From 06a5130b2558bfcd2c081133be59f0fbe6954166 Mon Sep 17 00:00:00 2001 From: Steffen Illium Date: Thu, 9 Nov 2023 17:50:20 +0100 Subject: [PATCH] new rules, new spawn logic, small fixes, default and narrow corridor debugged --- README.md | 2 +- marl_factory_grid/__init__.py | 7 +- .../algorithms/static/TSP_target_agent.py | 1 - marl_factory_grid/algorithms/static/utils.py | 20 +- marl_factory_grid/configs/default_config.yaml | 76 +++++--- .../configs/narrow_corridor.yaml | 50 +++-- marl_factory_grid/environment/actions.py | 4 +- marl_factory_grid/environment/constants.py | 5 +- marl_factory_grid/environment/entity/agent.py | 31 +-- .../environment/entity/entity.py | 17 +- marl_factory_grid/environment/entity/mixin.py | 24 --- .../environment/entity/object.py | 31 ++- marl_factory_grid/environment/entity/util.py | 5 +- marl_factory_grid/environment/entity/wall.py | 17 +- marl_factory_grid/environment/factory.py | 7 +- .../environment/groups/agents.py | 5 + .../environment/groups/collection.py | 77 ++++++-- .../environment/groups/global_entities.py | 32 ++- .../environment/groups/mixins.py | 4 - .../environment/groups/objects.py | 9 +- marl_factory_grid/environment/groups/utils.py | 25 +-- marl_factory_grid/environment/groups/walls.py | 9 +- marl_factory_grid/environment/rules.py | 55 ++++-- .../modules/batteries/__init__.py | 2 +- .../modules/batteries/actions.py | 5 +- .../modules/batteries/chargepods.png | Bin 0 -> 8135 bytes .../modules/batteries/entitites.py | 4 +- marl_factory_grid/modules/batteries/groups.py | 44 ++--- marl_factory_grid/modules/batteries/rules.py | 14 +- .../modules/clean_up/__init__.py | 2 +- .../modules/clean_up/entitites.py | 16 -- marl_factory_grid/modules/clean_up/groups.py | 71 +++---- marl_factory_grid/modules/clean_up/rules.py | 47 ++--- .../modules/destinations/__init__.py | 5 +- .../modules/destinations/entitites.py | 24 --- .../modules/destinations/groups.py | 33 +--- .../modules/destinations/rules.py | 62 ++---- marl_factory_grid/modules/doors/entitites.py | 68 ++++--- marl_factory_grid/modules/doors/groups.py | 10 +- marl_factory_grid/modules/doors/rules.py | 8 +- marl_factory_grid/modules/items/__init__.py | 1 - marl_factory_grid/modules/items/actions.py | 2 +- marl_factory_grid/modules/items/entitites.py | 23 --- marl_factory_grid/modules/items/groups.py | 48 +++-- marl_factory_grid/modules/items/rules.py | 40 +--- .../modules/machines/__init__.py | 1 - marl_factory_grid/modules/machines/actions.py | 6 +- .../modules/machines/entitites.py | 24 +-- marl_factory_grid/modules/machines/rules.py | 28 --- .../modules/maintenance/entities.py | 63 +++--- .../modules/maintenance/groups.py | 30 ++- .../modules/maintenance/rules.py | 23 +-- marl_factory_grid/modules/zones/rules.py | 2 +- marl_factory_grid/utils/__init__.py | 3 + marl_factory_grid/utils/config_parser.py | 156 ++++++++++----- marl_factory_grid/utils/helpers.py | 12 +- marl_factory_grid/utils/level_parser.py | 1 + .../utils/observation_builder.py | 183 +++--------------- marl_factory_grid/utils/ray_caster.py | 6 +- marl_factory_grid/utils/renderer.py | 36 ++-- marl_factory_grid/utils/results.py | 5 +- marl_factory_grid/utils/states.py | 16 +- marl_factory_grid/utils/tools.py | 1 - marl_factory_grid/utils/utility_classes.py | 3 + reload_agent.py | 3 +- studies/normalization_study.py | 2 +- transform_wg_to_json_no_priv.py | 43 ++++ 67 files changed, 768 insertions(+), 921 deletions(-) delete mode 100644 marl_factory_grid/environment/entity/mixin.py create mode 100644 marl_factory_grid/modules/batteries/chargepods.png create mode 100644 transform_wg_to_json_no_priv.py diff --git a/README.md b/README.md index a1d2740..07d67bf 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ Just define what your environment needs in a *yaml*-configfile like: - Items Rules: Defaults: {} - Collision: + WatchCollisions: done_at_collisions: !!bool True ItemRespawn: spawn_freq: 5 diff --git a/marl_factory_grid/__init__.py b/marl_factory_grid/__init__.py index b2bbfa3..49f5635 100644 --- a/marl_factory_grid/__init__.py +++ b/marl_factory_grid/__init__.py @@ -1,6 +1 @@ -from .environment import * -from .modules import * -from .utils import * - -from .quickstart import init - +from .quickstart import init \ No newline at end of file diff --git a/marl_factory_grid/algorithms/static/TSP_target_agent.py b/marl_factory_grid/algorithms/static/TSP_target_agent.py index 0c5de3a..5e0f989 100644 --- a/marl_factory_grid/algorithms/static/TSP_target_agent.py +++ b/marl_factory_grid/algorithms/static/TSP_target_agent.py @@ -30,4 +30,3 @@ class TSPTargetAgent(TSPBaseAgent): except (StopIteration, UnboundLocalError): print('Will not happen') return action_obj - diff --git a/marl_factory_grid/algorithms/static/utils.py b/marl_factory_grid/algorithms/static/utils.py index d5119db..60cba30 100644 --- a/marl_factory_grid/algorithms/static/utils.py +++ b/marl_factory_grid/algorithms/static/utils.py @@ -26,12 +26,16 @@ def points_to_graph(coordiniates, allow_euclidean_connections=True, allow_manhat assert allow_euclidean_connections or allow_manhattan_connections possible_connections = itertools.combinations(coordiniates, 2) graph = nx.Graph() - for a, b in possible_connections: - diff = np.linalg.norm(np.asarray(a)-np.asarray(b)) - if allow_manhattan_connections and allow_euclidean_connections and diff <= np.sqrt(2): - graph.add_edge(a, b) - elif not allow_manhattan_connections and allow_euclidean_connections and diff == np.sqrt(2): - graph.add_edge(a, b) - elif allow_manhattan_connections and not allow_euclidean_connections and diff == 1: - graph.add_edge(a, b) + if allow_manhattan_connections and allow_euclidean_connections: + graph.add_edges_from( + (a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) <= np.sqrt(2) + ) + elif not allow_manhattan_connections and allow_euclidean_connections: + graph.add_edges_from( + (a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) == np.sqrt(2) + ) + elif allow_manhattan_connections and not allow_euclidean_connections: + graph.add_edges_from( + (a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) == 1 + ) return graph diff --git a/marl_factory_grid/configs/default_config.yaml b/marl_factory_grid/configs/default_config.yaml index 44a0977..d3015c9 100644 --- a/marl_factory_grid/configs/default_config.yaml +++ b/marl_factory_grid/configs/default_config.yaml @@ -22,26 +22,41 @@ Agents: - Inventory - DropOffLocations - Maintainers + # This is special for agents, as each one is differten and can act as an adversary e.g. + Positions: + - (16, 7) + - (16, 6) + - (16, 3) + - (16, 4) + - (16, 5) Entities: Batteries: initial_charge: 0.8 per_action_costs: 0.02 - ChargePods: {} - Destinations: {} + ChargePods: + coords_or_quantity: 2 + Destinations: + coords_or_quantity: 1 + spawn_mode: GROUPED DirtPiles: + coords_or_quantity: 10 + initial_amount: 2 clean_amount: 1 dirt_spawn_r_var: 0.1 - initial_amount: 2 - initial_dirt_ratio: 0.05 max_global_amount: 20 max_local_amount: 5 - Doors: {} - DropOffLocations: {} + Doors: + DropOffLocations: + coords_or_quantity: 1 + max_dropoff_storage_size: 0 GlobalPositions: {} Inventories: {} - Items: {} - Machines: {} - Maintainers: {} + Items: + coords_or_quantity: 5 + Machines: + coords_or_quantity: 2 + Maintainers: + coords_or_quantity: 1 Zones: {} General: @@ -49,32 +64,31 @@ General: individual_rewards: true level_name: large pomdp_r: 3 - verbose: false + verbose: True + tests: false Rules: - SpawnAgents: {} - DoneAtBatteryDischarge: {} - Collision: - done_at_collisions: false - AssignGlobalPositions: {} - DoneAtDestinationReachAny: {} - DestinationReachReward: {} - SpawnDestinations: - n_dests: 1 - spawn_mode: GROUPED - DoneOnAllDirtCleaned: {} - SpawnDirt: - spawn_freq: 15 + # Environment Dynamics EntitiesSmearDirtOnMove: smear_ratio: 0.2 DoorAutoClose: close_frequency: 10 - ItemRules: - max_dropoff_storage_size: 0 - n_items: 5 - n_locations: 5 - spawn_frequency: 15 - MaxStepsReached: + MoveMaintainers: + + # Respawn Stuff + RespawnDirt: + respawn_freq: 15 + RespawnItems: + respawn_freq: 15 + + # Utilities + WatchCollisions: + done_at_collisions: false + + # Done Conditions + DoneAtDestinationReachAny: + DoneOnAllDirtCleaned: + DoneAtBatteryDischarge: + DoneAtMaintainerCollision: + DoneAtMaxStepsReached: max_steps: 500 -# AgentSingleZonePlacement: -# n_zones: 4 diff --git a/marl_factory_grid/configs/narrow_corridor.yaml b/marl_factory_grid/configs/narrow_corridor.yaml index 0006513..ddfeebd 100644 --- a/marl_factory_grid/configs/narrow_corridor.yaml +++ b/marl_factory_grid/configs/narrow_corridor.yaml @@ -1,3 +1,10 @@ +General: + env_seed: 69 + individual_rewards: true + level_name: narrow_corridor + pomdp_r: 0 + verbose: true + Agents: Wolfgang: Actions: @@ -10,6 +17,7 @@ Agents: Positions: - (2, 1) - (2, 5) + is_blocking_pos: true Karl-Heinz: Actions: - Noop @@ -21,26 +29,30 @@ Agents: Positions: - (2, 1) - (2, 5) -Entities: - Destinations: {} + is_blocking_pos: true -General: - env_seed: 69 - individual_rewards: true - level_name: narrow_corridor - pomdp_r: 0 - verbose: true +Entities: + Destinations: + ignore_blocking: true + spawnrule: + SpawnDestinationsPerAgent: + coords_or_quantity: + Wolfgang: + - (2, 1) + - (2, 5) + Karl-Heinz: + - (2, 1) + - (2, 5) + # Whether you want to provide a numeric Position observation. + # GlobalPositions: + # normalized: false Rules: - SpawnAgents: {} - Collision: + # Utilities + WatchCollisions: done_at_collisions: false - FixedDestinationSpawn: - per_agent_positions: - Wolfgang: - - (2, 1) - - (2, 5) - Karl-Heinz: - - (2, 1) - - (2, 5) - DestinationReachAll: {} + # Done Conditions + # DoneAtDestinationReachAny: + DoneAtDestinationReachAll: + DoneAtMaxStepsReached: + max_steps: 500 diff --git a/marl_factory_grid/environment/actions.py b/marl_factory_grid/environment/actions.py index 4edfe24..606832c 100644 --- a/marl_factory_grid/environment/actions.py +++ b/marl_factory_grid/environment/actions.py @@ -48,9 +48,9 @@ class Move(Action, abc.ABC): reward = r.MOVEMENTS_VALID if move_validity else r.MOVEMENTS_FAIL return ActionResult(entity=entity, identifier=self._identifier, validity=move_validity, reward=reward) else: # There is no place to go, propably collision - # This is currently handeld by the Collision rule, so that it can be switched on and off by conf.yml + # This is currently handeld by the WatchCollisions rule, so that it can be switched on and off by conf.yml # return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION) - return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=0) + return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID) def _calc_new_pos(self, pos): x_diff, y_diff = MOVEMAP[self._identifier] diff --git a/marl_factory_grid/environment/constants.py b/marl_factory_grid/environment/constants.py index 1fdf639..6ddb19a 100644 --- a/marl_factory_grid/environment/constants.py +++ b/marl_factory_grid/environment/constants.py @@ -10,6 +10,7 @@ AGENT = 'Agent' # Identifier of Agent-objects an OTHERS = 'Other' COMBINED = 'Combined' GLOBALPOSITIONS = 'GlobalPositions' # Identifier of the global position slice +SPAWN_ENTITY_RULE = 'SpawnEntity' # Attributes IS_BLOCKING_LIGHT = 'var_is_blocking_light' @@ -29,7 +30,7 @@ VALUE_NO_POS = (-9999, -9999) # Invalid Position value used in the e ACTION = 'action' # Identifier of Action-objects and groups (groups). -COLLISION = 'Collision' # Identifier to use in the context of collitions. +COLLISION = 'Collisions' # Identifier to use in the context of collitions. # LAST_POS = 'LAST_POS' # Identifiert for retrieving an enitites last pos. VALIDITY = 'VALIDITY' # Identifiert for retrieving the Validity of Action, Tick, etc. ... @@ -54,3 +55,5 @@ NOOP = 'Noop' # Result Identifier MOVEMENTS_VALID = 'motion_valid' MOVEMENTS_FAIL = 'motion_not_valid' +DEFAULT_PATH = 'environment' +MODULE_PATH = 'modules' diff --git a/marl_factory_grid/environment/entity/agent.py b/marl_factory_grid/environment/entity/agent.py index 285c8d2..0920604 100644 --- a/marl_factory_grid/environment/entity/agent.py +++ b/marl_factory_grid/environment/entity/agent.py @@ -12,14 +12,6 @@ from marl_factory_grid.environment import constants as c class Agent(Entity): - @property - def var_is_blocking_light(self): - return False - - @property - def var_can_move(self): - return True - @property def var_is_paralyzed(self): return len(self._paralyzed) @@ -28,14 +20,6 @@ class Agent(Entity): def paralyze_reasons(self): return [x for x in self._paralyzed] - @property - def var_is_blocking_pos(self): - return False - - @property - def var_has_position(self): - return True - @property def obs_tag(self): return self.name @@ -48,10 +32,6 @@ class Agent(Entity): def observations(self): return self._observations - @property - def var_can_collide(self): - return True - def step_result(self): pass @@ -60,16 +40,21 @@ class Agent(Entity): return self._collection @property - def state(self): - return self._state or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0) + def var_is_blocking_pos(self): + return self._is_blocking_pos - def __init__(self, actions: List[Action], observations: List[str], *args, **kwargs): + @property + def state(self): + return self._state or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID) + + def __init__(self, actions: List[Action], observations: List[str], *args, is_blocking_pos=False, **kwargs): super(Agent, self).__init__(*args, **kwargs) self._paralyzed = set() self.step_result = dict() self._actions = actions self._observations = observations self._state: Union[Result, None] = None + self._is_blocking_pos = is_blocking_pos # noinspection PyAttributeOutsideInit def clear_temp_state(self): diff --git a/marl_factory_grid/environment/entity/entity.py b/marl_factory_grid/environment/entity/entity.py index 637827f..4abf2af 100644 --- a/marl_factory_grid/environment/entity/entity.py +++ b/marl_factory_grid/environment/entity/entity.py @@ -14,7 +14,7 @@ class Entity(_Object, abc.ABC): @property def state(self): - return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0) + return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID) @property def var_has_position(self): @@ -60,6 +60,10 @@ class Entity(_Object, abc.ABC): def pos(self): return self._pos + def set_pos(self, pos): + assert isinstance(pos, tuple) and len(pos) == 2 + self._pos = pos + @property def last_pos(self): try: @@ -84,7 +88,7 @@ class Entity(_Object, abc.ABC): for observer in self.observers: observer.notify_del_entity(self) self._view_directory = curr_pos[0] - next_pos[0], curr_pos[1] - next_pos[1] - self._pos = next_pos + self.set_pos(next_pos) for observer in self.observers: observer.notify_add_entity(self) return valid @@ -93,7 +97,7 @@ class Entity(_Object, abc.ABC): def __init__(self, pos, bind_to=None, **kwargs): super().__init__(**kwargs) self._status = None - self._pos = pos + self.set_pos(pos) self._last_pos = pos if bind_to: try: @@ -109,8 +113,9 @@ class Entity(_Object, abc.ABC): def render(self): return RenderEntity(self.__class__.__name__.lower(), self.pos) - def __repr__(self): - return super(Entity, self).__repr__() + f'(@{self.pos})' + @abc.abstractmethod + def render(self): + return RenderEntity(self.__class__.__name__.lower(), self.pos) @property def obs_tag(self): @@ -149,4 +154,4 @@ class Entity(_Object, abc.ABC): except StopIteration: pass except ValueError: - print() + pass diff --git a/marl_factory_grid/environment/entity/mixin.py b/marl_factory_grid/environment/entity/mixin.py deleted file mode 100644 index bab6343..0000000 --- a/marl_factory_grid/environment/entity/mixin.py +++ /dev/null @@ -1,24 +0,0 @@ - - -# noinspection PyAttributeOutsideInit -class BoundEntityMixin: - - @property - def bound_entity(self): - return self._bound_entity - - @property - def name(self): - if self.bound_entity: - return f'{self.__class__.__name__}({self.bound_entity.name})' - else: - pass - - def belongs_to_entity(self, entity): - return entity == self.bound_entity - - def bind_to(self, entity): - self._bound_entity = entity - - def unbind(self): - self._bound_entity = None diff --git a/marl_factory_grid/environment/entity/object.py b/marl_factory_grid/environment/entity/object.py index 8810baf..768f8b5 100644 --- a/marl_factory_grid/environment/entity/object.py +++ b/marl_factory_grid/environment/entity/object.py @@ -13,10 +13,6 @@ class _Object: def __bool__(self): return True - @property - def var_has_position(self): - return False - @property def var_can_be_bound(self): try: @@ -30,22 +26,14 @@ class _Object: @property def name(self): - if self._str_ident is not None: - name = f'{self.__class__.__name__}[{self._str_ident}]' - else: - name = f'{self.__class__.__name__}#{self.u_int}' - if self.bound_entity: - name = h.add_bound_name(name, self.bound_entity) - if self.var_has_position: - name = h.add_pos_name(name, self) - return name + return f'{self.__class__.__name__}[{self.identifier}]' @property def identifier(self): if self._str_ident is not None: return self._str_ident else: - return self.name + return self.u_int def reset_uid(self): self._u_idx = defaultdict(lambda: 0) @@ -62,7 +50,15 @@ class _Object: print(f'Following kwargs were passed, but ignored: {kwargs}') def __repr__(self): - return f'{self.name}' + name = self.name + if self.bound_entity: + name = h.add_bound_name(name, self.bound_entity) + try: + if self.var_has_position: + name = h.add_pos_name(name, self) + except (AttributeError): + pass + return name def __eq__(self, other) -> bool: return other == self.identifier @@ -88,7 +84,7 @@ class _Object: def summarize_state(self): return dict() - def bind(self, entity): + def bind_to(self, entity): # noinspection PyAttributeOutsideInit self._bound_entity = entity return c.VALID @@ -100,9 +96,6 @@ class _Object: def bound_entity(self): return self._bound_entity - def bind_to(self, entity): - self._bound_entity = entity - def unbind(self): self._bound_entity = None diff --git a/marl_factory_grid/environment/entity/util.py b/marl_factory_grid/environment/entity/util.py index 1a5cbe3..d43c53a 100644 --- a/marl_factory_grid/environment/entity/util.py +++ b/marl_factory_grid/environment/entity/util.py @@ -24,7 +24,7 @@ class PlaceHolder(_Object): @property def name(self): - return "PlaceHolder" + return self.__class__.__name__ class GlobalPosition(_Object): @@ -36,7 +36,8 @@ class GlobalPosition(_Object): else: return self.bound_entity.pos - def __init__(self, level_shape, *args, normalized: bool = True, **kwargs): + def __init__(self, agent, level_shape, *args, normalized: bool = True, **kwargs): super(GlobalPosition, self).__init__(*args, **kwargs) + self.bind_to(agent) self._normalized = normalized self._shape = level_shape diff --git a/marl_factory_grid/environment/entity/wall.py b/marl_factory_grid/environment/entity/wall.py index 3f0fb7c..83044cd 100644 --- a/marl_factory_grid/environment/entity/wall.py +++ b/marl_factory_grid/environment/entity/wall.py @@ -5,13 +5,8 @@ from marl_factory_grid.utils.utility_classes import RenderEntity class Wall(Entity): - @property - def var_has_position(self): - return True - - @property - def var_can_collide(self): - return True + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) @property def encoding(self): @@ -19,11 +14,3 @@ class Wall(Entity): def render(self): return RenderEntity(c.WALL, self.pos) - - @property - def var_is_blocking_pos(self): - return True - - @property - def var_is_blocking_light(self): - return True diff --git a/marl_factory_grid/environment/factory.py b/marl_factory_grid/environment/factory.py index d840178..3c5f7f6 100644 --- a/marl_factory_grid/environment/factory.py +++ b/marl_factory_grid/environment/factory.py @@ -87,11 +87,14 @@ class Factory(gym.Env): entities = self.map.do_init() # Init rules - rules = self.conf.load_rules() + env_rules = self.conf.load_env_rules() + entity_rules = self.conf.load_entity_spawn_rules(entities) + env_rules.extend(entity_rules) # Parse the agent conf parsed_agents_conf = self.conf.parse_agents_conf() - self.state = Gamestate(entities, parsed_agents_conf, rules, self.conf.env_seed, self.conf.verbose) + self.state = Gamestate(entities, parsed_agents_conf, env_rules, self.map.level_shape, + self.conf.env_seed, self.conf.verbose) # All is set up, trigger entity init with variable pos # All is set up, trigger additional init (after agent entity spawn etc) diff --git a/marl_factory_grid/environment/groups/agents.py b/marl_factory_grid/environment/groups/agents.py index f4a6ac6..d549384 100644 --- a/marl_factory_grid/environment/groups/agents.py +++ b/marl_factory_grid/environment/groups/agents.py @@ -1,10 +1,15 @@ from marl_factory_grid.environment.entity.agent import Agent from marl_factory_grid.environment.groups.collection import Collection +from marl_factory_grid.environment.rules import SpawnAgents class Agents(Collection): _entity = Agent + @property + def spawn_rule(self): + return {SpawnAgents.__name__: {}} + @property def var_is_blocking_light(self): return False diff --git a/marl_factory_grid/environment/groups/collection.py b/marl_factory_grid/environment/groups/collection.py index 640c3b4..140c941 100644 --- a/marl_factory_grid/environment/groups/collection.py +++ b/marl_factory_grid/environment/groups/collection.py @@ -1,18 +1,25 @@ -from typing import List, Tuple, Union +from typing import List, Tuple, Union, Dict from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.groups.objects import _Objects +# noinspection PyProtectedMember from marl_factory_grid.environment.entity.object import _Object import marl_factory_grid.environment.constants as c +from marl_factory_grid.utils.results import Result class Collection(_Objects): _entity = _Object # entity? + symbol = None @property def var_is_blocking_light(self): return False + @property + def var_is_blocking_pos(self): + return False + @property def var_can_collide(self): return False @@ -23,29 +30,61 @@ class Collection(_Objects): @property def var_has_position(self): - return False - - # @property - # def var_has_bound(self): - # return False # batteries, globalpos, inventories true - - @property - def var_can_be_bound(self): - return False + return True @property def encodings(self): return [x.encoding for x in self] - def __init__(self, size, *args, **kwargs): - super(Collection, self).__init__(*args, **kwargs) - self.size = size - - def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): # woihn mit den args - if isinstance(coords_or_quantity, int): - self.add_items([self._entity() for _ in range(coords_or_quantity)]) + @property + def spawn_rule(self): + """Prevent SpawnRule creation if Objects are spawned by map, Doors e.g.""" + if self.symbol: + return None + elif self._spawnrule: + return self._spawnrule else: - self.add_items([self._entity(pos) for pos in coords_or_quantity]) + return {c.SPAWN_ENTITY_RULE: dict(collection=self, coords_or_quantity=self._coords_or_quantity)} + + def __init__(self, size, *args, coords_or_quantity: int = None, ignore_blocking=False, + spawnrule: Union[None, Dict[str, dict]] = None, + **kwargs): + super(Collection, self).__init__(*args, **kwargs) + self._coords_or_quantity = coords_or_quantity + self.size = size + self._spawnrule = spawnrule + self._ignore_blocking = ignore_blocking + + def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, ignore_blocking=False, **entity_kwargs): + coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity + if self.var_has_position: + if isinstance(coords_or_quantity, int): + if ignore_blocking or self._ignore_blocking: + coords_or_quantity = state.entities.floorlist[:coords_or_quantity] + else: + coords_or_quantity = state.get_n_random_free_positions(coords_or_quantity) + self.spawn(coords_or_quantity, *entity_args, **entity_kwargs) + state.print(f'{len(coords_or_quantity)} new {self.name} have been spawned at {coords_or_quantity}') + return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(coords_or_quantity)) + else: + if isinstance(coords_or_quantity, int): + self.spawn(coords_or_quantity, *entity_args, **entity_kwargs) + state.print(f'{coords_or_quantity} new {self.name} have been spawned randomly.') + return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=coords_or_quantity) + else: + raise ValueError(f'{self._entity.__name__} has no position!') + + def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args, **entity_kwargs): + if self.var_has_position: + if isinstance(coords_or_quantity, int): + raise ValueError(f'{self._entity.__name__} should have a position!') + else: + self.add_items([self._entity(pos, *entity_args, **entity_kwargs) for pos in coords_or_quantity]) + else: + if isinstance(coords_or_quantity, int): + self.add_items([self._entity(*entity_args, **entity_kwargs) for _ in range(coords_or_quantity)]) + else: + raise ValueError(f'{self._entity.__name__} has no position!') return c.VALID def despawn(self, items: List[_Object]): @@ -115,7 +154,7 @@ class Collection(_Objects): except StopIteration: pass except ValueError: - print() + pass @property def positions(self): diff --git a/marl_factory_grid/environment/groups/global_entities.py b/marl_factory_grid/environment/groups/global_entities.py index 8bfc9fe..7a50de4 100644 --- a/marl_factory_grid/environment/groups/global_entities.py +++ b/marl_factory_grid/environment/groups/global_entities.py @@ -1,6 +1,6 @@ from collections import defaultdict from operator import itemgetter -from random import shuffle, random +from random import shuffle from typing import Dict from marl_factory_grid.environment.groups.objects import _Objects @@ -12,10 +12,10 @@ class Entities(_Objects): @staticmethod def neighboring_positions(pos): - return (POS_MASK + pos).reshape(-1, 2) + return [tuple(x) for x in (POS_MASK + pos).reshape(-1, 2)] def get_entities_near_pos(self, pos): - return [y for x in itemgetter(*(tuple(x) for x in self.neighboring_positions(pos)))(self.pos_dict) for y in x] + return [y for x in itemgetter(*self.neighboring_positions(pos))(self.pos_dict) for y in x] def render(self): return [y for x in self for y in x.render() if x is not None] @@ -35,8 +35,9 @@ class Entities(_Objects): super().__init__() def guests_that_can_collide(self, pos): - return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide] + return [x for val in self.pos_dict[pos] for x in val if x.var_can_collide] + @property def empty_positions(self): empty_positions = [key for key in self.floorlist if not self.pos_dict[key]] shuffle(empty_positions) @@ -48,11 +49,23 @@ class Entities(_Objects): shuffle(empty_positions) return empty_positions - def is_blocked(self): - return[key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])] + @property + def blocked_positions(self): + blocked_positions = [key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])] + shuffle(blocked_positions) + return blocked_positions - def is_not_blocked(self): - return[key for key, val in self.pos_dict.items() if not all([x.var_is_blocking_pos for x in val])] + @property + def free_positions_generator(self): + generator = ( + key for key in self.floorlist if all(not x.var_can_collide and not x.var_is_blocking_pos + for x in self.pos_dict[key]) + ) + return generator + + @property + def free_positions_list(self): + return [x for x in self.free_positions_generator] def iter_entities(self): return iter((x for sublist in self.values() for x in sublist)) @@ -92,3 +105,6 @@ class Entities(_Objects): @property def positions(self): return [k for k, v in self.pos_dict.items() for _ in v] + + def is_occupied(self, pos): + return len([x for x in self.pos_dict[pos] if x.var_can_collide or x.var_is_blocking_pos]) >= 1 diff --git a/marl_factory_grid/environment/groups/mixins.py b/marl_factory_grid/environment/groups/mixins.py index 48333ca..acfac7e 100644 --- a/marl_factory_grid/environment/groups/mixins.py +++ b/marl_factory_grid/environment/groups/mixins.py @@ -4,10 +4,6 @@ from marl_factory_grid.environment import constants as c # noinspection PyUnresolvedReferences,PyTypeChecker class IsBoundMixin: - @property - def name(self): - return f'{self.__class__.__name__}({self._bound_entity.name})' - def __repr__(self): return f'{self.__class__.__name__}#{self._bound_entity.name}({self._data})' diff --git a/marl_factory_grid/environment/groups/objects.py b/marl_factory_grid/environment/groups/objects.py index d3f32af..d29cc2c 100644 --- a/marl_factory_grid/environment/groups/objects.py +++ b/marl_factory_grid/environment/groups/objects.py @@ -5,11 +5,16 @@ import numpy as np from marl_factory_grid.environment.entity.object import _Object import marl_factory_grid.environment.constants as c +from marl_factory_grid.utils import helpers as h class _Objects: _entity = _Object + @property + def var_can_be_bound(self): + return False + @property def observers(self): return self._observers @@ -148,12 +153,12 @@ class _Objects: def by_entity(self, entity): try: - return next((x for x in self if x.belongs_to_entity(entity))) + return h.get_first(self, filter_by=lambda x: x.belongs_to_entity(entity)) except (StopIteration, AttributeError): return None def idx_by_entity(self, entity): try: - return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity))) + return h.get_first_index(self, filter_by=lambda x: x.belongs_to_entity(entity)) except (StopIteration, AttributeError): return None diff --git a/marl_factory_grid/environment/groups/utils.py b/marl_factory_grid/environment/groups/utils.py index 5619041..d272152 100644 --- a/marl_factory_grid/environment/groups/utils.py +++ b/marl_factory_grid/environment/groups/utils.py @@ -1,7 +1,10 @@ from typing import List, Union +from marl_factory_grid.environment import constants as c from marl_factory_grid.environment.entity.util import GlobalPosition from marl_factory_grid.environment.groups.collection import Collection +from marl_factory_grid.utils.results import Result +from marl_factory_grid.utils.states import Gamestate class Combined(Collection): @@ -36,17 +39,17 @@ class GlobalPositions(Collection): _entity = GlobalPosition - @property - def var_is_blocking_light(self): - return False - - @property - def var_can_collide(self): - return False - - @property - def var_can_be_bound(self): - return True + var_is_blocking_light = False + var_can_be_bound = True + var_can_collide = False + var_has_position = False def __init__(self, *args, **kwargs): super(GlobalPositions, self).__init__(*args, **kwargs) + + def spawn(self, agents, level_shape, *args, **kwargs): + self.add_items([self._entity(agent, level_shape, *args, **kwargs) for agent in agents]) + return [Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))] + + def trigger_spawn(self, state: Gamestate, *args, **kwargs) -> [Result]: + return self.spawn(state[c.AGENT], state.lvl_shape, *args, **kwargs) diff --git a/marl_factory_grid/environment/groups/walls.py b/marl_factory_grid/environment/groups/walls.py index 2d85362..776bbca 100644 --- a/marl_factory_grid/environment/groups/walls.py +++ b/marl_factory_grid/environment/groups/walls.py @@ -7,9 +7,12 @@ class Walls(Collection): _entity = Wall symbol = c.SYMBOL_WALL - @property - def var_has_position(self): - return True + var_can_collide = True + var_is_blocking_light = True + var_can_move = False + var_has_position = True + var_can_be_bound = False + var_is_blocking_pos = True def __init__(self, *args, **kwargs): super(Walls, self).__init__(*args, **kwargs) diff --git a/marl_factory_grid/environment/rules.py b/marl_factory_grid/environment/rules.py index f9678b0..5f96424 100644 --- a/marl_factory_grid/environment/rules.py +++ b/marl_factory_grid/environment/rules.py @@ -1,6 +1,6 @@ import abc from random import shuffle -from typing import List +from typing import List, Collection, Union from marl_factory_grid.environment.entity.agent import Agent from marl_factory_grid.utils import helpers as h @@ -39,6 +39,29 @@ class Rule(abc.ABC): return [] +class SpawnEntity(Rule): + + @property + def _collection(self) -> Collection: + return Collection() + + @property + def name(self): + return f'{self.__class__.__name__}({self.collection.name})' + + def __init__(self, collection, coords_or_quantity, ignore_blocking=False): + super().__init__() + self.coords_or_quantity = coords_or_quantity + self.collection = collection + self.ignore_blocking = ignore_blocking + + def on_init(self, state, lvl_map) -> [TickResult]: + results = self.collection.trigger_spawn(state, ignore_blocking=self.ignore_blocking) + pos_str = f' on: {[x.pos for x in self.collection]}' if self.collection.var_has_position else '' + state.print(f'Initial {self.collection.__class__.__name__} were spawned{pos_str}') + return results + + class SpawnAgents(Rule): def __init__(self): @@ -46,14 +69,14 @@ class SpawnAgents(Rule): pass def on_init(self, state, lvl_map): - agent_conf = state.agents_conf # agents = Agents(lvl_map.size) agents = state[c.AGENT] - empty_positions = state.entities.empty_positions()[:len(agent_conf)] - for agent_name in agent_conf: - actions = agent_conf[agent_name]['actions'].copy() - observations = agent_conf[agent_name]['observations'].copy() - positions = agent_conf[agent_name]['positions'].copy() + empty_positions = state.entities.empty_positions[:len(state.agents_conf)] + for agent_name, agent_conf in state.agents_conf.items(): + actions = agent_conf['actions'].copy() + observations = agent_conf['observations'].copy() + positions = agent_conf['positions'].copy() + other = agent_conf['other'].copy() if positions: shuffle(positions) while True: @@ -61,18 +84,18 @@ class SpawnAgents(Rule): pos = positions.pop() except IndexError: raise ValueError(f'It was not possible to spawn an Agent on the available position: ' - f'\n{agent_name[agent_name]["positions"].copy()}') - if agents.by_pos(pos) and state.check_pos_validity(pos): + f'\n{agent_conf["positions"].copy()}') + if bool(agents.by_pos(pos)) or not state.check_pos_validity(pos): continue else: - agents.add_item(Agent(actions, observations, pos, str_ident=agent_name)) + agents.add_item(Agent(actions, observations, pos, str_ident=agent_name, **other)) break else: - agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name)) + agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name, **other)) pass -class MaxStepsReached(Rule): +class DoneAtMaxStepsReached(Rule): def __init__(self, max_steps: int = 500): super().__init__() @@ -83,8 +106,8 @@ class MaxStepsReached(Rule): def on_check_done(self, state): if self.max_steps <= state.curr_step: - return [DoneResult(validity=c.VALID, identifier=self.name, reward=0)] - return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)] + return [DoneResult(validity=c.VALID, identifier=self.name)] + return [DoneResult(validity=c.NOT_VALID, identifier=self.name)] class AssignGlobalPositions(Rule): @@ -101,7 +124,7 @@ class AssignGlobalPositions(Rule): return [] -class Collision(Rule): +class WatchCollisions(Rule): def __init__(self, done_at_collisions: bool = False): super().__init__() @@ -132,4 +155,4 @@ class Collision(Rule): move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT]) if inter_entity_collision_detected or move_failed: return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)] - return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)] + return [DoneResult(validity=c.NOT_VALID, identifier=self.name)] diff --git a/marl_factory_grid/modules/batteries/__init__.py b/marl_factory_grid/modules/batteries/__init__.py index 0218021..80671fd 100644 --- a/marl_factory_grid/modules/batteries/__init__.py +++ b/marl_factory_grid/modules/batteries/__init__.py @@ -1,4 +1,4 @@ from .actions import BtryCharge -from .entitites import Pod, Battery +from .entitites import ChargePod, Battery from .groups import ChargePods, Batteries from .rules import DoneAtBatteryDischarge, BatteryDecharge diff --git a/marl_factory_grid/modules/batteries/actions.py b/marl_factory_grid/modules/batteries/actions.py index 343bbcc..bd755a2 100644 --- a/marl_factory_grid/modules/batteries/actions.py +++ b/marl_factory_grid/modules/batteries/actions.py @@ -6,6 +6,7 @@ from marl_factory_grid.utils.results import ActionResult from marl_factory_grid.modules.batteries import constants as b from marl_factory_grid.environment import constants as c +from marl_factory_grid.utils import helpers as h class BtryCharge(Action): @@ -14,8 +15,8 @@ class BtryCharge(Action): super().__init__(b.ACTION_CHARGE) def do(self, entity, state) -> Union[None, ActionResult]: - if charge_pod := state[b.CHARGE_PODS].by_pos(entity.pos): - valid = charge_pod.charge_battery(state[b.BATTERIES].by_entity(entity)) + if charge_pod := h.get_first(state[b.CHARGE_PODS].by_pos(entity.pos)): + valid = h.get_first(charge_pod.charge_battery(state[b.BATTERIES].by_entity(entity))) if valid: state.print(f'{entity.name} just charged batteries at {charge_pod.name}.') else: diff --git a/marl_factory_grid/modules/batteries/chargepods.png b/marl_factory_grid/modules/batteries/chargepods.png new file mode 100644 index 0000000000000000000000000000000000000000..7221daa82694b49d42a4b8fec084514c45b93362 GIT binary patch literal 8135 zcmd6MS5y?=(r-^QBg~Mqs?Qc47I5->=*z5sC9KT@c@8; zUl9Ps8Tesyzu@*iU!Hhvb)c+|^CtlC1nFvCHubSvODFd+H9HOubP|rmVZ|aSbm%F0 z$AHTgRKX?uk++yBd7p;Xl$3=cZZL96(I*LdchH1H^hCN?xnj9EO2%U6?M3mLO2}Z{ zkG)5i9%5f!>06VRZt>0VH{Vj8IqEw}_pSF_npOt=UI1vpbBGre;!7WNAtVk#WQ@gu zj9CCf-I0wF+6b5=6O1Jxu@KTtEy)mM0uEsM3kVEvP-=f(9DpLJ;y|&YZh!G9UV=;kKp*k zh;ApqhFe5cKmJxPbZt`f5mTvNJgT9xuZ{&ui}OSKB_p$*au5~sbwRuvSvvvd*N96| zXvBaX;(%iLeiW4ap|6V!p=k6vp^=+bd~2M29uF_OwgOxncB!z6VFR}xy)fT)mu6FV z#~)P32&JXJJkyfabGwTVA$p!p^%+b1T(^xi-UimE)p%1#@(eM@MN_dd9T}iredy}J zw?C$+%fApUiRdjo6?<|vO^4c`u12r%+#{BmMB>Uj=T;8^K(S$;dGpyr*ZKBpJN{Wl zf6bVj6g#oPHzN#}z-0RqeVxJ5H*&#VEBqJqUo+dkc0q=S3(>KHu{$pIh)*r2jCW%w z2Tg2@WF)gkT-zMJYXF{cGhw6zxy>7T{S2{;n%AFtUn92{rugcE;8fUYX+x6Hl%^oP zpH|Q3-o5>Ey72uXCBbdYg|D5%oL55cZZtLZNiO!MkNB61+KQlP9><{HXy6O{`&NFr zoLU|6SNDwHJ~3b<2Xz9dz=M>ZcE)%Uzrka!)xxO}1NLY$L8J1ULVdQLLU_GE6Ha5Z z)-=bPhWvA?pfX(U3wV~Uz23!OM-C)}1;R1m3 zoA4oJ*Nf6=j#nA#{JvF&No&q&XYN)0F$T|w;sIF$Aa$gygCj=GID1OMP?61I$}7C} z-t6=RX>g!uy(!?NbzT%fcJDQCHyQAD9e+OB`KL;~c;bNyb<&jiPT8|cBS3{A3)p2W zaOOdHZ;S;IFPPrQ&VQiB!8#+Oa|uY&wmb*u;hZk{B4-EaNku}J2z#{pgC}7t58f5H z`WWBfvtv&%^42b2`J7Jq`0oN*p+?_H=H3)$FZ%u>zcL3Bhh84L!30HtW0k(r z=8{KaAT<`pB_V;lci@<#EU-d;ivbpu!J7VTwwYBH( zS%0j1Lz(SBv$aLF!qLsPnj)%dT_7I9{PsDt1r#N}#m#`{j&`Mc)Autm_m)Qv9MU{L zbwsWes;s6w=7HjGB~9)1`wukP+qe0)9N!df_Sp!*d><;SXU{)Sm?2L8Ix~(1^sMeIHV_a#sHbu0(qV}^%V^SKfg$uJ#E|> z7Z{i$&_`>3x40AK9I-w&9K0mLUKeBp?$ERu}^la&GM{ zJ3QL2tE&wt-`kI~+DfkHZvQEex!o;LKlaD^8;A3Gb|13w%%NeUKKI4mj{J?^wI9a* z)Y57d@2^oUm~#Dyg{6tHt0Ts+K-5uJT>flVc{Mrn-z}+=i?{|lod(|J&Cxvh{M!d|Nl2I;zf%+oCm95gGhmBu>2DFEe$tOGC5-;&Jt7 zi~%`H-1pL^8}h$Prm7FSR)UMmKd?Eg?l#+>i}@SSnxA~IXwAGm9WRTa=8%8q*Sw9k zr2~YBmW)_{99iKv`1gWmcT4x`sbi{q*IpY>#VNJ<6w+jeyCZS$R_1``mlUTg_s-Z~ zVb`~3Y#Ct2tjA0c88~2ht7(G z)3H&vzWXQK#E`O;RH*kH9$$LhY*9)Q;vxf_cV7_Efn@Y3qfnqa>SV9v+Z)Mo#kw8X zwOC-U4rf&Mdg7gsK592E&m!JpA8a-w*}8{ti8f$q+%q(IG%nO#NG<+%e5eG_yk5C1 zduyX69;)QhLYVSH_7-Kq`(dZIj_&WWJ;)VuA8SH27}s9*ceK4vH1!H}EKbbv-^unT z%V~;s`6lH}tC@S7ndjYIcpr*b#Bg7HR4&ksCVYOoko@O^{w=JteYmszv#q=_5=RBRX1KFo;^8lJx^!baO%!UPCq1@;>O;b_93Hjm?6IFFg$)*<7> z?5mcU(oAA-)PNiG#qCMa>aZtCc6m2{$5k0rNKWP<|0>UgGFlfP&8Jjt9qk^72^FL_ z4)vio%@l+3?ERj8GZDYUS3Xr(FjZ;!=hhxQyo~!pJdluv{O!H2a?a;uN>4Xq^sm%| z^(lRx1!TnZ+W=KHAED>LYqvJAUlRjfJ28(50Rfyw6Azeh%=TpFUZH7Wd)wKyIeX4r zUnxX%;`e3v)(W*3?DNz7Sf7~2}lN+Zm86Ft6LD%`s@%Q@q|Kf`rfB0?f z{evsv>r=5WyLp%tDXu)8&fafo*>^>(4ObWo*zDCR8?~_Oj{74?o*)y69$lb9zdWlY zVQ42{&q&Kzy#tKb$IXoV4u&%d`BlR^#%oe(3_uK0%pDa2Fm0d)xL19*t(?povKduz zkL#A*T}HC5c17!9)8IofsuW^^a)5Ym>5vDE#OrXdePb~On4op}(gd(NFn;la)uV%bxXR9%iTpQRtKb zE3?zUAp6G`hJ&cA-&fq!c2vK9+Z7MsRxtH7@CuT(0mirQRwJ){k@A<&J-|54^J{ ze^==w_aRq>%<>hiI0FCr-S>tUAd%BI1uY688Aw~6PFbfrzRcc}ectJtUYi~k-u$Rp%)Uk( z9zgR$?@UCphLy`L#_n}!%SYROw|}teQMRgyAu+EqkuHN`v?#Ayt;O313GwY}8}di@ z?ed;Bf9h?Ej&@2(ZEwGWAt6;L!?r1Sqho+yrMAo`0WH(7GR3FF#KWdz>>u0SZIzSC z2nH|o^+mp5!lhwm$FwLN+yep!Hxkz4?91x7#og5;H!?H@E zAd0|eKvsx_*ma5e6oIVU|@rc&vnhO?1kyOsBFMhqc}bp+^hhaGgQelo0M$} zUznk8t=mNPbu?@Z#VS{pl^?j7y>OExJ{tO(kBi31T)Qa zf3cS1`V&+Q@}wo8PYofPN8D0on%`PiIEITRyvG4!hZ}|yc%eo>Zge9gRK&FL%Cy`( z=ILnJA^nEhL9qJkM<0hD&V(+C)R^(i;%?S@$>)4pKkes{IhGO!?oHd7?Ej(B(GS?f z=}k&G8T~@$u6R7M`9O&y`{NjAd|0?O!#C&oDll*7bmRNA&4m-5>(_gZEcUysSof=C z@&|8%#%Yavx@qryO9q+3JFWPjX_IG$Pp{-kQixs!8OUzamMDMT&EbSYmJ_o5Rs1aP zDo&`|v4HkKuBFDQiSh+S*Bus zMxWwq>`U*TiPI>la%E42*I-R^M^}rja&aJcD36_$bMu6Qa#sx1- z2nVFzZqFpdo|;9{<$by;taiVj<-Z8NNPFh1N357eZfa|-Sa5UiKhrB=b;JUa-l4Lh zj1?LHTkejq4!VX?T+157|f?l9E1wqGrort(-SQ$zFA)DkFS(!nZ3#g1Tu)Ge3xImr$ z-WCRRG1VX~u-^cKS|2f~N^%edB)vjvmj9)B@DbNd>mG7Eg&-Vv0k_(g!TiU-Ja@ zlMl8_7>$3n`6mz5-G@8qbCgcxhC~0wUd{idYTF!f$Hf2gTP!ELUH-9x`xPT#dX=&m zFG$D*)IegT9@`*ApO2U8?BfK>@($SQD-XfD@=fgXEIA`<=;N6y=x{V_C;7##kKivBkMfu zCJ%yS&kDm5XwoRs$Z%^Ob5cseY0~_6d@i1+dJWh>eWIlTgcjgF53*e#lDT|(>6)6S z-ev~a;WOHX1xBq&U@sjJ2dDzs0hW>rs&i>`*HW$G;v7W*5ogcyJ?JqBh;U)qno)N+ zV7aPZZwo(Ltotpf3|2b=R(k+7LP&pT(< zoX0K2kI;WS6kJjtc0GBSbLkp7+OnM8Ud?T}pesglE9q`vuDr?Txjjz=69i-SC0&qB z+H-egulUK~b&7hMLrotC;+T{w`T-9L;CZ4E6!2(~D3u5jU&<}LO@j~!PCE^RAUEX_ zva^8Y_75x$bMX#-ze2RgQ-Mtzp`e@UX=O-y4bXfQNDR|xQz~L{n#IlrmY#-Ag|re( zy=d~`Q_k1%#r`OSXGzZ*Uj38v>;POltFx?cZ^*a?F5V7KfRc&$yO)HE z_O}O&VZ=Nlb;UOx%Fqwu%bNOPeX2kIfoMo@+Y1WRy!pJ3LsDqni1T~OKdV(bQtb#M zhEs#SNbfuv7lU5F8R_M32AsbTo?8)<1;oGa;4Np^j09xVG_EW~F*67-Tb1sdoH{i# zGzK;}G)lntEfSU~UqwMX22z%B3{U0{Nn46pWZoX7hr7b#u{ga-ITZW{9(3!v6NQg1 zFdjK1y#vyHG`rr^K023%bIP+!4aNegB&kwqq>2bM+s+j4W=jz?#Eh;xM&5uosQ=Ff z%h9~v=Xs?MX!zQFv>bRsHG-NWD-*+wfkc5@>LO?#V?dsk$Kuk<)e}SDo~h-b;IY9+ zbg${t8l^C>;uUsBQIGTtj7X+?nnNeGHpLfE)YK zvS{<`j}IXoWBz&8wYJt9;xGvnO!DNDX&)m@;;;n1y>HJzS6gKCf+1IU0cjc-YAKY5 zr$S{td3xDSX!aK3=FNpSx)e3~k)f7MecU6To%@)ep?1a0JY?fJXsjM8QCSn_%hQh1 zooj%doUG~`%JixfhNL^5)^i`sV1=T27+>Jk)I^;QPZKqi)YbnpWycC-ScjrEJKBwA zReJc00Ym-;8$*W8_dsxKmitvfFLLONn9crC9ETx#+9Zp?MT;s1D9QTBSw;`j>}HS? zkMrB%boTDN$h%+8PY1;yI)J<4&?|Zx)Oq8GD#0gU1yUfcPAlsSnm`Si4|{$&P!1hlQofHw z5fXX38uYR{DF_e#u*SY^084_2*a0KHzL#_u62o(Yd6pn{0jOlM7U7YH$%oOwb~T|w zQ)tAM>OZV+C~*CbZMk_T>@CM13+@?0c|clLucE7fkqI(3BP9tH=ZoCjtj=0m z(3<3b;#F1$sJ%R^CCl<+s4bg3si5Q4vmBsz>-(J~^j9YMgcj|DLzBybd&tw|`IrvX z|BoecRzRzEg8+lsBNH@X6E?N1IBQ}?LMlzfQ$5AgX3w@(@@jjwx2IZ&vFK#RFhslMFhX%Ede{P1+@9)l}`1l}UR*-FN^L}5%f>`Bk7_d|!6{pR>BCt+T zq0CBVej>}Xgfu$m>JnHXpep|NG0q#GGP4vTmr2LDRjgO??JS0L?sD|Q1gHo&v$T+R zOj)5_GbR|g`y?quV5GQ{NA$evp^l5@OY`EzNEuwyryyMR-NqCJ42eoL+AAIAs#YJW zx-U;QZ<&`P6c+9Y`s!5Xfe@Vzw>oERDN=dxT^5Xo0QBHb7UM4(0|#AM<1 za^U?nDpWPkiOlc!iiaPc4(fc-8%BaYUJlp`nzGZboc}%_TnrHQyl+KBaC*piGiJ^x zD?R!Ull{}mcL%X+^qDVcMK~06FX&8#Z;N^ve)w3Rm+46ZgbFv^*X!&HF^KVQ1C=ZR z+b*)pw75xbXDHFo|Ig?mT#=`m(0q^(rFyd)MH3#sP7tHFsO-_;f&RbadGEef$Bhs8mzY z%C0+gw2SfeUkwPb{=FyJV*&aHsVZ`tAwiX4;YL8%;6FupnHjf=<&~CX1rgu(DA$xu zJ8%y&AZ!Auo^F}uX%+QS{71M-g>i7Z6M0btI?iz45QrW@6|kHsaY{eKPvMa!V$Qel zCgRegz8@E9%e>xmV@S^PY5jQc3Nnj1_M(!chMqx1_1WA{PdMrNxYa%{6Q#dn0qI?Z z5ijrxs^V`h>0kBNf^}{_)!tejHGB6!vEB;2jht;pwsXx5R4jd#x*(kaUeJhfv=t$k z0~gUZ7iMTaE^*pEf4V|hQW6JNycG6SG$dgV?SdDS4zBCW^)mZ#`|j)rK;I`47_VO? zi6J;&;+U>Y?5{!bsZXvwPfO-9dVC>bxI*kxmjFP{>xns`uoMMKyHD-ukN=)QtOfBN zOR^OC6+Y(Ma&71F?Tjivj@HZsE8N|ss>qNfETq<#2lf}S8(=#Cbe*-1oCIety@>D` z`6>R}bR`ck)xFpk_YZF5qyGWBF=n8jd^Zkv35bfOQhz-3%o#r$a?^e;jWwjTDFgX1 z3yUH3Ox?IQ%urzsEKm+)oHH%HFg{=r8J4H3Tba43Cnhn)eDy<_e_^w&;q)g93C>Ws zjv*K{8udxe2%n@UmEBmtaZ;gM0KOgajLQ0V+#3Ol?6`k>2s|bWffu@Ba&3j-F&6+u zi0095fP0J=M&TOEc`4dE1s*Gd(TrjM=b97(L$V=*h5c3-zXdWu^=DvVhaS~E|13aF zGys04_^R?<=(4Hjyb7XmxPKMfGDh+LO;$QMS_B41E7AE9mqwoj%$H=_(4d$3aERsX z4oijzX%N0Tkd5!*L#V>yRIAVGgC18PFrq5#zM^1OI9$uaSsZnF`zVY6j}fEu;mBP& zxQ#O4a1Pzh2KPw&*t-{EA#QkA+}EEA1t|#N5bahP^_&a?8t~$Q#Oz~&McN@Dit!2Af-_!Ize0@|?V=i~_!YAIzSWWxrXt#H3zk7NoZM}tYr(;Uv zayrzhr~@H)>4~d_q~mS}8FVbf2x~5a^`v3g#KNY|jk;kaIE#ee(A{YvNEJN@+`Qvx z$%0n641vU8sU==S4BSdm89$q-1scNE_@^Z3kaA%mF>+zIUi^PA?hlXcRz^+f740;* PGY{x$8ETfP6N3K-iRyx+ literal 0 HcmV?d00001 diff --git a/marl_factory_grid/modules/batteries/entitites.py b/marl_factory_grid/modules/batteries/entitites.py index b51f2dd..751d57b 100644 --- a/marl_factory_grid/modules/batteries/entitites.py +++ b/marl_factory_grid/modules/batteries/entitites.py @@ -50,7 +50,7 @@ class Battery(_Object): return summary -class Pod(Entity): +class ChargePod(Entity): @property def encoding(self): @@ -58,7 +58,7 @@ class Pod(Entity): def __init__(self, *args, charge_rate: float = 0.4, multi_charge: bool = False, **kwargs): - super(Pod, self).__init__(*args, **kwargs) + super(ChargePod, self).__init__(*args, **kwargs) self.charge_rate = charge_rate self.multi_charge = multi_charge diff --git a/marl_factory_grid/modules/batteries/groups.py b/marl_factory_grid/modules/batteries/groups.py index 8d9e060..7db43bd 100644 --- a/marl_factory_grid/modules/batteries/groups.py +++ b/marl_factory_grid/modules/batteries/groups.py @@ -1,52 +1,36 @@ from typing import Union, List, Tuple +from marl_factory_grid.environment import constants as c from marl_factory_grid.environment.groups.collection import Collection -from marl_factory_grid.modules.batteries.entitites import Pod, Battery +from marl_factory_grid.modules.batteries.entitites import ChargePod, Battery +from marl_factory_grid.utils.results import Result class Batteries(Collection): _entity = Battery - @property - def var_is_blocking_light(self): - return False - - @property - def var_can_collide(self): - return False - - @property - def var_can_move(self): - return False - - @property - def var_has_position(self): - return False - - @property - def var_can_be_bound(self): - return True + var_has_position = False + var_can_be_bound = True @property def obs_tag(self): return self.__class__.__name__ - def __init__(self, *args, **kwargs): - super(Batteries, self).__init__(*args, **kwargs) + def __init__(self, size, initial_charge_level: float=1.0, *args, **kwargs): + super(Batteries, self).__init__(size, *args, **kwargs) + self.initial_charge_level = initial_charge_level - def spawn(self, agents, initial_charge_level): - batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)] + def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], agents, *entity_args, **entity_kwargs): + batteries = [self._entity(self.initial_charge_level, agent) for _, agent in enumerate(agents)] self.add_items(batteries) - # def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): hat keine pos - # agents = entity_args[0] - # initial_charge_level = entity_args[1] - # batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)] - # self.add_items(batteries) + def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, **entity_kwargs): + self.spawn(0, state[c.AGENT]) + return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self)) class ChargePods(Collection): - _entity = Pod + _entity = ChargePod def __init__(self, *args, **kwargs): super(ChargePods, self).__init__(*args, **kwargs) diff --git a/marl_factory_grid/modules/batteries/rules.py b/marl_factory_grid/modules/batteries/rules.py index e060629..84b7ef2 100644 --- a/marl_factory_grid/modules/batteries/rules.py +++ b/marl_factory_grid/modules/batteries/rules.py @@ -49,10 +49,6 @@ class BatteryDecharge(Rule): self.per_action_costs = per_action_costs self.initial_charge = initial_charge - def on_init(self, state, lvl_map): # on reset? - assert len(state[c.AGENT]), "There are no agents, did you already spawn them?" - state[b.BATTERIES].spawn(state[c.AGENT], self.initial_charge) - def tick_step(self, state) -> List[TickResult]: # Decharge batteries = state[b.BATTERIES] @@ -66,7 +62,7 @@ class BatteryDecharge(Rule): batteries.by_entity(agent).decharge(energy_consumption) - results.append(TickResult(self.name, reward=0, entity=agent, validity=c.VALID)) + results.append(TickResult(self.name, entity=agent, validity=c.VALID)) return results @@ -82,13 +78,13 @@ class BatteryDecharge(Rule): if self.paralyze_agents_on_discharge: btry.bound_entity.paralyze(self.name) results.append( - TickResult("Paralyzed", entity=btry.bound_entity, reward=0, validity=c.VALID) + TickResult("Paralyzed", entity=btry.bound_entity, validity=c.VALID) ) state.print(f'{btry.bound_entity.name} has just been paralyzed!') if btry.bound_entity.var_is_paralyzed and not btry.is_discharged: btry.bound_entity.de_paralyze(self.name) results.append( - TickResult("De-Paralyzed", entity=btry.bound_entity, reward=0, validity=c.VALID) + TickResult("De-Paralyzed", entity=btry.bound_entity, validity=c.VALID) ) state.print(f'{btry.bound_entity.name} has just been de-paralyzed!') return results @@ -132,7 +128,7 @@ class DoneAtBatteryDischarge(BatteryDecharge): if any_discharged or all_discharged: return [DoneResult(self.name, validity=c.VALID, reward=self.reward_discharge_done)] else: - return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)] + return [DoneResult(self.name, validity=c.NOT_VALID)] class SpawnChargePods(Rule): @@ -155,7 +151,7 @@ class SpawnChargePods(Rule): def on_init(self, state, lvl_map): pod_collection = state[b.CHARGE_PODS] - empty_positions = state.entities.empty_positions() + empty_positions = state.entities.empty_positions pods = pod_collection.from_coordinates(empty_positions, entity_kwargs=dict( multi_charge=self.multi_charge, charge_rate=self.charge_rate) ) diff --git a/marl_factory_grid/modules/clean_up/__init__.py b/marl_factory_grid/modules/clean_up/__init__.py index 31cb841..ec4d1e7 100644 --- a/marl_factory_grid/modules/clean_up/__init__.py +++ b/marl_factory_grid/modules/clean_up/__init__.py @@ -1,4 +1,4 @@ from .actions import CleanUp from .entitites import DirtPile from .groups import DirtPiles -from .rules import SpawnDirt, EntitiesSmearDirtOnMove, DoneOnAllDirtCleaned +from .rules import EntitiesSmearDirtOnMove, DoneOnAllDirtCleaned diff --git a/marl_factory_grid/modules/clean_up/entitites.py b/marl_factory_grid/modules/clean_up/entitites.py index 8ac8a0c..19e703c 100644 --- a/marl_factory_grid/modules/clean_up/entitites.py +++ b/marl_factory_grid/modules/clean_up/entitites.py @@ -7,22 +7,6 @@ from marl_factory_grid.modules.clean_up import constants as d class DirtPile(Entity): - @property - def var_can_collide(self): - return False - - @property - def var_can_move(self): - return False - - @property - def var_is_blocking_light(self): - return False - - @property - def var_has_position(self): - return True - @property def amount(self): return self._amount diff --git a/marl_factory_grid/modules/clean_up/groups.py b/marl_factory_grid/modules/clean_up/groups.py index 63e5898..2029171 100644 --- a/marl_factory_grid/modules/clean_up/groups.py +++ b/marl_factory_grid/modules/clean_up/groups.py @@ -9,68 +9,55 @@ from marl_factory_grid.modules.clean_up.entitites import DirtPile class DirtPiles(Collection): _entity = DirtPile - @property - def var_is_blocking_light(self): - return False + var_is_blocking_light = False + var_can_collide = False + var_can_move = False + var_has_position = True @property - def var_can_collide(self): - return False - - @property - def var_can_move(self): - return False - - @property - def var_has_position(self): - return True - - @property - def amount(self): + def global_amount(self): return sum([dirt.amount for dirt in self]) def __init__(self, *args, max_local_amount=5, clean_amount=1, - max_global_amount: int = 20, **kwargs): + max_global_amount: int = 20, + coords_or_quantity=10, + initial_amount=2, + amount_var=0.2, + n_var=0.2, + **kwargs): super(DirtPiles, self).__init__(*args, **kwargs) + self.amount_var = amount_var + self.n_var = n_var self.clean_amount = clean_amount self.max_global_amount = max_global_amount self.max_local_amount = max_local_amount + self.coords_or_quantity = coords_or_quantity + self.initial_amount = initial_amount - def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): - amount_s = entity_args[0] + def trigger_spawn(self, state, coords_or_quantity=0, amount=0) -> [Result]: + coords_or_quantity = coords_or_quantity if coords_or_quantity else self.coords_or_quantity + n_new = int(abs(coords_or_quantity + (state.rng.uniform(-self.n_var, self.n_var)))) + n_new = state.get_n_random_free_positions(n_new) + + amounts = [amount if amount else (self.initial_amount + state.rng.uniform(-self.amount_var, self.amount_var)) + for _ in range(coords_or_quantity)] spawn_counter = 0 - for idx, pos in enumerate(coords_or_quantity): - if not self.amount > self.max_global_amount: - amount = amount_s[idx] if isinstance(amount_s, list) else amount_s + for idx, (pos, a) in enumerate(zip(n_new, amounts)): + if not self.global_amount > self.max_global_amount: if dirt := self.by_pos(pos): dirt = next(dirt.iter()) - new_value = dirt.amount + amount + new_value = dirt.amount + a dirt.set_new_amount(new_value) else: - dirt = DirtPile(pos, amount=amount) - self.add_item(dirt) + super().spawn([pos], amount=a) spawn_counter += 1 else: - return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, reward=0, - value=spawn_counter) - return Result(identifier=f'{self.name}_spawn', validity=c.VALID, reward=0, value=spawn_counter) + return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, value=spawn_counter) - def trigger_dirt_spawn(self, n, amount, state, n_var=0.2, amount_var=0.2) -> Result: - free_for_dirt = [x for x in state.entities.floorlist if len(state.entities.pos_dict[x]) == 0 or ( - len(state.entities.pos_dict[x]) >= 1 and isinstance(next(y for y in x), DirtPile))] - # free_for_dirt = [x for x in state[c.FLOOR] - # if len(x.guests) == 0 or ( - # len(x.guests) == 1 and - # isinstance(next(y for y in x.guests), DirtPile))] - state.rng.shuffle(free_for_dirt) - - new_spawn = int(abs(n + (state.rng.uniform(-n_var, n_var)))) - new_amount_s = [abs(amount + (amount*state.rng.uniform(-amount_var, amount_var))) for _ in range(new_spawn)] - n_dirty_positions = free_for_dirt[:new_spawn] - return self.spawn(n_dirty_positions, new_amount_s) + return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=spawn_counter) def __repr__(self): s = super(DirtPiles, self).__repr__() - return f'{s[:-1]}, {self.amount})' + return f'{s[:-1]}, {self.global_amount}]' diff --git a/marl_factory_grid/modules/clean_up/rules.py b/marl_factory_grid/modules/clean_up/rules.py index 3f58cdb..be2f9b9 100644 --- a/marl_factory_grid/modules/clean_up/rules.py +++ b/marl_factory_grid/modules/clean_up/rules.py @@ -22,58 +22,37 @@ class DoneOnAllDirtCleaned(Rule): def on_check_done(self, state) -> [DoneResult]: if len(state[d.DIRT]) == 0 and state.curr_step: return [DoneResult(validity=c.VALID, identifier=self.name, reward=self.reward)] - return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)] + return [DoneResult(validity=c.NOT_VALID, identifier=self.name)] -class SpawnDirt(Rule): +class RespawnDirt(Rule): - def __init__(self, initial_n: int = 5, initial_amount: float = 1.3, - respawn_n: int = 3, respawn_amount: float = 0.8, - n_var: float = 0.2, amount_var: float = 0.2, spawn_freq: int = 15): + def __init__(self, respawn_freq: int = 15, respawn_n: int = 5, respawn_amount: float = 1.0): """ Defines the spawn pattern of intial and additional 'Dirt'-entitites. First chooses positions, then trys to spawn dirt until 'respawn_n' or the maximal global amount is reached. If there is allready some, it is topped up to min(max_local_amount, amount). - :type spawn_freq: int - :parameter spawn_freq: In which frequency should this Rule try to spawn new 'Dirt'? + :type respawn_freq: int + :parameter respawn_freq: In which frequency should this Rule try to spawn new 'Dirt'? :type respawn_n: int :parameter respawn_n: How many respawn positions are considered. - :type initial_n: int - :parameter initial_n: How much initial positions are considered. - :type amount_var: float - :parameter amount_var: Variance of amount to spawn. - :type n_var: float - :parameter n_var: Variance of n to spawn. :type respawn_amount: float :parameter respawn_amount: Defines how much dirt 'amount' is placed every 'spawn_freq' ticks. - :type initial_amount: float - :parameter initial_amount: Defines how much dirt 'amount' is initially placed. - """ super().__init__() - self.amount_var = amount_var - self.n_var = n_var - self.respawn_amount = respawn_amount self.respawn_n = respawn_n - self.initial_amount = initial_amount - self.initial_n = initial_n - self.spawn_freq = spawn_freq - self._next_dirt_spawn = spawn_freq - - def on_init(self, state, lvl_map) -> str: - result = state[d.DIRT].trigger_dirt_spawn(self.initial_n, self.initial_amount, state, - n_var=self.n_var, amount_var=self.amount_var) - state.print(f'Initial Dirt was spawned on: {[x.pos for x in state[d.DIRT]]}') - return result + self.respawn_amount = respawn_amount + self.respawn_freq = respawn_freq + self._next_dirt_spawn = respawn_freq def tick_step(self, state): + collection = state[d.DIRT] if self._next_dirt_spawn < 0: pass # No DirtPile Spawn elif not self._next_dirt_spawn: - result = [state[d.DIRT].trigger_dirt_spawn(self.respawn_n, self.respawn_amount, state, - n_var=self.n_var, amount_var=self.amount_var)] - self._next_dirt_spawn = self.spawn_freq + result = [collection.trigger_spawn(state, coords_or_quantity=self.respawn_n, amount=self.respawn_amount)] + self._next_dirt_spawn = self.respawn_freq else: self._next_dirt_spawn -= 1 result = [] @@ -99,8 +78,8 @@ class EntitiesSmearDirtOnMove(Rule): for entity in state.moving_entites: if is_move(entity.state.identifier) and entity.state.validity == c.VALID: if old_pos_dirt := state[d.DIRT].by_pos(entity.last_pos): + old_pos_dirt = next(iter(old_pos_dirt)) if smeared_dirt := round(old_pos_dirt.amount * self.smear_ratio, 2): if state[d.DIRT].spawn(entity.pos, amount=smeared_dirt): - results.append(TickResult(identifier=self.name, entity=entity, - reward=0, validity=c.VALID)) + results.append(TickResult(identifier=self.name, entity=entity, validity=c.VALID)) return results diff --git a/marl_factory_grid/modules/destinations/__init__.py b/marl_factory_grid/modules/destinations/__init__.py index 83e5988..4614dd7 100644 --- a/marl_factory_grid/modules/destinations/__init__.py +++ b/marl_factory_grid/modules/destinations/__init__.py @@ -1,4 +1,7 @@ from .actions import DestAction from .entitites import Destination from .groups import Destinations -from .rules import DoneAtDestinationReachAll, SpawnDestinations +from .rules import (DoneAtDestinationReachAll, + DoneAtDestinationReachAny, + SpawnDestinationsPerAgent, + DestinationReachReward) diff --git a/marl_factory_grid/modules/destinations/entitites.py b/marl_factory_grid/modules/destinations/entitites.py index 7b866b7..d75f9e0 100644 --- a/marl_factory_grid/modules/destinations/entitites.py +++ b/marl_factory_grid/modules/destinations/entitites.py @@ -9,30 +9,6 @@ from marl_factory_grid.utils.utility_classes import RenderEntity class Destination(Entity): - @property - def var_can_move(self): - return False - - @property - def var_can_collide(self): - return False - - @property - def var_has_position(self): - return True - - @property - def var_is_blocking_pos(self): - return False - - @property - def var_is_blocking_light(self): - return False - - @property - def var_can_be_bound(self): - return True - def was_reached(self): return self._was_reached diff --git a/marl_factory_grid/modules/destinations/groups.py b/marl_factory_grid/modules/destinations/groups.py index 5f91bb4..5f0b654 100644 --- a/marl_factory_grid/modules/destinations/groups.py +++ b/marl_factory_grid/modules/destinations/groups.py @@ -7,37 +7,14 @@ from marl_factory_grid.modules.destinations import constants as d class Destinations(Collection): _entity = Destination - @property - def var_is_blocking_light(self): - return False - - @property - def var_can_collide(self): - return False - - @property - def var_can_move(self): - return False - - @property - def var_has_position(self): - return True + var_is_blocking_light = False + var_can_collide = False + var_can_move = False + var_has_position = True + var_can_be_bound = True def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def __repr__(self): return super(Destinations, self).__repr__() - - @staticmethod - def trigger_destination_spawn(n_dests, state): - coordinates = state.entities.floorlist[:n_dests] - if destinations := [Destination(pos) for pos in coordinates]: - state[d.DESTINATION].add_items(destinations) - state.print(f'{n_dests} new destinations have been spawned') - return c.VALID - else: - state.print('No Destiantions are spawning, limit is reached.') - return c.NOT_VALID - - diff --git a/marl_factory_grid/modules/destinations/rules.py b/marl_factory_grid/modules/destinations/rules.py index afb8575..8e72141 100644 --- a/marl_factory_grid/modules/destinations/rules.py +++ b/marl_factory_grid/modules/destinations/rules.py @@ -2,8 +2,8 @@ import ast from random import shuffle from typing import List, Dict, Tuple -import marl_factory_grid.modules.destinations.constants from marl_factory_grid.environment.rules import Rule +from marl_factory_grid.utils import helpers as h from marl_factory_grid.utils.results import TickResult, DoneResult from marl_factory_grid.environment import constants as c @@ -54,7 +54,7 @@ class DoneAtDestinationReachAll(DestinationReachReward): """ This rule triggers and sets the done flag if ALL Destinations have been reached. - :type reward_at_done: object + :type reward_at_done: float :param reward_at_done: Specifies the reward, agent get, whenn all destinations are reached. :type dest_reach_reward: float :param dest_reach_reward: Specify the reward, agents get when reaching a single destination. @@ -65,7 +65,7 @@ class DoneAtDestinationReachAll(DestinationReachReward): def on_check_done(self, state) -> List[DoneResult]: if all(x.was_reached() for x in state[d.DESTINATION]): return [DoneResult(self.name, validity=c.VALID, reward=self.reward)] - return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)] + return [DoneResult(self.name, validity=c.NOT_VALID)] class DoneAtDestinationReachAny(DestinationReachReward): @@ -75,7 +75,7 @@ class DoneAtDestinationReachAny(DestinationReachReward): This rule triggers and sets the done flag if ANY Destinations has been reached. !!! IMPORTANT: 'reward_at_done' is shared between the agents; 'dest_reach_reward' is bound to a specific one. - :type reward_at_done: object + :type reward_at_done: float :param reward_at_done: Specifies the reward, all agent get, when any destinations has been reached. Default {d.REWARD_DEST_DONE} :type dest_reach_reward: float @@ -87,67 +87,29 @@ class DoneAtDestinationReachAny(DestinationReachReward): def on_check_done(self, state) -> List[DoneResult]: if any(x.was_reached() for x in state[d.DESTINATION]): - return [DoneResult(self.name, validity=c.VALID, reward=marl_factory_grid.modules.destinations.constants.REWARD_DEST_REACHED)] + return [DoneResult(self.name, validity=c.VALID, reward=d.REWARD_DEST_REACHED)] return [] -class SpawnDestinations(Rule): - - def __init__(self, n_dests: int = 1, spawn_mode: str = d.MODE_GROUPED): - f""" - Defines how destinations are initially spawned and respawned in addition. - !!! This rule introduces no kind of reward or Env.-Done condition! - - :type n_dests: int - :param n_dests: How many destiantions should be maintained (and initally spawnewd) on the map? - :type spawn_mode: str - :param spawn_mode: One of {d.SPAWN_MODES}. {d.MODE_GROUPED}: Always wait for all Dstiantions do be gone, - then respawn after the given time. {d.MODE_SINGLE}: Just spawn every destination, - that has been reached, after the given time - - """ - super(SpawnDestinations, self).__init__() - self.n_dests = n_dests - self.spawn_mode = spawn_mode - - def on_init(self, state, lvl_map): - # noinspection PyAttributeOutsideInit - state[d.DESTINATION].trigger_destination_spawn(self.n_dests, state) - pass - - def tick_pre_step(self, state) -> List[TickResult]: - pass - - def tick_step(self, state) -> List[TickResult]: - if n_dest_spawn := max(0, self.n_dests - len(state[d.DESTINATION])): - if self.spawn_mode == d.MODE_GROUPED and n_dest_spawn == self.n_dests: - validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state) - return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)] - elif self.spawn_mode == d.MODE_SINGLE and n_dest_spawn: - validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state) - return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)] - else: - pass - - class SpawnDestinationsPerAgent(Rule): - def __init__(self, per_agent_positions: Dict[str, List[Tuple[int, int]]]): + def __init__(self, coords_or_quantity: Dict[str, List[Tuple[int, int]]]): """ Special rule, that spawn distinations, that are bound to a single agent a fixed set of positions. Usefull for introducing specialists, etc. .. !!! This rule does not introduce any reward or done condition. - :type per_agent_positions: Dict[str, List[Tuple[int, int]] - :param per_agent_positions: Please provide a dictionary with agent names as keys; and a list of possible + :type coords_or_quantity: Dict[str, List[Tuple[int, int]] + :param coords_or_quantity: Please provide a dictionary with agent names as keys; and a list of possible destiantion coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]} """ super(Rule, self).__init__() - self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in per_agent_positions.items()} + self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in coords_or_quantity.items()} def on_init(self, state, lvl_map): for (agent_name, position_list) in self.per_agent_positions.items(): - agent = next(x for x in state[c.AGENT] if agent_name in x.name) # Fixme: Ugly AF + agent = h.get_first(state[c.AGENT], lambda x: agent_name in x.name) + assert agent position_list = position_list.copy() shuffle(position_list) while True: @@ -155,7 +117,7 @@ class SpawnDestinationsPerAgent(Rule): pos = position_list.pop() except IndexError: print(f"Could not spawn Destinations at: {self.per_agent_positions[agent_name]}") - print(f'Check your agent palcement: {state[c.AGENT]} ... Exit ...') + print(f'Check your agent placement: {state[c.AGENT]} ... Exit ...') exit(9999) if (not pos == agent.pos) and (not state[d.DESTINATION].by_pos(pos)): destination = Destination(pos, bind_to=agent) diff --git a/marl_factory_grid/modules/doors/entitites.py b/marl_factory_grid/modules/doors/entitites.py index 669f74e..1c33d7b 100644 --- a/marl_factory_grid/modules/doors/entitites.py +++ b/marl_factory_grid/modules/doors/entitites.py @@ -1,4 +1,5 @@ from marl_factory_grid.environment.entity.entity import Entity +from marl_factory_grid.utils import Result from marl_factory_grid.utils.utility_classes import RenderEntity from marl_factory_grid.environment import constants as c @@ -41,21 +42,6 @@ class Door(Entity): def str_state(self): return 'open' if self.is_open else 'closed' - def __init__(self, *args, closed_on_init=True, auto_close_interval=10, **kwargs): - self._status = d.STATE_CLOSED - super(Door, self).__init__(*args, **kwargs) - self.auto_close_interval = auto_close_interval - self.time_to_close = 0 - if not closed_on_init: - self._open() - else: - self._close() - - def summarize_state(self): - state_dict = super().summarize_state() - state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close)) - return state_dict - @property def is_closed(self): return self._status == d.STATE_CLOSED @@ -68,6 +54,25 @@ class Door(Entity): def status(self): return self._status + @property + def time_to_close(self): + return self._time_to_close + + def __init__(self, *args, closed_on_init=True, auto_close_interval=10, **kwargs): + self._status = d.STATE_CLOSED + super(Door, self).__init__(*args, **kwargs) + self._auto_close_interval = auto_close_interval + self._time_to_close = 0 + if not closed_on_init: + self._open() + else: + self._close() + + def summarize_state(self): + state_dict = super().summarize_state() + state_dict.update(state=str(self.str_state), time_to_close=self.time_to_close) + return state_dict + def render(self): name, state = 'door_open' if self.is_open else 'door_closed', 'blank' return RenderEntity(name, self.pos, 1, 'none', state, self.u_int + 1) @@ -80,18 +85,35 @@ class Door(Entity): return c.VALID def tick(self, state): - if self.is_open and len(state.entities.pos_dict[self.pos]) == 2 and self.time_to_close: - self.time_to_close -= 1 - return c.NOT_VALID - elif self.is_open and not self.time_to_close and len(state.entities.pos_dict[self.pos]) == 2: - self.use() - return c.VALID + # Check if no entity is standing in the door + if len(state.entities.pos_dict[self.pos]) <= 2: + if self.is_open and self.time_to_close: + self._decrement_timer() + return Result(f"{d.DOOR}_tick", c.VALID, entity=self) + elif self.is_open and not self.time_to_close: + self.use() + return Result(f"{d.DOOR}_closed", c.VALID, entity=self) + else: + # No one is in door, but it is closed... Nothing to do.... + return None else: - return c.NOT_VALID + # Entity is standing in the door, reset timer + self._reset_timer() + return Result(f"{d.DOOR}_reset", c.VALID, entity=self) def _open(self): self._status = d.STATE_OPEN - self.time_to_close = self.auto_close_interval + self._reset_timer() + return True def _close(self): self._status = d.STATE_CLOSED + return True + + def _decrement_timer(self): + self._time_to_close -= 1 + return True + + def _reset_timer(self): + self._time_to_close = self._auto_close_interval + return True diff --git a/marl_factory_grid/modules/doors/groups.py b/marl_factory_grid/modules/doors/groups.py index 687846e..a27d598 100644 --- a/marl_factory_grid/modules/doors/groups.py +++ b/marl_factory_grid/modules/doors/groups.py @@ -18,8 +18,10 @@ class Doors(Collection): super(Doors, self).__init__(*args, can_collide=True, **kwargs) def tick_doors(self, state): - result_dict = dict() + results = list() for door in self: - did_tick = door.tick(state) - result_dict.update({door.name: did_tick}) - return result_dict + tick_result = door.tick(state) + if tick_result is not None: + results.append(tick_result) + # TODO: Should return a Result object, not a random dict. + return results diff --git a/marl_factory_grid/modules/doors/rules.py b/marl_factory_grid/modules/doors/rules.py index da312cd..599d975 100644 --- a/marl_factory_grid/modules/doors/rules.py +++ b/marl_factory_grid/modules/doors/rules.py @@ -19,10 +19,10 @@ class DoorAutoClose(Rule): def tick_step(self, state): if doors := state[d.DOORS]: - doors_tick_result = doors.tick_doors(state) - doors_that_ticked = [key for key, val in doors_tick_result.items() if val] - state.print(f'{doors_that_ticked} were auto-closed' - if doors_that_ticked else 'No Doors were auto-closed') + doors_tick_results = doors.tick_doors(state) + doors_that_closed = [x.entity.name for x in doors_tick_results if 'closed' in x.identifier] + door_str = doors_that_closed if doors_that_closed else "No Doors" + state.print(f'{door_str} were auto-closed') return [TickResult(self.name, validity=c.VALID, value=1)] state.print('There are no doors, but you loaded the corresponding Module') return [] diff --git a/marl_factory_grid/modules/items/__init__.py b/marl_factory_grid/modules/items/__init__.py index 157c385..cb9b69b 100644 --- a/marl_factory_grid/modules/items/__init__.py +++ b/marl_factory_grid/modules/items/__init__.py @@ -1,4 +1,3 @@ from .actions import ItemAction from .entitites import Item, DropOffLocation from .groups import DropOffLocations, Items, Inventory, Inventories -from .rules import ItemRules diff --git a/marl_factory_grid/modules/items/actions.py b/marl_factory_grid/modules/items/actions.py index f9e4f6f..ef6aa99 100644 --- a/marl_factory_grid/modules/items/actions.py +++ b/marl_factory_grid/modules/items/actions.py @@ -29,7 +29,7 @@ class ItemAction(Action): elif items := state[i.ITEM].by_pos(entity.pos): item = items[0] item.change_parent_collection(inventory) - item.set_pos_to(c.VALUE_NO_POS) + item.set_pos(c.VALUE_NO_POS) state.print(f'{entity.name} just picked up an item at {entity.pos}') return ActionResult(entity=entity, identifier=self._identifier, validity=c.VALID, reward=r.PICK_UP_VALID) diff --git a/marl_factory_grid/modules/items/entitites.py b/marl_factory_grid/modules/items/entitites.py index b710282..ff34e23 100644 --- a/marl_factory_grid/modules/items/entitites.py +++ b/marl_factory_grid/modules/items/entitites.py @@ -8,16 +8,11 @@ from marl_factory_grid.modules.items import constants as i class Item(Entity): - @property - def var_can_collide(self): - return False - def render(self): return RenderEntity(i.ITEM, self.pos) if self.pos != c.VALUE_NO_POS else None def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._auto_despawn = -1 @property def auto_despawn(self): @@ -31,9 +26,6 @@ class Item(Entity): def set_auto_despawn(self, auto_despawn): self._auto_despawn = auto_despawn - def set_pos_to(self, no_pos): - self._pos = no_pos - def summarize_state(self) -> dict: super_summarization = super(Item, self).summarize_state() super_summarization.update(dict(auto_despawn=self.auto_despawn)) @@ -42,21 +34,6 @@ class Item(Entity): class DropOffLocation(Entity): - @property - def var_can_collide(self): - return False - - @property - def var_can_move(self): - return False - - @property - def var_is_blocking_light(self): - return False - - @property - def var_has_position(self): - return True def render(self): return RenderEntity(i.DROP_OFF, self.pos) diff --git a/marl_factory_grid/modules/items/groups.py b/marl_factory_grid/modules/items/groups.py index 707f743..deb1812 100644 --- a/marl_factory_grid/modules/items/groups.py +++ b/marl_factory_grid/modules/items/groups.py @@ -8,6 +8,7 @@ from marl_factory_grid.environment.groups.objects import _Objects from marl_factory_grid.environment.groups.mixins import IsBoundMixin from marl_factory_grid.environment.entity.agent import Agent from marl_factory_grid.modules.items.entitites import Item, DropOffLocation +from marl_factory_grid.utils.results import Result class Items(Collection): @@ -15,7 +16,7 @@ class Items(Collection): @property def var_has_position(self): - return False + return True @property def is_blocking_light(self): @@ -28,18 +29,18 @@ class Items(Collection): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - @staticmethod - def trigger_item_spawn(state, n_items, spawn_frequency): - if item_to_spawns := max(0, (n_items - len(state[i.ITEM]))): - position_list = [x for x in state.entities.floorlist] - shuffle(position_list) - position_list = state.entities.floorlist[:item_to_spawns] - state[i.ITEM].spawn(position_list) - state.print(f'{item_to_spawns} new items have been spawned; next spawn in {spawn_frequency}') - return len(position_list) + def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, **entity_kwargs) -> [Result]: + coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity + assert coords_or_quantity + + if item_to_spawns := max(0, (coords_or_quantity - len(self))): + return super().trigger_spawn(state, + *entity_args, + coords_or_quantity=item_to_spawns, + **entity_kwargs) else: state.print('No Items are spawning, limit is reached.') - return 0 + return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, value=coords_or_quantity) class Inventory(IsBoundMixin, Collection): @@ -76,9 +77,15 @@ class Inventory(IsBoundMixin, Collection): class Inventories(_Objects): _entity = Inventory + var_can_move = False + var_has_position = False + + + symbol = None + @property - def var_can_move(self): - return False + def spawn_rule(self): + return {c.SPAWN_ENTITY_RULE: dict(collection=self, coords_or_quantity=None)} def __init__(self, size: int, *args, **kwargs): super(Inventories, self).__init__(*args, **kwargs) @@ -86,10 +93,12 @@ class Inventories(_Objects): self._obs = None self._lazy_eval_transforms = [] - def spawn(self, agents): - inventories = [self._entity(agent, self.size, ) - for _, agent in enumerate(agents)] - self.add_items(inventories) + def spawn(self, agents, *args, **kwargs): + self.add_items([self._entity(agent, self.size, *args, **kwargs) for _, agent in enumerate(agents)]) + return [Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))] + + def trigger_spawn(self, state, *args, **kwargs) -> [Result]: + return self.spawn(state[c.AGENT], *args, **kwargs) def idx_by_entity(self, entity): try: @@ -106,9 +115,6 @@ class Inventories(_Objects): def summarize_states(self, **kwargs): return [val.summarize_states(**kwargs) for key, val in self.items()] - @staticmethod - def trigger_inventory_spawn(state): - state[i.INVENTORY].spawn(state[c.AGENT]) class DropOffLocations(Collection): @@ -135,7 +141,7 @@ class DropOffLocations(Collection): @staticmethod def trigger_drop_off_location_spawn(state, n_locations): - empty_positions = state.entities.empty_positions()[:n_locations] + empty_positions = state.entities.empty_positions[:n_locations] do_entites = state[i.DROP_OFF] drop_offs = [DropOffLocation(pos) for pos in empty_positions] do_entites.add_items(drop_offs) diff --git a/marl_factory_grid/modules/items/rules.py b/marl_factory_grid/modules/items/rules.py index 9f8a0cc..a655956 100644 --- a/marl_factory_grid/modules/items/rules.py +++ b/marl_factory_grid/modules/items/rules.py @@ -6,52 +6,28 @@ from marl_factory_grid.utils.results import TickResult from marl_factory_grid.modules.items import constants as i -class ItemRules(Rule): +class RespawnItems(Rule): - def __init__(self, n_items: int = 5, spawn_frequency: int = 15, - n_locations: int = 5, max_dropoff_storage_size: int = 0): + def __init__(self, n_items: int = 5, respawn_freq: int = 15, n_locations: int = 5): super().__init__() - self.spawn_frequency = spawn_frequency - self._next_item_spawn = spawn_frequency + self.spawn_frequency = respawn_freq + self._next_item_spawn = respawn_freq self.n_items = n_items - self.max_dropoff_storage_size = max_dropoff_storage_size self.n_locations = n_locations - def on_init(self, state, lvl_map): - state[i.DROP_OFF].trigger_drop_off_location_spawn(state, self.n_locations) - self._next_item_spawn = self.spawn_frequency - state[i.INVENTORY].trigger_inventory_spawn(state) - state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency) - def tick_step(self, state): - for item in list(state[i.ITEM].values()): - if item.auto_despawn >= 1: - item.set_auto_despawn(item.auto_despawn - 1) - elif not item.auto_despawn: - state[i.ITEM].delete_env_object(item) - else: - pass - if not self._next_item_spawn: - state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency) + state[i.ITEM].trigger_spawn(state, self.n_items, self.spawn_frequency) else: self._next_item_spawn = max(0, self._next_item_spawn - 1) return [] def tick_post_step(self, state) -> List[TickResult]: - for item in list(state[i.ITEM].values()): - if item.auto_despawn >= 1: - item.set_auto_despawn(item.auto_despawn-1) - elif not item.auto_despawn: - state[i.ITEM].delete_env_object(item) - else: - pass - if not self._next_item_spawn: - if spawned_items := state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency): - return [TickResult(self.name, validity=c.VALID, value=spawned_items, entity=None)] + if spawned_items := state[i.ITEM].trigger_spawn(state, self.n_items, self.spawn_frequency): + return [TickResult(self.name, validity=c.VALID, value=spawned_items.value)] else: - return [TickResult(self.name, validity=c.NOT_VALID, value=0, entity=None)] + return [TickResult(self.name, validity=c.NOT_VALID, value=0)] else: self._next_item_spawn = max(0, self._next_item_spawn-1) return [] diff --git a/marl_factory_grid/modules/machines/__init__.py b/marl_factory_grid/modules/machines/__init__.py index 36ba51d..233efbb 100644 --- a/marl_factory_grid/modules/machines/__init__.py +++ b/marl_factory_grid/modules/machines/__init__.py @@ -1,3 +1,2 @@ from .entitites import Machine from .groups import Machines -from .rules import MachineRule diff --git a/marl_factory_grid/modules/machines/actions.py b/marl_factory_grid/modules/machines/actions.py index 8f4eaaa..970f85f 100644 --- a/marl_factory_grid/modules/machines/actions.py +++ b/marl_factory_grid/modules/machines/actions.py @@ -5,6 +5,7 @@ from marl_factory_grid.utils.results import ActionResult from marl_factory_grid.modules.machines import constants as m, rewards as r from marl_factory_grid.environment import constants as c +from marl_factory_grid.utils import helpers as h class MachineAction(Action): @@ -13,13 +14,10 @@ class MachineAction(Action): super().__init__(m.MACHINE_ACTION) def do(self, entity, state) -> Union[None, ActionResult]: - if machine := state[m.MACHINES].by_pos(entity.pos): + if machine := h.get_first(state[m.MACHINES].by_pos(entity.pos)): if valid := machine.maintain(): return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_VALID) else: return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_FAIL) else: return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.MAINTAIN_FAIL) - - - diff --git a/marl_factory_grid/modules/machines/entitites.py b/marl_factory_grid/modules/machines/entitites.py index 36a87cc..f5775e1 100644 --- a/marl_factory_grid/modules/machines/entitites.py +++ b/marl_factory_grid/modules/machines/entitites.py @@ -8,22 +8,6 @@ from . import constants as m class Machine(Entity): - @property - def var_can_collide(self): - return False - - @property - def var_can_move(self): - return False - - @property - def var_is_blocking_light(self): - return False - - @property - def var_has_position(self): - return True - @property def encoding(self): return self._encodings[self.status] @@ -46,12 +30,12 @@ class Machine(Entity): else: return c.NOT_VALID - def tick(self): + def tick(self, state): # if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]): - if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.state.entities.pos_dict[self.pos]]): - return TickResult(identifier=self.name, validity=c.VALID, reward=0, entity=self) + if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in state.entities.pos_dict[self.pos]]): + return TickResult(identifier=self.name, validity=c.VALID, entity=self) # elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]): - elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.state.entities.pos_dict[self.pos]]): + elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in state.entities.pos_dict[self.pos]]): self.status = m.STATE_WORK self.reset_counter() return None diff --git a/marl_factory_grid/modules/machines/rules.py b/marl_factory_grid/modules/machines/rules.py index 84e3410..e69de29 100644 --- a/marl_factory_grid/modules/machines/rules.py +++ b/marl_factory_grid/modules/machines/rules.py @@ -1,28 +0,0 @@ -from typing import List -from marl_factory_grid.environment.rules import Rule -from marl_factory_grid.utils.results import TickResult, DoneResult -from marl_factory_grid.environment import constants as c -from marl_factory_grid.modules.machines import constants as m -from marl_factory_grid.modules.machines.entitites import Machine - - -class MachineRule(Rule): - - def __init__(self, n_machines: int = 2): - super(MachineRule, self).__init__() - self.n_machines = n_machines - - def on_init(self, state, lvl_map): - state[m.MACHINES].spawn(state.entities.empty_positions()) - - def tick_pre_step(self, state) -> List[TickResult]: - pass - - def tick_step(self, state) -> List[TickResult]: - pass - - def tick_post_step(self, state) -> List[TickResult]: - pass - - def on_check_done(self, state) -> List[DoneResult]: - pass diff --git a/marl_factory_grid/modules/maintenance/entities.py b/marl_factory_grid/modules/maintenance/entities.py index e084b0c..1a043c8 100644 --- a/marl_factory_grid/modules/maintenance/entities.py +++ b/marl_factory_grid/modules/maintenance/entities.py @@ -1,48 +1,35 @@ +from random import shuffle + import networkx as nx import numpy as np + from ...algorithms.static.utils import points_to_graph from ...environment import constants as c from ...environment.actions import Action, ALL_BASEACTIONS from ...environment.entity.entity import Entity from ..doors import constants as do from ..maintenance import constants as mi -from ...utils.helpers import MOVEMAP -from ...utils.utility_classes import RenderEntity -from ...utils.states import Gamestate +from ...utils import helpers as h +from ...utils.utility_classes import RenderEntity, Floor +from ..doors import DoorUse class Maintainer(Entity): - @property - def var_can_collide(self): - return True - - @property - def var_can_move(self): - return False - - @property - def var_is_blocking_light(self): - return False - - @property - def var_has_position(self): - return True - - def __init__(self, state: Gamestate, objective: str, action: Action, *args, **kwargs): + def __init__(self, objective: str, action: Action, *args, **kwargs): super().__init__(*args, **kwargs) self.action = action - self.actions = [x() for x in ALL_BASEACTIONS] + self.actions = [x() for x in ALL_BASEACTIONS] + [DoorUse()] self.objective = objective self._path = None self._next = [] self._last = [] self._last_serviced = 'None' - self._floortile_graph = points_to_graph(state.entities.floorlist) + self._floortile_graph = None def tick(self, state): - if found_objective := state[self.objective].by_pos(self.pos): + if found_objective := h.get_first(state[self.objective].by_pos(self.pos)): if found_objective.name != self._last_serviced: self.action.do(self, state) self._last_serviced = found_objective.name @@ -54,24 +41,27 @@ class Maintainer(Entity): return action.do(self, state) def get_move_action(self, state) -> Action: + if not self._floortile_graph: + state.print("Generating Floorgraph....") + self._floortile_graph = points_to_graph(state.entities.floorlist) if self._path is None or not self._path: if not self._next: - self._next = list(state[self.objective].values()) + self._next = list(state[self.objective].values()) + [Floor(*state.random_free_position)] + shuffle(self._next) self._last = [] self._last.append(self._next.pop()) + state.print("Calculating shortest path....") self._path = self.calculate_route(self._last[-1]) - if door := self._door_is_close(state): - if door.is_closed: - # Translate the action_object to an integer to have the same output as any other model - action = do.ACTION_DOOR_USE - else: - action = self._predict_move(state) + if door := self._closed_door_in_path(state): + state.print(f"{self} found {door} that is closed. Attempt to open.") + # Translate the action_object to an integer to have the same output as any other model + action = do.ACTION_DOOR_USE else: action = self._predict_move(state) # Translate the action_object to an integer to have the same output as any other model try: - action_obj = next(x for x in self.actions if x.name == action) + action_obj = h.get_first(self.actions, lambda x: x.name == action) except (StopIteration, UnboundLocalError): print('Will not happen') raise EnvironmentError @@ -81,11 +71,10 @@ class Maintainer(Entity): route = nx.shortest_path(self._floortile_graph, self.pos, entity.pos) return route[1:] - def _door_is_close(self, state): - state.print("Found a door that is close.") - try: - return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name) - except StopIteration: + def _closed_door_in_path(self, state): + if self._path: + return h.get_first(state[do.DOORS].by_pos(self._path[0]), lambda x: x.is_closed) + else: return None def _predict_move(self, state): @@ -96,7 +85,7 @@ class Maintainer(Entity): next_pos = self._path.pop(0) diff = np.subtract(next_pos, self.pos) # Retrieve action based on the pos dif (like in: What do I have to do to get there?) - action = next(action for action, pos_diff in MOVEMAP.items() if np.all(diff == pos_diff)) + action = next(action for action, pos_diff in h.MOVEMAP.items() if np.all(diff == pos_diff)) return action def render(self): diff --git a/marl_factory_grid/modules/maintenance/groups.py b/marl_factory_grid/modules/maintenance/groups.py index 2df70cb..79f7480 100644 --- a/marl_factory_grid/modules/maintenance/groups.py +++ b/marl_factory_grid/modules/maintenance/groups.py @@ -1,4 +1,4 @@ -from typing import Union, List, Tuple +from typing import Union, List, Tuple, Dict from marl_factory_grid.environment.groups.collection import Collection from .entities import Maintainer @@ -10,25 +10,21 @@ from ...utils.states import Gamestate class Maintainers(Collection): _entity = Maintainer - @property - def var_can_collide(self): - return True + var_can_collide = True + var_can_move = True + var_is_blocking_light = False + var_has_position = True - @property - def var_can_move(self): - return True - - @property - def var_is_blocking_light(self): - return False - - @property - def var_has_position(self): - return True + def __init__(self, size, *args, coords_or_quantity: int = None, + spawnrule: Union[None, Dict[str, dict]] = None, + **kwargs): + super(Collection, self).__init__(*args, **kwargs) + self._coords_or_quantity = coords_or_quantity + self.size = size + self._spawnrule = spawnrule def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): - state = entity_args[0] - self.add_items([self._entity(state, mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity]) + self.add_items([self._entity(mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity]) diff --git a/marl_factory_grid/modules/maintenance/rules.py b/marl_factory_grid/modules/maintenance/rules.py index 820183e..fdefe42 100644 --- a/marl_factory_grid/modules/maintenance/rules.py +++ b/marl_factory_grid/modules/maintenance/rules.py @@ -4,29 +4,24 @@ from marl_factory_grid.utils.results import TickResult, DoneResult from marl_factory_grid.environment import constants as c from . import rewards as r from . import constants as M -from marl_factory_grid.utils.states import Gamestate -class MaintenanceRule(Rule): +class MoveMaintainers(Rule): - def __init__(self, n_maintainer: int = 1, *args, **kwargs): - super(MaintenanceRule, self).__init__(*args, **kwargs) - self.n_maintainer = n_maintainer - - def on_init(self, state: Gamestate, lvl_map): - state[M.MAINTAINERS].spawn(state.entities.empty_positions[:self.n_maintainer], state) - pass - - def tick_pre_step(self, state) -> List[TickResult]: - pass + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) def tick_step(self, state) -> List[TickResult]: for maintainer in state[M.MAINTAINERS]: maintainer.tick(state) + # Todo: Return a Result Object. return [] - def tick_post_step(self, state) -> List[TickResult]: - pass + +class DoneAtMaintainerCollision(Rule): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) def on_check_done(self, state) -> List[DoneResult]: agents = list(state[c.AGENT].values()) diff --git a/marl_factory_grid/modules/zones/rules.py b/marl_factory_grid/modules/zones/rules.py index 2969186..f9b5c11 100644 --- a/marl_factory_grid/modules/zones/rules.py +++ b/marl_factory_grid/modules/zones/rules.py @@ -1,8 +1,8 @@ from random import choices, choice from . import constants as z, Zone +from .. import Destination from ..destinations import constants as d -from ... import Destination from ...environment.rules import Rule from ...environment import constants as c diff --git a/marl_factory_grid/utils/__init__.py b/marl_factory_grid/utils/__init__.py index e69de29..23848e0 100644 --- a/marl_factory_grid/utils/__init__.py +++ b/marl_factory_grid/utils/__init__.py @@ -0,0 +1,3 @@ +from . import helpers as h +from . import helpers +from .results import Result, DoneResult, ActionResult, TickResult diff --git a/marl_factory_grid/utils/config_parser.py b/marl_factory_grid/utils/config_parser.py index 093f1d0..7cdc9e6 100644 --- a/marl_factory_grid/utils/config_parser.py +++ b/marl_factory_grid/utils/config_parser.py @@ -1,28 +1,24 @@ import ast -from collections import defaultdict + from os import PathLike from pathlib import Path -from typing import Union +from typing import Union, List import yaml -from marl_factory_grid.environment.groups.agents import Agents -from marl_factory_grid.environment.entity.agent import Agent from marl_factory_grid.environment.rules import Rule from marl_factory_grid.utils.helpers import locate_and_import_class +from marl_factory_grid.environment.constants import DEFAULT_PATH, MODULE_PATH from marl_factory_grid.environment import constants as c -DEFAULT_PATH = 'environment' -MODULE_PATH = 'modules' - class FactoryConfigParser(object): default_entites = [] - default_rules = ['MaxStepsReached', 'Collision'] + default_rules = ['DoneAtMaxStepsReached', 'WatchCollision'] default_actions = [c.MOVE8, c.NOOP] default_observations = [c.WALLS, c.AGENT] - def __init__(self, config_path, custom_modules_path: Union[None, PathLike] = None): + def __init__(self, config_path, custom_modules_path: Union[PathLike] = None): self.config_path = Path(config_path) self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path self.config = yaml.safe_load(self.config_path.open()) @@ -46,6 +42,10 @@ class FactoryConfigParser(object): def rules(self): return self.config['Rules'] + @property + def tests(self): + return self.config.get('Tests', []) + @property def agents(self): return self.config['Agents'] @@ -61,7 +61,6 @@ class FactoryConfigParser(object): return self.config[item] def load_entities(self): - # entites = Entities() entity_classes = dict() entities = [] if c.DEFAULTS in self.entities: @@ -69,28 +68,40 @@ class FactoryConfigParser(object): entities.extend(x for x in self.entities if x != c.DEFAULTS) for entity in entities: + e1 = e2 = e3 = None try: folder_path = Path(__file__).parent.parent / DEFAULT_PATH entity_class = locate_and_import_class(entity, folder_path) - except AttributeError as e1: + except AttributeError as e: + e1 = e try: - folder_path = Path(__file__).parent.parent / MODULE_PATH - entity_class = locate_and_import_class(entity, folder_path) - except AttributeError as e2: - try: - folder_path = self.custom_modules_path - entity_class = locate_and_import_class(entity, folder_path) - except AttributeError as e3: - ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x] - print('### Error ### Error ### Error ### Error ### Error ###') - print() - print(f'Class "{entity}" was not found in "{folder_path.name}"') - print('Possible Entitys are:', str(ents)) - print() - print('Goodbye') - print() - exit() - # raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are:', str(ents)) + module_path = Path(__file__).parent.parent / MODULE_PATH + entity_class = locate_and_import_class(entity, module_path) + except AttributeError as e: + e2 = e + if self.custom_modules_path: + try: + entity_class = locate_and_import_class(entity, self.custom_modules_path) + except AttributeError as e: + e3 = e + pass + if (e1 and e2) or e3: + ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]] + print('##############################################################') + print('### Error ### Error ### Error ### Error ### Error ###') + print('##############################################################') + print(f'Class "{entity}" was not found in "{module_path.name}"') + print(f'Class "{entity}" was not found in "{folder_path.name}"') + print('##############################################################') + if self.custom_modules_path: + print(f'Class "{entity}" was not found in "{self.custom_modules_path}"') + print('Possible Entitys are:', str(ents)) + print('##############################################################') + print('Goodbye') + print('##############################################################') + print('### Error ### Error ### Error ### Error ### Error ###') + print('##############################################################') + exit(-99999) entity_kwargs = self.entities.get(entity, {}) entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None @@ -128,31 +139,86 @@ class FactoryConfigParser(object): observations.extend(self.default_observations) observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS) positions = [ast.literal_eval(x) for x in self.agents[name].get('Positions', [])] - parsed_agents_conf[name] = dict(actions=parsed_actions, observations=observations, positions=positions) + other_kwargs = {k: v for k, v in self.agents[name].items() if k not in + ['Actions', 'Observations', 'Positions']} + parsed_agents_conf[name] = dict( + actions=parsed_actions, observations=observations, positions=positions, other=other_kwargs + ) + return parsed_agents_conf - def load_rules(self): - # entites = Entities() - rules_classes = dict() - rules = [] + def load_env_rules(self) -> List[Rule]: + rules = self.rules.copy() if c.DEFAULTS in self.rules: for rule in self.default_rules: if rule not in rules: - rules.append(rule) - rules.extend(x for x in self.rules if x != c.DEFAULTS) + rules.append({rule: {}}) - for rule in rules: + return self._load_smth(rules, Rule) + + def load_env_tests(self) -> List[Rule]: + return self._load_smth(self.tests, None) # Test + + def _load_smth(self, config, class_obj): + rules = list() + rules_names = list() + for rule in config: + e1 = e2 = e3 = None try: folder_path = (Path(__file__).parent.parent / DEFAULT_PATH) rule_class = locate_and_import_class(rule, folder_path) - except AttributeError: + except AttributeError as e: + e1 = e try: - folder_path = (Path(__file__).parent.parent / MODULE_PATH) - rule_class = locate_and_import_class(rule, folder_path) + module_path = (Path(__file__).parent.parent / MODULE_PATH) + rule_class = locate_and_import_class(rule, module_path) + except AttributeError as e: + e2 = e + if self.custom_modules_path: + try: + rule_class = locate_and_import_class(rule, self.custom_modules_path) + except AttributeError as e: + e3 = e + pass + if (e1 and e2) or e3: + ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]] + print('### Error ### Error ### Error ### Error ### Error ###') + print('') + print(f'Class "{rule}" was not found in "{module_path.name}"') + print(f'Class "{rule}" was not found in "{folder_path.name}"') + if self.custom_modules_path: + print(f'Class "{rule}" was not found in "{self.custom_modules_path}"') + print('Possible Entitys are:', str(ents)) + print('') + print('Goodbye') + print('') + exit(-99999) + + if issubclass(rule_class, class_obj): + rule_kwargs = config.get(rule, {}) + rules.append(rule_class(**(rule_kwargs or {}))) + return rules + + def load_entity_spawn_rules(self, entities) -> List[Rule]: + rules = list() + rules_dicts = list() + for e in entities: + try: + if spawn_rule := e.spawn_rule: + rules_dicts.append(spawn_rule) + except AttributeError: + pass + + for rule_dict in rules_dicts: + for rule_name, rule_kwargs in rule_dict.items(): + try: + folder_path = (Path(__file__).parent.parent / DEFAULT_PATH) + rule_class = locate_and_import_class(rule_name, folder_path) except AttributeError: - rule_class = locate_and_import_class(rule, self.custom_modules_path) - # Fixme This check does not work! - # assert isinstance(rule_class, Rule), f'{rule_class.__name__} is no valid "Rule".' - rule_kwargs = self.rules.get(rule, {}) - rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}}) - return rules_classes + try: + folder_path = (Path(__file__).parent.parent / MODULE_PATH) + rule_class = locate_and_import_class(rule_name, folder_path) + except AttributeError: + rule_class = locate_and_import_class(rule_name, self.custom_modules_path) + rules.append(rule_class(**rule_kwargs)) + return rules diff --git a/marl_factory_grid/utils/helpers.py b/marl_factory_grid/utils/helpers.py index e2f3c9a..ae68bf7 100644 --- a/marl_factory_grid/utils/helpers.py +++ b/marl_factory_grid/utils/helpers.py @@ -2,7 +2,7 @@ import importlib from collections import defaultdict from pathlib import PurePath, Path -from typing import Union, Dict, List +from typing import Union, Dict, List, Iterable, Callable import numpy as np from numpy.typing import ArrayLike @@ -222,7 +222,7 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''): mod = importlib.import_module('.'.join(module_parts)) all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle()) and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union', - 'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin', + 'TickResult', 'ActionResult', 'Action', 'Agent', 'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin', 'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any' ]]) @@ -240,7 +240,13 @@ def add_bound_name(name_str, bound_e): def add_pos_name(name_str, bound_e): if bound_e.var_has_position: - return f'{name_str}({bound_e.pos})' + return f'{name_str}@{bound_e.pos}' return name_str +def get_first(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True): + return next((x for x in iterable if filter_by(x)), None) + + +def get_first_index(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True): + return next((idx for idx, x in enumerate(iterable) if filter_by(x)), None) diff --git a/marl_factory_grid/utils/level_parser.py b/marl_factory_grid/utils/level_parser.py index fc8b948..24a05df 100644 --- a/marl_factory_grid/utils/level_parser.py +++ b/marl_factory_grid/utils/level_parser.py @@ -47,6 +47,7 @@ class LevelParser(object): # All other for es_name in self.e_p_dict: e_class, e_kwargs = self.e_p_dict[es_name]['class'], self.e_p_dict[es_name]['kwargs'] + e_kwargs = e_kwargs if e_kwargs else {} if hasattr(e_class, 'symbol') and e_class.symbol is not None: symbols = e_class.symbol diff --git a/marl_factory_grid/utils/observation_builder.py b/marl_factory_grid/utils/observation_builder.py index 9fd1d26..df10ae9 100644 --- a/marl_factory_grid/utils/observation_builder.py +++ b/marl_factory_grid/utils/observation_builder.py @@ -1,17 +1,17 @@ -import math import re from collections import defaultdict -from itertools import product from typing import Dict, List import numpy as np -from numba import njit from marl_factory_grid.environment import constants as c +from marl_factory_grid.environment.entity.object import _Object from marl_factory_grid.environment.groups.utils import Combined -import marl_factory_grid.utils.helpers as h -from marl_factory_grid.utils.states import Gamestate from marl_factory_grid.utils.utility_classes import Floor +from marl_factory_grid.utils.ray_caster import RayCaster +from marl_factory_grid.utils.states import Gamestate +from marl_factory_grid.utils import helpers as h + class OBSBuilder(object): @@ -77,11 +77,13 @@ class OBSBuilder(object): def place_entity_in_observation(self, obs_array, agent, e): x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r - try: - obs_array[x, y] += e.encoding - except IndexError: - # Seemded to be visible but is out of range - pass + if not min([y, x]) < 0: + try: + obs_array[x, y] += e.encoding + except IndexError: + # Seemded to be visible but is out of range + pass + pass def build_for_agent(self, agent, state) -> (List[str], np.ndarray): assert self._curr_env_step == state.curr_step, ( @@ -121,18 +123,24 @@ class OBSBuilder(object): e = self.all_obs[l_name] except KeyError: try: - # Look for bound entity names! - pattern = re.compile(f'{re.escape(l_name)}(.*){re.escape(agent.name)}') - name = next((x for x in self.all_obs if pattern.search(x)), None) + # Look for bound entity REPRs! + pattern = re.compile(f'{re.escape(l_name)}' + f'{re.escape("[")}(.*){re.escape("]")}' + f'{re.escape("(")}{re.escape(agent.name)}{re.escape(")")}') + name = next((key for key, val in self.all_obs.items() + if pattern.search(str(val)) and isinstance(val, _Object)), None) e = self.all_obs[name] except KeyError: try: e = next(v for k, v in self.all_obs.items() if l_name in k and agent.name in k) except StopIteration: - raise KeyError( - f'Check for spelling errors! \n ' - f'No combination of "{l_name} and {agent.name}" could not be found in:\n ' - f'{list(dict(self.all_obs).keys())}') + print(f'# Check for spelling errors!') + print(f'# No combination of "{l_name}" and "{agent.name}" could not be found in:') + print(f'# {list(dict(self.all_obs).keys())}') + print('#') + print('# exiting...') + print('#') + exit(-99999) try: positional = e.var_has_position @@ -161,15 +169,14 @@ class OBSBuilder(object): try: light_map = np.zeros(self.obs_shape) visible_floor = self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False) - if self.pomdp_r: - for f in set(visible_floor): - self.place_entity_in_observation(light_map, agent, f) - else: - for f in set(visible_floor): - light_map[f.x, f.y] += f.encoding + + for f in set(visible_floor): + self.place_entity_in_observation(light_map, agent, f) + # else: + # for f in set(visible_floor): + # light_map[f.x, f.y] += f.encoding self.curr_lightmaps[agent.name] = light_map except (KeyError, ValueError): - print() pass return obs, self.obs_layers[agent.name] @@ -185,7 +192,7 @@ class OBSBuilder(object): for obs_str in agent.observations: if isinstance(obs_str, dict): - obs_str, vals = next(obs_str.items().__iter__()) + obs_str, vals = h.get_first(obs_str.items()) else: vals = None if obs_str == c.SELF: @@ -214,129 +221,3 @@ class OBSBuilder(object): obs_layers.append(obs_str) self.obs_layers[agent.name] = obs_layers self.curr_lightmaps[agent.name] = np.zeros(self.obs_shape) - - -class RayCaster: - def __init__(self, agent, pomdp_r, degs=360): - self.agent = agent - self.pomdp_r = pomdp_r - self.n_rays = (self.pomdp_r + 1) * 8 - self.degs = degs - self.ray_targets = self.build_ray_targets() - self.obs_shape_cube = np.array([self.pomdp_r, self.pomdp_r]) - self._cache_dict = {} - - def __repr__(self): - return f'{self.__class__.__name__}({self.agent.name})' - - def build_ray_targets(self): - north = np.array([0, -1]) * self.pomdp_r - thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]] - rot_M = [ - [[math.cos(theta), -math.sin(theta)], - [math.sin(theta), math.cos(theta)]] for theta in thetas - ] - rot_M = np.stack(rot_M, 0) - rot_M = np.unique(np.round(rot_M @ north), axis=0) - return rot_M.astype(int) - - def ray_block_cache(self, key, callback): - if key not in self._cache_dict: - self._cache_dict[key] = callback() - return self._cache_dict[key] - - def visible_entities(self, pos_dict, reset_cache=True): - visible = list() - if reset_cache: - self._cache_dict = {} - - for ray in self.get_rays(): - rx, ry = ray[0] - for x, y in ray: - cx, cy = x - rx, y - ry - - entities_hit = pos_dict[(x, y)] - hits = self.ray_block_cache((x, y), - lambda: any(True for e in entities_hit if e.var_is_blocking_light) - ) - - diag_hits = all([ - self.ray_block_cache( - key, - lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) and bool( - pos_dict[key])) - for key in ((x, y - cy), (x - cx, y)) - ]) if (cx != 0 and cy != 0) else False - - visible += entities_hit if not diag_hits else [] - if hits or diag_hits: - break - rx, ry = x, y - return visible - - def get_rays(self): - a_pos = self.agent.pos - outline = self.ray_targets + a_pos - return self.bresenham_loop(a_pos, outline) - - # todo do this once and cache the points! - def get_fov_outline(self) -> np.ndarray: - return self.ray_targets + self.agent.pos - - def get_square_outline(self): - agent = self.agent - x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1) - y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1) - outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \ - + list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords)) - return outline - - @staticmethod - @njit - def bresenham_loop(a_pos, points): - results = [] - for end in points: - x1, y1 = a_pos - x2, y2 = end - dx = x2 - x1 - dy = y2 - y1 - - # Determine how steep the line is - is_steep = abs(dy) > abs(dx) - - # Rotate line - if is_steep: - x1, y1 = y1, x1 - x2, y2 = y2, x2 - - # Swap start and end points if necessary and store swap state - swapped = False - if x1 > x2: - x1, x2 = x2, x1 - y1, y2 = y2, y1 - swapped = True - - # Recalculate differentials - dx = x2 - x1 - dy = y2 - y1 - - # Calculate error - error = int(dx / 2.0) - ystep = 1 if y1 < y2 else -1 - - # Iterate over bounding box generating points between start and end - y = y1 - points = [] - for x in range(int(x1), int(x2) + 1): - coord = [y, x] if is_steep else [x, y] - points.append(coord) - error -= abs(dy) - if error < 0: - y += ystep - error += dx - - # Reverse the list if the coordinates were swapped - if swapped: - points.reverse() - results.append(points) - return results diff --git a/marl_factory_grid/utils/ray_caster.py b/marl_factory_grid/utils/ray_caster.py index cf17bd1..ecbac6d 100644 --- a/marl_factory_grid/utils/ray_caster.py +++ b/marl_factory_grid/utils/ray_caster.py @@ -39,8 +39,9 @@ class RayCaster: if reset_cache: self._cache_dict = dict() - for ray in self.get_rays(): + for ray in self.get_rays(): # Do not check, just trust. rx, ry = ray[0] + # self.ray_block_cache(ray[0], lambda: False) We do not do that, because of doors etc... for x, y in ray: cx, cy = x - rx, y - ry @@ -52,7 +53,8 @@ class RayCaster: diag_hits = all([ self.ray_block_cache( key, - lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light)) + lambda: any(True for e in pos_dict[key] if e.var_is_blocking_light)) + # lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light)) for key in ((x, y-cy), (x-cx, y)) ]) if (cx != 0 and cy != 0) else False diff --git a/marl_factory_grid/utils/renderer.py b/marl_factory_grid/utils/renderer.py index db6a93f..1976974 100644 --- a/marl_factory_grid/utils/renderer.py +++ b/marl_factory_grid/utils/renderer.py @@ -31,7 +31,7 @@ class Renderer: def __init__(self, lvl_shape: Tuple[int, int] = (16, 16), lvl_padded_shape: Union[Tuple[int, int], None] = None, - cell_size: int = 40, fps: int = 7, + cell_size: int = 40, fps: int = 7, factor: float = 0.9, grid_lines: bool = True, view_radius: int = 2): # TODO: Customn_assets paths self.grid_h, self.grid_w = lvl_shape @@ -45,7 +45,7 @@ class Renderer: self.screen = pygame.display.set_mode(self.screen_size) self.clock = pygame.time.Clock() assets = list(self.ASSETS.rglob('*.png')) - self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets} + self.assets = {path.stem: self.load_asset(str(path), factor) for path in assets} self.fill_bg() now = time.time() @@ -110,22 +110,22 @@ class Renderer: pygame.quit() sys.exit() self.fill_bg() - blits = deque() - for entity in [x for x in entities]: - bp = self.blit_params(entity) - blits.append(bp) - if entity.name.lower() == AGENT: - if self.view_radius > 0: - vis_rects = self.visibility_rects(bp, entity.aux) - blits.extendleft(vis_rects) - if entity.state != BLANK: - agent_state_blits = self.blit_params( - RenderEntity(entity.state, (entity.pos[0] + 0.12, entity.pos[1]), 0.48, SCALE) - ) - textsurface = self.font.render(str(entity.id), False, (0, 0, 0)) - text_blit = dict(source=textsurface, dest=(bp['dest'].center[0]-.07*self.cell_size, - bp['dest'].center[1])) - blits += [agent_state_blits, text_blit] + # First all others + blits = deque(self.blit_params(x) for x in entities if not x.name.lower() == AGENT) + # Then Agents, so that agents are rendered on top. + for agent in (x for x in entities if x.name.lower() == AGENT): + agent_blit = self.blit_params(agent) + if self.view_radius > 0: + vis_rects = self.visibility_rects(agent_blit, agent.aux) + blits.extendleft(vis_rects) + if agent.state != BLANK: + state_blit = self.blit_params( + RenderEntity(agent.state, (agent.pos[0] + 0.12, agent.pos[1]), 0.48, SCALE) + ) + textsurface = self.font.render(str(agent.id), False, (0, 0, 0)) + text_blit = dict(source=textsurface, dest=(agent_blit['dest'].center[0]-.07*self.cell_size, + agent_blit['dest'].center[1])) + blits += [agent_blit, state_blit, text_blit] for blit in blits: self.screen.blit(**blit) diff --git a/marl_factory_grid/utils/results.py b/marl_factory_grid/utils/results.py index 9f0fa38..6abf11c 100644 --- a/marl_factory_grid/utils/results.py +++ b/marl_factory_grid/utils/results.py @@ -28,7 +28,10 @@ class Result: def __repr__(self): valid = "not " if not self.validity else "" - return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid: {self.reward})' + reward = f" | Reward: {self.reward}" if self.reward is not None else "" + value = f" | Value: {self.value}" if self.value is not None else "" + entity = f" | by: {self.entity.name}" if self.entity is not None else "" + return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value})' @dataclass diff --git a/marl_factory_grid/utils/states.py b/marl_factory_grid/utils/states.py index 4c1f7f2..fc07b95 100644 --- a/marl_factory_grid/utils/states.py +++ b/marl_factory_grid/utils/states.py @@ -1,3 +1,4 @@ +from itertools import islice from typing import List, Dict, Tuple import numpy as np @@ -59,14 +60,15 @@ class Gamestate(object): def moving_entites(self): return [y for x in self.entities for y in x if x.var_can_move] - def __init__(self, entities, agents_conf, rules: Dict[str, dict], env_seed=69, verbose=False): + def __init__(self, entities, agents_conf, rules: List[Rule], lvl_shape, env_seed=69, verbose=False): + self.lvl_shape = lvl_shape self.entities = entities self.curr_step = 0 self.curr_actions = None self.agents_conf = agents_conf self.verbose = verbose self.rng = np.random.default_rng(env_seed) - self.rules = StepRules(*(v['class'](**v['kwargs']) for v in rules.values())) + self.rules = StepRules(*rules) def __getitem__(self, item): return self.entities[item] @@ -80,6 +82,13 @@ class Gamestate(object): def __repr__(self): return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})' + @property + def random_free_position(self): + return self.get_n_random_free_positions(1)[0] + + def get_n_random_free_positions(self, n): + return list(islice(self.entities.free_positions_generator, n)) + def tick(self, actions) -> List[Result]: results = list() self.curr_step += 1 @@ -115,8 +124,7 @@ class Gamestate(object): return results def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]: - positions = [pos for pos, entity_list_for_position in self.entities.pos_dict.items() - if any([e.var_can_collide for e in entity_list_for_position])] + positions = [pos for pos, entities in self.entities.pos_dict.items() if len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2)] return positions def check_move_validity(self, moving_entity, position): diff --git a/marl_factory_grid/utils/tools.py b/marl_factory_grid/utils/tools.py index d2f9bd1..63c9f69 100644 --- a/marl_factory_grid/utils/tools.py +++ b/marl_factory_grid/utils/tools.py @@ -135,4 +135,3 @@ if __name__ == '__main__': ce.get_observations() ce.get_assets() all_conf = ce.get_all() - print() diff --git a/marl_factory_grid/utils/utility_classes.py b/marl_factory_grid/utils/utility_classes.py index 4844133..4d1cfe1 100644 --- a/marl_factory_grid/utils/utility_classes.py +++ b/marl_factory_grid/utils/utility_classes.py @@ -52,3 +52,6 @@ class Floor: def __hash__(self): return hash(self.name) + + def __repr__(self): + return f"Floor{self.pos}" diff --git a/reload_agent.py b/reload_agent.py index 8c16069..f0ed389 100644 --- a/reload_agent.py +++ b/reload_agent.py @@ -6,6 +6,7 @@ import yaml from marl_factory_grid.environment.factory import Factory from marl_factory_grid.utils.logging.envmonitor import EnvMonitor from marl_factory_grid.utils.logging.recorder import EnvRecorder +from marl_factory_grid.utils import helpers as h from marl_factory_grid.modules.doors import constants as d @@ -61,7 +62,7 @@ if __name__ == '__main__': if render: env.render() try: - door = next(x for x in env.unwrapped.unwrapped[d.DOORS] if x.is_open) + door = h.get_first([x for x in env.unwrapped.unwrapped[d.DOORS] if x.is_open]) print('openDoor found') except StopIteration: pass diff --git a/studies/normalization_study.py b/studies/normalization_study.py index 37e10c4..7c72982 100644 --- a/studies/normalization_study.py +++ b/studies/normalization_study.py @@ -1,8 +1,8 @@ from algorithms.utils import Checkpointer from pathlib import Path from algorithms.utils import load_yaml_file, add_env_props, instantiate_class, load_class -#from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC +# from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC for i in range(0, 5): diff --git a/transform_wg_to_json_no_priv.py b/transform_wg_to_json_no_priv.py new file mode 100644 index 0000000..d9bc8e1 --- /dev/null +++ b/transform_wg_to_json_no_priv.py @@ -0,0 +1,43 @@ +import configparser +import json +from datetime import datetime +from pathlib import Path + +if __name__ == '__main__': + + + conf_path = Path('wg0') + wg0_conf = configparser.ConfigParser() + wg0_conf.read(conf_path/'wg0.conf') + interface = wg0_conf['Interface'] + # Iterate all pears + for client_name in wg0_conf.sections(): + if client_name == 'Interface': + continue + # Delete any old conf.json for the current peer + (conf_path / f'{client_name}.json').unlink(missing_ok=True) + + + peer = wg0_conf[client_name] + + date_time = datetime.now().strftime('%Y-%m-%dT%H:%M:%S.%f000Z') + + jdict = dict( + id=client_name, + private_key=peer['PublicKey'], + public_key=peer['PublicKey'], + # preshared_key=wg0_conf[client_name_wg0]['PresharedKey'], + name=client_name, + email=f"sysadmin@mobile.ifi.lmu.de", + allocated_ips=[interface['Address'].replace('/24', '')], + allowed_ips=['10.4.0.0/24', '10.153.199.0/24'], + extra_allowed_ips=[], + use_server_dns=True, + enabled=True, + created_at=date_time, + updated_at=date_time + ) + + with (conf_path / f'{client_name}.json').open('w+') as f: + json.dump(jdict, f, indent='\t', separators=(',', ': ')) + print(client_name, ' written...')