Merge branch 'remove-tiles'

# Conflicts:
#	marl_factory_grid/environment/actions.py
#	marl_factory_grid/environment/entity/entity.py
#	marl_factory_grid/environment/factory.py
#	marl_factory_grid/modules/batteries/rules.py
#	marl_factory_grid/modules/clean_up/groups.py
#	marl_factory_grid/modules/destinations/entitites.py
#	marl_factory_grid/modules/destinations/groups.py
#	marl_factory_grid/modules/destinations/rules.py
#	marl_factory_grid/modules/items/rules.py
#	marl_factory_grid/modules/maintenance/entities.py
#	marl_factory_grid/utils/config_parser.py
#	marl_factory_grid/utils/level_parser.py
#	marl_factory_grid/utils/states.py
This commit is contained in:
Steffen Illium
2023-10-17 14:03:59 +02:00
40 changed files with 285 additions and 273 deletions

View File

@@ -54,9 +54,10 @@ class TSPBaseAgent(ABC):
nodes=nodes, cycle=True, method=tsp.greedy_tsp) nodes=nodes, cycle=True, method=tsp.greedy_tsp)
return route return route
def _door_is_close(self): def _door_is_close(self, state):
try: try:
return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name) # return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name)
return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name)
except StopIteration: except StopIteration:
return None return None

View File

@@ -14,7 +14,7 @@ class TSPDirtAgent(TSPBaseAgent):
if self._env.state[di.DIRT].by_pos(self.state.pos) is not None: if self._env.state[di.DIRT].by_pos(self.state.pos) is not None:
# Translate the action_object to an integer to have the same output as any other model # Translate the action_object to an integer to have the same output as any other model
action = di.CLEAN_UP action = di.CLEAN_UP
elif door := self._door_is_close(): elif door := self._door_is_close(self._env):
action = self._use_door_or_move(door, di.DIRT) action = self._use_door_or_move(door, di.DIRT)
else: else:
action = self._predict_move(di.DIRT) action = self._predict_move(di.DIRT)

View File

@@ -24,7 +24,7 @@ class TSPItemAgent(TSPBaseAgent):
elif self._env.state[i.DROP_OFF].by_pos(self.state.pos) is not None: elif self._env.state[i.DROP_OFF].by_pos(self.state.pos) is not None:
# Translate the action_object to an integer to have the same output as any other model # Translate the action_object to an integer to have the same output as any other model
action = i.ITEM_ACTION action = i.ITEM_ACTION
elif door := self._door_is_close(): elif door := self._door_is_close(self._env):
action = self._use_door_or_move(door, i.DROP_OFF if self.mode == MODE_BRING else i.ITEM) action = self._use_door_or_move(door, i.DROP_OFF if self.mode == MODE_BRING else i.ITEM)
else: else:
action = self._choose() action = self._choose()

View File

@@ -11,15 +11,16 @@ class TSPTargetAgent(TSPBaseAgent):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(TSPTargetAgent, self).__init__(*args, **kwargs) super(TSPTargetAgent, self).__init__(*args, **kwargs)
def _handle_doors(self): def _handle_doors(self, state):
try: try:
return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name) # return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name)
return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name)
except StopIteration: except StopIteration:
return None return None
def predict(self, *_, **__): def predict(self, *_, **__):
if door := self._door_is_close(): if door := self._door_is_close(self._env):
action = self._use_door_or_move(door, d.DESTINATION) action = self._use_door_or_move(door, d.DESTINATION)
else: else:
action = self._predict_move(d.DESTINATION) action = self._predict_move(d.DESTINATION)

View File

@@ -40,11 +40,11 @@ class Move(Action, abc.ABC):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def do(self, entity, env): def do(self, entity, state):
new_pos = self._calc_new_pos(entity.pos) new_pos = self._calc_new_pos(entity.pos)
if next_tile := env[c.FLOORS].by_pos(new_pos): if state.check_move_validity(entity, new_pos): # next_tile := state[c.FLOOR].by_pos(new_pos):
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
move_validity = entity.move(next_tile) move_validity = entity.move(new_pos, state)
reward = r.MOVEMENTS_VALID if move_validity else r.MOVEMENTS_FAIL reward = r.MOVEMENTS_VALID if move_validity else r.MOVEMENTS_FAIL
return ActionResult(entity=entity, identifier=self._identifier, validity=move_validity, reward=reward) return ActionResult(entity=entity, identifier=self._identifier, validity=move_validity, reward=reward)
else: # There is no floor, propably collision else: # There is no floor, propably collision

View File

@@ -27,57 +27,43 @@ class Entity(EnvObject, abc.ABC):
@property @property
def pos(self): def pos(self):
return self._tile.pos return self._pos
@property @property
def tile(self): def tile(self):
return self._tile return self._tile # wall_n_floors funktionalität
@property # @property
def last_tile(self): # def last_tile(self):
try: # try:
return self._last_tile # return self._last_tile
except AttributeError: # except AttributeError:
# noinspection PyAttributeOutsideInit # # noinspection PyAttributeOutsideInit
self._last_tile = None # self._last_tile = None
return self._last_tile # return self._last_tile
@property
def last_pos(self):
try:
return self.last_tile.pos
except AttributeError:
return c.VALUE_NO_POS
@property @property
def direction_of_view(self): def direction_of_view(self):
last_x, last_y = self.last_pos last_x, last_y = self._last_pos
curr_x, curr_y = self.pos curr_x, curr_y = self.pos
return last_x - curr_x, last_y - curr_y return last_x - curr_x, last_y - curr_y
def destroy(self): def move(self, next_pos, state):
if next_pos = next_pos
valid = self._collection.remove_item(self) curr_pos = self._pos
for observer in self.observers: if not_same_pos := curr_pos != next_pos:
observer.notify_del_entity(self) if valid := state.check_move_validity(self, next_pos):
return valid self._pos = next_pos
self._last_pos = curr_pos
def move(self, next_tile):
curr_tile = self.tile
if not_same_tile := curr_tile != next_tile:
if valid := next_tile.enter(self):
curr_tile.leave(self)
self._tile = next_tile
self._last_tile = curr_tile
for observer in self.observers: for observer in self.observers:
observer.notify_change_pos(self) observer.notify_change_pos(self)
return valid return valid
return not_same_tile return not_same_pos
def __init__(self, tile, bind_to=None, **kwargs): def __init__(self, pos, bind_to=None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self._status = None self._status = None
self._tile = tile self._pos = pos
if bind_to: if bind_to:
try: try:
self.bind_to(bind_to) self.bind_to(bind_to)
@@ -85,11 +71,8 @@ class Entity(EnvObject, abc.ABC):
print(f'Objects of {self.__class__.__name__} can not be bound to other entities.') print(f'Objects of {self.__class__.__name__} can not be bound to other entities.')
exit() exit()
assert tile.enter(self, spawn=True), "Positions was not valid!" def summarize_state(self) -> dict: # tile=str(self.tile.name)
return dict(name=str(self.name), x=int(self.x), y=int(self.y), can_collide=bool(self.var_can_collide))
def summarize_state(self) -> dict:
return dict(name=str(self.name), x=int(self.x), y=int(self.y),
tile=str(self.tile.name), can_collide=bool(self.var_can_collide))
@abc.abstractmethod @abc.abstractmethod
def render(self): def render(self):

View File

@@ -45,9 +45,9 @@ class Floor(EnvObject):
def encoding(self): def encoding(self):
return c.VALUE_OCCUPIED_CELL return c.VALUE_OCCUPIED_CELL
@property # @property
def guests_that_can_collide(self): # def guests_that_can_collide(self):
return [x for x in self.guests if x.var_can_collide] # return [x for x in self.guests if x.var_can_collide]
@property @property
def guests(self): def guests(self):

View File

@@ -66,7 +66,8 @@ class Factory(gym.Env):
self.map: LevelParser self.map: LevelParser
self.obs_builder: OBSBuilder self.obs_builder: OBSBuilder
# TODO: Reset ---> document this # reset env to initial state, preparing env for new episode.
# returns tuple where the first dict contains initial observation for each agent in the env
self.reset() self.reset()
def __getitem__(self, item): def __getitem__(self, item):
@@ -82,10 +83,10 @@ class Factory(gym.Env):
self.state = None self.state = None
# Init entity: # Init entities
entities = self.map.do_init() entities = self.map.do_init()
# Grab all env-rules: # Init rules
rules = self.conf.load_rules() rules = self.conf.load_rules()
# Parse the agent conf # Parse the agent conf
@@ -93,9 +94,10 @@ class Factory(gym.Env):
self.state = Gamestate(entities, parsed_agents_conf, rules, self.conf.env_seed, self.conf.verbose) self.state = Gamestate(entities, parsed_agents_conf, rules, self.conf.env_seed, self.conf.verbose)
# All is set up, trigger entity init with variable pos # All is set up, trigger entity init with variable pos
# All is set up, trigger additional init (after agent entity spawn etc)
self.state.rules.do_all_init(self.state, self.map) self.state.rules.do_all_init(self.state, self.map)
# Observations # Build initial observations for all agents
# noinspection PyAttributeOutsideInit # noinspection PyAttributeOutsideInit
self.obs_builder = OBSBuilder(self.map.level_shape, self.state, self.map.pomdp_r) self.obs_builder = OBSBuilder(self.map.level_shape, self.state, self.map.pomdp_r)
return self.obs_builder.refresh_and_build_for_all(self.state) return self.obs_builder.refresh_and_build_for_all(self.state)

View File

@@ -23,10 +23,33 @@ class Entities(Objects):
def names(self): def names(self):
return list(self._data.keys()) return list(self._data.keys())
def __init__(self): @property
def floorlist(self):
return self._floor_positions
def __init__(self, floor_positions):
self._floor_positions = floor_positions
self.pos_dict = defaultdict(list) self.pos_dict = defaultdict(list)
super().__init__() super().__init__()
# def all_floors(self):
# return[key for key, val in self.pos_dict.items() if any('floor' in x.name.lower() for x in val)]
def guests_that_can_collide(self, pos):
return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
def empty_tiles(self):
return[key for key in self.floorlist if not any(self.pos_dict[key])]
def occupied_tiles(self): # positions that are not empty
return[key for key in self.floorlist if any(self.pos_dict[key])]
def is_blocked(self):
return[key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])]
def is_not_blocked(self):
return[key for key, val in self.pos_dict.items() if not all([x.var_is_blocking_pos for x in val])]
def iter_entities(self): def iter_entities(self):
return iter((x for sublist in self.values() for x in sublist)) return iter((x for sublist in self.values() for x in sublist))

View File

@@ -1,4 +1,6 @@
from typing import List from typing import List, Tuple
import numpy as np
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.entity.entity import Entity
@@ -6,40 +8,36 @@ from marl_factory_grid.environment.entity.wall_floor import Floor
class PositionMixin: class PositionMixin:
_entity = Entity _entity = Entity
var_is_blocking_light: bool = True var_is_blocking_light: bool = True
var_can_collide: bool = True var_can_collide: bool = True
var_has_position: bool = True var_has_position: bool = True
def spawn(self, tiles: List[Floor]): def spawn(self, coords: List[Tuple[(int, int)]]):
self.add_items([self._entity(tile) for tile in tiles]) self.add_items([self._entity(pos) for pos in coords])
def render(self): def render(self):
return [y for y in [x.render() for x in self] if y is not None] return [y for y in [x.render() for x in self] if y is not None]
# @classmethod
# def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs):
# collection = cls(*args, **kwargs)
# entities = [cls._entity(tile, str_ident=i,
# **entity_kwargs if entity_kwargs is not None else {})
# for i, tile in enumerate(tiles)]
# collection.add_items(entities)
# return collection
@classmethod @classmethod
def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs): def from_coordinates(cls, positions: [(int, int)], *args, entity_kwargs=None, **kwargs, ):
collection = cls(*args, **kwargs) collection = cls(*args, **kwargs)
entities = [cls._entity(tile, str_ident=i, collection.add_items(
**entity_kwargs if entity_kwargs is not None else {}) [cls._entity(tuple(pos), **entity_kwargs if entity_kwargs is not None else {}) for pos in positions])
for i, tile in enumerate(tiles)]
collection.add_items(entities)
return collection return collection
@classmethod
def from_coordinates(cls, positions: [(int, int)], tiles, *args, entity_kwargs=None, **kwargs, ):
return cls.from_tiles([tiles.by_pos(position) for position in positions], tiles.size, *args,
entity_kwargs=entity_kwargs,
**kwargs)
@property
def tiles(self):
return [entity.tile for entity in self]
def __delitem__(self, name): def __delitem__(self, name):
idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name) idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
obj.tile.leave(obj) obj.tile.leave(obj) # observer notify?
super().__delitem__(name) super().__delitem__(name)
def by_pos(self, pos: (int, int)): def by_pos(self, pos: (int, int)):

View File

@@ -1,5 +1,5 @@
import random import random
from typing import List from typing import List, Tuple
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.groups.env_objects import EnvObjects from marl_factory_grid.environment.groups.env_objects import EnvObjects
@@ -15,16 +15,12 @@ class Walls(PositionMixin, EnvObjects):
super(Walls, self).__init__(*args, **kwargs) super(Walls, self).__init__(*args, **kwargs)
self._value = c.VALUE_OCCUPIED_CELL self._value = c.VALUE_OCCUPIED_CELL
@classmethod # @classmethod
def from_coordinates(cls, argwhere_coordinates, *args, **kwargs): # def from_coordinates(cls, argwhere_coordinates, *args, **kwargs):
tiles = cls(*args, **kwargs) # tiles = cls(*args, **kwargs)
# noinspection PyTypeChecker # # noinspection PyTypeChecker
tiles.add_items([cls._entity(pos) for pos in argwhere_coordinates]) # tiles.add_items([cls._entity(pos) for pos in argwhere_coordinates])
return tiles # return tiles
@classmethod
def from_tiles(cls, tiles, *args, **kwargs):
raise RuntimeError()
def by_pos(self, pos: (int, int)): def by_pos(self, pos: (int, int)):
try: try:
@@ -43,18 +39,4 @@ class Floors(Walls):
super(Floors, self).__init__(*args, **kwargs) super(Floors, self).__init__(*args, **kwargs)
self._value = c.VALUE_FREE_CELL self._value = c.VALUE_FREE_CELL
@property
def occupied_tiles(self):
tiles = [tile for tile in self if tile.is_occupied()]
random.shuffle(tiles)
return tiles
@property
def empty_tiles(self) -> List[Floor]:
tiles = [tile for tile in self if tile.is_empty()]
random.shuffle(tiles)
return tiles
@classmethod
def from_tiles(cls, tiles, *args, **kwargs):
raise RuntimeError()

View File

@@ -111,10 +111,10 @@ class Collision(Rule):
def tick_post_step(self, state) -> List[TickResult]: def tick_post_step(self, state) -> List[TickResult]:
self.curr_done = False self.curr_done = False
tiles_with_collisions = state.get_all_tiles_with_collisions() pos_with_collisions = state.get_all_pos_with_collisions()
results = list() results = list()
for tile in tiles_with_collisions: for pos in pos_with_collisions:
guests = tile.guests_that_can_collide guests = [x for x in state.entities.pos_dict[pos] if x.var_can_collide]
if len(guests) >= 2: if len(guests) >= 2:
for i, guest in enumerate(guests): for i, guest in enumerate(guests):
try: try:

View File

@@ -66,7 +66,8 @@ class Pod(Entity):
def charge_battery(self, battery: Battery): def charge_battery(self, battery: Battery):
if battery.charge_level == 1.0: if battery.charge_level == 1.0:
return c.NOT_VALID return c.NOT_VALID
if sum(guest for guest in self.tile.guests if 'agent' in guest.name.lower()) > 1: # if sum(guest for guest in self.tile.guests if 'agent' in guest.name.lower()) > 1:
if sum(1 for key, val in self.state.entities.pos_dict[self.pos] for guest in val if 'agent' in guest.name.lower()) > 1:
return c.NOT_VALID return c.NOT_VALID
valid = battery.do_charge_action(self.charge_rate) valid = battery.do_charge_action(self.charge_rate)
return valid return valid

View File

@@ -71,7 +71,7 @@ class PodRules(Rule):
def on_init(self, state, lvl_map): def on_init(self, state, lvl_map):
pod_collection = state[b.CHARGE_PODS] pod_collection = state[b.CHARGE_PODS]
empty_tiles = state[c.FLOORS].empty_tiles[:self.n_pods] empty_tiles = state[c.FLOORS].empty_tiles[:self.n_pods]
pods = pod_collection.from_tiles(empty_tiles, entity_kwargs=dict( pods = pod_collection.from_coordinates(empty_tiles, entity_kwargs=dict(
multi_charge=self.multi_charge, charge_rate=self.charge_rate) multi_charge=self.multi_charge, charge_rate=self.charge_rate)
) )
pod_collection.add_items(pods) pod_collection.add_items(pods)

View File

@@ -1,6 +1,4 @@
from .actions import CleanUp from .actions import CleanUp
from .entitites import DirtPile from .entitites import DirtPile
from .groups import DirtPiles from .groups import DirtPiles
from .rule_respawn import DirtRespawnRule from .rules import DirtRespawnRule, DirtSmearOnMove, DirtAllCleanDone
from .rule_smear_on_move import DirtSmearOnMove
from .rule_done_on_all_clean import DirtAllCleanDone

View File

@@ -14,7 +14,7 @@ class CleanUp(Action):
super().__init__(d.CLEAN_UP) super().__init__(d.CLEAN_UP)
def do(self, entity, state) -> Union[None, ActionResult]: def do(self, entity, state) -> Union[None, ActionResult]:
if dirt := state[d.DIRT].by_pos(entity.pos): if dirt := next((x for x in state.entities.pos_dict[entity.pos] if "dirt" in x.name.lower()), None):
new_dirt_amount = dirt.amount - state[d.DIRT].clean_amount new_dirt_amount = dirt.amount - state[d.DIRT].clean_amount
if new_dirt_amount <= 0: if new_dirt_amount <= 0:

View File

@@ -48,4 +48,4 @@ class DirtPile(Entity):
return state_dict return state_dict
def render(self): def render(self):
return RenderEntity(d.DIRT, self.tile.pos, min(0.15 + self.amount, 1.5), 'scale') return RenderEntity(d.DIRT, self.pos, min(0.15 + self.amount, 1.5), 'scale')

View File

@@ -7,7 +7,6 @@ from marl_factory_grid.environment import constants as c
class DirtPiles(PositionMixin, EnvObjects): class DirtPiles(PositionMixin, EnvObjects):
_entity = DirtPile _entity = DirtPile
is_blocking_light: bool = False is_blocking_light: bool = False
can_collide: bool = False can_collide: bool = False
@@ -31,27 +30,28 @@ class DirtPiles(PositionMixin, EnvObjects):
self.max_global_amount = max_global_amount self.max_global_amount = max_global_amount
self.max_local_amount = max_local_amount self.max_local_amount = max_local_amount
def spawn(self, then_dirty_tiles, amount) -> bool: def spawn(self, then_dirty_positions, amount) -> bool:
if isinstance(then_dirty_tiles, Floor): # if isinstance(then_dirty_tiles, Floor):
then_dirty_tiles = [then_dirty_tiles] # then_dirty_tiles = [then_dirty_tiles]
for tile in then_dirty_tiles: for pos in then_dirty_positions:
if not self.amount > self.max_global_amount: if not self.amount > self.max_global_amount:
if dirt := self.by_pos(tile.pos): if dirt := self.by_pos(pos):
new_value = dirt.amount + amount new_value = dirt.amount + amount
dirt.set_new_amount(new_value) dirt.set_new_amount(new_value)
else: else:
dirt = DirtPile(tile, initial_amount=amount, spawn_variation=self.dirt_spawn_r_var) dirt = DirtPile(pos, initial_amount=amount, spawn_variation=self.dirt_spawn_r_var)
self.add_item(dirt) self.add_item(dirt)
else: else:
return c.NOT_VALID return c.NOT_VALID
return c.VALID return c.VALID
def trigger_dirt_spawn(self, state, initial_spawn=False) -> bool: def trigger_dirt_spawn(self, state, initial_spawn=False) -> bool:
free_for_dirt = [x for x in state[c.FLOORS] free_for_dirt = [x for x in state.entities.floorlist if len(state.entities.pos_dict[x]) == 1 or (
if len(x.guests) == 0 or ( len(state.entities.pos_dict[x]) == 2 and isinstance(next(y for y in x), DirtPile))]
len(x.guests) == 1 and # free_for_dirt = [x for x in state[c.FLOOR]
isinstance(next(y for y in x.guests), DirtPile)) # if len(x.guests) == 0 or (
] # len(x.guests) == 1 and
# isinstance(next(y for y in x.guests), DirtPile))]
state.rng.shuffle(free_for_dirt) state.rng.shuffle(free_for_dirt)
var = self.dirt_spawn_r_var var = self.dirt_spawn_r_var

View File

@@ -1,15 +0,0 @@
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import DoneResult
from marl_factory_grid.modules.clean_up import constants as d, rewards as r
class DirtAllCleanDone(Rule):
def __init__(self):
super().__init__()
def on_check_done(self, state) -> [DoneResult]:
if len(state[d.DIRT]) == 0 and state.curr_step:
return [DoneResult(validity=c.VALID, identifier=self.name, reward=r.CLEAN_UP_ALL)]
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]

View File

@@ -1,28 +0,0 @@
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult
from marl_factory_grid.modules.clean_up import constants as d
class DirtRespawnRule(Rule):
def __init__(self, spawn_freq=15):
super().__init__()
self.spawn_freq = spawn_freq
self._next_dirt_spawn = spawn_freq
def on_init(self, state, lvl_map) -> str:
state[d.DIRT].trigger_dirt_spawn(state, initial_spawn=True)
return f'Initial Dirt was spawned on: {[x.pos for x in state[d.DIRT]]}'
def tick_step(self, state):
if self._next_dirt_spawn < 0:
pass # No DirtPile Spawn
elif not self._next_dirt_spawn:
validity = state[d.DIRT].trigger_dirt_spawn(state)
return [TickResult(entity=None, validity=validity, identifier=self.name, reward=0)]
self._next_dirt_spawn = self.spawn_freq
else:
self._next_dirt_spawn -= 1
return []

View File

@@ -1,24 +0,0 @@
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.helpers import is_move
from marl_factory_grid.utils.results import TickResult
from marl_factory_grid.environment import constants as c
from marl_factory_grid.modules.clean_up import constants as d
class DirtSmearOnMove(Rule):
def __init__(self, smear_amount: float = 0.2):
super().__init__()
self.smear_amount = smear_amount
def tick_post_step(self, state):
results = list()
for entity in state.moving_entites:
if is_move(entity.state.identifier) and entity.state.validity == c.VALID:
if old_pos_dirt := state[d.DIRT].by_pos(entity.last_pos):
if smeared_dirt := round(old_pos_dirt.amount * self.smear_amount, 2):
if state[d.DIRT].spawn(entity.tile, amount=smeared_dirt):
results.append(TickResult(identifier=self.name, entity=entity,
reward=0, validity=c.VALID))
return results

View File

@@ -0,0 +1,60 @@
from marl_factory_grid.modules.clean_up import constants as d, rewards as r
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.helpers import is_move
from marl_factory_grid.utils.results import TickResult
from marl_factory_grid.utils.results import DoneResult
class DirtAllCleanDone(Rule):
def __init__(self):
super().__init__()
def on_check_done(self, state) -> [DoneResult]:
if len(state[d.DIRT]) == 0 and state.curr_step:
return [DoneResult(validity=c.VALID, identifier=self.name, reward=r.CLEAN_UP_ALL)]
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
class DirtRespawnRule(Rule):
def __init__(self, spawn_freq=15):
super().__init__()
self.spawn_freq = spawn_freq
self._next_dirt_spawn = spawn_freq
def on_init(self, state, lvl_map) -> str:
state[d.DIRT].trigger_dirt_spawn(state, initial_spawn=True)
return f'Initial Dirt was spawned on: {[x.pos for x in state[d.DIRT]]}'
def tick_step(self, state):
if self._next_dirt_spawn < 0:
pass # No DirtPile Spawn
elif not self._next_dirt_spawn:
validity = state[d.DIRT].trigger_dirt_spawn(state)
return [TickResult(entity=None, validity=validity, identifier=self.name, reward=0)]
self._next_dirt_spawn = self.spawn_freq
else:
self._next_dirt_spawn -= 1
return []
class DirtSmearOnMove(Rule):
def __init__(self, smear_amount: float = 0.2):
super().__init__()
self.smear_amount = smear_amount
def tick_post_step(self, state):
results = list()
for entity in state.moving_entites:
if is_move(entity.state.identifier) and entity.state.validity == c.VALID:
if old_pos_dirt := state[d.DIRT].by_pos(entity.last_pos):
if smeared_dirt := round(old_pos_dirt.amount * self.smear_amount, 2):
if state[d.DIRT].spawn(entity.pos, amount=smeared_dirt): # pos statt tile
results.append(TickResult(identifier=self.name, entity=entity,
reward=0, validity=c.VALID))
return results

View File

@@ -39,7 +39,8 @@ class Destination(BoundEntityMixin, Entity):
def has_just_been_reached(self): def has_just_been_reached(self):
if self.was_reached: if self.was_reached:
return False return False
agent_at_position = any(self.bound_entity == x for x in self.tile.guests_that_can_collide) agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in state.entities.pos_dict[self.pos] if x.var_can_collide)
if self.bound_entity: if self.bound_entity:
return ((agent_at_position and not self.action_counts) return ((agent_at_position and not self.action_counts)
or self._per_agent_actions[self.bound_entity.name] >= self.action_counts >= 1) or self._per_agent_actions[self.bound_entity.name] >= self.action_counts >= 1)

View File

@@ -1,10 +1,11 @@
from marl_factory_grid.environment.groups.env_objects import EnvObjects from marl_factory_grid.environment.groups.env_objects import EnvObjects
from marl_factory_grid.environment.groups.mixins import PositionMixin, HasBoundMixin from marl_factory_grid.environment.groups.mixins import PositionMixin
from marl_factory_grid.modules.destinations.entitites import Destination from marl_factory_grid.modules.destinations.entitites import Destination
from marl_factory_grid.environment import constants as c
from marl_factory_grid.modules.destinations import constants as d
class Destinations(PositionMixin, EnvObjects): class Destinations(PositionMixin, EnvObjects):
_entity = Destination _entity = Destination
is_blocking_light: bool = False is_blocking_light: bool = False
can_collide: bool = False can_collide: bool = False
@@ -14,3 +15,16 @@ class Destinations(PositionMixin, EnvObjects):
def __repr__(self): def __repr__(self):
return super(Destinations, self).__repr__() return super(Destinations, self).__repr__()
@staticmethod
def trigger_destination_spawn(n_dests, state):
coordinates = state.entities.floorlist[:n_dests]
if destinations := [Destination(pos) for pos in coordinates]:
state[d.DESTINATION].add_items(destinations)
state.print(f'{n_dests} new destinations have been spawned')
return c.VALID
else:
state.print('No Destiantions are spawning, limit is reached.')
return c.NOT_VALID

View File

@@ -81,11 +81,11 @@ class Door(Entity):
self._open() self._open()
return c.VALID return c.VALID
def tick(self): def tick(self, state):
if self.is_open and len(self.tile) == 1 and self.time_to_close: if self.is_open and len(state.entities.pos_dict[self.pos]) == 2 and self.time_to_close:
self.time_to_close -= 1 self.time_to_close -= 1
return c.NOT_VALID return c.NOT_VALID
elif self.is_open and not self.time_to_close and len(self.tile) == 1: elif self.is_open and not self.time_to_close and len(state.entities.pos_dict[self.pos]) == 2:
self.use() self.use()
return c.VALID return c.VALID
else: else:

View File

@@ -14,9 +14,9 @@ class Doors(PositionMixin, EnvObjects):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(Doors, self).__init__(*args, can_collide=True, **kwargs) super(Doors, self).__init__(*args, can_collide=True, **kwargs)
def tick_doors(self): def tick_doors(self, state):
result_dict = dict() result_dict = dict()
for door in self: for door in self:
did_tick = door.tick() did_tick = door.tick(state)
result_dict.update({door.name: did_tick}) result_dict.update({door.name: did_tick})
return result_dict return result_dict

View File

@@ -12,7 +12,7 @@ class DoorAutoClose(Rule):
def tick_step(self, state): def tick_step(self, state):
if doors := state[d.DOORS]: if doors := state[d.DOORS]:
doors_tick_result = doors.tick_doors() doors_tick_result = doors.tick_doors(state)
doors_that_ticked = [key for key, val in doors_tick_result.items() if val] doors_that_ticked = [key for key, val in doors_tick_result.items() if val]
state.print(f'{doors_that_ticked} were auto-closed' state.print(f'{doors_that_ticked} were auto-closed'
if doors_that_ticked else 'No Doors were auto-closed') if doors_that_ticked else 'No Doors were auto-closed')

View File

@@ -23,7 +23,7 @@ class AgentSingleZonePlacementBeta(Rule):
raise ValueError raise ValueError
tiles = [state[c.FLOORS].by_pos(pos) for pos in coordinates] tiles = [state[c.FLOORS].by_pos(pos) for pos in coordinates]
for agent, tile in zip(agents, tiles): for agent, tile in zip(agents, tiles):
agent.move(tile) agent.move(tile, state)
def tick_step(self, state): def tick_step(self, state):
return [] return []

View File

@@ -29,7 +29,7 @@ class ItemAction(Action):
elif items := state[i.ITEM].by_pos(entity.pos): elif items := state[i.ITEM].by_pos(entity.pos):
item = items[0] item = items[0]
item.change_parent_collection(inventory) item.change_parent_collection(inventory)
item.set_tile_to(state.NO_POS_TILE) item.set_pos_to(c.VALUE_NO_POS)
state.print(f'{entity.name} just picked up an item at {entity.pos}') state.print(f'{entity.name} just picked up an item at {entity.pos}')
return ActionResult(entity=entity, identifier=self._identifier, validity=c.VALID, reward=r.PICK_UP_VALID) return ActionResult(entity=entity, identifier=self._identifier, validity=c.VALID, reward=r.PICK_UP_VALID)

View File

@@ -11,7 +11,7 @@ class Item(Entity):
var_can_collide = False var_can_collide = False
def render(self): def render(self):
return RenderEntity(i.ITEM, self.tile.pos) if self.pos != c.VALUE_NO_POS else None return RenderEntity(i.ITEM, self.pos) if self.pos != c.VALUE_NO_POS else None
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@@ -29,8 +29,8 @@ class Item(Entity):
def set_auto_despawn(self, auto_despawn): def set_auto_despawn(self, auto_despawn):
self._auto_despawn = auto_despawn self._auto_despawn = auto_despawn
def set_tile_to(self, no_pos_tile): def set_pos_to(self, no_pos):
self._tile = no_pos_tile self._pos = no_pos
def summarize_state(self) -> dict: def summarize_state(self) -> dict:
super_summarization = super(Item, self).summarize_state() super_summarization = super(Item, self).summarize_state()
@@ -57,7 +57,7 @@ class DropOffLocation(Entity):
return True return True
def render(self): def render(self):
return RenderEntity(i.DROP_OFF, self.tile.pos) return RenderEntity(i.DROP_OFF, self.pos)
@property @property
def encoding(self): def encoding(self):

View File

@@ -1,15 +1,14 @@
from typing import List from marl_factory_grid.modules.items import constants as i
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.groups.env_objects import EnvObjects from marl_factory_grid.environment.groups.env_objects import EnvObjects
from marl_factory_grid.environment.groups.objects import Objects from marl_factory_grid.environment.groups.objects import Objects
from marl_factory_grid.environment.groups.mixins import PositionMixin, IsBoundMixin, HasBoundMixin from marl_factory_grid.environment.groups.mixins import PositionMixin, IsBoundMixin, HasBoundMixin
from marl_factory_grid.environment.entity.wall_floor import Floor
from marl_factory_grid.environment.entity.agent import Agent from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.modules.items.entitites import Item, DropOffLocation from marl_factory_grid.modules.items.entitites import Item, DropOffLocation
class Items(PositionMixin, EnvObjects): class Items(PositionMixin, EnvObjects):
_entity = Item _entity = Item
is_blocking_light: bool = False is_blocking_light: bool = False
can_collide: bool = False can_collide: bool = False
@@ -17,9 +16,19 @@ class Items(PositionMixin, EnvObjects):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@staticmethod
def trigger_item_spawn(state, n_items, spawn_frequency):
if item_to_spawns := max(0, (n_items - len(state[i.ITEM]))):
floor_list = state.entities.floorlist[:item_to_spawns]
state[i.ITEM].spawn(floor_list)
state.print(f'{item_to_spawns} new items have been spawned; next spawn in {spawn_frequency}') # spawn in self._next_item_spawn ?
return len(floor_list)
else:
state.print('No Items are spawning, limit is reached.')
return 0
class Inventory(IsBoundMixin, EnvObjects): class Inventory(IsBoundMixin, EnvObjects):
_accepted_objects = Item _accepted_objects = Item
@property @property
@@ -27,7 +36,7 @@ class Inventory(IsBoundMixin, EnvObjects):
return self.name return self.name
def __init__(self, agent: Agent, *args, **kwargs): def __init__(self, agent: Agent, *args, **kwargs):
super(Inventory, self).__init__(*args, **kwargs) super(Inventory, self).__init__(*args, **kwargs)
self._collection = None self._collection = None
self.bind(agent) self.bind(agent)
@@ -47,7 +56,6 @@ class Inventory(IsBoundMixin, EnvObjects):
class Inventories(HasBoundMixin, Objects): class Inventories(HasBoundMixin, Objects):
_entity = Inventory _entity = Inventory
var_can_move = False var_can_move = False
@@ -58,7 +66,7 @@ class Inventories(HasBoundMixin, Objects):
self._lazy_eval_transforms = [] self._lazy_eval_transforms = []
def spawn(self, agents): def spawn(self, agents):
inventories = [self._entity(agent, self.size,) inventories = [self._entity(agent, self.size, )
for _, agent in enumerate(agents)] for _, agent in enumerate(agents)]
self.add_items(inventories) self.add_items(inventories)
@@ -77,12 +85,22 @@ class Inventories(HasBoundMixin, Objects):
def summarize_states(self, **kwargs): def summarize_states(self, **kwargs):
return [val.summarize_states(**kwargs) for key, val in self.items()] return [val.summarize_states(**kwargs) for key, val in self.items()]
@staticmethod
def trigger_inventory_spawn(state):
state[i.INVENTORY].spawn(state[c.AGENT])
class DropOffLocations(PositionMixin, EnvObjects): class DropOffLocations(PositionMixin, EnvObjects):
_entity = DropOffLocation _entity = DropOffLocation
is_blocking_light: bool = False is_blocking_light: bool = False
can_collide: bool = False can_collide: bool = False
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(DropOffLocations, self).__init__(*args, **kwargs) super(DropOffLocations, self).__init__(*args, **kwargs)
@staticmethod
def trigger_drop_off_location_spawn(state, n_locations):
empty_tiles = state.entities.floorlist[:n_locations]
do_entites = state[i.DROP_OFF]
drop_offs = [DropOffLocation(tile) for tile in empty_tiles]
do_entites.add_items(drop_offs)

View File

@@ -4,7 +4,6 @@ from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils.results import TickResult from marl_factory_grid.utils.results import TickResult
from marl_factory_grid.modules.items import constants as i from marl_factory_grid.modules.items import constants as i
from marl_factory_grid.modules.items.entitites import DropOffLocation
class ItemRules(Rule): class ItemRules(Rule):
@@ -19,10 +18,10 @@ class ItemRules(Rule):
self.n_locations = n_locations self.n_locations = n_locations
def on_init(self, state, lvl_map): def on_init(self, state, lvl_map):
self.trigger_drop_off_location_spawn(state) state[i.DROP_OFF].trigger_drop_off_location_spawn(state, self.n_locations)
self._next_item_spawn = self.spawn_frequency self._next_item_spawn = self.spawn_frequency
self.trigger_inventory_spawn(state) state[i.INVENTORY].trigger_inventory_spawn(state)
self.trigger_item_spawn(state) state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency)
def tick_step(self, state): def tick_step(self, state):
for item in list(state[i.ITEM].values()): for item in list(state[i.ITEM].values()):
@@ -34,26 +33,11 @@ class ItemRules(Rule):
pass pass
if not self._next_item_spawn: if not self._next_item_spawn:
self.trigger_item_spawn(state) state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency)
else: else:
self._next_item_spawn = max(0, self._next_item_spawn - 1) self._next_item_spawn = max(0, self._next_item_spawn - 1)
return [] return []
def trigger_item_spawn(self, state):
if item_to_spawns := max(0, (self.n_items - len(state[i.ITEM]))):
empty_tiles = state[c.FLOORS].empty_tiles[:item_to_spawns]
state[i.ITEM].spawn(empty_tiles)
self._next_item_spawn = self.spawn_frequency
state.print(f'{item_to_spawns} new items have been spawned; next spawn in {self._next_item_spawn}')
return len(empty_tiles)
else:
state.print('No Items are spawning, limit is reached.')
return 0
@staticmethod
def trigger_inventory_spawn(state):
state[i.INVENTORY].spawn(state[c.AGENT])
def tick_post_step(self, state) -> List[TickResult]: def tick_post_step(self, state) -> List[TickResult]:
for item in list(state[i.ITEM].values()): for item in list(state[i.ITEM].values()):
if item.auto_despawn >= 1: if item.auto_despawn >= 1:
@@ -64,7 +48,7 @@ class ItemRules(Rule):
pass pass
if not self._next_item_spawn: if not self._next_item_spawn:
if spawned_items := self.trigger_item_spawn(state): if spawned_items := state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency):
return [TickResult(self.name, validity=c.VALID, value=spawned_items, entity=None)] return [TickResult(self.name, validity=c.VALID, value=spawned_items, entity=None)]
else: else:
return [TickResult(self.name, validity=c.NOT_VALID, value=0, entity=None)] return [TickResult(self.name, validity=c.NOT_VALID, value=0, entity=None)]
@@ -72,8 +56,3 @@ class ItemRules(Rule):
self._next_item_spawn = max(0, self._next_item_spawn-1) self._next_item_spawn = max(0, self._next_item_spawn-1)
return [] return []
def trigger_drop_off_location_spawn(self, state):
empty_tiles = state[c.FLOORS].empty_tiles[:self.n_locations]
do_entites = state[i.DROP_OFF]
drop_offs = [DropOffLocation(tile) for tile in empty_tiles]
do_entites.add_items(drop_offs)

View File

@@ -47,9 +47,11 @@ class Machine(Entity):
return c.NOT_VALID return c.NOT_VALID
def tick(self): def tick(self):
if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]): # if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]):
if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.state.entities.pos_dict[self.pos]]):
return TickResult(identifier=self.name, validity=c.VALID, reward=0, entity=self) return TickResult(identifier=self.name, validity=c.VALID, reward=0, entity=self)
elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]): # elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]):
elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.state.entities.pos_dict[self.pos]]):
self.status = m.STATE_WORK self.status = m.STATE_WORK
self.reset_counter() self.reset_counter()
return None return None

View File

@@ -61,7 +61,7 @@ class Maintainer(Entity):
self._last.append(self._next.pop()) self._last.append(self._next.pop())
self._path = self.calculate_route(self._last[-1]) self._path = self.calculate_route(self._last[-1])
if door := self._door_is_close(): if door := self._door_is_close(state):
if door.is_closed: if door.is_closed:
# Translate the action_object to an integer to have the same output as any other model # Translate the action_object to an integer to have the same output as any other model
action = do.ACTION_DOOR_USE action = do.ACTION_DOOR_USE
@@ -81,15 +81,18 @@ class Maintainer(Entity):
route = nx.shortest_path(self._floortile_graph, self.pos, entity.pos) route = nx.shortest_path(self._floortile_graph, self.pos, entity.pos)
return route[1:] return route[1:]
def _door_is_close(self): def _door_is_close(self, state):
state.print("Found a door that is close.")
try: try:
return next(y for x in self.tile.neighboring_floor for y in x.guests if do.DOOR in y.name) # return next(y for x in self.tile.neighboring_floor for y in x.guests if do.DOOR in y.name)
return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name)
except StopIteration: except StopIteration:
return None return None
def _predict_move(self, state): def _predict_move(self, state):
next_pos = self._path[0] next_pos = self._path[0]
if len(state[c.FLOORS].by_pos(next_pos).guests_that_can_collide) > 0: # if len(state[c.FLOORS].by_pos(next_pos).guests_that_can_collide) > 0:
if any(x for x in state.entities.pos_dict[next_pos] if x.var_can_collide) > 0:
action = c.NOOP action = c.NOOP
else: else:
next_pos = self._path.pop(0) next_pos = self._path.pop(0)

View File

@@ -21,5 +21,5 @@ class Maintainers(PositionMixin, EnvObjects):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def spawn(self, tiles: List[Floor], state: Gamestate): def spawn(self, position, state: Gamestate):
self.add_items([self._entity(state, mc.MACHINES, MachineAction(), tile) for tile in tiles]) self.add_items([self._entity(state, mc.MACHINES, MachineAction(), pos) for pos in position])

View File

@@ -1,5 +1,5 @@
import random import random
from typing import List from typing import List, Tuple
from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.entity.object import Object from marl_factory_grid.environment.entity.object import Object
@@ -14,12 +14,12 @@ class Zone(Object):
@property @property
def positions(self): def positions(self):
return [x.pos for x in self.tiles] return self.coords
def __init__(self, tiles: List[Floor], *args, **kwargs): def __init__(self, coords: List[Tuple[(int, int)]], *args, **kwargs):
super(Zone, self).__init__(*args, **kwargs) super(Zone, self).__init__(*args, **kwargs)
self.tiles = tiles self.coords = coords
@property @property
def random_tile(self): def random_tile(self):
return random.choice(self.tiles) return random.choice(self.coords)

View File

@@ -38,7 +38,7 @@ class AgentSingleZonePlacement(Rule):
z_idxs = choices(list(range(len(state[z.ZONES]))), k=n_agents) z_idxs = choices(list(range(len(state[z.ZONES]))), k=n_agents)
for agent in state[c.AGENT]: for agent in state[c.AGENT]:
agent.move(state[z.ZONES][z_idxs.pop()].random_tile) agent.move(state[z.ZONES][z_idxs.pop()].random_tile, state)
return [] return []
def tick_step(self, state): def tick_step(self, state):

View File

@@ -16,7 +16,6 @@ MODULE_PATH = 'modules'
class FactoryConfigParser(object): class FactoryConfigParser(object):
default_entites = [] default_entites = []
default_rules = ['MaxStepsReached', 'Collision'] default_rules = ['MaxStepsReached', 'Collision']
default_actions = [c.MOVE8, c.NOOP] default_actions = [c.MOVE8, c.NOOP]

View File

@@ -23,7 +23,7 @@ class LevelParser(object):
self._parsed_level = h.parse_level(Path(level_file_path)) self._parsed_level = h.parse_level(Path(level_file_path))
level_array = h.one_hot_level(self._parsed_level, c.SYMBOL_WALL) level_array = h.one_hot_level(self._parsed_level, c.SYMBOL_WALL)
self.level_shape = level_array.shape self.level_shape = level_array.shape
self.size = self.pomdp_r**2 if self.pomdp_r else np.prod(self.level_shape) self.size = self.pomdp_r ** 2 if self.pomdp_r else np.prod(self.level_shape)
def get_coordinates_for_symbol(self, symbol, negate=False): def get_coordinates_for_symbol(self, symbol, negate=False):
level_array = h.one_hot_level(self._parsed_level, symbol) level_array = h.one_hot_level(self._parsed_level, symbol)
@@ -33,14 +33,17 @@ class LevelParser(object):
return np.argwhere(level_array == c.VALUE_OCCUPIED_CELL) return np.argwhere(level_array == c.VALUE_OCCUPIED_CELL)
def do_init(self): def do_init(self):
entities = Entities() # Global Entities
list_of_all_floors = ([tuple(floor) for floor in self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True)])
entities = Entities(list_of_all_floors)
# Walls # Walls
walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size) walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size)
entities.add_items({c.WALLS: walls}) entities.add_items({c.WALLS: walls})
# Floor # Floor
floor = Floors.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True), self.size) floor = Floors.from_coordinates(list_of_all_floors, self.size)
entities.add_items({c.FLOORS: floor}) entities.add_items({c.FLOOR: floor})
entities.add_items({c.AGENT: Agents(self.size)}) entities.add_items({c.AGENT: Agents(self.size)})
# All other # All other
@@ -56,8 +59,7 @@ class LevelParser(object):
if np.any(level_array): if np.any(level_array):
# TODO: Get rid of this! # TODO: Get rid of this!
e = e_class.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist(), e = e_class.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist(),
entities[c.FLOORS], self.size, entity_kwargs=e_kwargs self.size, entity_kwargs=e_kwargs)
)
else: else:
raise ValueError(f'No {e_class} (Symbol: {e_class.symbol}) could be found!\n' raise ValueError(f'No {e_class} (Symbol: {e_class.symbol}) could be found!\n'
f'Check your level file!') f'Check your level file!')

View File

@@ -1,8 +1,7 @@
from typing import List, Dict from typing import List, Dict, Tuple
import numpy as np import numpy as np
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.wall_floor import Floor from marl_factory_grid.environment.entity.wall_floor import Floor
from marl_factory_grid.environment.groups.global_entities import Entities from marl_factory_grid.environment.groups.global_entities import Entities
@@ -60,11 +59,10 @@ class Gamestate(object):
@property @property
def moving_entites(self): def moving_entites(self):
return [y for x in self.entities for y in x if x.var_can_move] return [y for x in self.entities for y in x if x.var_can_move] # wird das aus dem String gelesen?
def __init__(self, entitites, agents_conf, rules: Dict[str, dict], env_seed=69, verbose=False): def __init__(self, entities, agents_conf, rules: Dict[str, dict], env_seed=69, verbose=False):
self.entities: Entities = entitites self.entities = entities
self.NO_POS_TILE = Floor(c.VALUE_NO_POS)
self.curr_step = 0 self.curr_step = 0
self.curr_actions = None self.curr_actions = None
self.agents_conf = agents_conf self.agents_conf = agents_conf
@@ -114,8 +112,22 @@ class Gamestate(object):
results.extend(on_check_done_result) results.extend(on_check_done_result)
return results return results
def get_all_tiles_with_collisions(self) -> List[Floor]: # def get_all_tiles_with_collisions(self) -> List[Floor]:
tiles = [self[c.FLOORS].by_pos(pos) for pos, e in self.entities.pos_dict.items() # tiles = [self[c.FLOORS].by_pos(pos) for pos, e in self.entities.pos_dict.items()
if sum([x.var_can_collide for x in e]) > 1] # if sum([x.var_can_collide for x in e]) > 1]
# tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1] # # tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1]
return tiles # return tiles
def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]:
positions = [pos for pos, e in self.entities.pos_dict.items()
if sum([x.var_can_collide for x in e]) > 1]
return positions
def check_move_validity(self, moving_entity, position):
# if (guest.name not in self._guests and not self.is_blocked)
# and not (guest.var_is_blocking_pos and self.is_occupied()):
if moving_entity.pos != position and not any(
entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]) and not (
moving_entity.var_is_blocking_pos and moving_entity.is_occupied()):
return True
return False