diff --git a/marl_factory_grid/algorithms/static/TSP_base_agent.py b/marl_factory_grid/algorithms/static/TSP_base_agent.py index 919d57e..bc48f7c 100644 --- a/marl_factory_grid/algorithms/static/TSP_base_agent.py +++ b/marl_factory_grid/algorithms/static/TSP_base_agent.py @@ -21,7 +21,7 @@ class TSPBaseAgent(ABC): self.local_optimization = True self._env = state self.state = self._env.state[c.AGENT][agent_i] - self._floortile_graph = points_to_graph(self._env[c.FLOORS].positions) + self._position_graph = points_to_graph(self._env.entities.floorlist) self._static_route = None @abstractmethod @@ -50,7 +50,7 @@ class TSPBaseAgent(ABC): else: nodes = [self.state.pos] + positions - route = tsp.traveling_salesman_problem(self._floortile_graph, + route = tsp.traveling_salesman_problem(self._position_graph, nodes=nodes, cycle=True, method=tsp.greedy_tsp) return route diff --git a/marl_factory_grid/algorithms/static/utils.py b/marl_factory_grid/algorithms/static/utils.py index 2543152..d5119db 100644 --- a/marl_factory_grid/algorithms/static/utils.py +++ b/marl_factory_grid/algorithms/static/utils.py @@ -4,17 +4,17 @@ import networkx as nx import numpy as np -def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True): +def points_to_graph(coordiniates, allow_euclidean_connections=True, allow_manhattan_connections=True): """ Given a set of coordinates, this function contructs a non-directed graph, by conncting adjected points. There are three combinations of settings: Allow all neigbors: Distance(a, b) <= sqrt(2) Allow only manhattan: Distance(a, b) == 1 - Allow only euclidean: Distance(a, b) == sqrt(2) + Allow only Euclidean: Distance(a, b) == sqrt(2) - :param coordiniates_or_tiles: A set of coordinates. - :type coordiniates_or_tiles: Tiles + :param coordiniates: A set of coordinates. + :type coordiniates: Tuple[int, int] :param allow_euclidean_connections: Whether to regard diagonal adjected cells as neighbors :type: bool :param allow_manhattan_connections: Whether to regard directly adjected cells as neighbors @@ -24,9 +24,7 @@ def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, all :rtype: nx.Graph """ assert allow_euclidean_connections or allow_manhattan_connections - if hasattr(coordiniates_or_tiles, 'positions'): - coordiniates_or_tiles = coordiniates_or_tiles.positions - possible_connections = itertools.combinations(coordiniates_or_tiles, 2) + possible_connections = itertools.combinations(coordiniates, 2) graph = nx.Graph() for a, b in possible_connections: diff = np.linalg.norm(np.asarray(a)-np.asarray(b)) diff --git a/marl_factory_grid/configs/default_config.yaml b/marl_factory_grid/configs/default_config.yaml index 9c9bb6c..e8687ff 100644 --- a/marl_factory_grid/configs/default_config.yaml +++ b/marl_factory_grid/configs/default_config.yaml @@ -66,7 +66,6 @@ Rules: DestinationDone: {} DestinationReach: n_dests: 1 - tiles: null DestinationSpawn: n_dests: 1 spawn_frequency: 5 diff --git a/marl_factory_grid/environment/actions.py b/marl_factory_grid/environment/actions.py index eaa8781..4edfe24 100644 --- a/marl_factory_grid/environment/actions.py +++ b/marl_factory_grid/environment/actions.py @@ -42,12 +42,12 @@ class Move(Action, abc.ABC): def do(self, entity, state): new_pos = self._calc_new_pos(entity.pos) - if state.check_move_validity(entity, new_pos): # next_tile := state[c.FLOOR].by_pos(new_pos): + if state.check_move_validity(entity, new_pos): # noinspection PyUnresolvedReferences move_validity = entity.move(new_pos, state) reward = r.MOVEMENTS_VALID if move_validity else r.MOVEMENTS_FAIL return ActionResult(entity=entity, identifier=self._identifier, validity=move_validity, reward=reward) - else: # There is no floor, propably collision + else: # There is no place to go, propably collision # This is currently handeld by the Collision rule, so that it can be switched on and off by conf.yml # return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION) return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=0) diff --git a/marl_factory_grid/environment/constants.py b/marl_factory_grid/environment/constants.py index d229766..1fdf639 100644 --- a/marl_factory_grid/environment/constants.py +++ b/marl_factory_grid/environment/constants.py @@ -3,15 +3,13 @@ DANGER_ZONE = 'x' # Dange Zone tile _identifier fo DEFAULTS = 'Defaults' SELF = 'Self' PLACEHOLDER = 'Placeholder' -FLOOR = 'Floor' # Identifier of Floor-objects and groups (groups). -FLOORS = 'Floors' # Identifier of Floor-objects and groups (groups). WALL = 'Wall' # Identifier of Wall-objects and groups (groups). WALLS = 'Walls' # Identifier of Wall-objects and groups (groups). LEVEL = 'Level' # Identifier of Level-objects and groups (groups). AGENT = 'Agent' # Identifier of Agent-objects and groups (groups). OTHERS = 'Other' COMBINED = 'Combined' -GLOBALPOSITIONS = 'GlobalPositions' # Identifier of the global position slice +GLOBALPOSITIONS = 'GlobalPositions' # Identifier of the global position slice # Attributes IS_BLOCKING_LIGHT = 'var_is_blocking_light' @@ -32,7 +30,7 @@ VALUE_NO_POS = (-9999, -9999) # Invalid Position value used in the e ACTION = 'action' # Identifier of Action-objects and groups (groups). COLLISION = 'Collision' # Identifier to use in the context of collitions. -LAST_POS = 'LAST_POS' # Identifiert for retrieving an enitites last pos. +# LAST_POS = 'LAST_POS' # Identifiert for retrieving an enitites last pos. VALIDITY = 'VALIDITY' # Identifiert for retrieving the Validity of Action, Tick, etc. ... # Actions diff --git a/marl_factory_grid/environment/entity/agent.py b/marl_factory_grid/environment/entity/agent.py index 206be49..61e33d3 100644 --- a/marl_factory_grid/environment/entity/agent.py +++ b/marl_factory_grid/environment/entity/agent.py @@ -2,7 +2,7 @@ from typing import List, Union from marl_factory_grid.environment.actions import Action from marl_factory_grid.environment.entity.entity import Entity -from marl_factory_grid.utils.render import RenderEntity +from marl_factory_grid.utils.utility_classes import RenderEntity from marl_factory_grid.utils import renderer from marl_factory_grid.utils.helpers import is_move from marl_factory_grid.utils.results import ActionResult, Result diff --git a/marl_factory_grid/environment/entity/entity.py b/marl_factory_grid/environment/entity/entity.py index 67bf665..f24c86c 100644 --- a/marl_factory_grid/environment/entity/entity.py +++ b/marl_factory_grid/environment/entity/entity.py @@ -1,8 +1,10 @@ import abc +import numpy as np + from .. import constants as c from .object import EnvObject -from ...utils.render import RenderEntity +from ...utils.utility_classes import RenderEntity from ...utils.results import ActionResult @@ -30,33 +32,32 @@ class Entity(EnvObject, abc.ABC): return self._pos @property - def tile(self): - return self._tile # wall_n_floors funktionalität - - # @property - # def last_tile(self): - # try: - # return self._last_tile - # except AttributeError: - # # noinspection PyAttributeOutsideInit - # self._last_tile = None - # return self._last_tile + def last_pos(self): + try: + return self._last_pos + except AttributeError: + # noinspection PyAttributeOutsideInit + self._last_pos = c.VALUE_NO_POS + return self._last_pos @property def direction_of_view(self): - last_x, last_y = self._last_pos - curr_x, curr_y = self.pos - return last_x - curr_x, last_y - curr_y + if self._last_pos != c.VALUE_NO_POS: + return 0, 0 + else: + return np.subtract(self._last_pos, self.pos) def move(self, next_pos, state): next_pos = next_pos curr_pos = self._pos if not_same_pos := curr_pos != next_pos: if valid := state.check_move_validity(self, next_pos): - self._pos = next_pos - self._last_pos = curr_pos for observer in self.observers: - observer.notify_change_pos(self) + observer.notify_del_entity(self) + self._view_directory = curr_pos[0]-next_pos[0], curr_pos[1]-next_pos[1] + self._pos = next_pos + for observer in self.observers: + observer.notify_add_entity(self) return valid return not_same_pos @@ -64,6 +65,7 @@ class Entity(EnvObject, abc.ABC): super().__init__(**kwargs) self._status = None self._pos = pos + self._last_pos = pos if bind_to: try: self.bind_to(bind_to) diff --git a/marl_factory_grid/environment/entity/wall_floor.py b/marl_factory_grid/environment/entity/wall_floor.py index c5bb09a..e8b153e 100644 --- a/marl_factory_grid/environment/entity/wall_floor.py +++ b/marl_factory_grid/environment/entity/wall_floor.py @@ -4,7 +4,7 @@ import numpy as np from marl_factory_grid.environment import constants as c from marl_factory_grid.environment.entity.object import EnvObject -from marl_factory_grid.utils.render import RenderEntity +from marl_factory_grid.utils.utility_classes import RenderEntity from marl_factory_grid.utils import helpers as h @@ -30,17 +30,6 @@ class Floor(EnvObject): def var_is_blocking_light(self): return False - @property - def neighboring_floor(self): - if self._neighboring_floor: - pass - else: - self._neighboring_floor = [x for x in [self._collection.by_pos(np.add(self.pos, pos)) - for pos in h.POS_MASK.reshape(-1, 2) - if not np.all(pos == [0, 0])] - if x] - return self._neighboring_floor - @property def encoding(self): return c.VALUE_OCCUPIED_CELL diff --git a/marl_factory_grid/environment/factory.py b/marl_factory_grid/environment/factory.py index 2054d99..d840178 100644 --- a/marl_factory_grid/environment/factory.py +++ b/marl_factory_grid/environment/factory.py @@ -197,7 +197,7 @@ class Factory(gym.Env): del rewards['global'] reward = [rewards[agent.name] for agent in self.state[c.AGENT]] reward = [x + global_rewards for x in reward] - self.state.print(f"rewards are {rewards}") + self.state.print(f"Individual rewards are {dict(rewards)}") return reward, combined_info_dict, done else: reward = sum(rewards.values()) @@ -220,7 +220,7 @@ class Factory(gym.Env): def summarize_header(self): header = {'rec_step': self.state.curr_step} - for entity_group in (x for x in self.state if x.name in ['Walls', 'Floors', 'DropOffLocations', 'ChargePods']): + for entity_group in (x for x in self.state if x.name in ['Walls', 'DropOffLocations', 'ChargePods']): header.update({f'rec{entity_group.name}': entity_group.summarize_states()}) return header @@ -229,7 +229,7 @@ class Factory(gym.Env): # Todo: Protobuff Compatibility Section ####### # for entity_group in (x for x in self.state if x.name not in [c.WALLS, c.FLOORS]): - for entity_group in (x for x in self.state if x.name not in [c.FLOORS]): + for entity_group in self.state: summary.update({entity_group.name.lower(): entity_group.summarize_states()}) # TODO Section End ######## for key in list(summary.keys()): diff --git a/marl_factory_grid/environment/groups/global_entities.py b/marl_factory_grid/environment/groups/global_entities.py index 3051504..d403a41 100644 --- a/marl_factory_grid/environment/groups/global_entities.py +++ b/marl_factory_grid/environment/groups/global_entities.py @@ -1,5 +1,6 @@ from collections import defaultdict from operator import itemgetter +from random import shuffle from typing import Dict from marl_factory_grid.environment.groups.objects import Objects @@ -13,7 +14,7 @@ class Entities(Objects): def neighboring_positions(pos): return (POS_MASK + pos).reshape(-1, 2) - def get_near_pos(self, pos): + def get_entities_near_pos(self, pos): return [y for x in itemgetter(*(tuple(x) for x in self.neighboring_positions(pos)))(self.pos_dict) for y in x] def render(self): @@ -38,11 +39,17 @@ class Entities(Objects): def guests_that_can_collide(self, pos): return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide] - def empty_tiles(self): - return[key for key in self.floorlist if not any(self.pos_dict[key])] + @property + def empty_positions(self): + empty_positions= [key for key in self.floorlist if self.pos_dict[key]] + shuffle(empty_positions) + return empty_positions - def occupied_tiles(self): # positions that are not empty - return[key for key in self.floorlist if any(self.pos_dict[key])] + @property + def occupied_positions(self): # positions that are not empty + empty_positions = [key for key in self.floorlist if self.pos_dict[key]] + shuffle(empty_positions) + return empty_positions def is_blocked(self): return[key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])] diff --git a/marl_factory_grid/environment/groups/mixins.py b/marl_factory_grid/environment/groups/mixins.py index 34bab16..7171d43 100644 --- a/marl_factory_grid/environment/groups/mixins.py +++ b/marl_factory_grid/environment/groups/mixins.py @@ -37,7 +37,11 @@ class PositionMixin: def __delitem__(self, name): idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name) - obj.tile.leave(obj) # observer notify? + try: + for observer in obj.observers: + observer.notify_del_entity(obj) + except AttributeError: + pass super().__delitem__(name) def by_pos(self, pos: (int, int)): diff --git a/marl_factory_grid/environment/groups/objects.py b/marl_factory_grid/environment/groups/objects.py index 57e0106..0e3a49e 100644 --- a/marl_factory_grid/environment/groups/objects.py +++ b/marl_factory_grid/environment/groups/objects.py @@ -103,6 +103,9 @@ class Objects: except StopIteration: return None + def by_name(self, name): + return next(x for x in self if x.name == name) + def __getitem__(self, item): if isinstance(item, (int, np.int64, np.int32)): if item < 0: @@ -120,7 +123,7 @@ class Objects: raise TypeError def __repr__(self): - repr_dict = { key: val for key, val in self._data.items() if key not in [c.WALLS, c.FLOORS]} + repr_dict = { key: val for key, val in self._data.items() if key not in [c.WALLS]} return f'{self.__class__.__name__}[{repr_dict}]' def spawn(self, n: int): @@ -132,22 +135,25 @@ class Objects: for item in items: del self[item] - def notify_change_pos(self, entity: object): - try: - self.pos_dict[entity.last_pos].remove(entity) - except (ValueError, AttributeError): - pass - if entity.var_has_position: - try: - self.pos_dict[entity.pos].append(entity) - except (ValueError, AttributeError): - pass + # def notify_change_pos(self, entity: object): + # try: + # self.pos_dict[entity.last_pos].remove(entity) + # except (ValueError, AttributeError): + # pass + # if entity.var_has_position: + # try: + # self.pos_dict[entity.pos].append(entity) + # except (ValueError, AttributeError): + # pass def notify_del_entity(self, entity: Object): try: entity.del_observer(self) + except AttributeError: + pass + try: self.pos_dict[entity.pos].remove(entity) - except (ValueError, AttributeError): + except (AttributeError, ValueError, IndexError): pass def notify_add_entity(self, entity: Object): diff --git a/marl_factory_grid/environment/groups/wall_n_floors.py b/marl_factory_grid/environment/groups/walls.py similarity index 96% rename from marl_factory_grid/environment/groups/wall_n_floors.py rename to marl_factory_grid/environment/groups/walls.py index ffeb033..7fcd939 100644 --- a/marl_factory_grid/environment/groups/wall_n_floors.py +++ b/marl_factory_grid/environment/groups/walls.py @@ -15,6 +15,7 @@ class Walls(PositionMixin, EnvObjects): super(Walls, self).__init__(*args, **kwargs) self._value = c.VALUE_OCCUPIED_CELL + #ToDo: Do we need this? Move to spawn methode? # @classmethod # def from_coordinates(cls, argwhere_coordinates, *args, **kwargs): # tiles = cls(*args, **kwargs) diff --git a/marl_factory_grid/environment/rules.py b/marl_factory_grid/environment/rules.py index efdd69e..1fadc6a 100644 --- a/marl_factory_grid/environment/rules.py +++ b/marl_factory_grid/environment/rules.py @@ -49,7 +49,7 @@ class SpawnAgents(Rule): agent_conf = state.agents_conf # agents = Agents(lvl_map.size) agents = state[c.AGENT] - empty_tiles = state[c.FLOORS].empty_tiles[:len(agent_conf)] + empty_positions = state.entities.empty_positions[:len(agent_conf)] for agent_name in agent_conf: actions = agent_conf[agent_name]['actions'].copy() observations = agent_conf[agent_name]['observations'].copy() @@ -58,18 +58,17 @@ class SpawnAgents(Rule): shuffle(positions) while True: try: - tile = state[c.FLOORS].by_pos(positions.pop()) + pos = positions.pop() except IndexError as e: raise ValueError(f'It was not possible to spawn an Agent on the available position: ' f'\n{agent_name[agent_name]["positions"].copy()}') - try: - agents.add_item(Agent(actions, observations, tile, str_ident=agent_name)) - except AssertionError: - state.print(f'No valid pos:{tile.pos} for {agent_name}') + if agents.by_pos(pos) and state.check_pos_validity(pos): continue + else: + agents.add_item(Agent(actions, observations, pos, str_ident=agent_name)) break else: - agents.add_item(Agent(actions, observations, empty_tiles.pop(), str_ident=agent_name)) + agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name)) pass diff --git a/marl_factory_grid/modules/batteries/entitites.py b/marl_factory_grid/modules/batteries/entitites.py index fd8f40a..0e8153e 100644 --- a/marl_factory_grid/modules/batteries/entitites.py +++ b/marl_factory_grid/modules/batteries/entitites.py @@ -2,7 +2,7 @@ from marl_factory_grid.environment.entity.mixin import BoundEntityMixin from marl_factory_grid.environment.entity.object import EnvObject from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment import constants as c -from marl_factory_grid.utils.render import RenderEntity +from marl_factory_grid.utils.utility_classes import RenderEntity from marl_factory_grid.modules.batteries import constants as b diff --git a/marl_factory_grid/modules/batteries/rules.py b/marl_factory_grid/modules/batteries/rules.py index 27f5f59..d81dae6 100644 --- a/marl_factory_grid/modules/batteries/rules.py +++ b/marl_factory_grid/modules/batteries/rules.py @@ -70,8 +70,8 @@ class PodRules(Rule): def on_init(self, state, lvl_map): pod_collection = state[b.CHARGE_PODS] - empty_tiles = state[c.FLOORS].empty_tiles[:self.n_pods] - pods = pod_collection.from_coordinates(empty_tiles, entity_kwargs=dict( + empty_positions = state.entities.empty_positions() + pods = pod_collection.from_coordinates(empty_positions, entity_kwargs=dict( multi_charge=self.multi_charge, charge_rate=self.charge_rate) ) pod_collection.add_items(pods) diff --git a/marl_factory_grid/modules/clean_up/entitites.py b/marl_factory_grid/modules/clean_up/entitites.py index 53d84e8..9d5e4cf 100644 --- a/marl_factory_grid/modules/clean_up/entitites.py +++ b/marl_factory_grid/modules/clean_up/entitites.py @@ -1,7 +1,7 @@ from numpy import random from marl_factory_grid.environment.entity.entity import Entity -from marl_factory_grid.utils.render import RenderEntity +from marl_factory_grid.utils.utility_classes import RenderEntity from marl_factory_grid.modules.clean_up import constants as d diff --git a/marl_factory_grid/modules/clean_up/groups.py b/marl_factory_grid/modules/clean_up/groups.py index ad24bd7..0ee6893 100644 --- a/marl_factory_grid/modules/clean_up/groups.py +++ b/marl_factory_grid/modules/clean_up/groups.py @@ -1,6 +1,5 @@ from marl_factory_grid.environment.groups.env_objects import EnvObjects from marl_factory_grid.environment.groups.mixins import PositionMixin -from marl_factory_grid.environment.entity.wall_floor import Floor from marl_factory_grid.modules.clean_up.entitites import DirtPile from marl_factory_grid.environment import constants as c @@ -31,8 +30,6 @@ class DirtPiles(PositionMixin, EnvObjects): self.max_local_amount = max_local_amount def spawn(self, then_dirty_positions, amount) -> bool: - # if isinstance(then_dirty_tiles, Floor): - # then_dirty_tiles = [then_dirty_tiles] for pos in then_dirty_positions: if not self.amount > self.max_global_amount: if dirt := self.by_pos(pos): @@ -56,8 +53,8 @@ class DirtPiles(PositionMixin, EnvObjects): var = self.dirt_spawn_r_var new_spawn = abs(self.initial_dirt_ratio + (state.rng.uniform(-var, var) if initial_spawn else 0)) - n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt))) - return self.spawn(free_for_dirt[:n_dirt_tiles], self.initial_amount) + n_dirty_positions = max(0, int(new_spawn * len(free_for_dirt))) + return self.spawn(free_for_dirt[:n_dirty_positions], self.initial_amount) def __repr__(self): s = super(DirtPiles, self).__repr__() diff --git a/marl_factory_grid/modules/destinations/entitites.py b/marl_factory_grid/modules/destinations/entitites.py index 99cd2c4..42669fd 100644 --- a/marl_factory_grid/modules/destinations/entitites.py +++ b/marl_factory_grid/modules/destinations/entitites.py @@ -4,7 +4,7 @@ from marl_factory_grid.environment.entity.agent import Agent from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment import constants as c from marl_factory_grid.environment.entity.mixin import BoundEntityMixin -from marl_factory_grid.utils.render import RenderEntity +from marl_factory_grid.utils.utility_classes import RenderEntity from marl_factory_grid.modules.destinations import constants as d @@ -17,7 +17,6 @@ class Destination(BoundEntityMixin, Entity): var_is_blocking_light = False var_can_be_bound = True # Introduce this globally! - @property def was_reached(self): return self._was_reached @@ -35,11 +34,10 @@ class Destination(BoundEntityMixin, Entity): self._per_agent_actions[agent.name] += 1 return c.VALID - @property - def has_just_been_reached(self): - if self.was_reached: + def has_just_been_reached(self, state): + if self.was_reached(): return False - agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in state.entities.pos_dict[self.pos] if x.var_can_collide) + agent_at_position = any(state[c.AGENT].by_pos(self.pos)) if self.bound_entity: return ((agent_at_position and not self.action_counts) @@ -57,7 +55,7 @@ class Destination(BoundEntityMixin, Entity): return state_summary def render(self): - if self.was_reached: + if self.was_reached(): return None else: return RenderEntity(d.DESTINATION, self.pos) diff --git a/marl_factory_grid/modules/destinations/rules.py b/marl_factory_grid/modules/destinations/rules.py index fd77da1..8773f2d 100644 --- a/marl_factory_grid/modules/destinations/rules.py +++ b/marl_factory_grid/modules/destinations/rules.py @@ -16,28 +16,29 @@ class DestinationReachAll(Rule): def tick_step(self, state) -> List[TickResult]: results = [] + reached = False for dest in state[d.DESTINATION]: - if dest.has_just_been_reached and not dest.was_reached: - # Dest has just been reached, some agent needs to stand here, grab any first. + if dest.has_just_been_reached(state) and not dest.was_reached(): + # Dest has just been reached, some agent needs to stand here for agent in state[c.AGENT].by_pos(dest.pos): if dest.bound_entity: if dest.bound_entity == agent: - results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent)) + reached = True else: pass else: - results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent)) - state.print(f'{dest.name} is reached now, mark as reached...') - dest.mark_as_reached() + reached = True else: pass + if reached: + state.print(f'{dest.name} is reached now, mark as reached...') + dest.mark_as_reached() + results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent)) return results - def tick_post_step(self, state) -> List[TickResult]: - return [] def on_check_done(self, state) -> List[DoneResult]: - if all(x.was_reached for x in state[d.DESTINATION]): + if all(x.was_reached() for x in state[d.DESTINATION]): return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)] return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)] @@ -48,7 +49,7 @@ class DestinationReachAny(DestinationReachAll): super(DestinationReachAny, self).__init__() def on_check_done(self, state) -> List[DoneResult]: - if any(x.was_reached for x in state[d.DESTINATION]): + if any(x.was_reached() for x in state[d.DESTINATION]): return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)] return [] @@ -63,7 +64,7 @@ class DestinationSpawn(Rule): def on_init(self, state, lvl_map): # noinspection PyAttributeOutsideInit - self.trigger_destination_spawn(self.n_dests, state) + state[d.DESTINATION].trigger_destination_spawn(self.n_dests, state) pass def tick_pre_step(self, state) -> List[TickResult]: @@ -72,24 +73,14 @@ class DestinationSpawn(Rule): def tick_step(self, state) -> List[TickResult]: if n_dest_spawn := max(0, self.n_dests - len(state[d.DESTINATION])): if self.spawn_mode == d.MODE_GROUPED and n_dest_spawn == self.n_dests: - validity = self.trigger_destination_spawn(n_dest_spawn, state) + validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state) return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)] elif self.spawn_mode == d.MODE_SINGLE and n_dest_spawn: - validity = self.trigger_destination_spawn(n_dest_spawn, state) + validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state) return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)] else: pass - def trigger_destination_spawn(self, n_dests, state): - empty_positions = state[c.FLOORS].empty_tiles[:n_dests] - if destinations := [Destination(pos) for pos in empty_positions]: - state[d.DESTINATION].add_items(destinations) - state.print(f'{n_dests} new destinations have been spawned') - return c.VALID - else: - state.print('No Destiantions are spawning, limit is reached.') - return c.NOT_VALID - class FixedDestinationSpawn(Rule): def __init__(self, per_agent_positions: Dict[str, List[Tuple[int, int]]]): @@ -99,11 +90,17 @@ class FixedDestinationSpawn(Rule): def on_init(self, state, lvl_map): for (agent_name, position_list) in self.per_agent_positions.items(): agent = next(x for x in state[c.AGENT] if agent_name in x.name) # Fixme: Ugly AF + position_list = position_list.copy() shuffle(position_list) while True: - pos = position_list.pop() - if pos != agent.pos and not state[d.DESTINATION].by_pos(pos): - destination = Destination(state[c.FLOORS].by_pos(pos), bind_to=agent) + try: + pos = position_list.pop() + except IndexError: + print(f"Could not spawn Destinations at: {self.per_agent_positions[agent_name]}") + print(f'Check your agent palcement: {state[c.AGENT]} ... Exit ...') + exit(9999) + if (not pos == agent.pos) and (not state[d.DESTINATION].by_pos(pos)): + destination = Destination(pos, bind_to=agent) break else: continue diff --git a/marl_factory_grid/modules/doors/__init__.py b/marl_factory_grid/modules/doors/__init__.py index e5dc1cf..4f1d0a2 100644 --- a/marl_factory_grid/modules/doors/__init__.py +++ b/marl_factory_grid/modules/doors/__init__.py @@ -1,4 +1,4 @@ from .actions import DoorUse from .entitites import Door, DoorIndicator from .groups import Doors -from .rule_door_auto_close import DoorAutoClose +from .rules import DoorAutoClose, DoorIndicateArea diff --git a/marl_factory_grid/modules/doors/actions.py b/marl_factory_grid/modules/doors/actions.py index 31a5bf8..c7d06ed 100644 --- a/marl_factory_grid/modules/doors/actions.py +++ b/marl_factory_grid/modules/doors/actions.py @@ -13,8 +13,9 @@ class DoorUse(Action): def do(self, entity, state) -> Union[None, ActionResult]: # Check if agent really is standing on a door: - e = state.entities.get_near_pos(entity.pos) + e = state.entities.get_entities_near_pos(entity.pos) try: + # Only one door opens TODO introcude loop door = next(x for x in e if x.name.startswith(d.DOOR)) valid = door.use() state.print(f'{entity.name} just used a {door.name} at {door.pos}') diff --git a/marl_factory_grid/modules/doors/entitites.py b/marl_factory_grid/modules/doors/entitites.py index 6f386ca..0eea655 100644 --- a/marl_factory_grid/modules/doors/entitites.py +++ b/marl_factory_grid/modules/doors/entitites.py @@ -1,5 +1,5 @@ from marl_factory_grid.environment.entity.entity import Entity -from marl_factory_grid.utils.render import RenderEntity +from marl_factory_grid.utils.utility_classes import RenderEntity from marl_factory_grid.environment import constants as c from marl_factory_grid.modules.doors import constants as d @@ -41,7 +41,7 @@ class Door(Entity): def str_state(self): return 'open' if self.is_open else 'closed' - def __init__(self, *args, closed_on_init=True, auto_close_interval=10, indicate_area=False, **kwargs): + def __init__(self, *args, closed_on_init=True, auto_close_interval=10, **kwargs): self._status = d.STATE_CLOSED super(Door, self).__init__(*args, **kwargs) self.auto_close_interval = auto_close_interval @@ -50,8 +50,6 @@ class Door(Entity): self._open() else: self._close() - if indicate_area: - self._collection.add_items([DoorIndicator(x) for x in self.tile.neighboring_floor]) def summarize_state(self): state_dict = super().summarize_state() diff --git a/marl_factory_grid/modules/doors/rule_door_auto_close.py b/marl_factory_grid/modules/doors/rules.py similarity index 71% rename from marl_factory_grid/modules/doors/rule_door_auto_close.py rename to marl_factory_grid/modules/doors/rules.py index 8792bf6..282cb57 100644 --- a/marl_factory_grid/modules/doors/rule_door_auto_close.py +++ b/marl_factory_grid/modules/doors/rules.py @@ -1,7 +1,8 @@ from marl_factory_grid.environment.rules import Rule from marl_factory_grid.environment import constants as c from marl_factory_grid.utils.results import TickResult -from marl_factory_grid.modules.doors import constants as d +from . import constants as d +from .entitites import DoorIndicator class DoorAutoClose(Rule): @@ -19,3 +20,13 @@ class DoorAutoClose(Rule): return [TickResult(self.name, validity=c.VALID, value=0)] state.print('There are no doors, but you loaded the corresponding Module') return [] + + +class DoorIndicateArea(Rule): + + def __init__(self): + super().__init__() + + def on_init(self, state, lvl_map): + for door in state[d.DOORS]: + state[d.DOORS].add_items([DoorIndicator(x) for x in state.entities.neighboring_positions(door.pos)]) diff --git a/marl_factory_grid/modules/factory/rules.py b/marl_factory_grid/modules/factory/rules.py index 38068b0..d736f7a 100644 --- a/marl_factory_grid/modules/factory/rules.py +++ b/marl_factory_grid/modules/factory/rules.py @@ -9,6 +9,8 @@ from marl_factory_grid.utils.results import TickResult class AgentSingleZonePlacementBeta(Rule): def __init__(self): + raise NotImplementedError() + # TODO!!!! Is this concept needed any more? super().__init__() def on_init(self, state, lvl_map): @@ -21,9 +23,9 @@ class AgentSingleZonePlacementBeta(Rule): coordinates = random.choices(self.coordinates, k=len(agents)) else: raise ValueError - tiles = [state[c.FLOORS].by_pos(pos) for pos in coordinates] - for agent, tile in zip(agents, tiles): - agent.move(tile, state) + + for agent, pos in zip(agents, coordinates): + agent.move(pos, state) def tick_step(self, state): return [] diff --git a/marl_factory_grid/modules/items/entitites.py b/marl_factory_grid/modules/items/entitites.py index 4372d3a..0e5def5 100644 --- a/marl_factory_grid/modules/items/entitites.py +++ b/marl_factory_grid/modules/items/entitites.py @@ -2,7 +2,7 @@ from collections import deque from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment import constants as c -from marl_factory_grid.utils.render import RenderEntity +from marl_factory_grid.utils.utility_classes import RenderEntity from marl_factory_grid.modules.items import constants as i diff --git a/marl_factory_grid/modules/items/groups.py b/marl_factory_grid/modules/items/groups.py index 6f5adc0..295bc08 100644 --- a/marl_factory_grid/modules/items/groups.py +++ b/marl_factory_grid/modules/items/groups.py @@ -1,3 +1,5 @@ +from random import shuffle + from marl_factory_grid.modules.items import constants as i from marl_factory_grid.environment import constants as c @@ -19,10 +21,12 @@ class Items(PositionMixin, EnvObjects): @staticmethod def trigger_item_spawn(state, n_items, spawn_frequency): if item_to_spawns := max(0, (n_items - len(state[i.ITEM]))): - floor_list = state.entities.floorlist[:item_to_spawns] - state[i.ITEM].spawn(floor_list) - state.print(f'{item_to_spawns} new items have been spawned; next spawn in {spawn_frequency}') # spawn in self._next_item_spawn ? - return len(floor_list) + position_list = [x for x in state.entities.floorlist] + shuffle(position_list) + position_list = state.entities.floorlist[:item_to_spawns] + state[i.ITEM].spawn(position_list) + state.print(f'{item_to_spawns} new items have been spawned; next spawn in {spawn_frequency}') + return len(position_list) else: state.print('No Items are spawning, limit is reached.') return 0 @@ -100,7 +104,7 @@ class DropOffLocations(PositionMixin, EnvObjects): @staticmethod def trigger_drop_off_location_spawn(state, n_locations): - empty_tiles = state.entities.floorlist[:n_locations] + empty_positions = state.entities.empty_positions()[:n_locations] do_entites = state[i.DROP_OFF] - drop_offs = [DropOffLocation(tile) for tile in empty_tiles] + drop_offs = [DropOffLocation(pos) for pos in empty_positions] do_entites.add_items(drop_offs) diff --git a/marl_factory_grid/modules/machines/entitites.py b/marl_factory_grid/modules/machines/entitites.py index e0585a9..36a87cc 100644 --- a/marl_factory_grid/modules/machines/entitites.py +++ b/marl_factory_grid/modules/machines/entitites.py @@ -1,5 +1,5 @@ from marl_factory_grid.environment.entity.entity import Entity -from marl_factory_grid.utils.render import RenderEntity +from ...utils.utility_classes import RenderEntity from marl_factory_grid.environment import constants as c from marl_factory_grid.utils.results import TickResult diff --git a/marl_factory_grid/modules/machines/rules.py b/marl_factory_grid/modules/machines/rules.py index e9402c7..18709dc 100644 --- a/marl_factory_grid/modules/machines/rules.py +++ b/marl_factory_grid/modules/machines/rules.py @@ -13,8 +13,8 @@ class MachineRule(Rule): self.n_machines = n_machines def on_init(self, state, lvl_map): - empty_tiles = state[c.FLOORS].empty_tiles[:self.n_machines] - state[m.MACHINES].add_items(Machine(tile) for tile in empty_tiles) + # TODO Move to spawn!!! + state[m.MACHINES].add_items(Machine(pos) for pos in state.entities.empty_positions()) def tick_pre_step(self, state) -> List[TickResult]: pass diff --git a/marl_factory_grid/modules/maintenance/entities.py b/marl_factory_grid/modules/maintenance/entities.py index 45af947..3582197 100644 --- a/marl_factory_grid/modules/maintenance/entities.py +++ b/marl_factory_grid/modules/maintenance/entities.py @@ -8,7 +8,7 @@ from ...environment.entity.entity import Entity from ..doors import constants as do from ..maintenance import constants as mi from ...utils.helpers import MOVEMAP -from ...utils.render import RenderEntity +from ...utils.utility_classes import RenderEntity from ...utils.states import Gamestate @@ -39,7 +39,7 @@ class Maintainer(Entity): self._next = [] self._last = [] self._last_serviced = 'None' - self._floortile_graph = points_to_graph(state[c.FLOORS].positions) + self._floortile_graph = points_to_graph(state.entities.floorlist) def tick(self, state): if found_objective := state[self.objective].by_pos(self.pos): diff --git a/marl_factory_grid/modules/maintenance/rules.py b/marl_factory_grid/modules/maintenance/rules.py index e82673e..bf73d67 100644 --- a/marl_factory_grid/modules/maintenance/rules.py +++ b/marl_factory_grid/modules/maintenance/rules.py @@ -14,7 +14,8 @@ class MaintenanceRule(Rule): self.n_maintainer = n_maintainer def on_init(self, state: Gamestate, lvl_map): - state[M.MAINTAINERS].spawn(state[c.FLOORS].empty_tiles[:self.n_maintainer], state) + # Move to spawn? : #TODO + state[M.MAINTAINERS].spawn(state.entities.empty_positions[:self.n_maintainer], state) pass def tick_pre_step(self, state) -> List[TickResult]: diff --git a/marl_factory_grid/modules/zones/entitites.py b/marl_factory_grid/modules/zones/entitites.py index c6632f6..f4923b7 100644 --- a/marl_factory_grid/modules/zones/entitites.py +++ b/marl_factory_grid/modules/zones/entitites.py @@ -3,8 +3,7 @@ from typing import List, Tuple from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.entity.object import Object -from marl_factory_grid.environment.entity.wall_floor import Floor -from marl_factory_grid.utils.render import RenderEntity +from marl_factory_grid.utils.utility_classes import RenderEntity from marl_factory_grid.environment import constants as c from marl_factory_grid.modules.doors import constants as d @@ -21,5 +20,5 @@ class Zone(Object): self.coords = coords @property - def random_tile(self): + def random_pos(self): return random.choice(self.coords) diff --git a/marl_factory_grid/modules/zones/rules.py b/marl_factory_grid/modules/zones/rules.py index 16f6085..2969186 100644 --- a/marl_factory_grid/modules/zones/rules.py +++ b/marl_factory_grid/modules/zones/rules.py @@ -19,7 +19,7 @@ class ZoneInit(Rule): while z_idx: zone_positions = lvl_map.get_coordinates_for_symbol(z_idx) if len(zone_positions): - zones.append(Zone([state[c.FLOORS].by_pos(pos) for pos in zone_positions])) + zones.append(Zone(zone_positions)) z_idx += 1 else: z_idx = 0 @@ -38,7 +38,7 @@ class AgentSingleZonePlacement(Rule): z_idxs = choices(list(range(len(state[z.ZONES]))), k=n_agents) for agent in state[c.AGENT]: - agent.move(state[z.ZONES][z_idxs.pop()].random_tile, state) + agent.move(state[z.ZONES][z_idxs.pop()].random_pos, state) return [] def tick_step(self, state): @@ -65,10 +65,10 @@ class IndividualDestinationZonePlacement(Rule): other_zones = [x for x in state[z.ZONES] if x not in agent_zones] already_has_destination = True while already_has_destination: - tile = choice(other_zones).random_tile - if state[d.DESTINATION].by_pos(tile.pos) is None: + pos = choice(other_zones).random_pos + if state[d.DESTINATION].by_pos(pos) is None: already_has_destination = False - destination = Destination(tile, bind_to=agent) + destination = Destination(pos, bind_to=agent) state[d.DESTINATION].add_item(destination) continue diff --git a/marl_factory_grid/utils/helpers.py b/marl_factory_grid/utils/helpers.py index ca3a20c..15ced7b 100644 --- a/marl_factory_grid/utils/helpers.py +++ b/marl_factory_grid/utils/helpers.py @@ -25,10 +25,8 @@ This file is used for: LEVELS_DIR = 'modules/levels' # for use in studies and experiments STEPS_START = 1 # Define where to the stepcount; which is the first step -# Not used anymore? Clean! -# TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles'] IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignored when loading monitor files - 'train_step', 'step', 'index', 'dirt_amount', 'dirty_tile_count', 'terminal_observation', + 'train_step', 'step', 'index', 'dirt_amount', 'dirty_pos_count', 'terminal_observation', 'episode'] POS_MASK = np.asarray([[[-1, -1], [0, -1], [1, -1]], @@ -223,7 +221,7 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''): module_parts = [x.replace('.py', '') for idx, x in enumerate(module_path.parts) if idx >= package_pos] mod = importlib.import_module('.'.join(module_parts)) all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle()) - and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union', 'Floor' + and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union', 'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin', 'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin', 'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any' diff --git a/marl_factory_grid/utils/level_parser.py b/marl_factory_grid/utils/level_parser.py index 0acc31d..fc8b948 100644 --- a/marl_factory_grid/utils/level_parser.py +++ b/marl_factory_grid/utils/level_parser.py @@ -6,7 +6,7 @@ import numpy as np from marl_factory_grid.environment.groups.agents import Agents from marl_factory_grid.environment.groups.global_entities import Entities -from marl_factory_grid.environment.groups.wall_n_floors import Walls, Floors +from marl_factory_grid.environment.groups.walls import Walls from marl_factory_grid.utils import helpers as h from marl_factory_grid.environment import constants as c @@ -34,16 +34,14 @@ class LevelParser(object): def do_init(self): # Global Entities - list_of_all_floors = ([tuple(floor) for floor in self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True)]) - entities = Entities(list_of_all_floors) + list_of_all_positions = ([tuple(f) for f in self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True)]) + entities = Entities(list_of_all_positions) # Walls walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size) entities.add_items({c.WALLS: walls}) - # Floor - floor = Floors.from_coordinates(list_of_all_floors, self.size) - entities.add_items({c.FLOOR: floor}) + # Agents entities.add_items({c.AGENT: Agents(self.size)}) # All other diff --git a/marl_factory_grid/utils/observation_builder.py b/marl_factory_grid/utils/observation_builder.py index 4ae4fe8..1377a92 100644 --- a/marl_factory_grid/utils/observation_builder.py +++ b/marl_factory_grid/utils/observation_builder.py @@ -9,6 +9,7 @@ from numba import njit from marl_factory_grid.environment import constants as c from marl_factory_grid.environment.groups.utils import Combined from marl_factory_grid.utils.states import Gamestate +from marl_factory_grid.utils.utility_classes import Floor class OBSBuilder(object): @@ -39,6 +40,7 @@ class OBSBuilder(object): self.reset_struc_obs_block(state) self.curr_lightmaps = dict() + self._floortiles = defaultdict(list, {pos: [Floor(*pos)] for pos in state.entities.floorlist}) def reset_struc_obs_block(self, state): self._curr_env_step = state.curr_step @@ -82,19 +84,23 @@ class OBSBuilder(object): self._sort_and_name_observation_conf(agent) agent_want_obs = self.obs_layers[agent.name] - # Handle in-grid observations aka visible observations - visible_entitites = self.ray_caster[agent.name].visible_entities(state.entities) - pre_sort_obs = defaultdict(lambda: np.zeros((self.pomdp_d, self.pomdp_d))) - for e in set(visible_entitites): - x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r - try: - pre_sort_obs[e.obs_tag][x, y] += e.encoding - except IndexError: - # Seemded to be visible but is out or range - pass + # Handle in-grid observations aka visible observations (Things on the map, with pos) + visible_entitites = self.ray_caster[agent.name].visible_entities(state.entities.pos_dict) + pre_sort_obs = defaultdict(lambda: np.zeros(self.obs_shape)) + if self.pomdp_r: + for e in set(visible_entitites): + x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r + try: + pre_sort_obs[e.obs_tag][x, y] += e.encoding + except IndexError: + # Seemded to be visible but is out or range + pass + else: + for e in set(visible_entitites): + pre_sort_obs[e.obs_tag][e.x, e.y] += e.encoding pre_sort_obs = dict(pre_sort_obs) - obs = np.zeros((len(agent_want_obs), self.pomdp_d, self.pomdp_d)) + obs = np.zeros((len(agent_want_obs), self.obs_shape[0], self.obs_shape[1])) for idx, l_name in enumerate(agent_want_obs): try: @@ -144,13 +150,26 @@ class OBSBuilder(object): raise ValueError(f'Max(obs.size) for {e.name}: {obs[idx].size}, but was: {len(v)}.') try: - self.curr_lightmaps[agent.name] = pre_sort_obs[c.FLOORS].astype(bool) + light_map = np.zeros(self.obs_shape) + visible_floor = set(self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False)) + if self.pomdp_r: + coords = [((f.x - agent.x) + self.pomdp_r, (f.y - agent.y) + self.pomdp_r) for f in visible_floor] + else: + coords = [x.pos for x in visible_floor] + np.put(light_map, np.ravel_multi_index(np.asarray(coords).T, light_map.shape), 1) + self.curr_lightmaps[agent.name] = light_map except KeyError: print() return obs, self.obs_layers[agent.name] def _sort_and_name_observation_conf(self, agent): - self.ray_caster[agent.name] = RayCaster(agent, self.pomdp_r) + ''' + Builds the useable observation scheme per agent from conf.yaml. + :param agent: + :return: + ''' + # Fixme: no asymetric shapes possible. + self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape)) obs_layers = [] for obs_str in agent.observations: @@ -173,7 +192,7 @@ class OBSBuilder(object): names.extend([x.name for x in agent.collection if x.name != agent.name]) else: names.append(val) - combined = Combined(names, self.pomdp_r, identifier=agent.name) + combined = Combined(names, self.size, identifier=agent.name) self.all_obs[combined.name] = combined obs_layers.append(combined.name) elif obs_str == c.OTHERS: @@ -183,19 +202,18 @@ class OBSBuilder(object): else: obs_layers.append(obs_str) self.obs_layers[agent.name] = obs_layers - self.curr_lightmaps[agent.name] = np.zeros((self.pomdp_d or self.level_shape[0], - self.pomdp_d or self.level_shape[1] - )) + self.curr_lightmaps[agent.name] = np.zeros(self.obs_shape) class RayCaster: def __init__(self, agent, pomdp_r, degs=360): self.agent = agent self.pomdp_r = pomdp_r - self.n_rays = 100 # (self.pomdp_r + 1) * 8 + self.n_rays = (self.pomdp_r + 1) * 8 self.degs = degs self.ray_targets = self.build_ray_targets() self.obs_shape_cube = np.array([self.pomdp_r, self.pomdp_r]) + self._cache_dict = {} def __repr__(self): return f'{self.__class__.__name__}({self.agent.name})' @@ -211,30 +229,30 @@ class RayCaster: rot_M = np.unique(np.round(rot_M @ north), axis=0) return rot_M.astype(int) - def ray_block_cache(self, cache_dict, key, callback): - if key not in cache_dict: - cache_dict[key] = callback() - return cache_dict[key] + def ray_block_cache(self, key, callback): + if key not in self._cache_dict: + self._cache_dict[key] = callback() + return self._cache_dict[key] - def visible_entities(self, entities): + def visible_entities(self, pos_dict, reset_cache=True): visible = list() - cache_blocking = {} + if reset_cache: + self._cache_dict = {} for ray in self.get_rays(): rx, ry = ray[0] for x, y in ray: cx, cy = x - rx, y - ry - entities_hit = entities.pos_dict[(x, y)] - hits = self.ray_block_cache(cache_blocking, - (x, y), - lambda: any(True for e in entities_hit if e.var_is_blocking_light)) + entities_hit = pos_dict[(x, y)] + hits = self.ray_block_cache((x, y), + lambda: any(True for e in entities_hit if e.var_is_blocking_light) + ) diag_hits = all([ self.ray_block_cache( - cache_blocking, key, - lambda: all(False for e in entities.pos_dict[key] if not e.var_is_blocking_light)) + lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) and bool(pos_dict[key])) for key in ((x, y-cy), (x-cx, y)) ]) if (cx != 0 and cy != 0) else False diff --git a/marl_factory_grid/utils/render.py b/marl_factory_grid/utils/render.py deleted file mode 100644 index efcaaa9..0000000 --- a/marl_factory_grid/utils/render.py +++ /dev/null @@ -1,16 +0,0 @@ -from dataclasses import dataclass -from typing import Any - -import numpy as np - - -@dataclass -class RenderEntity: - name: str - pos: np.array - value: float = 1 - value_operation: str = 'none' - state: str = None - id: int = 0 - aux: Any = None - real_name: str = 'none' diff --git a/marl_factory_grid/utils/renderer.py b/marl_factory_grid/utils/renderer.py index 38a8e22..db6a93f 100644 --- a/marl_factory_grid/utils/renderer.py +++ b/marl_factory_grid/utils/renderer.py @@ -9,7 +9,7 @@ import pygame from typing import Tuple, Union import time -from marl_factory_grid.utils.render import RenderEntity +from marl_factory_grid.utils.utility_classes import RenderEntity AGENT: str = 'agent' STATE_IDLE: str = 'idle' diff --git a/marl_factory_grid/utils/states.py b/marl_factory_grid/utils/states.py index 04e8dcd..83dbcf9 100644 --- a/marl_factory_grid/utils/states.py +++ b/marl_factory_grid/utils/states.py @@ -3,8 +3,6 @@ from typing import List, Dict, Tuple import numpy as np from marl_factory_grid.environment import constants as c -from marl_factory_grid.environment.entity.wall_floor import Floor -from marl_factory_grid.environment.groups.global_entities import Entities from marl_factory_grid.environment.rules import Rule from marl_factory_grid.utils.results import Result @@ -112,15 +110,10 @@ class Gamestate(object): results.extend(on_check_done_result) return results - # def get_all_tiles_with_collisions(self) -> List[Floor]: - # tiles = [self[c.FLOORS].by_pos(pos) for pos, e in self.entities.pos_dict.items() - # if sum([x.var_can_collide for x in e]) > 1] - # # tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1] - # return tiles def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]: - positions = [pos for pos, e in self.entities.pos_dict.items() - if sum([x.var_can_collide for x in e]) > 1] + positions = [pos for pos, entity_list_for_position in self.entities.pos_dict.items() + if any([e.var_can_collide for e in entity_list_for_position])] return positions def check_move_validity(self, moving_entity, position): @@ -128,6 +121,14 @@ class Gamestate(object): # and not (guest.var_is_blocking_pos and self.is_occupied()): if moving_entity.pos != position and not any( entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]) and not ( - moving_entity.var_is_blocking_pos and moving_entity.is_occupied()): + moving_entity.var_is_blocking_pos and self.entities.is_occupied(position)): return True - return False + else: + return False + + def check_pos_validity(self, position): + if not any(entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]): + return True + else: + return False + diff --git a/marl_factory_grid/utils/tools.py b/marl_factory_grid/utils/tools.py index 7747c18..d2f9bd1 100644 --- a/marl_factory_grid/utils/tools.py +++ b/marl_factory_grid/utils/tools.py @@ -15,7 +15,7 @@ ENTITIES = 'Objects' OBSERVATIONS = 'Observations' RULES = 'Rule' ASSETS = 'Assets' -EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Floor', 'Agent', 'GlobalPositions', 'Walls', +EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Agent', 'GlobalPositions', 'Walls', 'TemplateRule', 'Entities', 'EnvObjects', 'Zones', ] diff --git a/marl_factory_grid/utils/utility_classes.py b/marl_factory_grid/utils/utility_classes.py index 8fee782..5574a81 100644 --- a/marl_factory_grid/utils/utility_classes.py +++ b/marl_factory_grid/utils/utility_classes.py @@ -1,4 +1,8 @@ +from dataclasses import dataclass +from typing import Any + import gymnasium as gym +import numpy as np class MarlFrameStack(gym.ObservationWrapper): @@ -10,3 +14,37 @@ class MarlFrameStack(gym.ObservationWrapper): if isinstance(self.env, gym.wrappers.FrameStack) and self.env.unwrapped.n_agents > 1: return observation[0:].swapaxes(0, 1) return observation + + +@dataclass +class RenderEntity: + name: str + pos: np.array + value: float = 1 + value_operation: str = 'none' + state: str = None + id: int = 0 + aux: Any = None + real_name: str = 'none' + + +@dataclass +class Floor: + + @property + def name(self): + return f"Floor({self.pos})" + + @property + def pos(self): + return self.x, self.y + + x: int + y: int + var_is_blocking_light: bool = False + + def __eq__(self, other): + return self.name == other.name + + def __hash__(self): + return hash(self.name)