diff --git a/marl_factory_grid/algorithms/marl/base_ac.py b/marl_factory_grid/algorithms/marl/base_ac.py index ef195b7..8b64262 100644 --- a/marl_factory_grid/algorithms/marl/base_ac.py +++ b/marl_factory_grid/algorithms/marl/base_ac.py @@ -174,7 +174,7 @@ class BaseActorCritic: hidden_critic=out.get(nms.HIDDEN_CRITIC, None) ) eps_rew += torch.tensor(reward) - results.append(eps_rew.tolist() + [np.sum(eps_rew).item()] + [episode]) + results.append(eps_rew.tolist() + [sum(eps_rew).item()] + [episode]) episode += 1 agent_columns = [f'agent#{i}' for i in range(self.cfg['environment']['n_agents'])] results = pd.DataFrame(results, columns=agent_columns + ['sum', 'episode']) diff --git a/marl_factory_grid/environment/actions.py b/marl_factory_grid/environment/actions.py index 606832c..94cbef1 100644 --- a/marl_factory_grid/environment/actions.py +++ b/marl_factory_grid/environment/actions.py @@ -18,6 +18,7 @@ class Action(abc.ABC): @abc.abstractmethod def do(self, entity, state) -> Union[None, ActionResult]: + print() return def __repr__(self): diff --git a/marl_factory_grid/environment/entity/object.py b/marl_factory_grid/environment/entity/object.py index e8c69da..1b5190d 100644 --- a/marl_factory_grid/environment/entity/object.py +++ b/marl_factory_grid/environment/entity/object.py @@ -41,7 +41,7 @@ class Object: def __init__(self, str_ident: Union[str, None] = None, **kwargs): self._bound_entity = None - self._observers = [] + self._observers = set() self._str_ident = str_ident self.u_int = self._identify_and_count_up() self._collection = None @@ -75,7 +75,7 @@ class Object: self._collection = collection def add_observer(self, observer): - self.observers.append(observer) + self.observers.add(observer) observer.notify_add_entity(self) def del_observer(self, observer): diff --git a/marl_factory_grid/environment/factory.py b/marl_factory_grid/environment/factory.py index 6662581..f37ce3a 100644 --- a/marl_factory_grid/environment/factory.py +++ b/marl_factory_grid/environment/factory.py @@ -69,23 +69,6 @@ class Factory(gym.Env): # expensive - don't use; unless required ! self._renderer = None - # reset env to initial state, preparing env for new episode. - # returns tuple where the first dict contains initial observation for each agent in the env - self.reset() - - def __getitem__(self, item): - return self.state.entities[item] - - def reset(self) -> (dict, dict): - if self.state is not None: - for entity_group in self.state.entities: - try: - entity_group[0].reset_uid() - except (AttributeError, TypeError): - pass - - self.state = None - # Init entities entities = self.map.do_init() @@ -101,7 +84,6 @@ class Factory(gym.Env): self.state = Gamestate(entities, parsed_agents_conf, env_rules, env_tests, self.map.level_shape, self.conf.env_seed, self.conf.verbose) - # All is set up, trigger entity init with variable pos # All is set up, trigger additional init (after agent entity spawn etc) self.state.rules.do_all_init(self.state, self.map) @@ -110,6 +92,17 @@ class Factory(gym.Env): # Build initial observations for all agents # noinspection PyAttributeOutsideInit self.obs_builder = OBSBuilder(self.map.level_shape, self.state, self.map.pomdp_r) + + def __getitem__(self, item): + return self.state.entities[item] + + def reset(self) -> (dict, dict): + self.state.entities.reset() + + # All is set up, trigger entity spawn with variable pos + self.state.rules.do_all_reset(self.state) + + # Build initial observations for all agents return self.obs_builder.refresh_and_build_for_all(self.state) def manual_step_init(self) -> List[Result]: diff --git a/marl_factory_grid/environment/groups/collection.py b/marl_factory_grid/environment/groups/collection.py index c0f0f6b..c7c9f4c 100644 --- a/marl_factory_grid/environment/groups/collection.py +++ b/marl_factory_grid/environment/groups/collection.py @@ -2,7 +2,6 @@ from typing import List, Tuple, Union, Dict from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.groups.objects import Objects -# noinspection PyProtectedMember from marl_factory_grid.environment.entity.object import Object import marl_factory_grid.environment.constants as c from marl_factory_grid.utils.results import Result diff --git a/marl_factory_grid/environment/groups/global_entities.py b/marl_factory_grid/environment/groups/global_entities.py index 37779f9..601ce4d 100644 --- a/marl_factory_grid/environment/groups/global_entities.py +++ b/marl_factory_grid/environment/groups/global_entities.py @@ -31,9 +31,12 @@ class Entities(Objects): def __init__(self, floor_positions): self._floor_positions = floor_positions - self.pos_dict = defaultdict(list) + self.pos_dict = None super().__init__() + def __repr__(self): + return f'{self.__class__.__name__}{[x for x in self]}' + def guests_that_can_collide(self, pos): return [x for val in self.pos_dict[pos] for x in val if x.var_can_collide] @@ -108,3 +111,12 @@ class Entities(Objects): def is_occupied(self, pos): return len([x for x in self.pos_dict[pos] if x.var_can_collide or x.var_is_blocking_pos]) >= 1 + + def reset(self): + self._observers = set(self) + self.pos_dict = defaultdict(list) + for entity_group in self: + entity_group.reset() + + if hasattr(entity_group, "var_has_position") and entity_group.var_has_position: + entity_group.add_observer(self) diff --git a/marl_factory_grid/environment/groups/objects.py b/marl_factory_grid/environment/groups/objects.py index 9229787..23e91ab 100644 --- a/marl_factory_grid/environment/groups/objects.py +++ b/marl_factory_grid/environment/groups/objects.py @@ -44,7 +44,7 @@ class Objects: def __init__(self, *args, **kwargs): self._data = defaultdict(lambda: None) - self._observers = [self] + self._observers = set(self) self.pos_dict = defaultdict(list) def __len__(self): @@ -59,6 +59,8 @@ class Objects: assert self._data[item.name] is None, f'{item.name} allready exists!!!' self._data.update({item.name: item}) item.set_collection(self) + if hasattr(self, "var_has_position") and self.var_has_position: + item.add_observer(self) for observer in self.observers: observer.notify_add_entity(item) return self @@ -82,10 +84,9 @@ class Objects: # noinspection PyUnresolvedReferences def add_observer(self, observer): - self.observers.append(observer) + self.observers.add(observer) for entity in self: - if observer not in entity.observers: - entity.add_observer(observer) + entity.add_observer(observer) def add_items(self, items: List[_entity]): for item in items: @@ -127,8 +128,7 @@ class Objects: raise TypeError def __repr__(self): - repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]} - return f'{self.__class__.__name__}[{repr_dict}]' + return f'{self.__class__.__name__}[{len(self)}]' def notify_del_entity(self, entity: Object): try: @@ -163,3 +163,9 @@ class Objects: return h.get_first_index(self, filter_by=lambda x: x.belongs_to_entity(entity)) except (StopIteration, AttributeError): return None + + def reset(self): + self._data = defaultdict(lambda: None) + self._observers = set(self) + self.pos_dict = defaultdict(list) + diff --git a/marl_factory_grid/environment/groups/walls.py b/marl_factory_grid/environment/groups/walls.py index 776bbca..c03f724 100644 --- a/marl_factory_grid/environment/groups/walls.py +++ b/marl_factory_grid/environment/groups/walls.py @@ -23,3 +23,7 @@ class Walls(Collection): return super().by_pos(pos)[0] except IndexError: return None + + def reset(self): + pass + diff --git a/marl_factory_grid/environment/rules.py b/marl_factory_grid/environment/rules.py index f5b6836..5884cac 100644 --- a/marl_factory_grid/environment/rules.py +++ b/marl_factory_grid/environment/rules.py @@ -23,7 +23,7 @@ class Rule(abc.ABC): def on_init(self, state, lvl_map): return [] - def on_reset(self): + def on_reset(self, state) -> List[TickResult]: return [] def tick_pre_step(self, state) -> List[TickResult]: @@ -55,7 +55,7 @@ class SpawnEntity(Rule): self.collection = collection self.ignore_blocking = ignore_blocking - def on_init(self, state, lvl_map) -> [TickResult]: + def on_reset(self, state) -> [TickResult]: results = self.collection.trigger_spawn(state, ignore_blocking=self.ignore_blocking) pos_str = f' on: {[x.pos for x in self.collection]}' if self.collection.var_has_position else '' state.print(f'Initial {self.collection.__class__.__name__} were spawned{pos_str}') @@ -68,8 +68,7 @@ class SpawnAgents(Rule): super().__init__() pass - def on_init(self, state, lvl_map): - # agents = Agents(lvl_map.size) + def on_reset(self, state): agents = state[c.AGENT] empty_positions = state.entities.empty_positions[:len(state.agents_conf)] for agent_name, agent_conf in state.agents_conf.items(): @@ -101,9 +100,6 @@ class DoneAtMaxStepsReached(Rule): super().__init__() self.max_steps = max_steps - def on_init(self, state, lvl_map): - pass - def on_check_done(self, state): if self.max_steps <= state.curr_step: return [DoneResult(validity=c.VALID, identifier=self.name)] @@ -115,7 +111,7 @@ class AssignGlobalPositions(Rule): def __init__(self): super().__init__() - def on_init(self, state, lvl_map): + def on_reset(self, state, lvl_map): from marl_factory_grid.environment.entity.util import GlobalPosition for agent in state[c.AGENT]: gp = GlobalPosition(agent, lvl_map.level_shape) diff --git a/marl_factory_grid/modules/batteries/rules.py b/marl_factory_grid/modules/batteries/rules.py index 8a4725b..7314c93 100644 --- a/marl_factory_grid/modules/batteries/rules.py +++ b/marl_factory_grid/modules/batteries/rules.py @@ -127,30 +127,3 @@ class DoneAtBatteryDischarge(BatteryDecharge): return [DoneResult(self.name, validity=c.VALID, reward=self.reward_discharge_done)] else: return [DoneResult(self.name, validity=c.NOT_VALID)] - - -class SpawnChargePods(Rule): - - def __init__(self, n_pods: int, charge_rate: float = 0.4, multi_charge: bool = False): - """ - Spawn Chargepods in accordance to the given parameters. - - :type n_pods: int - :param n_pods: How many charge pods are there? - :type charge_rate: float - :param charge_rate: How much juice does each use of the charge action top up? - :type multi_charge: bool - :param multi_charge: Whether multiple agents are able to charge at the same time. - """ - super().__init__() - self.multi_charge = multi_charge - self.charge_rate = charge_rate - self.n_pods = n_pods - - def on_init(self, state, lvl_map): - pod_collection = state[b.CHARGE_PODS] - empty_positions = state.entities.empty_positions - pods = pod_collection.from_coordinates(empty_positions, entity_kwargs=dict( - multi_charge=self.multi_charge, charge_rate=self.charge_rate) - ) - pod_collection.add_items(pods) diff --git a/marl_factory_grid/modules/clean_up/groups.py b/marl_factory_grid/modules/clean_up/groups.py index 7ae3247..83a9f02 100644 --- a/marl_factory_grid/modules/clean_up/groups.py +++ b/marl_factory_grid/modules/clean_up/groups.py @@ -34,7 +34,12 @@ class DirtPiles(Collection): self.coords_or_quantity = coords_or_quantity self.initial_amount = initial_amount - def trigger_spawn(self, state, coords_or_quantity=0, amount=0) -> [Result]: + def trigger_spawn(self, state, coords_or_quantity=0, amount=0, ignore_blocking=False) -> [Result]: + if ignore_blocking: + print("##########################################") + print("Blocking should not be ignored for this Entity") + print("Exiting....") + exit() coords_or_quantity = coords_or_quantity if coords_or_quantity else self.coords_or_quantity n_new = int(abs(coords_or_quantity + (state.rng.uniform(-self.n_var, self.n_var)))) n_new = state.get_n_random_free_positions(n_new) diff --git a/marl_factory_grid/modules/destinations/rules.py b/marl_factory_grid/modules/destinations/rules.py index 8e72141..ef004c3 100644 --- a/marl_factory_grid/modules/destinations/rules.py +++ b/marl_factory_grid/modules/destinations/rules.py @@ -106,7 +106,7 @@ class SpawnDestinationsPerAgent(Rule): super(Rule, self).__init__() self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in coords_or_quantity.items()} - def on_init(self, state, lvl_map): + def on_reset(self, state, lvl_map): for (agent_name, position_list) in self.per_agent_positions.items(): agent = h.get_first(state[c.AGENT], lambda x: agent_name in x.name) assert agent diff --git a/marl_factory_grid/modules/doors/actions.py b/marl_factory_grid/modules/doors/actions.py index c7d06ed..a8c7c14 100644 --- a/marl_factory_grid/modules/doors/actions.py +++ b/marl_factory_grid/modules/doors/actions.py @@ -15,7 +15,7 @@ class DoorUse(Action): # Check if agent really is standing on a door: e = state.entities.get_entities_near_pos(entity.pos) try: - # Only one door opens TODO introcude loop + # Only one door opens TODO introduce loop door = next(x for x in e if x.name.startswith(d.DOOR)) valid = door.use() state.print(f'{entity.name} just used a {door.name} at {door.pos}') diff --git a/marl_factory_grid/modules/doors/entitites.py b/marl_factory_grid/modules/doors/entitites.py index 1c33d7b..4a84628 100644 --- a/marl_factory_grid/modules/doors/entitites.py +++ b/marl_factory_grid/modules/doors/entitites.py @@ -117,3 +117,7 @@ class Door(Entity): def _reset_timer(self): self._time_to_close = self._auto_close_interval return True + + def reset(self): + self._close() + self._reset_timer() diff --git a/marl_factory_grid/modules/doors/groups.py b/marl_factory_grid/modules/doors/groups.py index 973d1ab..0e83881 100644 --- a/marl_factory_grid/modules/doors/groups.py +++ b/marl_factory_grid/modules/doors/groups.py @@ -23,3 +23,7 @@ class Doors(Collection): results.append(tick_result) # TODO: Should return a Result object, not a random dict. return results + + def reset(self): + for door in self: + door.reset() diff --git a/marl_factory_grid/modules/doors/rules.py b/marl_factory_grid/modules/doors/rules.py index 599d975..4b24470 100644 --- a/marl_factory_grid/modules/doors/rules.py +++ b/marl_factory_grid/modules/doors/rules.py @@ -40,6 +40,6 @@ class IndicateDoorAreaInObservation(Rule): # Could then be combined with the "Combine"-approach. super().__init__() - def on_init(self, state, lvl_map): + def on_reset(self, state, lvl_map): for door in state[d.DOORS]: state[d.DOORS].add_items([DoorIndicator(x) for x in state.entities.neighboring_positions(door.pos)]) diff --git a/marl_factory_grid/modules/factory/__init__.py b/marl_factory_grid/modules/factory/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/marl_factory_grid/modules/factory/rules.py b/marl_factory_grid/modules/factory/rules.py deleted file mode 100644 index e056135..0000000 --- a/marl_factory_grid/modules/factory/rules.py +++ /dev/null @@ -1,32 +0,0 @@ -import random -from typing import List - -from marl_factory_grid.environment import constants as c -from marl_factory_grid.environment.rules import Rule -from marl_factory_grid.utils.results import TickResult - - -class AgentSingleZonePlacementBeta(Rule): - - def __init__(self): - raise NotImplementedError() - # TODO!!!! Is this concept needed any more? - super().__init__() - - def on_init(self, state, lvl_map): - agents = state[c.AGENT] - if len(self.coordinates) == len(agents): - coordinates = self.coordinates - elif len(self.coordinates) > len(agents): - coordinates = random.choices(self.coordinates, k=len(agents)) - else: - raise ValueError - - for agent, pos in zip(agents, coordinates): - agent.move(pos, state) - - def tick_step(self, state): - return [] - - def tick_post_step(self, state) -> List[TickResult]: - return [] diff --git a/marl_factory_grid/modules/maintenance/entities.py b/marl_factory_grid/modules/maintenance/entities.py index 1a043c8..479e4c8 100644 --- a/marl_factory_grid/modules/maintenance/entities.py +++ b/marl_factory_grid/modules/maintenance/entities.py @@ -3,8 +3,6 @@ from random import shuffle import networkx as nx import numpy as np - -from ...algorithms.static.utils import points_to_graph from ...environment import constants as c from ...environment.actions import Action, ALL_BASEACTIONS from ...environment.entity.entity import Entity @@ -26,7 +24,6 @@ class Maintainer(Entity): self._next = [] self._last = [] self._last_serviced = 'None' - self._floortile_graph = None def tick(self, state): if found_objective := h.get_first(state[self.objective].by_pos(self.pos)): @@ -41,17 +38,18 @@ class Maintainer(Entity): return action.do(self, state) def get_move_action(self, state) -> Action: - if not self._floortile_graph: - state.print("Generating Floorgraph....") - self._floortile_graph = points_to_graph(state.entities.floorlist) - if self._path is None or not self._path: + if self._path is None or not len(self._path): if not self._next: self._next = list(state[self.objective].values()) + [Floor(*state.random_free_position)] shuffle(self._next) self._last = [] self._last.append(self._next.pop()) state.print("Calculating shortest path....") - self._path = self.calculate_route(self._last[-1]) + self._path = self.calculate_route(self._last[-1], state.floortile_graph) + if not self._path: + self._last.append(self._next.pop()) + state.print("Calculating shortest path.... Again....") + self._path = self.calculate_route(self._last[-1], state.floortile_graph) if door := self._closed_door_in_path(state): state.print(f"{self} found {door} that is closed. Attempt to open.") @@ -67,8 +65,8 @@ class Maintainer(Entity): raise EnvironmentError return action_obj - def calculate_route(self, entity): - route = nx.shortest_path(self._floortile_graph, self.pos, entity.pos) + def calculate_route(self, entity, floortile_graph): + route = nx.shortest_path(floortile_graph, self.pos, entity.pos) return route[1:] def _closed_door_in_path(self, state): diff --git a/marl_factory_grid/modules/maintenance/groups.py b/marl_factory_grid/modules/maintenance/groups.py index 5b09c9c..4ac416b 100644 --- a/marl_factory_grid/modules/maintenance/groups.py +++ b/marl_factory_grid/modules/maintenance/groups.py @@ -14,14 +14,8 @@ class Maintainers(Collection): var_is_blocking_light = False var_has_position = True - def __init__(self, size, *args, coords_or_quantity: int = None, - spawnrule: Union[None, Dict[str, dict]] = None, - **kwargs): - super(Collection, self).__init__(*args, **kwargs) - self._coords_or_quantity = coords_or_quantity - self.size = size - self._spawnrule = spawnrule - + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): self.add_items([self._entity(mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity]) diff --git a/marl_factory_grid/modules/zones/rules.py b/marl_factory_grid/modules/zones/rules.py index f9b5c11..c31666c 100644 --- a/marl_factory_grid/modules/zones/rules.py +++ b/marl_factory_grid/modules/zones/rules.py @@ -11,19 +11,21 @@ class ZoneInit(Rule): def __init__(self): super().__init__() + self._zones = list() def on_init(self, state, lvl_map): - zones = [] z_idx = 1 while z_idx: zone_positions = lvl_map.get_coordinates_for_symbol(z_idx) if len(zone_positions): - zones.append(Zone(zone_positions)) + self._zones.append(Zone(zone_positions)) z_idx += 1 else: z_idx = 0 - state[z.ZONES].add_items(zones) + + def on_reset(self, state): + state[z.ZONES].add_items(self._zones) return [] @@ -32,7 +34,7 @@ class AgentSingleZonePlacement(Rule): def __init__(self): super().__init__() - def on_init(self, state, lvl_map): + def on_reset(self, state): n_agents = len(state[c.AGENT]) assert len(state[z.ZONES]) >= n_agents @@ -48,19 +50,16 @@ class AgentSingleZonePlacement(Rule): class IndividualDestinationZonePlacement(Rule): def __init__(self): + raise NotImplementedError("This is rpetty new, and needs to be debugged, after the zones") super().__init__() - def on_init(self, state, lvl_map): + def on_reset(self, state): for agent in state[c.AGENT]: - self.trigger_destination_spawn(agent, state) - pass - return [] - - def tick_step(self, state): + self.trigger_spawn(agent, state) return [] @staticmethod - def trigger_destination_spawn(agent, state): + def trigger_spawn(agent, state): agent_zones = state[z.ZONES].by_pos(agent.pos) other_zones = [x for x in state[z.ZONES] if x not in agent_zones] already_has_destination = True diff --git a/marl_factory_grid/utils/states.py b/marl_factory_grid/utils/states.py index d54db6a..b5f794c 100644 --- a/marl_factory_grid/utils/states.py +++ b/marl_factory_grid/utils/states.py @@ -3,6 +3,7 @@ from typing import List, Tuple import numpy as np +from marl_factory_grid.algorithms.static.utils import points_to_graph from marl_factory_grid.environment import constants as c from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.rules import Rule @@ -29,6 +30,12 @@ class StepRules: self.rules.append(item) return True + def do_all_reset(self, state): + for rule in self.rules: + if rule_reset_printline := rule.on_reset(state): + state.print(rule_reset_printline) + return c.VALID + def do_all_init(self, state, lvl_map): for rule in self.rules: if rule_init_printline := rule.on_init(state, lvl_map): @@ -59,6 +66,13 @@ class StepRules: class Gamestate(object): + @property + def floortile_graph(self): + if not self._floortile_graph: + self.print("Generating Floorgraph....") + self._floortile_graph = points_to_graph(self.entities.floorlist) + return self._floortile_graph + @property def moving_entites(self): return [y for x in self.entities for y in x if x.var_can_move] @@ -72,6 +86,7 @@ class Gamestate(object): self.verbose = verbose self.rng = np.random.default_rng(env_seed) self.rules = StepRules(*rules) + self._floortile_graph = None self.tests = StepTests(*tests) def __getitem__(self, item):