From f5c6317158a0214bee060c08ef52ab88d9a1b4df Mon Sep 17 00:00:00 2001 From: Steffen Illium Date: Thu, 12 Oct 2023 17:14:32 +0200 Subject: [PATCH] Remove BoundDestination Object New Variable 'var_can_be_bound' Observations adjusted accordingly --- .../configs/narrow_corridor.yaml | 10 +++-- .../configs/two_rooms_one_door.yaml | 6 +-- .../environment/entity/entity.py | 10 ++++- marl_factory_grid/environment/entity/mixin.py | 8 +++- .../environment/entity/object.py | 7 +++ marl_factory_grid/environment/factory.py | 6 +-- .../environment/groups/agents.py | 4 -- .../environment/groups/env_objects.py | 1 + .../environment/groups/mixins.py | 4 +- .../environment/groups/objects.py | 4 +- marl_factory_grid/environment/rules.py | 11 ++--- marl_factory_grid/modules/aomas/__init__.py | 0 .../modules/aomas/narrow_corridor/__init__.py | 0 .../modules/aomas/narrow_corridor/rules.py | 28 ------------ .../modules/destinations/__init__.py | 2 +- .../modules/destinations/actions.py | 4 +- .../modules/destinations/constants.py | 1 - .../modules/destinations/entitites.py | 43 ++++++++++--------- .../modules/destinations/groups.py | 11 +---- .../modules/destinations/rules.py | 37 +++++++++------- .../modules/levels/narrow_corridor.txt | 2 +- marl_factory_grid/modules/zones/rules.py | 9 ++-- 22 files changed, 98 insertions(+), 110 deletions(-) delete mode 100644 marl_factory_grid/modules/aomas/__init__.py delete mode 100644 marl_factory_grid/modules/aomas/narrow_corridor/__init__.py delete mode 100644 marl_factory_grid/modules/aomas/narrow_corridor/rules.py diff --git a/marl_factory_grid/configs/narrow_corridor.yaml b/marl_factory_grid/configs/narrow_corridor.yaml index 446cf4f..0006513 100644 --- a/marl_factory_grid/configs/narrow_corridor.yaml +++ b/marl_factory_grid/configs/narrow_corridor.yaml @@ -5,7 +5,8 @@ Agents: - Move8 Observations: - Walls - - BoundDestination + - Other + - Destination Positions: - (2, 1) - (2, 5) @@ -15,12 +16,13 @@ Agents: - Move8 Observations: - Walls - - BoundDestination + - Other + - Destination Positions: - (2, 1) - (2, 5) Entities: - BoundDestinations: {} + Destinations: {} General: env_seed: 69 @@ -32,7 +34,7 @@ General: Rules: SpawnAgents: {} Collision: - done_at_collisions: true + done_at_collisions: false FixedDestinationSpawn: per_agent_positions: Wolfgang: diff --git a/marl_factory_grid/configs/two_rooms_one_door.yaml b/marl_factory_grid/configs/two_rooms_one_door.yaml index f1b9a32..cf6a074 100644 --- a/marl_factory_grid/configs/two_rooms_one_door.yaml +++ b/marl_factory_grid/configs/two_rooms_one_door.yaml @@ -6,7 +6,7 @@ General: verbose: false Entities: - BoundDestinations: {} + Destinations: {} Doors: {} GlobalPositions: {} Zones: {} @@ -36,7 +36,7 @@ Agents: - Walls - Other - Doors - - BoundDestination + - Destination Sigmund: Actions: - Move8 @@ -47,5 +47,5 @@ Agents: - Combined: - Other - Walls - - BoundDestination + - Destination - Doors \ No newline at end of file diff --git a/marl_factory_grid/environment/entity/entity.py b/marl_factory_grid/environment/entity/entity.py index dfcb504..c995eeb 100644 --- a/marl_factory_grid/environment/entity/entity.py +++ b/marl_factory_grid/environment/entity/entity.py @@ -56,6 +56,7 @@ class Entity(EnvObject, abc.ABC): return last_x - curr_x, last_y - curr_y def destroy(self): + if valid = self._collection.remove_item(self) for observer in self.observers: observer.notify_del_entity(self) @@ -73,10 +74,17 @@ class Entity(EnvObject, abc.ABC): return valid return not_same_tile - def __init__(self, tile, **kwargs): + def __init__(self, tile, bind_to=None, **kwargs): super().__init__(**kwargs) self._status = None self._tile = tile + if bind_to: + try: + self.bind_to(bind_to) + except AttributeError: + print(f'Objects of {self.__class__.__name__} can not be bound to other entities.') + exit() + assert tile.enter(self, spawn=True), "Positions was not valid!" def summarize_state(self) -> dict: diff --git a/marl_factory_grid/environment/entity/mixin.py b/marl_factory_grid/environment/entity/mixin.py index 7148504..2a318c7 100644 --- a/marl_factory_grid/environment/entity/mixin.py +++ b/marl_factory_grid/environment/entity/mixin.py @@ -9,10 +9,16 @@ class BoundEntityMixin: @property def name(self): - return f'{self.__class__.__name__}({self.bound_entity.name})' + if self.bound_entity: + return f'{self.__class__.__name__}({self.bound_entity.name})' + else: + print() def belongs_to_entity(self, entity): return entity == self.bound_entity def bind_to(self, entity): self._bound_entity = entity + + def unbind(self): + self._bound_entity = None diff --git a/marl_factory_grid/environment/entity/object.py b/marl_factory_grid/environment/entity/object.py index 2de33cd..1e023b6 100644 --- a/marl_factory_grid/environment/entity/object.py +++ b/marl_factory_grid/environment/entity/object.py @@ -91,6 +91,13 @@ class EnvObject(Object): except AttributeError: return False + @property + def var_can_be_bound(self): + try: + return self._collection.var_can_be_bound or False + except AttributeError: + return False + @property def var_can_move(self): try: diff --git a/marl_factory_grid/environment/factory.py b/marl_factory_grid/environment/factory.py index ba636fc..1f20ffb 100644 --- a/marl_factory_grid/environment/factory.py +++ b/marl_factory_grid/environment/factory.py @@ -90,7 +90,7 @@ class Factory(gym.Env): # Parse the agent conf parsed_agents_conf = self.conf.parse_agents_conf() - self.state = Gamestate(entities, parsed_agents_conf, rules, self.conf.env_seed) + self.state = Gamestate(entities, parsed_agents_conf, rules, self.conf.env_seed, self.conf.verbose) # All is set up, trigger entity init with variable pos self.state.rules.do_all_init(self.state, self.map) @@ -235,10 +235,6 @@ class Factory(gym.Env): del summary[key] return summary - def print(self, string): - if self.conf.verbose: - print(string) - def save_params(self, filepath: Path): # noinspection PyProtectedMember filepath = Path(filepath) diff --git a/marl_factory_grid/environment/groups/agents.py b/marl_factory_grid/environment/groups/agents.py index f7839ba..0169f88 100644 --- a/marl_factory_grid/environment/groups/agents.py +++ b/marl_factory_grid/environment/groups/agents.py @@ -11,10 +11,6 @@ class Agents(PositionMixin, EnvObjects): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - @property - def obs_pairs(self): - return [(a.name, a) for a in self] - @property def action_space(self): from gymnasium import spaces diff --git a/marl_factory_grid/environment/groups/env_objects.py b/marl_factory_grid/environment/groups/env_objects.py index dcbcd26..1113833 100644 --- a/marl_factory_grid/environment/groups/env_objects.py +++ b/marl_factory_grid/environment/groups/env_objects.py @@ -9,6 +9,7 @@ class EnvObjects(Objects): var_can_collide: bool = False var_has_position: bool = False var_can_move: bool = False + var_can_be_bound: bool = False @property def encodings(self): diff --git a/marl_factory_grid/environment/groups/mixins.py b/marl_factory_grid/environment/groups/mixins.py index 88d2841..bab2b84 100644 --- a/marl_factory_grid/environment/groups/mixins.py +++ b/marl_factory_grid/environment/groups/mixins.py @@ -92,11 +92,11 @@ class HasBoundMixin: def by_entity(self, entity): try: return next((x for x in self if x.belongs_to_entity(entity))) - except StopIteration: + except (StopIteration, AttributeError): return None def idx_by_entity(self, entity): try: return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity))) - except StopIteration: + except (StopIteration, AttributeError): return None diff --git a/marl_factory_grid/environment/groups/objects.py b/marl_factory_grid/environment/groups/objects.py index d9c8f4c..57e0106 100644 --- a/marl_factory_grid/environment/groups/objects.py +++ b/marl_factory_grid/environment/groups/objects.py @@ -24,7 +24,9 @@ class Objects: @property def obs_pairs(self): - return [(self.name, self)] + pair_list = [(self.name, self)] + pair_list.extend([(a.name, a) for a in self]) + return pair_list @property def names(self): diff --git a/marl_factory_grid/environment/rules.py b/marl_factory_grid/environment/rules.py index 11abd42..99530a8 100644 --- a/marl_factory_grid/environment/rules.py +++ b/marl_factory_grid/environment/rules.py @@ -124,12 +124,13 @@ class Collision(Rule): pass results.append(TickResult(entity=guest, identifier=c.COLLISION, reward=r.COLLISION, validity=c.VALID)) - self.curr_done = True + self.curr_done = True if self.done_at_collisions else False return results def on_check_done(self, state) -> List[DoneResult]: - inter_entity_collision_detected = self.curr_done and self.done_at_collisions - move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT]) - if inter_entity_collision_detected or move_failed: - return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)] + if self.done_at_collisions: + inter_entity_collision_detected = self.curr_done + move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT]) + if inter_entity_collision_detected or move_failed: + return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)] return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)] diff --git a/marl_factory_grid/modules/aomas/__init__.py b/marl_factory_grid/modules/aomas/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/marl_factory_grid/modules/aomas/narrow_corridor/__init__.py b/marl_factory_grid/modules/aomas/narrow_corridor/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/marl_factory_grid/modules/aomas/narrow_corridor/rules.py b/marl_factory_grid/modules/aomas/narrow_corridor/rules.py deleted file mode 100644 index 01a12cc..0000000 --- a/marl_factory_grid/modules/aomas/narrow_corridor/rules.py +++ /dev/null @@ -1,28 +0,0 @@ -from random import shuffle -from typing import List, Tuple - -from marl_factory_grid.environment.rules import Rule -from marl_factory_grid.environment import constants as c -from marl_factory_grid.modules.destinations import constants as d -from marl_factory_grid.modules.destinations.entitites import BoundDestination - - -class NarrowCorridorSpawn(Rule): - def __init__(self, positions: List[Tuple[int, int]], fixed: bool = False): - super().__init__() - self.fixed = fixed - self.positions = positions - - def on_init(self, state, lvl_map): - if not self.fixed: - shuffle(self.positions) - for agent in state[c.AGENT]: - pass - - def trigger_destination_spawn(self, state): - for (agent_name, position_list) in self.per_agent_positions.items(): - agent = state[c.AGENT][agent_name] - destinations = [BoundDestination(agent, pos) for pos in position_list] - state[d.DESTINATION].add_items(destinations) - return c.VALID - diff --git a/marl_factory_grid/modules/destinations/__init__.py b/marl_factory_grid/modules/destinations/__init__.py index 565056c..3bbf997 100644 --- a/marl_factory_grid/modules/destinations/__init__.py +++ b/marl_factory_grid/modules/destinations/__init__.py @@ -1,4 +1,4 @@ from .actions import DestAction from .entitites import Destination -from .groups import Destinations, BoundDestinations +from .groups import Destinations from .rules import DestinationReachAll, DestinationSpawn diff --git a/marl_factory_grid/modules/destinations/actions.py b/marl_factory_grid/modules/destinations/actions.py index a0529b5..05c4955 100644 --- a/marl_factory_grid/modules/destinations/actions.py +++ b/marl_factory_grid/modules/destinations/actions.py @@ -13,9 +13,7 @@ class DestAction(Action): super().__init__(d.DESTINATION) def do(self, entity, state) -> Union[None, ActionResult]: - dest_entities = d.DESTINATION if d.DESTINATION in state else d.BOUNDDESTINATION - assert dest_entities - if destination := state[dest_entities].by_pos(entity.pos): + if destination := state[d.DESTINATION].by_pos(entity.pos): valid = destination.do_wait_action(entity) state.print(f'{entity.name} just waited at {entity.pos}') else: diff --git a/marl_factory_grid/modules/destinations/constants.py b/marl_factory_grid/modules/destinations/constants.py index 85ee07d..5202b35 100644 --- a/marl_factory_grid/modules/destinations/constants.py +++ b/marl_factory_grid/modules/destinations/constants.py @@ -1,7 +1,6 @@ # Destination Env DESTINATION = 'Destinations' -BOUNDDESTINATION = 'BoundDestinations' DEST_SYMBOL = 1 WAIT_ON_DEST = 'WAIT' diff --git a/marl_factory_grid/modules/destinations/entitites.py b/marl_factory_grid/modules/destinations/entitites.py index 7c292b5..179f38d 100644 --- a/marl_factory_grid/modules/destinations/entitites.py +++ b/marl_factory_grid/modules/destinations/entitites.py @@ -8,13 +8,18 @@ from marl_factory_grid.utils.render import RenderEntity from marl_factory_grid.modules.destinations import constants as d -class Destination(Entity): +class Destination(BoundEntityMixin, Entity): var_can_move = False var_can_collide = False var_has_position = True var_is_blocking_pos = False var_is_blocking_light = False + var_can_be_bound = True # Introduce this globally! + + @property + def was_reached(self): + return self._was_reached @property def encoding(self): @@ -22,6 +27,7 @@ class Destination(Entity): def __init__(self, *args, action_counts=0, **kwargs): super(Destination, self).__init__(*args, **kwargs) + self._was_reached = False self.action_counts = action_counts self._per_agent_actions = defaultdict(lambda: 0) @@ -30,9 +36,15 @@ class Destination(Entity): return c.VALID @property - def is_considered_reached(self): - agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in self.tile.guests_that_can_collide) - return agent_at_position or any(x >= self.action_counts for x in self._per_agent_actions.values()) + def has_just_been_reached(self): + if self.was_reached: + return False + agent_at_position = any(self.bound_entity == x for x in self.tile.guests_that_can_collide) + if self.bound_entity: + return ((agent_at_position and not self.action_counts) + or self._per_agent_actions[self.bound_entity.name] >= self.action_counts >= 1) + else: + return agent_at_position or any(x >= self.action_counts for x in self._per_agent_actions.values()) def agent_did_action(self, agent: Agent): return self._per_agent_actions[agent.name] >= self.action_counts @@ -44,21 +56,10 @@ class Destination(Entity): return state_summary def render(self): - return RenderEntity(d.DESTINATION, self.pos) + if self.was_reached: + return None + else: + return RenderEntity(d.DESTINATION, self.pos) - -class BoundDestination(BoundEntityMixin, Destination): - - @property - def encoding(self): - return d.DEST_SYMBOL - - def __init__(self, entity, *args, **kwargs): - self.bind_to(entity) - super().__init__(*args, **kwargs) - - @property - def is_considered_reached(self): - agent_at_position = any(self.bound_entity == x for x in self.tile.guests_that_can_collide) - return ((agent_at_position and not self.action_counts) - or self._per_agent_actions[self.bound_entity.name] >= self.action_counts >= 1) + def mark_as_reached(self): + self._was_reached = True diff --git a/marl_factory_grid/modules/destinations/groups.py b/marl_factory_grid/modules/destinations/groups.py index 3e296ff..b0e55c9 100644 --- a/marl_factory_grid/modules/destinations/groups.py +++ b/marl_factory_grid/modules/destinations/groups.py @@ -1,6 +1,6 @@ from marl_factory_grid.environment.groups.env_objects import EnvObjects from marl_factory_grid.environment.groups.mixins import PositionMixin, HasBoundMixin -from marl_factory_grid.modules.destinations.entitites import Destination, BoundDestination +from marl_factory_grid.modules.destinations.entitites import Destination class Destinations(PositionMixin, EnvObjects): @@ -14,12 +14,3 @@ class Destinations(PositionMixin, EnvObjects): def __repr__(self): return super(Destinations, self).__repr__() - - -class BoundDestinations(HasBoundMixin, Destinations): - - _entity = BoundDestination - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - diff --git a/marl_factory_grid/modules/destinations/rules.py b/marl_factory_grid/modules/destinations/rules.py index 7ea320e..fd77da1 100644 --- a/marl_factory_grid/modules/destinations/rules.py +++ b/marl_factory_grid/modules/destinations/rules.py @@ -1,12 +1,12 @@ import ast from random import shuffle -from typing import List, Union, Dict, Tuple +from typing import List, Dict, Tuple from marl_factory_grid.environment.rules import Rule from marl_factory_grid.utils.results import TickResult, DoneResult from marl_factory_grid.environment import constants as c from marl_factory_grid.modules.destinations import constants as d, rewards as r -from marl_factory_grid.modules.destinations.entitites import Destination, BoundDestination +from marl_factory_grid.modules.destinations.entitites import Destination class DestinationReachAll(Rule): @@ -16,21 +16,28 @@ class DestinationReachAll(Rule): def tick_step(self, state) -> List[TickResult]: results = [] - for dest in list(state[next(key for key in state.entities.names if d.DESTINATION in key)]): - if dest.is_considered_reached: - agent = state[c.AGENT].by_pos(dest.pos) - results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent)) - state.print(f'{dest.name} is reached now, removing...') - assert dest.destroy(), f'{dest.name} could not be destroyed. Critical Error.' + for dest in state[d.DESTINATION]: + if dest.has_just_been_reached and not dest.was_reached: + # Dest has just been reached, some agent needs to stand here, grab any first. + for agent in state[c.AGENT].by_pos(dest.pos): + if dest.bound_entity: + if dest.bound_entity == agent: + results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent)) + else: + pass + else: + results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent)) + state.print(f'{dest.name} is reached now, mark as reached...') + dest.mark_as_reached() else: - pass - return [TickResult(self.name, validity=c.VALID, reward=0, entity=None)] + pass + return results def tick_post_step(self, state) -> List[TickResult]: return [] def on_check_done(self, state) -> List[DoneResult]: - if not len(state[next(key for key in state.entities.names if d.DESTINATION in key)]): + if all(x.was_reached for x in state[d.DESTINATION]): return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)] return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)] @@ -41,7 +48,7 @@ class DestinationReachAny(DestinationReachAll): super(DestinationReachAny, self).__init__() def on_check_done(self, state) -> List[DoneResult]: - if not len(state[next(key for key in state.entities.names if d.DESTINATION in key)]): + if any(x.was_reached for x in state[d.DESTINATION]): return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)] return [] @@ -95,10 +102,10 @@ class FixedDestinationSpawn(Rule): shuffle(position_list) while True: pos = position_list.pop() - if pos != agent.pos and not state[d.BOUNDDESTINATION].by_pos(pos): - destination = BoundDestination(agent, state[c.FLOORS].by_pos(pos)) + if pos != agent.pos and not state[d.DESTINATION].by_pos(pos): + destination = Destination(state[c.FLOORS].by_pos(pos), bind_to=agent) break else: continue - state[d.BOUNDDESTINATION].add_item(destination) + state[d.DESTINATION].add_item(destination) pass diff --git a/marl_factory_grid/modules/levels/narrow_corridor.txt b/marl_factory_grid/modules/levels/narrow_corridor.txt index 978fbd6..3f56673 100644 --- a/marl_factory_grid/modules/levels/narrow_corridor.txt +++ b/marl_factory_grid/modules/levels/narrow_corridor.txt @@ -1,5 +1,5 @@ ####### ###-### -#1---2# +#-----# ###-### ####### diff --git a/marl_factory_grid/modules/zones/rules.py b/marl_factory_grid/modules/zones/rules.py index 3416e59..8921c9d 100644 --- a/marl_factory_grid/modules/zones/rules.py +++ b/marl_factory_grid/modules/zones/rules.py @@ -2,7 +2,7 @@ from random import choices, choice from . import constants as z, Zone from ..destinations import constants as d -from ..destinations.entitites import BoundDestination +from ... import Destination from ...environment.rules import Rule from ...environment import constants as c @@ -66,9 +66,10 @@ class IndividualDestinationZonePlacement(Rule): already_has_destination = True while already_has_destination: tile = choice(other_zones).random_tile - if state[d.BOUNDDESTINATION].by_pos(tile.pos) is None: + if state[d.DESTINATION].by_pos(tile.pos) is None: already_has_destination = False - destination = BoundDestination(agent, tile) - state[d.BOUNDDESTINATION].add_item(destination) + destination = Destination(tile, bind_to=agent) + + state[d.DESTINATION].add_item(destination) continue return c.VALID