diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..b58b603 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,5 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ diff --git a/README.md b/README.md index a1d2740..d0c0a19 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ Just define what your environment needs in a *yaml*-configfile like: - Items Rules: Defaults: {} - Collision: + WatchCollisions: done_at_collisions: !!bool True ItemRespawn: spawn_freq: 5 @@ -94,7 +94,7 @@ All [Entites](marl_factory_grid/environment/entity/global_entities.py) are avail #### Rules -[Rules](marl_factory_grid/environment/entity/object.py) define how the environment behaves on micro-scale. +[Rules](marl_factory_grid/environment/entity/object.py) define how the environment behaves on microscale. Each of the hookes (`on_init`, `pre_step`, `on_step`, '`post_step`', `on_done`) provide env-access to implement customn logic, calculate rewards, or gather information. @@ -107,6 +107,7 @@ Make sure to bring your own assets for each Entity living in the Gridworld as th PNG-files (transparent background) of square aspect-ratio should do the job, in general. + diff --git a/marl_factory_grid/__init__.py b/marl_factory_grid/__init__.py index b2bbfa3..259e3cf 100644 --- a/marl_factory_grid/__init__.py +++ b/marl_factory_grid/__init__.py @@ -1,6 +1 @@ -from .environment import * -from .modules import * -from .utils import * - from .quickstart import init - diff --git a/marl_factory_grid/algorithms/__init__.py b/marl_factory_grid/algorithms/__init__.py index 0980070..cc2c489 100644 --- a/marl_factory_grid/algorithms/__init__.py +++ b/marl_factory_grid/algorithms/__init__.py @@ -1 +1,4 @@ -import os, sys; sys.path.append(os.path.dirname(os.path.realpath(__file__))) +import os +import sys + +sys.path.append(os.path.dirname(os.path.realpath(__file__))) diff --git a/marl_factory_grid/algorithms/marl/__init__.py b/marl_factory_grid/algorithms/marl/__init__.py index 984588c..a4c30ef 100644 --- a/marl_factory_grid/algorithms/marl/__init__.py +++ b/marl_factory_grid/algorithms/marl/__init__.py @@ -1 +1 @@ -from marl_factory_grid.algorithms.marl.memory import MARLActorCriticMemory \ No newline at end of file +from marl_factory_grid.algorithms.marl.memory import MARLActorCriticMemory diff --git a/marl_factory_grid/algorithms/marl/base_ac.py b/marl_factory_grid/algorithms/marl/base_ac.py index 3bb0318..ef195b7 100644 --- a/marl_factory_grid/algorithms/marl/base_ac.py +++ b/marl_factory_grid/algorithms/marl/base_ac.py @@ -28,6 +28,7 @@ class Names: BATCH_SIZE = 'bnatch_size' N_ACTIONS = 'n_actions' + nms = Names ListOrTensor = Union[List, torch.Tensor] @@ -112,10 +113,9 @@ class BaseActorCritic: next_obs, reward, done, info = env.step(action) done = [done] * self.n_agents if isinstance(done, bool) else done - last_hiddens = dict(hidden_actor =out[nms.HIDDEN_ACTOR], + last_hiddens = dict(hidden_actor=out[nms.HIDDEN_ACTOR], hidden_critic=out[nms.HIDDEN_CRITIC]) - tm.add(observation=obs, action=action, reward=reward, done=done, logits=out.get(nms.LOGITS, None), values=out.get(nms.CRITIC, None), **last_hiddens) @@ -142,7 +142,9 @@ class BaseActorCritic: print(f'reward at episode: {episode} = {rew_log}') episode += 1 df_results.append([episode, rew_log, *reward]) - df_results = pd.DataFrame(df_results, columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]]) + df_results = pd.DataFrame(df_results, + columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]] + ) if checkpointer is not None: df_results.to_csv(checkpointer.path / 'results.csv', index=False) return df_results @@ -157,24 +159,27 @@ class BaseActorCritic: last_action, reward = [-1] * self.n_agents, [0.] * self.n_agents done, rew_log, eps_rew = [False] * self.n_agents, 0, torch.zeros(self.n_agents) while not all(done): - if render: env.render() + if render: + env.render() out = self.forward(obs, last_action, **last_hiddens) action = self.get_actions(out) next_obs, reward, done, info = env.step(action) - if isinstance(done, bool): done = [done] * obs.shape[0] + if isinstance(done, bool): + done = [done] * obs.shape[0] obs = next_obs last_action = action last_hiddens = dict(hidden_actor=out.get(nms.HIDDEN_ACTOR, None), hidden_critic=out.get(nms.HIDDEN_CRITIC, None) ) eps_rew += torch.tensor(reward) - results.append(eps_rew.tolist() + [sum(eps_rew).item()] + [episode]) + results.append(eps_rew.tolist() + [np.sum(eps_rew).item()] + [episode]) episode += 1 agent_columns = [f'agent#{i}' for i in range(self.cfg['environment']['n_agents'])] results = pd.DataFrame(results, columns=agent_columns + ['sum', 'episode']) - results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'], value_name='reward', var_name='agent') + results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'], + value_name='reward', var_name='agent') return results @staticmethod diff --git a/marl_factory_grid/algorithms/marl/mappo.py b/marl_factory_grid/algorithms/marl/mappo.py index d22fa08..faf3b0d 100644 --- a/marl_factory_grid/algorithms/marl/mappo.py +++ b/marl_factory_grid/algorithms/marl/mappo.py @@ -36,7 +36,7 @@ class LoopMAPPO(LoopSNAC): rewards_ = torch.stack(rewards_, dim=1) return rewards_ - def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **kwargs): + def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **__): out = network(batch[nms.OBSERVATION], batch[nms.ACTION], batch[nms.HIDDEN_ACTOR], batch[nms.HIDDEN_CRITIC]) logits = out[nms.LOGITS][:, :-1] # last one only needed for v_{t+1} @@ -45,7 +45,7 @@ class LoopMAPPO(LoopSNAC): # monte carlo returns mc_returns = self.monte_carlo_returns(batch[nms.REWARD], batch[nms.DONE], gamma) - mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8) #todo: norm across agent ok? + mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8) # todo: norm across agent ok? advantages = mc_returns - out[nms.CRITIC][:, :-1] # policy loss diff --git a/marl_factory_grid/algorithms/marl/networks.py b/marl_factory_grid/algorithms/marl/networks.py index c4fdb72..796c03f 100644 --- a/marl_factory_grid/algorithms/marl/networks.py +++ b/marl_factory_grid/algorithms/marl/networks.py @@ -1,8 +1,7 @@ +import numpy as np import torch import torch.nn as nn -import numpy as np import torch.nn.functional as F -from torch.nn.utils import spectral_norm class RecurrentAC(nn.Module): @@ -88,8 +87,8 @@ class NormalizedLinear(nn.Linear): self.trainable_magnitude = trainable_magnitude self.scale = nn.Parameter(torch.tensor([1.]), requires_grad=trainable_magnitude) - def forward(self, input): - normalized_input = F.normalize(input, dim=-1, p=2, eps=1e-5) + def forward(self, in_array): + normalized_input = F.normalize(in_array, dim=-1, p=2, eps=1e-5) normalized_weight = F.normalize(self.weight, dim=-1, p=2, eps=1e-5) return F.linear(normalized_input, normalized_weight) * self.d_sqrt * self.scale diff --git a/marl_factory_grid/algorithms/marl/seac.py b/marl_factory_grid/algorithms/marl/seac.py index 9c458c7..07e8267 100644 --- a/marl_factory_grid/algorithms/marl/seac.py +++ b/marl_factory_grid/algorithms/marl/seac.py @@ -16,7 +16,7 @@ class LoopSEAC(LoopIAC): with torch.inference_mode(True): true_action_logp = torch.stack([ torch.log_softmax(out[nms.LOGITS][ag_i, :-1], -1) - .gather(index=actions[ag_i, 1:, None], dim=-1) + .gather(index=actions[ag_i, 1:, None], dim=-1) for ag_i, out in enumerate(outputs) ], 0).squeeze() @@ -38,7 +38,6 @@ class LoopSEAC(LoopIAC): a2c_loss = (-iw*log_ap * advantages.detach()).mean(-1) - value_loss = (iw*advantages.pow(2)).mean(-1) # n_agent # weighted loss @@ -53,4 +52,4 @@ class LoopSEAC(LoopIAC): self.optimizer[ag_i].zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(self.net[ag_i].parameters(), 0.5) - self.optimizer[ag_i].step() \ No newline at end of file + self.optimizer[ag_i].step() diff --git a/marl_factory_grid/algorithms/marl/snac.py b/marl_factory_grid/algorithms/marl/snac.py index b249754..11be902 100644 --- a/marl_factory_grid/algorithms/marl/snac.py +++ b/marl_factory_grid/algorithms/marl/snac.py @@ -30,4 +30,4 @@ class LoopSNAC(BaseActorCritic): self._as_torch(actions).unsqueeze(1), hidden_actor, hidden_critic ) - return out \ No newline at end of file + return out diff --git a/marl_factory_grid/algorithms/static/TSP_base_agent.py b/marl_factory_grid/algorithms/static/TSP_base_agent.py index bc48f7c..7d25f63 100644 --- a/marl_factory_grid/algorithms/static/TSP_base_agent.py +++ b/marl_factory_grid/algorithms/static/TSP_base_agent.py @@ -56,8 +56,8 @@ class TSPBaseAgent(ABC): def _door_is_close(self, state): try: - # return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name) - return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name) + return next(y for x in state.entities.neighboring_positions(self.state.pos) + for y in state.entities.pos_dict[x] if do.DOOR in y.name) except StopIteration: return None diff --git a/marl_factory_grid/algorithms/static/TSP_target_agent.py b/marl_factory_grid/algorithms/static/TSP_target_agent.py index 0c5de3a..b0d8b29 100644 --- a/marl_factory_grid/algorithms/static/TSP_target_agent.py +++ b/marl_factory_grid/algorithms/static/TSP_target_agent.py @@ -14,8 +14,8 @@ class TSPTargetAgent(TSPBaseAgent): def _handle_doors(self, state): try: - # return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name) - return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name) + return next(y for x in state.entities.neighboring_positions(self.state.pos) + for y in state.entities.pos_dict[x] if do.DOOR in y.name) except StopIteration: return None @@ -30,4 +30,3 @@ class TSPTargetAgent(TSPBaseAgent): except (StopIteration, UnboundLocalError): print('Will not happen') return action_obj - diff --git a/marl_factory_grid/algorithms/static/utils.py b/marl_factory_grid/algorithms/static/utils.py index d5119db..60cba30 100644 --- a/marl_factory_grid/algorithms/static/utils.py +++ b/marl_factory_grid/algorithms/static/utils.py @@ -26,12 +26,16 @@ def points_to_graph(coordiniates, allow_euclidean_connections=True, allow_manhat assert allow_euclidean_connections or allow_manhattan_connections possible_connections = itertools.combinations(coordiniates, 2) graph = nx.Graph() - for a, b in possible_connections: - diff = np.linalg.norm(np.asarray(a)-np.asarray(b)) - if allow_manhattan_connections and allow_euclidean_connections and diff <= np.sqrt(2): - graph.add_edge(a, b) - elif not allow_manhattan_connections and allow_euclidean_connections and diff == np.sqrt(2): - graph.add_edge(a, b) - elif allow_manhattan_connections and not allow_euclidean_connections and diff == 1: - graph.add_edge(a, b) + if allow_manhattan_connections and allow_euclidean_connections: + graph.add_edges_from( + (a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) <= np.sqrt(2) + ) + elif not allow_manhattan_connections and allow_euclidean_connections: + graph.add_edges_from( + (a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) == np.sqrt(2) + ) + elif allow_manhattan_connections and not allow_euclidean_connections: + graph.add_edges_from( + (a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) == 1 + ) return graph diff --git a/marl_factory_grid/algorithms/utils.py b/marl_factory_grid/algorithms/utils.py index 59e78bd..8c60386 100644 --- a/marl_factory_grid/algorithms/utils.py +++ b/marl_factory_grid/algorithms/utils.py @@ -1,8 +1,9 @@ -import torch -import numpy as np -import yaml from pathlib import Path +import numpy as np +import torch +import yaml + def load_class(classname): from importlib import import_module @@ -42,7 +43,6 @@ def get_class(arguments): def get_arguments(arguments): - from importlib import import_module d = dict(arguments) if "classname" in d: del d["classname"] @@ -82,4 +82,4 @@ class Checkpointer(object): for name, model in to_save: self.save_experiment(name, model) self.__current_checkpoint += 1 - self.__current_step += 1 \ No newline at end of file + self.__current_step += 1 diff --git a/marl_factory_grid/configs/default_config.yaml b/marl_factory_grid/configs/default_config.yaml index 44a0977..d3015c9 100644 --- a/marl_factory_grid/configs/default_config.yaml +++ b/marl_factory_grid/configs/default_config.yaml @@ -22,26 +22,41 @@ Agents: - Inventory - DropOffLocations - Maintainers + # This is special for agents, as each one is differten and can act as an adversary e.g. + Positions: + - (16, 7) + - (16, 6) + - (16, 3) + - (16, 4) + - (16, 5) Entities: Batteries: initial_charge: 0.8 per_action_costs: 0.02 - ChargePods: {} - Destinations: {} + ChargePods: + coords_or_quantity: 2 + Destinations: + coords_or_quantity: 1 + spawn_mode: GROUPED DirtPiles: + coords_or_quantity: 10 + initial_amount: 2 clean_amount: 1 dirt_spawn_r_var: 0.1 - initial_amount: 2 - initial_dirt_ratio: 0.05 max_global_amount: 20 max_local_amount: 5 - Doors: {} - DropOffLocations: {} + Doors: + DropOffLocations: + coords_or_quantity: 1 + max_dropoff_storage_size: 0 GlobalPositions: {} Inventories: {} - Items: {} - Machines: {} - Maintainers: {} + Items: + coords_or_quantity: 5 + Machines: + coords_or_quantity: 2 + Maintainers: + coords_or_quantity: 1 Zones: {} General: @@ -49,32 +64,31 @@ General: individual_rewards: true level_name: large pomdp_r: 3 - verbose: false + verbose: True + tests: false Rules: - SpawnAgents: {} - DoneAtBatteryDischarge: {} - Collision: - done_at_collisions: false - AssignGlobalPositions: {} - DoneAtDestinationReachAny: {} - DestinationReachReward: {} - SpawnDestinations: - n_dests: 1 - spawn_mode: GROUPED - DoneOnAllDirtCleaned: {} - SpawnDirt: - spawn_freq: 15 + # Environment Dynamics EntitiesSmearDirtOnMove: smear_ratio: 0.2 DoorAutoClose: close_frequency: 10 - ItemRules: - max_dropoff_storage_size: 0 - n_items: 5 - n_locations: 5 - spawn_frequency: 15 - MaxStepsReached: + MoveMaintainers: + + # Respawn Stuff + RespawnDirt: + respawn_freq: 15 + RespawnItems: + respawn_freq: 15 + + # Utilities + WatchCollisions: + done_at_collisions: false + + # Done Conditions + DoneAtDestinationReachAny: + DoneOnAllDirtCleaned: + DoneAtBatteryDischarge: + DoneAtMaintainerCollision: + DoneAtMaxStepsReached: max_steps: 500 -# AgentSingleZonePlacement: -# n_zones: 4 diff --git a/marl_factory_grid/configs/narrow_corridor.yaml b/marl_factory_grid/configs/narrow_corridor.yaml index 0006513..f53b972 100644 --- a/marl_factory_grid/configs/narrow_corridor.yaml +++ b/marl_factory_grid/configs/narrow_corridor.yaml @@ -1,15 +1,41 @@ +General: + # Your Seed + env_seed: 69 + # Individual or global rewards? + individual_rewards: true + # The level.txt file to load + level_name: narrow_corridor + # View Radius; 0 = full observatbility + pomdp_r: 0 + # print all messages and events + verbose: true + Agents: + # Agents are identified by their name Wolfgang: + # The available actions for this particular agent Actions: + # Able to do nothing - Noop + # Able to move in all 8 directions - Move8 + # Stuff the agent can observe (per 2d slice) + # use "Combined" if you want to merge multiple slices into one Observations: + # He sees walls - Walls + # he sees other agent, "karl-Heinz" in this setting would be fine, too - Other + # He can see Destinations, that are assigned to him (hence the singular) - Destination + # Avaiable Spawn Positions as list Positions: - (2, 1) - (2, 5) + # It is okay to collide with other agents, so that + # they end up on the same position + is_blocking_pos: true + # See Above.... Karl-Heinz: Actions: - Noop @@ -21,26 +47,43 @@ Agents: Positions: - (2, 1) - (2, 5) + is_blocking_pos: true + +# Other noteworthy Entitites Entities: - Destinations: {} - -General: - env_seed: 69 - individual_rewards: true - level_name: narrow_corridor - pomdp_r: 0 - verbose: true + # The destiantions or positional targets to reach + Destinations: + # Let them spawn on closed doors and agent positions + ignore_blocking: true + # We need a special spawn rule... + spawnrule: + # ...which assigns the destinations per agent + SpawnDestinationsPerAgent: + # we use this parameter + coords_or_quantity: + # to enable and assign special positions per agent + Wolfgang: + - (2, 1) + - (2, 5) + Karl-Heinz: + - (2, 1) + - (2, 5) + # Whether you want to provide a numeric Position observation. + # GlobalPositions: + # normalized: false +# Define the env. dynamics Rules: - SpawnAgents: {} - Collision: + # Utilities + # This rule Checks for Collision, also it assigns the (negative) reward + WatchCollisions: + reward: -0.1 + reward_at_done: -1 done_at_collisions: false - FixedDestinationSpawn: - per_agent_positions: - Wolfgang: - - (2, 1) - - (2, 5) - Karl-Heinz: - - (2, 1) - - (2, 5) - DestinationReachAll: {} + # Done Conditions + # Load any of the rules, to check for done conditions. + # DoneAtDestinationReachAny: + DoneAtDestinationReachAll: + # reward_at_done: 1 + DoneAtMaxStepsReached: + max_steps: 200 diff --git a/marl_factory_grid/environment/actions.py b/marl_factory_grid/environment/actions.py index 4edfe24..606832c 100644 --- a/marl_factory_grid/environment/actions.py +++ b/marl_factory_grid/environment/actions.py @@ -48,9 +48,9 @@ class Move(Action, abc.ABC): reward = r.MOVEMENTS_VALID if move_validity else r.MOVEMENTS_FAIL return ActionResult(entity=entity, identifier=self._identifier, validity=move_validity, reward=reward) else: # There is no place to go, propably collision - # This is currently handeld by the Collision rule, so that it can be switched on and off by conf.yml + # This is currently handeld by the WatchCollisions rule, so that it can be switched on and off by conf.yml # return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION) - return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=0) + return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID) def _calc_new_pos(self, pos): x_diff, y_diff = MOVEMAP[self._identifier] diff --git a/marl_factory_grid/environment/constants.py b/marl_factory_grid/environment/constants.py index 1fdf639..6ddb19a 100644 --- a/marl_factory_grid/environment/constants.py +++ b/marl_factory_grid/environment/constants.py @@ -10,6 +10,7 @@ AGENT = 'Agent' # Identifier of Agent-objects an OTHERS = 'Other' COMBINED = 'Combined' GLOBALPOSITIONS = 'GlobalPositions' # Identifier of the global position slice +SPAWN_ENTITY_RULE = 'SpawnEntity' # Attributes IS_BLOCKING_LIGHT = 'var_is_blocking_light' @@ -29,7 +30,7 @@ VALUE_NO_POS = (-9999, -9999) # Invalid Position value used in the e ACTION = 'action' # Identifier of Action-objects and groups (groups). -COLLISION = 'Collision' # Identifier to use in the context of collitions. +COLLISION = 'Collisions' # Identifier to use in the context of collitions. # LAST_POS = 'LAST_POS' # Identifiert for retrieving an enitites last pos. VALIDITY = 'VALIDITY' # Identifiert for retrieving the Validity of Action, Tick, etc. ... @@ -54,3 +55,5 @@ NOOP = 'Noop' # Result Identifier MOVEMENTS_VALID = 'motion_valid' MOVEMENTS_FAIL = 'motion_not_valid' +DEFAULT_PATH = 'environment' +MODULE_PATH = 'modules' diff --git a/marl_factory_grid/environment/entity/agent.py b/marl_factory_grid/environment/entity/agent.py index 285c8d2..0920604 100644 --- a/marl_factory_grid/environment/entity/agent.py +++ b/marl_factory_grid/environment/entity/agent.py @@ -12,14 +12,6 @@ from marl_factory_grid.environment import constants as c class Agent(Entity): - @property - def var_is_blocking_light(self): - return False - - @property - def var_can_move(self): - return True - @property def var_is_paralyzed(self): return len(self._paralyzed) @@ -28,14 +20,6 @@ class Agent(Entity): def paralyze_reasons(self): return [x for x in self._paralyzed] - @property - def var_is_blocking_pos(self): - return False - - @property - def var_has_position(self): - return True - @property def obs_tag(self): return self.name @@ -48,10 +32,6 @@ class Agent(Entity): def observations(self): return self._observations - @property - def var_can_collide(self): - return True - def step_result(self): pass @@ -60,16 +40,21 @@ class Agent(Entity): return self._collection @property - def state(self): - return self._state or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0) + def var_is_blocking_pos(self): + return self._is_blocking_pos - def __init__(self, actions: List[Action], observations: List[str], *args, **kwargs): + @property + def state(self): + return self._state or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID) + + def __init__(self, actions: List[Action], observations: List[str], *args, is_blocking_pos=False, **kwargs): super(Agent, self).__init__(*args, **kwargs) self._paralyzed = set() self.step_result = dict() self._actions = actions self._observations = observations self._state: Union[Result, None] = None + self._is_blocking_pos = is_blocking_pos # noinspection PyAttributeOutsideInit def clear_temp_state(self): diff --git a/marl_factory_grid/environment/entity/entity.py b/marl_factory_grid/environment/entity/entity.py index 637827f..999787b 100644 --- a/marl_factory_grid/environment/entity/entity.py +++ b/marl_factory_grid/environment/entity/entity.py @@ -1,20 +1,19 @@ import abc -from collections import defaultdict import numpy as np -from .object import _Object +from .object import Object from .. import constants as c from ...utils.results import ActionResult from ...utils.utility_classes import RenderEntity -class Entity(_Object, abc.ABC): +class Entity(Object, abc.ABC): """Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc...""" @property def state(self): - return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0) + return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID) @property def var_has_position(self): @@ -60,6 +59,10 @@ class Entity(_Object, abc.ABC): def pos(self): return self._pos + def set_pos(self, pos): + assert isinstance(pos, tuple) and len(pos) == 2 + self._pos = pos + @property def last_pos(self): try: @@ -84,7 +87,7 @@ class Entity(_Object, abc.ABC): for observer in self.observers: observer.notify_del_entity(self) self._view_directory = curr_pos[0] - next_pos[0], curr_pos[1] - next_pos[1] - self._pos = next_pos + self.set_pos(next_pos) for observer in self.observers: observer.notify_add_entity(self) return valid @@ -92,6 +95,7 @@ class Entity(_Object, abc.ABC): def __init__(self, pos, bind_to=None, **kwargs): super().__init__(**kwargs) + self._view_directory = c.VALUE_NO_POS self._status = None self._pos = pos self._last_pos = pos @@ -109,9 +113,6 @@ class Entity(_Object, abc.ABC): def render(self): return RenderEntity(self.__class__.__name__.lower(), self.pos) - def __repr__(self): - return super(Entity, self).__repr__() + f'(@{self.pos})' - @property def obs_tag(self): try: @@ -128,25 +129,3 @@ class Entity(_Object, abc.ABC): self._collection.delete_env_object(self) self._collection = other_collection return self._collection == other_collection - - @classmethod - def from_coordinates(cls, positions: [(int, int)], *args, entity_kwargs=None, **kwargs, ): - collection = cls(*args, **kwargs) - collection.add_items( - [cls._entity(tuple(pos), **entity_kwargs if entity_kwargs is not None else {}) for pos in positions]) - return collection - - def notify_del_entity(self, entity): - try: - self.pos_dict[entity.pos].remove(entity) - except (ValueError, AttributeError): - pass - - def by_pos(self, pos: (int, int)): - pos = tuple(pos) - try: - return self.state.entities.pos_dict[pos] - except StopIteration: - pass - except ValueError: - print() diff --git a/marl_factory_grid/environment/entity/mixin.py b/marl_factory_grid/environment/entity/mixin.py deleted file mode 100644 index bab6343..0000000 --- a/marl_factory_grid/environment/entity/mixin.py +++ /dev/null @@ -1,24 +0,0 @@ - - -# noinspection PyAttributeOutsideInit -class BoundEntityMixin: - - @property - def bound_entity(self): - return self._bound_entity - - @property - def name(self): - if self.bound_entity: - return f'{self.__class__.__name__}({self.bound_entity.name})' - else: - pass - - def belongs_to_entity(self, entity): - return entity == self.bound_entity - - def bind_to(self, entity): - self._bound_entity = entity - - def unbind(self): - self._bound_entity = None diff --git a/marl_factory_grid/environment/entity/object.py b/marl_factory_grid/environment/entity/object.py index 8810baf..e8c69da 100644 --- a/marl_factory_grid/environment/entity/object.py +++ b/marl_factory_grid/environment/entity/object.py @@ -5,7 +5,7 @@ from marl_factory_grid.environment import constants as c import marl_factory_grid.utils.helpers as h -class _Object: +class Object: """Generell Objects for Organisation and Maintanance such as Actions etc...""" _u_idx = defaultdict(lambda: 0) @@ -13,10 +13,6 @@ class _Object: def __bool__(self): return True - @property - def var_has_position(self): - return False - @property def var_can_be_bound(self): try: @@ -30,22 +26,14 @@ class _Object: @property def name(self): - if self._str_ident is not None: - name = f'{self.__class__.__name__}[{self._str_ident}]' - else: - name = f'{self.__class__.__name__}#{self.u_int}' - if self.bound_entity: - name = h.add_bound_name(name, self.bound_entity) - if self.var_has_position: - name = h.add_pos_name(name, self) - return name + return f'{self.__class__.__name__}[{self.identifier}]' @property def identifier(self): if self._str_ident is not None: return self._str_ident else: - return self.name + return self.u_int def reset_uid(self): self._u_idx = defaultdict(lambda: 0) @@ -62,7 +50,15 @@ class _Object: print(f'Following kwargs were passed, but ignored: {kwargs}') def __repr__(self): - return f'{self.name}' + name = self.name + if self.bound_entity: + name = h.add_bound_name(name, self.bound_entity) + try: + if self.var_has_position: + name = h.add_pos_name(name, self) + except AttributeError: + pass + return name def __eq__(self, other) -> bool: return other == self.identifier @@ -71,8 +67,8 @@ class _Object: return hash(self.identifier) def _identify_and_count_up(self): - idx = _Object._u_idx[self.__class__.__name__] - _Object._u_idx[self.__class__.__name__] += 1 + idx = Object._u_idx[self.__class__.__name__] + Object._u_idx[self.__class__.__name__] += 1 return idx def set_collection(self, collection): @@ -88,7 +84,7 @@ class _Object: def summarize_state(self): return dict() - def bind(self, entity): + def bind_to(self, entity): # noinspection PyAttributeOutsideInit self._bound_entity = entity return c.VALID @@ -100,84 +96,5 @@ class _Object: def bound_entity(self): return self._bound_entity - def bind_to(self, entity): - self._bound_entity = entity - def unbind(self): self._bound_entity = None - - -# class EnvObject(_Object): -# """Objects that hold Information that are observable, but have no position on the environment grid. Inventories etc...""" -# - # _u_idx = defaultdict(lambda: 0) -# -# @property -# def obs_tag(self): -# try: -# return self._collection.name or self.name -# except AttributeError: -# return self.name -# -# @property -# def var_is_blocking_light(self): -# try: -# return self._collection.var_is_blocking_light or False -# except AttributeError: -# return False -# -# @property -# def var_can_be_bound(self): -# try: -# return self._collection.var_can_be_bound or False -# except AttributeError: -# return False -# -# @property -# def var_can_move(self): -# try: -# return self._collection.var_can_move or False -# except AttributeError: -# return False -# -# @property -# def var_is_blocking_pos(self): -# try: -# return self._collection.var_is_blocking_pos or False -# except AttributeError: -# return False -# -# @property -# def var_has_position(self): -# try: -# return self._collection.var_has_position or False -# except AttributeError: -# return False -# -# @property -# def var_can_collide(self): -# try: -# return self._collection.var_can_collide or False -# except AttributeError: -# return False -# -# -# @property -# def encoding(self): -# return c.VALUE_OCCUPIED_CELL -# -# -# def __init__(self, **kwargs): -# self._bound_entity = None -# super(EnvObject, self).__init__(**kwargs) -# -# -# def change_parent_collection(self, other_collection): -# other_collection.add_item(self) -# self._collection.delete_env_object(self) -# self._collection = other_collection -# return self._collection == other_collection -# -# -# def summarize_state(self): -# return dict(name=str(self.name)) diff --git a/marl_factory_grid/environment/entity/util.py b/marl_factory_grid/environment/entity/util.py index 1a5cbe3..2a15c41 100644 --- a/marl_factory_grid/environment/entity/util.py +++ b/marl_factory_grid/environment/entity/util.py @@ -1,6 +1,6 @@ import numpy as np -from marl_factory_grid.environment.entity.object import _Object +from marl_factory_grid.environment.entity.object import Object ########################################################################## @@ -8,7 +8,7 @@ from marl_factory_grid.environment.entity.object import _Object ########################################################################## -class PlaceHolder(_Object): +class PlaceHolder(Object): def __init__(self, *args, fill_value=0, **kwargs): super().__init__(*args, **kwargs) @@ -24,10 +24,10 @@ class PlaceHolder(_Object): @property def name(self): - return "PlaceHolder" + return self.__class__.__name__ -class GlobalPosition(_Object): +class GlobalPosition(Object): @property def encoding(self): @@ -36,7 +36,8 @@ class GlobalPosition(_Object): else: return self.bound_entity.pos - def __init__(self, level_shape, *args, normalized: bool = True, **kwargs): + def __init__(self, agent, level_shape, *args, normalized: bool = True, **kwargs): super(GlobalPosition, self).__init__(*args, **kwargs) + self.bind_to(agent) self._normalized = normalized self._shape = level_shape diff --git a/marl_factory_grid/environment/entity/wall.py b/marl_factory_grid/environment/entity/wall.py index 3f0fb7c..83044cd 100644 --- a/marl_factory_grid/environment/entity/wall.py +++ b/marl_factory_grid/environment/entity/wall.py @@ -5,13 +5,8 @@ from marl_factory_grid.utils.utility_classes import RenderEntity class Wall(Entity): - @property - def var_has_position(self): - return True - - @property - def var_can_collide(self): - return True + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) @property def encoding(self): @@ -19,11 +14,3 @@ class Wall(Entity): def render(self): return RenderEntity(c.WALL, self.pos) - - @property - def var_is_blocking_pos(self): - return True - - @property - def var_is_blocking_light(self): - return True diff --git a/marl_factory_grid/environment/factory.py b/marl_factory_grid/environment/factory.py index 651444e..6662581 100644 --- a/marl_factory_grid/environment/factory.py +++ b/marl_factory_grid/environment/factory.py @@ -56,15 +56,18 @@ class Factory(gym.Env): self.level_filepath = Path(custom_level_path) else: self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt' - self._renderer = None # expensive - don't use; unless required ! parsed_entities = self.conf.load_entities() self.map = LevelParser(self.level_filepath, parsed_entities, self.conf.pomdp_r) # Init for later usage: - self.state: Gamestate - self.map: LevelParser - self.obs_builder: OBSBuilder + # noinspection PyTypeChecker + self.state: Gamestate = None + # noinspection PyTypeChecker + self.obs_builder: OBSBuilder = None + + # expensive - don't use; unless required ! + self._renderer = None # reset env to initial state, preparing env for new episode. # returns tuple where the first dict contains initial observation for each agent in the env @@ -74,7 +77,7 @@ class Factory(gym.Env): return self.state.entities[item] def reset(self) -> (dict, dict): - if hasattr(self, 'state'): + if self.state is not None: for entity_group in self.state.entities: try: entity_group[0].reset_uid() @@ -87,12 +90,16 @@ class Factory(gym.Env): entities = self.map.do_init() # Init rules - rules = self.conf.load_env_rules() + env_rules = self.conf.load_env_rules() + entity_rules = self.conf.load_entity_spawn_rules(entities) + env_rules.extend(entity_rules) + env_tests = self.conf.load_env_tests() if self.conf.tests else [] # Parse the agent conf parsed_agents_conf = self.conf.parse_agents_conf() - self.state = Gamestate(entities, parsed_agents_conf, rules, env_tests, self.conf.env_seed, self.conf.verbose) + self.state = Gamestate(entities, parsed_agents_conf, env_rules, env_tests, self.map.level_shape, + self.conf.env_seed, self.conf.verbose) # All is set up, trigger entity init with variable pos # All is set up, trigger additional init (after agent entity spawn etc) @@ -160,7 +167,7 @@ class Factory(gym.Env): # Finalize reward, reward_info, done = self.summarize_step_results(tick_result, done_results) - info = reward_info + info = dict(reward_info) info.update(step_reward=sum(reward), step=self.state.curr_step) diff --git a/marl_factory_grid/environment/groups/agents.py b/marl_factory_grid/environment/groups/agents.py index f4a6ac6..d549384 100644 --- a/marl_factory_grid/environment/groups/agents.py +++ b/marl_factory_grid/environment/groups/agents.py @@ -1,10 +1,15 @@ from marl_factory_grid.environment.entity.agent import Agent from marl_factory_grid.environment.groups.collection import Collection +from marl_factory_grid.environment.rules import SpawnAgents class Agents(Collection): _entity = Agent + @property + def spawn_rule(self): + return {SpawnAgents.__name__: {}} + @property def var_is_blocking_light(self): return False diff --git a/marl_factory_grid/environment/groups/collection.py b/marl_factory_grid/environment/groups/collection.py index 640c3b4..c0f0f6b 100644 --- a/marl_factory_grid/environment/groups/collection.py +++ b/marl_factory_grid/environment/groups/collection.py @@ -1,18 +1,25 @@ -from typing import List, Tuple, Union +from typing import List, Tuple, Union, Dict from marl_factory_grid.environment.entity.entity import Entity -from marl_factory_grid.environment.groups.objects import _Objects -from marl_factory_grid.environment.entity.object import _Object +from marl_factory_grid.environment.groups.objects import Objects +# noinspection PyProtectedMember +from marl_factory_grid.environment.entity.object import Object import marl_factory_grid.environment.constants as c +from marl_factory_grid.utils.results import Result -class Collection(_Objects): - _entity = _Object # entity? +class Collection(Objects): + _entity = Object # entity? + symbol = None @property def var_is_blocking_light(self): return False + @property + def var_is_blocking_pos(self): + return False + @property def var_can_collide(self): return False @@ -23,33 +30,65 @@ class Collection(_Objects): @property def var_has_position(self): - return False - - # @property - # def var_has_bound(self): - # return False # batteries, globalpos, inventories true - - @property - def var_can_be_bound(self): - return False + return True @property def encodings(self): return [x.encoding for x in self] - def __init__(self, size, *args, **kwargs): - super(Collection, self).__init__(*args, **kwargs) - self.size = size - - def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): # woihn mit den args - if isinstance(coords_or_quantity, int): - self.add_items([self._entity() for _ in range(coords_or_quantity)]) + @property + def spawn_rule(self): + """Prevent SpawnRule creation if Objects are spawned by map, Doors e.g.""" + if self.symbol: + return None + elif self._spawnrule: + return self._spawnrule else: - self.add_items([self._entity(pos) for pos in coords_or_quantity]) + return {c.SPAWN_ENTITY_RULE: dict(collection=self, coords_or_quantity=self._coords_or_quantity)} + + def __init__(self, size, *args, coords_or_quantity: int = None, ignore_blocking=False, + spawnrule: Union[None, Dict[str, dict]] = None, + **kwargs): + super(Collection, self).__init__(*args, **kwargs) + self._coords_or_quantity = coords_or_quantity + self.size = size + self._spawnrule = spawnrule + self._ignore_blocking = ignore_blocking + + def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, ignore_blocking=False, **entity_kwargs): + coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity + if self.var_has_position: + if self.var_has_position and isinstance(coords_or_quantity, int): + if ignore_blocking or self._ignore_blocking: + coords_or_quantity = state.entities.floorlist[:coords_or_quantity] + else: + coords_or_quantity = state.get_n_random_free_positions(coords_or_quantity) + self.spawn(coords_or_quantity, *entity_args, **entity_kwargs) + state.print(f'{len(coords_or_quantity)} new {self.name} have been spawned at {coords_or_quantity}') + return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(coords_or_quantity)) + else: + if isinstance(coords_or_quantity, int): + self.spawn(coords_or_quantity, *entity_args, **entity_kwargs) + state.print(f'{coords_or_quantity} new {self.name} have been spawned randomly.') + return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=coords_or_quantity) + else: + raise ValueError(f'{self._entity.__name__} has no position!') + + def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args, **entity_kwargs): + if self.var_has_position: + if isinstance(coords_or_quantity, int): + raise ValueError(f'{self._entity.__name__} should have a position!') + else: + self.add_items([self._entity(pos, *entity_args, **entity_kwargs) for pos in coords_or_quantity]) + else: + if isinstance(coords_or_quantity, int): + self.add_items([self._entity(*entity_args, **entity_kwargs) for _ in range(coords_or_quantity)]) + else: + raise ValueError(f'{self._entity.__name__} has no position!') return c.VALID - def despawn(self, items: List[_Object]): - items = [items] if isinstance(items, _Object) else items + def despawn(self, items: List[Object]): + items = [items] if isinstance(items, Object) else items for item in items: del self[item] @@ -115,7 +154,7 @@ class Collection(_Objects): except StopIteration: pass except ValueError: - print() + pass @property def positions(self): diff --git a/marl_factory_grid/environment/groups/global_entities.py b/marl_factory_grid/environment/groups/global_entities.py index 8bfc9fe..37779f9 100644 --- a/marl_factory_grid/environment/groups/global_entities.py +++ b/marl_factory_grid/environment/groups/global_entities.py @@ -1,21 +1,21 @@ from collections import defaultdict from operator import itemgetter -from random import shuffle, random +from random import shuffle from typing import Dict -from marl_factory_grid.environment.groups.objects import _Objects +from marl_factory_grid.environment.groups.objects import Objects from marl_factory_grid.utils.helpers import POS_MASK -class Entities(_Objects): - _entity = _Objects +class Entities(Objects): + _entity = Objects @staticmethod def neighboring_positions(pos): - return (POS_MASK + pos).reshape(-1, 2) + return [tuple(x) for x in (POS_MASK + pos).reshape(-1, 2)] def get_entities_near_pos(self, pos): - return [y for x in itemgetter(*(tuple(x) for x in self.neighboring_positions(pos)))(self.pos_dict) for y in x] + return [y for x in itemgetter(*self.neighboring_positions(pos))(self.pos_dict) for y in x] def render(self): return [y for x in self for y in x.render() if x is not None] @@ -35,8 +35,9 @@ class Entities(_Objects): super().__init__() def guests_that_can_collide(self, pos): - return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide] + return [x for val in self.pos_dict[pos] for x in val if x.var_can_collide] + @property def empty_positions(self): empty_positions = [key for key in self.floorlist if not self.pos_dict[key]] shuffle(empty_positions) @@ -48,11 +49,23 @@ class Entities(_Objects): shuffle(empty_positions) return empty_positions - def is_blocked(self): - return[key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])] + @property + def blocked_positions(self): + blocked_positions = [key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])] + shuffle(blocked_positions) + return blocked_positions - def is_not_blocked(self): - return[key for key, val in self.pos_dict.items() if not all([x.var_is_blocking_pos for x in val])] + @property + def free_positions_generator(self): + generator = ( + key for key in self.floorlist if all(not x.var_can_collide and not x.var_is_blocking_pos + for x in self.pos_dict[key]) + ) + return generator + + @property + def free_positions_list(self): + return [x for x in self.free_positions_generator] def iter_entities(self): return iter((x for sublist in self.values() for x in sublist)) @@ -74,7 +87,7 @@ class Entities(_Objects): def __delitem__(self, name): assert_str = 'This group of entity does not exist in this collection!' assert any([key for key in name.keys() if key in self.keys()]), assert_str - self[name]._observers.delete(self) + self[name].del_observer(self) for entity in self[name]: entity.del_observer(self) return super(Entities, self).__delitem__(name) @@ -92,3 +105,6 @@ class Entities(_Objects): @property def positions(self): return [k for k, v in self.pos_dict.items() for _ in v] + + def is_occupied(self, pos): + return len([x for x in self.pos_dict[pos] if x.var_can_collide or x.var_is_blocking_pos]) >= 1 diff --git a/marl_factory_grid/environment/groups/mixins.py b/marl_factory_grid/environment/groups/mixins.py index 48333ca..acfac7e 100644 --- a/marl_factory_grid/environment/groups/mixins.py +++ b/marl_factory_grid/environment/groups/mixins.py @@ -4,10 +4,6 @@ from marl_factory_grid.environment import constants as c # noinspection PyUnresolvedReferences,PyTypeChecker class IsBoundMixin: - @property - def name(self): - return f'{self.__class__.__name__}({self._bound_entity.name})' - def __repr__(self): return f'{self.__class__.__name__}#{self._bound_entity.name}({self._data})' diff --git a/marl_factory_grid/environment/groups/objects.py b/marl_factory_grid/environment/groups/objects.py index d3f32af..9229787 100644 --- a/marl_factory_grid/environment/groups/objects.py +++ b/marl_factory_grid/environment/groups/objects.py @@ -1,14 +1,19 @@ from collections import defaultdict -from typing import List +from typing import List, Iterator, Union import numpy as np -from marl_factory_grid.environment.entity.object import _Object +from marl_factory_grid.environment.entity.object import Object import marl_factory_grid.environment.constants as c +from marl_factory_grid.utils import helpers as h -class _Objects: - _entity = _Object +class Objects: + _entity = Object + + @property + def var_can_be_bound(self): + return False @property def observers(self): @@ -45,7 +50,7 @@ class _Objects: def __len__(self): return len(self._data) - def __iter__(self): + def __iter__(self) -> Iterator[Union[Object, None]]: return iter(self.values()) def add_item(self, item: _entity): @@ -125,13 +130,14 @@ class _Objects: repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]} return f'{self.__class__.__name__}[{repr_dict}]' - def notify_del_entity(self, entity: _Object): + def notify_del_entity(self, entity: Object): try: + # noinspection PyUnresolvedReferences self.pos_dict[entity.pos].remove(entity) except (AttributeError, ValueError, IndexError): pass - def notify_add_entity(self, entity: _Object): + def notify_add_entity(self, entity: Object): try: if self not in entity.observers: entity.add_observer(self) @@ -148,12 +154,12 @@ class _Objects: def by_entity(self, entity): try: - return next((x for x in self if x.belongs_to_entity(entity))) + return h.get_first(self, filter_by=lambda x: x.belongs_to_entity(entity)) except (StopIteration, AttributeError): return None def idx_by_entity(self, entity): try: - return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity))) + return h.get_first_index(self, filter_by=lambda x: x.belongs_to_entity(entity)) except (StopIteration, AttributeError): return None diff --git a/marl_factory_grid/environment/groups/utils.py b/marl_factory_grid/environment/groups/utils.py index 5619041..d272152 100644 --- a/marl_factory_grid/environment/groups/utils.py +++ b/marl_factory_grid/environment/groups/utils.py @@ -1,7 +1,10 @@ from typing import List, Union +from marl_factory_grid.environment import constants as c from marl_factory_grid.environment.entity.util import GlobalPosition from marl_factory_grid.environment.groups.collection import Collection +from marl_factory_grid.utils.results import Result +from marl_factory_grid.utils.states import Gamestate class Combined(Collection): @@ -36,17 +39,17 @@ class GlobalPositions(Collection): _entity = GlobalPosition - @property - def var_is_blocking_light(self): - return False - - @property - def var_can_collide(self): - return False - - @property - def var_can_be_bound(self): - return True + var_is_blocking_light = False + var_can_be_bound = True + var_can_collide = False + var_has_position = False def __init__(self, *args, **kwargs): super(GlobalPositions, self).__init__(*args, **kwargs) + + def spawn(self, agents, level_shape, *args, **kwargs): + self.add_items([self._entity(agent, level_shape, *args, **kwargs) for agent in agents]) + return [Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))] + + def trigger_spawn(self, state: Gamestate, *args, **kwargs) -> [Result]: + return self.spawn(state[c.AGENT], state.lvl_shape, *args, **kwargs) diff --git a/marl_factory_grid/environment/groups/walls.py b/marl_factory_grid/environment/groups/walls.py index 2d85362..776bbca 100644 --- a/marl_factory_grid/environment/groups/walls.py +++ b/marl_factory_grid/environment/groups/walls.py @@ -7,9 +7,12 @@ class Walls(Collection): _entity = Wall symbol = c.SYMBOL_WALL - @property - def var_has_position(self): - return True + var_can_collide = True + var_is_blocking_light = True + var_can_move = False + var_has_position = True + var_can_be_bound = False + var_is_blocking_pos = True def __init__(self, *args, **kwargs): super(Walls, self).__init__(*args, **kwargs) diff --git a/marl_factory_grid/environment/rewards.py b/marl_factory_grid/environment/rewards.py index b3ebe8c..aa0acbd 100644 --- a/marl_factory_grid/environment/rewards.py +++ b/marl_factory_grid/environment/rewards.py @@ -2,3 +2,4 @@ MOVEMENTS_VALID: float = -0.001 MOVEMENTS_FAIL: float = -0.05 NOOP: float = -0.01 COLLISION: float = -0.5 +COLLISION_DONE: float = -1 diff --git a/marl_factory_grid/environment/rules.py b/marl_factory_grid/environment/rules.py index f9678b0..f5b6836 100644 --- a/marl_factory_grid/environment/rules.py +++ b/marl_factory_grid/environment/rules.py @@ -1,11 +1,11 @@ import abc from random import shuffle -from typing import List +from typing import List, Collection +from marl_factory_grid.environment import rewards as r, constants as c from marl_factory_grid.environment.entity.agent import Agent from marl_factory_grid.utils import helpers as h from marl_factory_grid.utils.results import TickResult, DoneResult -from marl_factory_grid.environment import rewards as r, constants as c class Rule(abc.ABC): @@ -39,6 +39,29 @@ class Rule(abc.ABC): return [] +class SpawnEntity(Rule): + + @property + def _collection(self) -> Collection: + return Collection() + + @property + def name(self): + return f'{self.__class__.__name__}({self.collection.name})' + + def __init__(self, collection, coords_or_quantity, ignore_blocking=False): + super().__init__() + self.coords_or_quantity = coords_or_quantity + self.collection = collection + self.ignore_blocking = ignore_blocking + + def on_init(self, state, lvl_map) -> [TickResult]: + results = self.collection.trigger_spawn(state, ignore_blocking=self.ignore_blocking) + pos_str = f' on: {[x.pos for x in self.collection]}' if self.collection.var_has_position else '' + state.print(f'Initial {self.collection.__class__.__name__} were spawned{pos_str}') + return results + + class SpawnAgents(Rule): def __init__(self): @@ -46,14 +69,14 @@ class SpawnAgents(Rule): pass def on_init(self, state, lvl_map): - agent_conf = state.agents_conf # agents = Agents(lvl_map.size) agents = state[c.AGENT] - empty_positions = state.entities.empty_positions()[:len(agent_conf)] - for agent_name in agent_conf: - actions = agent_conf[agent_name]['actions'].copy() - observations = agent_conf[agent_name]['observations'].copy() - positions = agent_conf[agent_name]['positions'].copy() + empty_positions = state.entities.empty_positions[:len(state.agents_conf)] + for agent_name, agent_conf in state.agents_conf.items(): + actions = agent_conf['actions'].copy() + observations = agent_conf['observations'].copy() + positions = agent_conf['positions'].copy() + other = agent_conf['other'].copy() if positions: shuffle(positions) while True: @@ -61,18 +84,18 @@ class SpawnAgents(Rule): pos = positions.pop() except IndexError: raise ValueError(f'It was not possible to spawn an Agent on the available position: ' - f'\n{agent_name[agent_name]["positions"].copy()}') - if agents.by_pos(pos) and state.check_pos_validity(pos): + f'\n{agent_conf["positions"].copy()}') + if bool(agents.by_pos(pos)) or not state.check_pos_validity(pos): continue else: - agents.add_item(Agent(actions, observations, pos, str_ident=agent_name)) + agents.add_item(Agent(actions, observations, pos, str_ident=agent_name, **other)) break else: - agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name)) + agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name, **other)) pass -class MaxStepsReached(Rule): +class DoneAtMaxStepsReached(Rule): def __init__(self, max_steps: int = 500): super().__init__() @@ -83,8 +106,8 @@ class MaxStepsReached(Rule): def on_check_done(self, state): if self.max_steps <= state.curr_step: - return [DoneResult(validity=c.VALID, identifier=self.name, reward=0)] - return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)] + return [DoneResult(validity=c.VALID, identifier=self.name)] + return [DoneResult(validity=c.NOT_VALID, identifier=self.name)] class AssignGlobalPositions(Rule): @@ -95,16 +118,17 @@ class AssignGlobalPositions(Rule): def on_init(self, state, lvl_map): from marl_factory_grid.environment.entity.util import GlobalPosition for agent in state[c.AGENT]: - gp = GlobalPosition(lvl_map.level_shape) - gp.bind_to(agent) + gp = GlobalPosition(agent, lvl_map.level_shape) state[c.GLOBALPOSITIONS].add_item(gp) return [] -class Collision(Rule): +class WatchCollisions(Rule): - def __init__(self, done_at_collisions: bool = False): + def __init__(self, reward=r.COLLISION, done_at_collisions: bool = False, reward_at_done=r.COLLISION_DONE): super().__init__() + self.reward_at_done = reward_at_done + self.reward = reward self.done_at_collisions = done_at_collisions self.curr_done = False @@ -117,12 +141,12 @@ class Collision(Rule): if len(guests) >= 2: for i, guest in enumerate(guests): try: - guest.set_state(TickResult(identifier=c.COLLISION, reward=r.COLLISION, + guest.set_state(TickResult(identifier=c.COLLISION, reward=self.reward, validity=c.NOT_VALID, entity=self)) except AttributeError: pass results.append(TickResult(entity=guest, identifier=c.COLLISION, - reward=r.COLLISION, validity=c.VALID)) + reward=self.reward, validity=c.VALID)) self.curr_done = True if self.done_at_collisions else False return results @@ -131,5 +155,5 @@ class Collision(Rule): inter_entity_collision_detected = self.curr_done move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT]) if inter_entity_collision_detected or move_failed: - return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)] - return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)] + return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=self.reward_at_done)] + return [] diff --git a/marl_factory_grid/modules/_template/rules.py b/marl_factory_grid/modules/_template/rules.py index 6ed2f2d..7696616 100644 --- a/marl_factory_grid/modules/_template/rules.py +++ b/marl_factory_grid/modules/_template/rules.py @@ -6,7 +6,9 @@ from marl_factory_grid.utils.results import TickResult, DoneResult class TemplateRule(Rule): def __init__(self, *args, **kwargs): - super(TemplateRule, self).__init__(*args, **kwargs) + super(TemplateRule, self).__init__() + self.args = args + self.kwargs = kwargs def on_init(self, state, lvl_map): pass diff --git a/marl_factory_grid/modules/batteries/__init__.py b/marl_factory_grid/modules/batteries/__init__.py index 0218021..80671fd 100644 --- a/marl_factory_grid/modules/batteries/__init__.py +++ b/marl_factory_grid/modules/batteries/__init__.py @@ -1,4 +1,4 @@ from .actions import BtryCharge -from .entitites import Pod, Battery +from .entitites import ChargePod, Battery from .groups import ChargePods, Batteries from .rules import DoneAtBatteryDischarge, BatteryDecharge diff --git a/marl_factory_grid/modules/batteries/actions.py b/marl_factory_grid/modules/batteries/actions.py index 343bbcc..7d1c4a2 100644 --- a/marl_factory_grid/modules/batteries/actions.py +++ b/marl_factory_grid/modules/batteries/actions.py @@ -1,11 +1,11 @@ from typing import Union -import marl_factory_grid.modules.batteries.constants from marl_factory_grid.environment.actions import Action from marl_factory_grid.utils.results import ActionResult from marl_factory_grid.modules.batteries import constants as b from marl_factory_grid.environment import constants as c +from marl_factory_grid.utils import helpers as h class BtryCharge(Action): @@ -14,8 +14,8 @@ class BtryCharge(Action): super().__init__(b.ACTION_CHARGE) def do(self, entity, state) -> Union[None, ActionResult]: - if charge_pod := state[b.CHARGE_PODS].by_pos(entity.pos): - valid = charge_pod.charge_battery(state[b.BATTERIES].by_entity(entity)) + if charge_pod := h.get_first(state[b.CHARGE_PODS].by_pos(entity.pos)): + valid = h.get_first(charge_pod.charge_battery(state[b.BATTERIES].by_entity(entity))) if valid: state.print(f'{entity.name} just charged batteries at {charge_pod.name}.') else: @@ -23,5 +23,6 @@ class BtryCharge(Action): else: valid = c.NOT_VALID state.print(f'{entity.name} failed to charged batteries at {entity.pos}.') + return ActionResult(entity=entity, identifier=self._identifier, validity=valid, - reward=marl_factory_grid.modules.batteries.constants.REWARD_CHARGE_VALID if valid else marl_factory_grid.modules.batteries.constants.Reward_CHARGE_FAIL) + reward=b.REWARD_CHARGE_VALID if valid else b.Reward_CHARGE_FAIL) diff --git a/marl_factory_grid/modules/batteries/chargepods.png b/marl_factory_grid/modules/batteries/chargepods.png new file mode 100644 index 0000000..7221daa Binary files /dev/null and b/marl_factory_grid/modules/batteries/chargepods.png differ diff --git a/marl_factory_grid/modules/batteries/entitites.py b/marl_factory_grid/modules/batteries/entitites.py index b51f2dd..7675fe9 100644 --- a/marl_factory_grid/modules/batteries/entitites.py +++ b/marl_factory_grid/modules/batteries/entitites.py @@ -1,11 +1,11 @@ from marl_factory_grid.environment import constants as c from marl_factory_grid.environment.entity.entity import Entity -from marl_factory_grid.environment.entity.object import _Object +from marl_factory_grid.environment.entity.object import Object from marl_factory_grid.modules.batteries import constants as b from marl_factory_grid.utils.utility_classes import RenderEntity -class Battery(_Object): +class Battery(Object): @property def var_can_be_bound(self): @@ -50,7 +50,7 @@ class Battery(_Object): return summary -class Pod(Entity): +class ChargePod(Entity): @property def encoding(self): @@ -58,7 +58,7 @@ class Pod(Entity): def __init__(self, *args, charge_rate: float = 0.4, multi_charge: bool = False, **kwargs): - super(Pod, self).__init__(*args, **kwargs) + super(ChargePod, self).__init__(*args, **kwargs) self.charge_rate = charge_rate self.multi_charge = multi_charge diff --git a/marl_factory_grid/modules/batteries/groups.py b/marl_factory_grid/modules/batteries/groups.py index 8d9e060..7db43bd 100644 --- a/marl_factory_grid/modules/batteries/groups.py +++ b/marl_factory_grid/modules/batteries/groups.py @@ -1,52 +1,36 @@ from typing import Union, List, Tuple +from marl_factory_grid.environment import constants as c from marl_factory_grid.environment.groups.collection import Collection -from marl_factory_grid.modules.batteries.entitites import Pod, Battery +from marl_factory_grid.modules.batteries.entitites import ChargePod, Battery +from marl_factory_grid.utils.results import Result class Batteries(Collection): _entity = Battery - @property - def var_is_blocking_light(self): - return False - - @property - def var_can_collide(self): - return False - - @property - def var_can_move(self): - return False - - @property - def var_has_position(self): - return False - - @property - def var_can_be_bound(self): - return True + var_has_position = False + var_can_be_bound = True @property def obs_tag(self): return self.__class__.__name__ - def __init__(self, *args, **kwargs): - super(Batteries, self).__init__(*args, **kwargs) + def __init__(self, size, initial_charge_level: float=1.0, *args, **kwargs): + super(Batteries, self).__init__(size, *args, **kwargs) + self.initial_charge_level = initial_charge_level - def spawn(self, agents, initial_charge_level): - batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)] + def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], agents, *entity_args, **entity_kwargs): + batteries = [self._entity(self.initial_charge_level, agent) for _, agent in enumerate(agents)] self.add_items(batteries) - # def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): hat keine pos - # agents = entity_args[0] - # initial_charge_level = entity_args[1] - # batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)] - # self.add_items(batteries) + def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, **entity_kwargs): + self.spawn(0, state[c.AGENT]) + return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self)) class ChargePods(Collection): - _entity = Pod + _entity = ChargePod def __init__(self, *args, **kwargs): super(ChargePods, self).__init__(*args, **kwargs) diff --git a/marl_factory_grid/modules/batteries/rules.py b/marl_factory_grid/modules/batteries/rules.py index e060629..8a4725b 100644 --- a/marl_factory_grid/modules/batteries/rules.py +++ b/marl_factory_grid/modules/batteries/rules.py @@ -1,11 +1,9 @@ from typing import List, Union -import marl_factory_grid.modules.batteries.constants -from marl_factory_grid.environment.rules import Rule -from marl_factory_grid.utils.results import TickResult, DoneResult - from marl_factory_grid.environment import constants as c +from marl_factory_grid.environment.rules import Rule from marl_factory_grid.modules.batteries import constants as b +from marl_factory_grid.utils.results import TickResult, DoneResult class BatteryDecharge(Rule): @@ -49,10 +47,6 @@ class BatteryDecharge(Rule): self.per_action_costs = per_action_costs self.initial_charge = initial_charge - def on_init(self, state, lvl_map): # on reset? - assert len(state[c.AGENT]), "There are no agents, did you already spawn them?" - state[b.BATTERIES].spawn(state[c.AGENT], self.initial_charge) - def tick_step(self, state) -> List[TickResult]: # Decharge batteries = state[b.BATTERIES] @@ -66,7 +60,7 @@ class BatteryDecharge(Rule): batteries.by_entity(agent).decharge(energy_consumption) - results.append(TickResult(self.name, reward=0, entity=agent, validity=c.VALID)) + results.append(TickResult(self.name, entity=agent, validity=c.VALID)) return results @@ -82,13 +76,13 @@ class BatteryDecharge(Rule): if self.paralyze_agents_on_discharge: btry.bound_entity.paralyze(self.name) results.append( - TickResult("Paralyzed", entity=btry.bound_entity, reward=0, validity=c.VALID) + TickResult("Paralyzed", entity=btry.bound_entity, validity=c.VALID) ) state.print(f'{btry.bound_entity.name} has just been paralyzed!') if btry.bound_entity.var_is_paralyzed and not btry.is_discharged: btry.bound_entity.de_paralyze(self.name) results.append( - TickResult("De-Paralyzed", entity=btry.bound_entity, reward=0, validity=c.VALID) + TickResult("De-Paralyzed", entity=btry.bound_entity, validity=c.VALID) ) state.print(f'{btry.bound_entity.name} has just been de-paralyzed!') return results @@ -132,7 +126,7 @@ class DoneAtBatteryDischarge(BatteryDecharge): if any_discharged or all_discharged: return [DoneResult(self.name, validity=c.VALID, reward=self.reward_discharge_done)] else: - return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)] + return [DoneResult(self.name, validity=c.NOT_VALID)] class SpawnChargePods(Rule): @@ -155,7 +149,7 @@ class SpawnChargePods(Rule): def on_init(self, state, lvl_map): pod_collection = state[b.CHARGE_PODS] - empty_positions = state.entities.empty_positions() + empty_positions = state.entities.empty_positions pods = pod_collection.from_coordinates(empty_positions, entity_kwargs=dict( multi_charge=self.multi_charge, charge_rate=self.charge_rate) ) diff --git a/marl_factory_grid/modules/clean_up/__init__.py b/marl_factory_grid/modules/clean_up/__init__.py index 31cb841..ec4d1e7 100644 --- a/marl_factory_grid/modules/clean_up/__init__.py +++ b/marl_factory_grid/modules/clean_up/__init__.py @@ -1,4 +1,4 @@ from .actions import CleanUp from .entitites import DirtPile from .groups import DirtPiles -from .rules import SpawnDirt, EntitiesSmearDirtOnMove, DoneOnAllDirtCleaned +from .rules import EntitiesSmearDirtOnMove, DoneOnAllDirtCleaned diff --git a/marl_factory_grid/modules/clean_up/entitites.py b/marl_factory_grid/modules/clean_up/entitites.py index 8ac8a0c..25c6eb1 100644 --- a/marl_factory_grid/modules/clean_up/entitites.py +++ b/marl_factory_grid/modules/clean_up/entitites.py @@ -1,5 +1,3 @@ -from numpy import random - from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.utils.utility_classes import RenderEntity from marl_factory_grid.modules.clean_up import constants as d @@ -7,22 +5,6 @@ from marl_factory_grid.modules.clean_up import constants as d class DirtPile(Entity): - @property - def var_can_collide(self): - return False - - @property - def var_can_move(self): - return False - - @property - def var_is_blocking_light(self): - return False - - @property - def var_has_position(self): - return True - @property def amount(self): return self._amount diff --git a/marl_factory_grid/modules/clean_up/groups.py b/marl_factory_grid/modules/clean_up/groups.py index 63e5898..7ae3247 100644 --- a/marl_factory_grid/modules/clean_up/groups.py +++ b/marl_factory_grid/modules/clean_up/groups.py @@ -1,76 +1,61 @@ -from typing import Union, List, Tuple - from marl_factory_grid.environment import constants as c -from marl_factory_grid.utils.results import Result from marl_factory_grid.environment.groups.collection import Collection from marl_factory_grid.modules.clean_up.entitites import DirtPile +from marl_factory_grid.utils.results import Result class DirtPiles(Collection): _entity = DirtPile - @property - def var_is_blocking_light(self): - return False + var_is_blocking_light = False + var_can_collide = False + var_can_move = False + var_has_position = True @property - def var_can_collide(self): - return False - - @property - def var_can_move(self): - return False - - @property - def var_has_position(self): - return True - - @property - def amount(self): + def global_amount(self): return sum([dirt.amount for dirt in self]) def __init__(self, *args, max_local_amount=5, clean_amount=1, - max_global_amount: int = 20, **kwargs): + max_global_amount: int = 20, + coords_or_quantity=10, + initial_amount=2, + amount_var=0.2, + n_var=0.2, + **kwargs): super(DirtPiles, self).__init__(*args, **kwargs) + self.amount_var = amount_var + self.n_var = n_var self.clean_amount = clean_amount self.max_global_amount = max_global_amount self.max_local_amount = max_local_amount + self.coords_or_quantity = coords_or_quantity + self.initial_amount = initial_amount - def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): - amount_s = entity_args[0] + def trigger_spawn(self, state, coords_or_quantity=0, amount=0) -> [Result]: + coords_or_quantity = coords_or_quantity if coords_or_quantity else self.coords_or_quantity + n_new = int(abs(coords_or_quantity + (state.rng.uniform(-self.n_var, self.n_var)))) + n_new = state.get_n_random_free_positions(n_new) + + amounts = [amount if amount else (self.initial_amount + state.rng.uniform(-self.amount_var, self.amount_var)) + for _ in range(coords_or_quantity)] spawn_counter = 0 - for idx, pos in enumerate(coords_or_quantity): - if not self.amount > self.max_global_amount: - amount = amount_s[idx] if isinstance(amount_s, list) else amount_s + for idx, (pos, a) in enumerate(zip(n_new, amounts)): + if not self.global_amount > self.max_global_amount: if dirt := self.by_pos(pos): dirt = next(dirt.iter()) - new_value = dirt.amount + amount + new_value = dirt.amount + a dirt.set_new_amount(new_value) else: - dirt = DirtPile(pos, amount=amount) - self.add_item(dirt) + super().spawn([pos], amount=a) spawn_counter += 1 else: - return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, reward=0, - value=spawn_counter) - return Result(identifier=f'{self.name}_spawn', validity=c.VALID, reward=0, value=spawn_counter) + return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, value=spawn_counter) - def trigger_dirt_spawn(self, n, amount, state, n_var=0.2, amount_var=0.2) -> Result: - free_for_dirt = [x for x in state.entities.floorlist if len(state.entities.pos_dict[x]) == 0 or ( - len(state.entities.pos_dict[x]) >= 1 and isinstance(next(y for y in x), DirtPile))] - # free_for_dirt = [x for x in state[c.FLOOR] - # if len(x.guests) == 0 or ( - # len(x.guests) == 1 and - # isinstance(next(y for y in x.guests), DirtPile))] - state.rng.shuffle(free_for_dirt) - - new_spawn = int(abs(n + (state.rng.uniform(-n_var, n_var)))) - new_amount_s = [abs(amount + (amount*state.rng.uniform(-amount_var, amount_var))) for _ in range(new_spawn)] - n_dirty_positions = free_for_dirt[:new_spawn] - return self.spawn(n_dirty_positions, new_amount_s) + return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=spawn_counter) def __repr__(self): s = super(DirtPiles, self).__repr__() - return f'{s[:-1]}, {self.amount})' + return f'{s[:-1]}, {self.global_amount}]' diff --git a/marl_factory_grid/modules/clean_up/rules.py b/marl_factory_grid/modules/clean_up/rules.py index 3f58cdb..b81ee41 100644 --- a/marl_factory_grid/modules/clean_up/rules.py +++ b/marl_factory_grid/modules/clean_up/rules.py @@ -22,58 +22,37 @@ class DoneOnAllDirtCleaned(Rule): def on_check_done(self, state) -> [DoneResult]: if len(state[d.DIRT]) == 0 and state.curr_step: return [DoneResult(validity=c.VALID, identifier=self.name, reward=self.reward)] - return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)] + return [DoneResult(validity=c.NOT_VALID, identifier=self.name)] -class SpawnDirt(Rule): +class RespawnDirt(Rule): - def __init__(self, initial_n: int = 5, initial_amount: float = 1.3, - respawn_n: int = 3, respawn_amount: float = 0.8, - n_var: float = 0.2, amount_var: float = 0.2, spawn_freq: int = 15): + def __init__(self, respawn_freq: int = 15, respawn_n: int = 5, respawn_amount: float = 1.0): """ Defines the spawn pattern of intial and additional 'Dirt'-entitites. First chooses positions, then trys to spawn dirt until 'respawn_n' or the maximal global amount is reached. If there is allready some, it is topped up to min(max_local_amount, amount). - :type spawn_freq: int - :parameter spawn_freq: In which frequency should this Rule try to spawn new 'Dirt'? + :type respawn_freq: int + :parameter respawn_freq: In which frequency should this Rule try to spawn new 'Dirt'? :type respawn_n: int :parameter respawn_n: How many respawn positions are considered. - :type initial_n: int - :parameter initial_n: How much initial positions are considered. - :type amount_var: float - :parameter amount_var: Variance of amount to spawn. - :type n_var: float - :parameter n_var: Variance of n to spawn. :type respawn_amount: float :parameter respawn_amount: Defines how much dirt 'amount' is placed every 'spawn_freq' ticks. - :type initial_amount: float - :parameter initial_amount: Defines how much dirt 'amount' is initially placed. - """ super().__init__() - self.amount_var = amount_var - self.n_var = n_var - self.respawn_amount = respawn_amount self.respawn_n = respawn_n - self.initial_amount = initial_amount - self.initial_n = initial_n - self.spawn_freq = spawn_freq - self._next_dirt_spawn = spawn_freq - - def on_init(self, state, lvl_map) -> str: - result = state[d.DIRT].trigger_dirt_spawn(self.initial_n, self.initial_amount, state, - n_var=self.n_var, amount_var=self.amount_var) - state.print(f'Initial Dirt was spawned on: {[x.pos for x in state[d.DIRT]]}') - return result + self.respawn_amount = respawn_amount + self.respawn_freq = respawn_freq + self._next_dirt_spawn = respawn_freq def tick_step(self, state): + collection = state[d.DIRT] if self._next_dirt_spawn < 0: - pass # No DirtPile Spawn + result = [] # No DirtPile Spawn elif not self._next_dirt_spawn: - result = [state[d.DIRT].trigger_dirt_spawn(self.respawn_n, self.respawn_amount, state, - n_var=self.n_var, amount_var=self.amount_var)] - self._next_dirt_spawn = self.spawn_freq + result = [collection.trigger_spawn(state, coords_or_quantity=self.respawn_n, amount=self.respawn_amount)] + self._next_dirt_spawn = self.respawn_freq else: self._next_dirt_spawn -= 1 result = [] @@ -99,8 +78,8 @@ class EntitiesSmearDirtOnMove(Rule): for entity in state.moving_entites: if is_move(entity.state.identifier) and entity.state.validity == c.VALID: if old_pos_dirt := state[d.DIRT].by_pos(entity.last_pos): + old_pos_dirt = next(iter(old_pos_dirt)) if smeared_dirt := round(old_pos_dirt.amount * self.smear_ratio, 2): if state[d.DIRT].spawn(entity.pos, amount=smeared_dirt): - results.append(TickResult(identifier=self.name, entity=entity, - reward=0, validity=c.VALID)) + results.append(TickResult(identifier=self.name, entity=entity, validity=c.VALID)) return results diff --git a/marl_factory_grid/modules/destinations/__init__.py b/marl_factory_grid/modules/destinations/__init__.py index 83e5988..4614dd7 100644 --- a/marl_factory_grid/modules/destinations/__init__.py +++ b/marl_factory_grid/modules/destinations/__init__.py @@ -1,4 +1,7 @@ from .actions import DestAction from .entitites import Destination from .groups import Destinations -from .rules import DoneAtDestinationReachAll, SpawnDestinations +from .rules import (DoneAtDestinationReachAll, + DoneAtDestinationReachAny, + SpawnDestinationsPerAgent, + DestinationReachReward) diff --git a/marl_factory_grid/modules/destinations/actions.py b/marl_factory_grid/modules/destinations/actions.py index 13f7fe3..6367acd 100644 --- a/marl_factory_grid/modules/destinations/actions.py +++ b/marl_factory_grid/modules/destinations/actions.py @@ -21,4 +21,4 @@ class DestAction(Action): valid = c.NOT_VALID state.print(f'{entity.name} just tried to do_wait_action do_wait_action at {entity.pos} but failed') return ActionResult(entity=entity, identifier=self._identifier, validity=valid, - reward=marl_factory_grid.modules.destinations.constants.REWARD_WAIT_VALID if valid else marl_factory_grid.modules.destinations.constants.REWARD_WAIT_FAIL) + reward=d.REWARD_WAIT_VALID if valid else d.REWARD_WAIT_FAIL) diff --git a/marl_factory_grid/modules/destinations/entitites.py b/marl_factory_grid/modules/destinations/entitites.py index 7b866b7..d75f9e0 100644 --- a/marl_factory_grid/modules/destinations/entitites.py +++ b/marl_factory_grid/modules/destinations/entitites.py @@ -9,30 +9,6 @@ from marl_factory_grid.utils.utility_classes import RenderEntity class Destination(Entity): - @property - def var_can_move(self): - return False - - @property - def var_can_collide(self): - return False - - @property - def var_has_position(self): - return True - - @property - def var_is_blocking_pos(self): - return False - - @property - def var_is_blocking_light(self): - return False - - @property - def var_can_be_bound(self): - return True - def was_reached(self): return self._was_reached diff --git a/marl_factory_grid/modules/destinations/groups.py b/marl_factory_grid/modules/destinations/groups.py index 5f91bb4..f0b7f9e 100644 --- a/marl_factory_grid/modules/destinations/groups.py +++ b/marl_factory_grid/modules/destinations/groups.py @@ -1,43 +1,18 @@ from marl_factory_grid.environment.groups.collection import Collection from marl_factory_grid.modules.destinations.entitites import Destination -from marl_factory_grid.environment import constants as c -from marl_factory_grid.modules.destinations import constants as d class Destinations(Collection): _entity = Destination - @property - def var_is_blocking_light(self): - return False - - @property - def var_can_collide(self): - return False - - @property - def var_can_move(self): - return False - - @property - def var_has_position(self): - return True + var_is_blocking_light = False + var_can_collide = False + var_can_move = False + var_has_position = True + var_can_be_bound = True def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def __repr__(self): return super(Destinations, self).__repr__() - - @staticmethod - def trigger_destination_spawn(n_dests, state): - coordinates = state.entities.floorlist[:n_dests] - if destinations := [Destination(pos) for pos in coordinates]: - state[d.DESTINATION].add_items(destinations) - state.print(f'{n_dests} new destinations have been spawned') - return c.VALID - else: - state.print('No Destiantions are spawning, limit is reached.') - return c.NOT_VALID - - diff --git a/marl_factory_grid/modules/destinations/rules.py b/marl_factory_grid/modules/destinations/rules.py index afb8575..8e72141 100644 --- a/marl_factory_grid/modules/destinations/rules.py +++ b/marl_factory_grid/modules/destinations/rules.py @@ -2,8 +2,8 @@ import ast from random import shuffle from typing import List, Dict, Tuple -import marl_factory_grid.modules.destinations.constants from marl_factory_grid.environment.rules import Rule +from marl_factory_grid.utils import helpers as h from marl_factory_grid.utils.results import TickResult, DoneResult from marl_factory_grid.environment import constants as c @@ -54,7 +54,7 @@ class DoneAtDestinationReachAll(DestinationReachReward): """ This rule triggers and sets the done flag if ALL Destinations have been reached. - :type reward_at_done: object + :type reward_at_done: float :param reward_at_done: Specifies the reward, agent get, whenn all destinations are reached. :type dest_reach_reward: float :param dest_reach_reward: Specify the reward, agents get when reaching a single destination. @@ -65,7 +65,7 @@ class DoneAtDestinationReachAll(DestinationReachReward): def on_check_done(self, state) -> List[DoneResult]: if all(x.was_reached() for x in state[d.DESTINATION]): return [DoneResult(self.name, validity=c.VALID, reward=self.reward)] - return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)] + return [DoneResult(self.name, validity=c.NOT_VALID)] class DoneAtDestinationReachAny(DestinationReachReward): @@ -75,7 +75,7 @@ class DoneAtDestinationReachAny(DestinationReachReward): This rule triggers and sets the done flag if ANY Destinations has been reached. !!! IMPORTANT: 'reward_at_done' is shared between the agents; 'dest_reach_reward' is bound to a specific one. - :type reward_at_done: object + :type reward_at_done: float :param reward_at_done: Specifies the reward, all agent get, when any destinations has been reached. Default {d.REWARD_DEST_DONE} :type dest_reach_reward: float @@ -87,67 +87,29 @@ class DoneAtDestinationReachAny(DestinationReachReward): def on_check_done(self, state) -> List[DoneResult]: if any(x.was_reached() for x in state[d.DESTINATION]): - return [DoneResult(self.name, validity=c.VALID, reward=marl_factory_grid.modules.destinations.constants.REWARD_DEST_REACHED)] + return [DoneResult(self.name, validity=c.VALID, reward=d.REWARD_DEST_REACHED)] return [] -class SpawnDestinations(Rule): - - def __init__(self, n_dests: int = 1, spawn_mode: str = d.MODE_GROUPED): - f""" - Defines how destinations are initially spawned and respawned in addition. - !!! This rule introduces no kind of reward or Env.-Done condition! - - :type n_dests: int - :param n_dests: How many destiantions should be maintained (and initally spawnewd) on the map? - :type spawn_mode: str - :param spawn_mode: One of {d.SPAWN_MODES}. {d.MODE_GROUPED}: Always wait for all Dstiantions do be gone, - then respawn after the given time. {d.MODE_SINGLE}: Just spawn every destination, - that has been reached, after the given time - - """ - super(SpawnDestinations, self).__init__() - self.n_dests = n_dests - self.spawn_mode = spawn_mode - - def on_init(self, state, lvl_map): - # noinspection PyAttributeOutsideInit - state[d.DESTINATION].trigger_destination_spawn(self.n_dests, state) - pass - - def tick_pre_step(self, state) -> List[TickResult]: - pass - - def tick_step(self, state) -> List[TickResult]: - if n_dest_spawn := max(0, self.n_dests - len(state[d.DESTINATION])): - if self.spawn_mode == d.MODE_GROUPED and n_dest_spawn == self.n_dests: - validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state) - return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)] - elif self.spawn_mode == d.MODE_SINGLE and n_dest_spawn: - validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state) - return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)] - else: - pass - - class SpawnDestinationsPerAgent(Rule): - def __init__(self, per_agent_positions: Dict[str, List[Tuple[int, int]]]): + def __init__(self, coords_or_quantity: Dict[str, List[Tuple[int, int]]]): """ Special rule, that spawn distinations, that are bound to a single agent a fixed set of positions. Usefull for introducing specialists, etc. .. !!! This rule does not introduce any reward or done condition. - :type per_agent_positions: Dict[str, List[Tuple[int, int]] - :param per_agent_positions: Please provide a dictionary with agent names as keys; and a list of possible + :type coords_or_quantity: Dict[str, List[Tuple[int, int]] + :param coords_or_quantity: Please provide a dictionary with agent names as keys; and a list of possible destiantion coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]} """ super(Rule, self).__init__() - self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in per_agent_positions.items()} + self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in coords_or_quantity.items()} def on_init(self, state, lvl_map): for (agent_name, position_list) in self.per_agent_positions.items(): - agent = next(x for x in state[c.AGENT] if agent_name in x.name) # Fixme: Ugly AF + agent = h.get_first(state[c.AGENT], lambda x: agent_name in x.name) + assert agent position_list = position_list.copy() shuffle(position_list) while True: @@ -155,7 +117,7 @@ class SpawnDestinationsPerAgent(Rule): pos = position_list.pop() except IndexError: print(f"Could not spawn Destinations at: {self.per_agent_positions[agent_name]}") - print(f'Check your agent palcement: {state[c.AGENT]} ... Exit ...') + print(f'Check your agent placement: {state[c.AGENT]} ... Exit ...') exit(9999) if (not pos == agent.pos) and (not state[d.DESTINATION].by_pos(pos)): destination = Destination(pos, bind_to=agent) diff --git a/marl_factory_grid/modules/doors/entitites.py b/marl_factory_grid/modules/doors/entitites.py index 669f74e..1c33d7b 100644 --- a/marl_factory_grid/modules/doors/entitites.py +++ b/marl_factory_grid/modules/doors/entitites.py @@ -1,4 +1,5 @@ from marl_factory_grid.environment.entity.entity import Entity +from marl_factory_grid.utils import Result from marl_factory_grid.utils.utility_classes import RenderEntity from marl_factory_grid.environment import constants as c @@ -41,21 +42,6 @@ class Door(Entity): def str_state(self): return 'open' if self.is_open else 'closed' - def __init__(self, *args, closed_on_init=True, auto_close_interval=10, **kwargs): - self._status = d.STATE_CLOSED - super(Door, self).__init__(*args, **kwargs) - self.auto_close_interval = auto_close_interval - self.time_to_close = 0 - if not closed_on_init: - self._open() - else: - self._close() - - def summarize_state(self): - state_dict = super().summarize_state() - state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close)) - return state_dict - @property def is_closed(self): return self._status == d.STATE_CLOSED @@ -68,6 +54,25 @@ class Door(Entity): def status(self): return self._status + @property + def time_to_close(self): + return self._time_to_close + + def __init__(self, *args, closed_on_init=True, auto_close_interval=10, **kwargs): + self._status = d.STATE_CLOSED + super(Door, self).__init__(*args, **kwargs) + self._auto_close_interval = auto_close_interval + self._time_to_close = 0 + if not closed_on_init: + self._open() + else: + self._close() + + def summarize_state(self): + state_dict = super().summarize_state() + state_dict.update(state=str(self.str_state), time_to_close=self.time_to_close) + return state_dict + def render(self): name, state = 'door_open' if self.is_open else 'door_closed', 'blank' return RenderEntity(name, self.pos, 1, 'none', state, self.u_int + 1) @@ -80,18 +85,35 @@ class Door(Entity): return c.VALID def tick(self, state): - if self.is_open and len(state.entities.pos_dict[self.pos]) == 2 and self.time_to_close: - self.time_to_close -= 1 - return c.NOT_VALID - elif self.is_open and not self.time_to_close and len(state.entities.pos_dict[self.pos]) == 2: - self.use() - return c.VALID + # Check if no entity is standing in the door + if len(state.entities.pos_dict[self.pos]) <= 2: + if self.is_open and self.time_to_close: + self._decrement_timer() + return Result(f"{d.DOOR}_tick", c.VALID, entity=self) + elif self.is_open and not self.time_to_close: + self.use() + return Result(f"{d.DOOR}_closed", c.VALID, entity=self) + else: + # No one is in door, but it is closed... Nothing to do.... + return None else: - return c.NOT_VALID + # Entity is standing in the door, reset timer + self._reset_timer() + return Result(f"{d.DOOR}_reset", c.VALID, entity=self) def _open(self): self._status = d.STATE_OPEN - self.time_to_close = self.auto_close_interval + self._reset_timer() + return True def _close(self): self._status = d.STATE_CLOSED + return True + + def _decrement_timer(self): + self._time_to_close -= 1 + return True + + def _reset_timer(self): + self._time_to_close = self._auto_close_interval + return True diff --git a/marl_factory_grid/modules/doors/groups.py b/marl_factory_grid/modules/doors/groups.py index 687846e..973d1ab 100644 --- a/marl_factory_grid/modules/doors/groups.py +++ b/marl_factory_grid/modules/doors/groups.py @@ -1,5 +1,3 @@ -from typing import Union - from marl_factory_grid.environment.groups.collection import Collection from marl_factory_grid.modules.doors import constants as d from marl_factory_grid.modules.doors.entitites import Door @@ -18,8 +16,10 @@ class Doors(Collection): super(Doors, self).__init__(*args, can_collide=True, **kwargs) def tick_doors(self, state): - result_dict = dict() + results = list() for door in self: - did_tick = door.tick(state) - result_dict.update({door.name: did_tick}) - return result_dict + tick_result = door.tick(state) + if tick_result is not None: + results.append(tick_result) + # TODO: Should return a Result object, not a random dict. + return results diff --git a/marl_factory_grid/modules/doors/rewards.py b/marl_factory_grid/modules/doors/rewards.py index c87d123..b38c7c5 100644 --- a/marl_factory_grid/modules/doors/rewards.py +++ b/marl_factory_grid/modules/doors/rewards.py @@ -1,2 +1,2 @@ USE_DOOR_VALID: float = -0.00 -USE_DOOR_FAIL: float = -0.01 \ No newline at end of file +USE_DOOR_FAIL: float = -0.01 diff --git a/marl_factory_grid/modules/doors/rules.py b/marl_factory_grid/modules/doors/rules.py index da312cd..599d975 100644 --- a/marl_factory_grid/modules/doors/rules.py +++ b/marl_factory_grid/modules/doors/rules.py @@ -19,10 +19,10 @@ class DoorAutoClose(Rule): def tick_step(self, state): if doors := state[d.DOORS]: - doors_tick_result = doors.tick_doors(state) - doors_that_ticked = [key for key, val in doors_tick_result.items() if val] - state.print(f'{doors_that_ticked} were auto-closed' - if doors_that_ticked else 'No Doors were auto-closed') + doors_tick_results = doors.tick_doors(state) + doors_that_closed = [x.entity.name for x in doors_tick_results if 'closed' in x.identifier] + door_str = doors_that_closed if doors_that_closed else "No Doors" + state.print(f'{door_str} were auto-closed') return [TickResult(self.name, validity=c.VALID, value=1)] state.print('There are no doors, but you loaded the corresponding Module') return [] diff --git a/marl_factory_grid/modules/factory/rules.py b/marl_factory_grid/modules/factory/rules.py index d736f7a..e056135 100644 --- a/marl_factory_grid/modules/factory/rules.py +++ b/marl_factory_grid/modules/factory/rules.py @@ -1,8 +1,8 @@ import random -from typing import List, Union +from typing import List -from marl_factory_grid.environment.rules import Rule from marl_factory_grid.environment import constants as c +from marl_factory_grid.environment.rules import Rule from marl_factory_grid.utils.results import TickResult @@ -14,8 +14,6 @@ class AgentSingleZonePlacementBeta(Rule): super().__init__() def on_init(self, state, lvl_map): - zones = state[c.ZONES] - n_zones = state[c.ZONES] agents = state[c.AGENT] if len(self.coordinates) == len(agents): coordinates = self.coordinates @@ -31,4 +29,4 @@ class AgentSingleZonePlacementBeta(Rule): return [] def tick_post_step(self, state) -> List[TickResult]: - return [] \ No newline at end of file + return [] diff --git a/marl_factory_grid/modules/items/__init__.py b/marl_factory_grid/modules/items/__init__.py index 157c385..cb9b69b 100644 --- a/marl_factory_grid/modules/items/__init__.py +++ b/marl_factory_grid/modules/items/__init__.py @@ -1,4 +1,3 @@ from .actions import ItemAction from .entitites import Item, DropOffLocation from .groups import DropOffLocations, Items, Inventory, Inventories -from .rules import ItemRules diff --git a/marl_factory_grid/modules/items/actions.py b/marl_factory_grid/modules/items/actions.py index f9e4f6f..ef6aa99 100644 --- a/marl_factory_grid/modules/items/actions.py +++ b/marl_factory_grid/modules/items/actions.py @@ -29,7 +29,7 @@ class ItemAction(Action): elif items := state[i.ITEM].by_pos(entity.pos): item = items[0] item.change_parent_collection(inventory) - item.set_pos_to(c.VALUE_NO_POS) + item.set_pos(c.VALUE_NO_POS) state.print(f'{entity.name} just picked up an item at {entity.pos}') return ActionResult(entity=entity, identifier=self._identifier, validity=c.VALID, reward=r.PICK_UP_VALID) diff --git a/marl_factory_grid/modules/items/constants.py b/marl_factory_grid/modules/items/constants.py index 86b8b0c..5cb82c3 100644 --- a/marl_factory_grid/modules/items/constants.py +++ b/marl_factory_grid/modules/items/constants.py @@ -1,6 +1,3 @@ -from typing import NamedTuple - - SYMBOL_NO_ITEM = 0 SYMBOL_DROP_OFF = 1 # Item Env diff --git a/marl_factory_grid/modules/items/entitites.py b/marl_factory_grid/modules/items/entitites.py index b710282..8549134 100644 --- a/marl_factory_grid/modules/items/entitites.py +++ b/marl_factory_grid/modules/items/entitites.py @@ -8,56 +8,20 @@ from marl_factory_grid.modules.items import constants as i class Item(Entity): - @property - def var_can_collide(self): - return False - def render(self): return RenderEntity(i.ITEM, self.pos) if self.pos != c.VALUE_NO_POS else None def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._auto_despawn = -1 - - @property - def auto_despawn(self): - return self._auto_despawn @property def encoding(self): # Edit this if you want items to be drawn in the ops differently return 1 - def set_auto_despawn(self, auto_despawn): - self._auto_despawn = auto_despawn - - def set_pos_to(self, no_pos): - self._pos = no_pos - - def summarize_state(self) -> dict: - super_summarization = super(Item, self).summarize_state() - super_summarization.update(dict(auto_despawn=self.auto_despawn)) - return super_summarization - class DropOffLocation(Entity): - @property - def var_can_collide(self): - return False - - @property - def var_can_move(self): - return False - - @property - def var_is_blocking_light(self): - return False - - @property - def var_has_position(self): - return True - def render(self): return RenderEntity(i.DROP_OFF, self.pos) @@ -65,18 +29,16 @@ class DropOffLocation(Entity): def encoding(self): return i.SYMBOL_DROP_OFF - def __init__(self, *args, storage_size_until_full: int = 5, auto_item_despawn_interval: int = 5, **kwargs): + def __init__(self, *args, storage_size_until_full: int = 5, **kwargs): super(DropOffLocation, self).__init__(*args, **kwargs) - self.auto_item_despawn_interval = auto_item_despawn_interval self.storage = deque(maxlen=storage_size_until_full or None) def place_item(self, item: Item): if self.is_full: raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.") - return bc.NOT_VALID # in Zeile 81 verschieben? + return bc.NOT_VALID else: self.storage.append(item) - item.set_auto_despawn(self.auto_item_despawn_interval) return c.VALID @property diff --git a/marl_factory_grid/modules/items/groups.py b/marl_factory_grid/modules/items/groups.py index 707f743..be5ca49 100644 --- a/marl_factory_grid/modules/items/groups.py +++ b/marl_factory_grid/modules/items/groups.py @@ -1,13 +1,11 @@ -from random import shuffle - -from marl_factory_grid.modules.items import constants as i from marl_factory_grid.environment import constants as c - -from marl_factory_grid.environment.groups.collection import Collection -from marl_factory_grid.environment.groups.objects import _Objects -from marl_factory_grid.environment.groups.mixins import IsBoundMixin from marl_factory_grid.environment.entity.agent import Agent +from marl_factory_grid.environment.groups.collection import Collection +from marl_factory_grid.environment.groups.mixins import IsBoundMixin +from marl_factory_grid.environment.groups.objects import Objects +from marl_factory_grid.modules.items import constants as i from marl_factory_grid.modules.items.entitites import Item, DropOffLocation +from marl_factory_grid.utils.results import Result class Items(Collection): @@ -15,7 +13,7 @@ class Items(Collection): @property def var_has_position(self): - return False + return True @property def is_blocking_light(self): @@ -28,18 +26,18 @@ class Items(Collection): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - @staticmethod - def trigger_item_spawn(state, n_items, spawn_frequency): - if item_to_spawns := max(0, (n_items - len(state[i.ITEM]))): - position_list = [x for x in state.entities.floorlist] - shuffle(position_list) - position_list = state.entities.floorlist[:item_to_spawns] - state[i.ITEM].spawn(position_list) - state.print(f'{item_to_spawns} new items have been spawned; next spawn in {spawn_frequency}') - return len(position_list) + def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, **entity_kwargs) -> [Result]: + coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity + assert coords_or_quantity + + if item_to_spawns := max(0, (coords_or_quantity - len(self))): + return super().trigger_spawn(state, + *entity_args, + coords_or_quantity=item_to_spawns, + **entity_kwargs) else: state.print('No Items are spawning, limit is reached.') - return 0 + return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, value=coords_or_quantity) class Inventory(IsBoundMixin, Collection): @@ -73,12 +71,17 @@ class Inventory(IsBoundMixin, Collection): self._collection = collection -class Inventories(_Objects): +class Inventories(Objects): _entity = Inventory + var_can_move = False + var_has_position = False + + symbol = None + @property - def var_can_move(self): - return False + def spawn_rule(self): + return {c.SPAWN_ENTITY_RULE: dict(collection=self, coords_or_quantity=None)} def __init__(self, size: int, *args, **kwargs): super(Inventories, self).__init__(*args, **kwargs) @@ -86,10 +89,12 @@ class Inventories(_Objects): self._obs = None self._lazy_eval_transforms = [] - def spawn(self, agents): - inventories = [self._entity(agent, self.size, ) - for _, agent in enumerate(agents)] - self.add_items(inventories) + def spawn(self, agents, *args, **kwargs): + self.add_items([self._entity(agent, self.size, *args, **kwargs) for _, agent in enumerate(agents)]) + return [Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))] + + def trigger_spawn(self, state, *args, **kwargs) -> [Result]: + return self.spawn(state[c.AGENT], *args, **kwargs) def idx_by_entity(self, entity): try: @@ -106,10 +111,6 @@ class Inventories(_Objects): def summarize_states(self, **kwargs): return [val.summarize_states(**kwargs) for key, val in self.items()] - @staticmethod - def trigger_inventory_spawn(state): - state[i.INVENTORY].spawn(state[c.AGENT]) - class DropOffLocations(Collection): _entity = DropOffLocation @@ -135,7 +136,7 @@ class DropOffLocations(Collection): @staticmethod def trigger_drop_off_location_spawn(state, n_locations): - empty_positions = state.entities.empty_positions()[:n_locations] + empty_positions = state.entities.empty_positions[:n_locations] do_entites = state[i.DROP_OFF] drop_offs = [DropOffLocation(pos) for pos in empty_positions] do_entites.add_items(drop_offs) diff --git a/marl_factory_grid/modules/items/rewards.py b/marl_factory_grid/modules/items/rewards.py index 40adf46..bcd2918 100644 --- a/marl_factory_grid/modules/items/rewards.py +++ b/marl_factory_grid/modules/items/rewards.py @@ -1,4 +1,4 @@ DROP_OFF_VALID: float = 0.1 DROP_OFF_FAIL: float = -0.1 PICK_UP_FAIL: float = -0.1 -PICK_UP_VALID: float = 0.1 \ No newline at end of file +PICK_UP_VALID: float = 0.1 diff --git a/marl_factory_grid/modules/items/rules.py b/marl_factory_grid/modules/items/rules.py index 9f8a0cc..a655956 100644 --- a/marl_factory_grid/modules/items/rules.py +++ b/marl_factory_grid/modules/items/rules.py @@ -6,52 +6,28 @@ from marl_factory_grid.utils.results import TickResult from marl_factory_grid.modules.items import constants as i -class ItemRules(Rule): +class RespawnItems(Rule): - def __init__(self, n_items: int = 5, spawn_frequency: int = 15, - n_locations: int = 5, max_dropoff_storage_size: int = 0): + def __init__(self, n_items: int = 5, respawn_freq: int = 15, n_locations: int = 5): super().__init__() - self.spawn_frequency = spawn_frequency - self._next_item_spawn = spawn_frequency + self.spawn_frequency = respawn_freq + self._next_item_spawn = respawn_freq self.n_items = n_items - self.max_dropoff_storage_size = max_dropoff_storage_size self.n_locations = n_locations - def on_init(self, state, lvl_map): - state[i.DROP_OFF].trigger_drop_off_location_spawn(state, self.n_locations) - self._next_item_spawn = self.spawn_frequency - state[i.INVENTORY].trigger_inventory_spawn(state) - state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency) - def tick_step(self, state): - for item in list(state[i.ITEM].values()): - if item.auto_despawn >= 1: - item.set_auto_despawn(item.auto_despawn - 1) - elif not item.auto_despawn: - state[i.ITEM].delete_env_object(item) - else: - pass - if not self._next_item_spawn: - state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency) + state[i.ITEM].trigger_spawn(state, self.n_items, self.spawn_frequency) else: self._next_item_spawn = max(0, self._next_item_spawn - 1) return [] def tick_post_step(self, state) -> List[TickResult]: - for item in list(state[i.ITEM].values()): - if item.auto_despawn >= 1: - item.set_auto_despawn(item.auto_despawn-1) - elif not item.auto_despawn: - state[i.ITEM].delete_env_object(item) - else: - pass - if not self._next_item_spawn: - if spawned_items := state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency): - return [TickResult(self.name, validity=c.VALID, value=spawned_items, entity=None)] + if spawned_items := state[i.ITEM].trigger_spawn(state, self.n_items, self.spawn_frequency): + return [TickResult(self.name, validity=c.VALID, value=spawned_items.value)] else: - return [TickResult(self.name, validity=c.NOT_VALID, value=0, entity=None)] + return [TickResult(self.name, validity=c.NOT_VALID, value=0)] else: self._next_item_spawn = max(0, self._next_item_spawn-1) return [] diff --git a/marl_factory_grid/modules/machines/__init__.py b/marl_factory_grid/modules/machines/__init__.py index 36ba51d..233efbb 100644 --- a/marl_factory_grid/modules/machines/__init__.py +++ b/marl_factory_grid/modules/machines/__init__.py @@ -1,3 +1,2 @@ from .entitites import Machine from .groups import Machines -from .rules import MachineRule diff --git a/marl_factory_grid/modules/machines/actions.py b/marl_factory_grid/modules/machines/actions.py index 8f4eaaa..dbb303f 100644 --- a/marl_factory_grid/modules/machines/actions.py +++ b/marl_factory_grid/modules/machines/actions.py @@ -1,10 +1,12 @@ from typing import Union +import marl_factory_grid.modules.machines.constants from marl_factory_grid.environment.actions import Action from marl_factory_grid.utils.results import ActionResult -from marl_factory_grid.modules.machines import constants as m, rewards as r +from marl_factory_grid.modules.machines import constants as m from marl_factory_grid.environment import constants as c +from marl_factory_grid.utils import helpers as h class MachineAction(Action): @@ -13,13 +15,12 @@ class MachineAction(Action): super().__init__(m.MACHINE_ACTION) def do(self, entity, state) -> Union[None, ActionResult]: - if machine := state[m.MACHINES].by_pos(entity.pos): + if machine := h.get_first(state[m.MACHINES].by_pos(entity.pos)): if valid := machine.maintain(): - return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_VALID) + return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_VALID) else: - return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_FAIL) + return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_FAIL) else: - return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.MAINTAIN_FAIL) - - - + return ActionResult(entity=entity, identifier=self._identifier, + validity=c.NOT_VALID, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_FAIL + ) diff --git a/marl_factory_grid/modules/machines/constants.py b/marl_factory_grid/modules/machines/constants.py index 29ce3bc..3771cbb 100644 --- a/marl_factory_grid/modules/machines/constants.py +++ b/marl_factory_grid/modules/machines/constants.py @@ -11,3 +11,7 @@ STATE_MAINTAIN = 'maintenance' SYMBOL_WORK = 1 SYMBOL_IDLE = 0.6 SYMBOL_MAINTAIN = 0.3 +MAINTAIN_VALID: float = 0.5 +MAINTAIN_FAIL: float = -0.1 +FAIL_MISSING_MAINTENANCE: float = -0.5 +NONE: float = 0 diff --git a/marl_factory_grid/modules/machines/entitites.py b/marl_factory_grid/modules/machines/entitites.py index 36a87cc..581adf6 100644 --- a/marl_factory_grid/modules/machines/entitites.py +++ b/marl_factory_grid/modules/machines/entitites.py @@ -8,22 +8,6 @@ from . import constants as m class Machine(Entity): - @property - def var_can_collide(self): - return False - - @property - def var_can_move(self): - return False - - @property - def var_is_blocking_light(self): - return False - - @property - def var_has_position(self): - return True - @property def encoding(self): return self._encodings[self.status] @@ -46,12 +30,11 @@ class Machine(Entity): else: return c.NOT_VALID - def tick(self): - # if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]): - if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.state.entities.pos_dict[self.pos]]): - return TickResult(identifier=self.name, validity=c.VALID, reward=0, entity=self) - # elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]): - elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.state.entities.pos_dict[self.pos]]): + def tick(self, state): + others = state.entities.pos_dict[self.pos] + if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in others]): + return TickResult(identifier=self.name, validity=c.VALID, entity=self) + elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in others]): self.status = m.STATE_WORK self.reset_counter() return None diff --git a/marl_factory_grid/modules/machines/groups.py b/marl_factory_grid/modules/machines/groups.py index 5f2d970..9d89d6c 100644 --- a/marl_factory_grid/modules/machines/groups.py +++ b/marl_factory_grid/modules/machines/groups.py @@ -1,5 +1,3 @@ -from typing import Union, List, Tuple - from marl_factory_grid.environment.groups.collection import Collection from .entitites import Machine diff --git a/marl_factory_grid/modules/machines/rewards.py b/marl_factory_grid/modules/machines/rewards.py deleted file mode 100644 index c868196..0000000 --- a/marl_factory_grid/modules/machines/rewards.py +++ /dev/null @@ -1,5 +0,0 @@ -MAINTAIN_VALID: float = 0.5 -MAINTAIN_FAIL: float = -0.1 -FAIL_MISSING_MAINTENANCE: float = -0.5 - -NONE: float = 0 diff --git a/marl_factory_grid/modules/machines/rules.py b/marl_factory_grid/modules/machines/rules.py index 84e3410..e69de29 100644 --- a/marl_factory_grid/modules/machines/rules.py +++ b/marl_factory_grid/modules/machines/rules.py @@ -1,28 +0,0 @@ -from typing import List -from marl_factory_grid.environment.rules import Rule -from marl_factory_grid.utils.results import TickResult, DoneResult -from marl_factory_grid.environment import constants as c -from marl_factory_grid.modules.machines import constants as m -from marl_factory_grid.modules.machines.entitites import Machine - - -class MachineRule(Rule): - - def __init__(self, n_machines: int = 2): - super(MachineRule, self).__init__() - self.n_machines = n_machines - - def on_init(self, state, lvl_map): - state[m.MACHINES].spawn(state.entities.empty_positions()) - - def tick_pre_step(self, state) -> List[TickResult]: - pass - - def tick_step(self, state) -> List[TickResult]: - pass - - def tick_post_step(self, state) -> List[TickResult]: - pass - - def on_check_done(self, state) -> List[DoneResult]: - pass diff --git a/marl_factory_grid/modules/maintenance/constants.py b/marl_factory_grid/modules/maintenance/constants.py index e0ab12c..3aed36c 100644 --- a/marl_factory_grid/modules/maintenance/constants.py +++ b/marl_factory_grid/modules/maintenance/constants.py @@ -1,3 +1,4 @@ MAINTAINER = 'Maintainer' # TEMPLATE _identifier. Define your own! MAINTAINERS = 'Maintainers' # TEMPLATE _identifier. Define your own! +MAINTAINER_COLLISION_REWARD = -5 diff --git a/marl_factory_grid/modules/maintenance/entities.py b/marl_factory_grid/modules/maintenance/entities.py index e084b0c..1a043c8 100644 --- a/marl_factory_grid/modules/maintenance/entities.py +++ b/marl_factory_grid/modules/maintenance/entities.py @@ -1,48 +1,35 @@ +from random import shuffle + import networkx as nx import numpy as np + from ...algorithms.static.utils import points_to_graph from ...environment import constants as c from ...environment.actions import Action, ALL_BASEACTIONS from ...environment.entity.entity import Entity from ..doors import constants as do from ..maintenance import constants as mi -from ...utils.helpers import MOVEMAP -from ...utils.utility_classes import RenderEntity -from ...utils.states import Gamestate +from ...utils import helpers as h +from ...utils.utility_classes import RenderEntity, Floor +from ..doors import DoorUse class Maintainer(Entity): - @property - def var_can_collide(self): - return True - - @property - def var_can_move(self): - return False - - @property - def var_is_blocking_light(self): - return False - - @property - def var_has_position(self): - return True - - def __init__(self, state: Gamestate, objective: str, action: Action, *args, **kwargs): + def __init__(self, objective: str, action: Action, *args, **kwargs): super().__init__(*args, **kwargs) self.action = action - self.actions = [x() for x in ALL_BASEACTIONS] + self.actions = [x() for x in ALL_BASEACTIONS] + [DoorUse()] self.objective = objective self._path = None self._next = [] self._last = [] self._last_serviced = 'None' - self._floortile_graph = points_to_graph(state.entities.floorlist) + self._floortile_graph = None def tick(self, state): - if found_objective := state[self.objective].by_pos(self.pos): + if found_objective := h.get_first(state[self.objective].by_pos(self.pos)): if found_objective.name != self._last_serviced: self.action.do(self, state) self._last_serviced = found_objective.name @@ -54,24 +41,27 @@ class Maintainer(Entity): return action.do(self, state) def get_move_action(self, state) -> Action: + if not self._floortile_graph: + state.print("Generating Floorgraph....") + self._floortile_graph = points_to_graph(state.entities.floorlist) if self._path is None or not self._path: if not self._next: - self._next = list(state[self.objective].values()) + self._next = list(state[self.objective].values()) + [Floor(*state.random_free_position)] + shuffle(self._next) self._last = [] self._last.append(self._next.pop()) + state.print("Calculating shortest path....") self._path = self.calculate_route(self._last[-1]) - if door := self._door_is_close(state): - if door.is_closed: - # Translate the action_object to an integer to have the same output as any other model - action = do.ACTION_DOOR_USE - else: - action = self._predict_move(state) + if door := self._closed_door_in_path(state): + state.print(f"{self} found {door} that is closed. Attempt to open.") + # Translate the action_object to an integer to have the same output as any other model + action = do.ACTION_DOOR_USE else: action = self._predict_move(state) # Translate the action_object to an integer to have the same output as any other model try: - action_obj = next(x for x in self.actions if x.name == action) + action_obj = h.get_first(self.actions, lambda x: x.name == action) except (StopIteration, UnboundLocalError): print('Will not happen') raise EnvironmentError @@ -81,11 +71,10 @@ class Maintainer(Entity): route = nx.shortest_path(self._floortile_graph, self.pos, entity.pos) return route[1:] - def _door_is_close(self, state): - state.print("Found a door that is close.") - try: - return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name) - except StopIteration: + def _closed_door_in_path(self, state): + if self._path: + return h.get_first(state[do.DOORS].by_pos(self._path[0]), lambda x: x.is_closed) + else: return None def _predict_move(self, state): @@ -96,7 +85,7 @@ class Maintainer(Entity): next_pos = self._path.pop(0) diff = np.subtract(next_pos, self.pos) # Retrieve action based on the pos dif (like in: What do I have to do to get there?) - action = next(action for action, pos_diff in MOVEMAP.items() if np.all(diff == pos_diff)) + action = next(action for action, pos_diff in h.MOVEMAP.items() if np.all(diff == pos_diff)) return action def render(self): diff --git a/marl_factory_grid/modules/maintenance/groups.py b/marl_factory_grid/modules/maintenance/groups.py index 2df70cb..5b09c9c 100644 --- a/marl_factory_grid/modules/maintenance/groups.py +++ b/marl_factory_grid/modules/maintenance/groups.py @@ -1,34 +1,27 @@ -from typing import Union, List, Tuple +from typing import Union, List, Tuple, Dict from marl_factory_grid.environment.groups.collection import Collection from .entities import Maintainer from ..machines import constants as mc from ..machines.actions import MachineAction -from ...utils.states import Gamestate class Maintainers(Collection): _entity = Maintainer - @property - def var_can_collide(self): - return True + var_can_collide = True + var_can_move = True + var_is_blocking_light = False + var_has_position = True - @property - def var_can_move(self): - return True + def __init__(self, size, *args, coords_or_quantity: int = None, + spawnrule: Union[None, Dict[str, dict]] = None, + **kwargs): + super(Collection, self).__init__(*args, **kwargs) + self._coords_or_quantity = coords_or_quantity + self.size = size + self._spawnrule = spawnrule - @property - def var_is_blocking_light(self): - return False - - @property - def var_has_position(self): - return True - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): - state = entity_args[0] - self.add_items([self._entity(state, mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity]) + self.add_items([self._entity(mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity]) diff --git a/marl_factory_grid/modules/maintenance/rewards.py b/marl_factory_grid/modules/maintenance/rewards.py deleted file mode 100644 index 425ac3b..0000000 --- a/marl_factory_grid/modules/maintenance/rewards.py +++ /dev/null @@ -1 +0,0 @@ -MAINTAINER_COLLISION_REWARD = -5 \ No newline at end of file diff --git a/marl_factory_grid/modules/maintenance/rules.py b/marl_factory_grid/modules/maintenance/rules.py index 820183e..92e6e75 100644 --- a/marl_factory_grid/modules/maintenance/rules.py +++ b/marl_factory_grid/modules/maintenance/rules.py @@ -1,32 +1,28 @@ from typing import List + +import marl_factory_grid.modules.maintenance.constants from marl_factory_grid.environment.rules import Rule from marl_factory_grid.utils.results import TickResult, DoneResult from marl_factory_grid.environment import constants as c -from . import rewards as r from . import constants as M -from marl_factory_grid.utils.states import Gamestate -class MaintenanceRule(Rule): +class MoveMaintainers(Rule): - def __init__(self, n_maintainer: int = 1, *args, **kwargs): - super(MaintenanceRule, self).__init__(*args, **kwargs) - self.n_maintainer = n_maintainer - - def on_init(self, state: Gamestate, lvl_map): - state[M.MAINTAINERS].spawn(state.entities.empty_positions[:self.n_maintainer], state) - pass - - def tick_pre_step(self, state) -> List[TickResult]: - pass + def __init__(self): + super().__init__() def tick_step(self, state) -> List[TickResult]: for maintainer in state[M.MAINTAINERS]: maintainer.tick(state) + # Todo: Return a Result Object. return [] - def tick_post_step(self, state) -> List[TickResult]: - pass + +class DoneAtMaintainerCollision(Rule): + + def __init__(self): + super().__init__() def on_check_done(self, state) -> List[DoneResult]: agents = list(state[c.AGENT].values()) @@ -35,5 +31,5 @@ class MaintenanceRule(Rule): for agent in agents: if agent.pos in m_pos: done_results.append(DoneResult(entity=agent, validity=c.VALID, identifier=self.name, - reward=r.MAINTAINER_COLLISION_REWARD)) + reward=marl_factory_grid.modules.maintenance.constants.MAINTAINER_COLLISION_REWARD)) return done_results diff --git a/marl_factory_grid/modules/zones/entitites.py b/marl_factory_grid/modules/zones/entitites.py index cfd313f..4aa0f70 100644 --- a/marl_factory_grid/modules/zones/entitites.py +++ b/marl_factory_grid/modules/zones/entitites.py @@ -1,10 +1,10 @@ import random from typing import List, Tuple -from marl_factory_grid.environment.entity.object import _Object +from marl_factory_grid.environment.entity.object import Object -class Zone(_Object): +class Zone(Object): @property def positions(self): diff --git a/marl_factory_grid/modules/zones/groups.py b/marl_factory_grid/modules/zones/groups.py index 71eb329..f5494cd 100644 --- a/marl_factory_grid/modules/zones/groups.py +++ b/marl_factory_grid/modules/zones/groups.py @@ -1,8 +1,8 @@ -from marl_factory_grid.environment.groups.objects import _Objects +from marl_factory_grid.environment.groups.objects import Objects from marl_factory_grid.modules.zones import Zone -class Zones(_Objects): +class Zones(Objects): symbol = None _entity = Zone diff --git a/marl_factory_grid/modules/zones/rules.py b/marl_factory_grid/modules/zones/rules.py index 2969186..f9b5c11 100644 --- a/marl_factory_grid/modules/zones/rules.py +++ b/marl_factory_grid/modules/zones/rules.py @@ -1,8 +1,8 @@ from random import choices, choice from . import constants as z, Zone +from .. import Destination from ..destinations import constants as d -from ... import Destination from ...environment.rules import Rule from ...environment import constants as c diff --git a/marl_factory_grid/utils/__init__.py b/marl_factory_grid/utils/__init__.py index e69de29..23848e0 100644 --- a/marl_factory_grid/utils/__init__.py +++ b/marl_factory_grid/utils/__init__.py @@ -0,0 +1,3 @@ +from . import helpers as h +from . import helpers +from .results import Result, DoneResult, ActionResult, TickResult diff --git a/marl_factory_grid/utils/config_parser.py b/marl_factory_grid/utils/config_parser.py index c9223f8..5cad113 100644 --- a/marl_factory_grid/utils/config_parser.py +++ b/marl_factory_grid/utils/config_parser.py @@ -1,4 +1,5 @@ import ast + from os import PathLike from pathlib import Path from typing import Union, List @@ -9,18 +10,17 @@ from marl_factory_grid.environment import constants as c from marl_factory_grid.environment.rules import Rule from marl_factory_grid.environment.tests import Test from marl_factory_grid.utils.helpers import locate_and_import_class - -DEFAULT_PATH = 'environment' -MODULE_PATH = 'modules' +from marl_factory_grid.environment.constants import DEFAULT_PATH, MODULE_PATH +from marl_factory_grid.environment import constants as c class FactoryConfigParser(object): default_entites = [] - default_rules = ['MaxStepsReached', 'Collision'] + default_rules = ['DoneAtMaxStepsReached', 'WatchCollision'] default_actions = [c.MOVE8, c.NOOP] default_observations = [c.WALLS, c.AGENT] - def __init__(self, config_path, custom_modules_path: Union[None, PathLike] = None): + def __init__(self, config_path, custom_modules_path: Union[PathLike] = None): self.config_path = Path(config_path) self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path self.config = yaml.safe_load(self.config_path.open()) @@ -44,6 +44,10 @@ class FactoryConfigParser(object): def rules(self): return self.config['Rules'] + @property + def tests(self): + return self.config.get('Tests', []) + @property def agents(self): return self.config['Agents'] @@ -56,10 +60,12 @@ class FactoryConfigParser(object): return str(self.config) def __getitem__(self, item): - return self.config[item] + try: + return self.config[item] + except KeyError: + print(f'The mandatory {item} section could not be found in your .config gile. Check Spelling!') def load_entities(self): - # entites = Entities() entity_classes = dict() entities = [] if c.DEFAULTS in self.entities: @@ -67,28 +73,40 @@ class FactoryConfigParser(object): entities.extend(x for x in self.entities if x != c.DEFAULTS) for entity in entities: + e1 = e2 = e3 = None try: folder_path = Path(__file__).parent.parent / DEFAULT_PATH entity_class = locate_and_import_class(entity, folder_path) - except AttributeError as e1: + except AttributeError as e: + e1 = e try: - folder_path = Path(__file__).parent.parent / MODULE_PATH - entity_class = locate_and_import_class(entity, folder_path) - except AttributeError as e2: - try: - folder_path = self.custom_modules_path - entity_class = locate_and_import_class(entity, folder_path) - except AttributeError as e3: - ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x] - print('### Error ### Error ### Error ### Error ### Error ###') - print() - print(f'Class "{entity}" was not found in "{folder_path.name}"') - print('Possible Entitys are:', str(ents)) - print() - print('Goodbye') - print() - exit() - # raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are:', str(ents)) + module_path = Path(__file__).parent.parent / MODULE_PATH + entity_class = locate_and_import_class(entity, module_path) + except AttributeError as e: + e2 = e + if self.custom_modules_path: + try: + entity_class = locate_and_import_class(entity, self.custom_modules_path) + except AttributeError as e: + e3 = e + pass + if (e1 and e2) or e3: + ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]] + print('##############################################################') + print('### Error ### Error ### Error ### Error ### Error ###') + print('##############################################################') + print(f'Class "{entity}" was not found in "{module_path.name}"') + print(f'Class "{entity}" was not found in "{folder_path.name}"') + print('##############################################################') + if self.custom_modules_path: + print(f'Class "{entity}" was not found in "{self.custom_modules_path}"') + print('Possible Entitys are:', str(ents)) + print('##############################################################') + print('Goodbye') + print('##############################################################') + print('### Error ### Error ### Error ### Error ### Error ###') + print('##############################################################') + exit(-99999) entity_kwargs = self.entities.get(entity, {}) entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None @@ -126,7 +144,12 @@ class FactoryConfigParser(object): observations.extend(self.default_observations) observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS) positions = [ast.literal_eval(x) for x in self.agents[name].get('Positions', [])] - parsed_agents_conf[name] = dict(actions=parsed_actions, observations=observations, positions=positions) + other_kwargs = {k: v for k, v in self.agents[name].items() if k not in + ['Actions', 'Observations', 'Positions']} + parsed_agents_conf[name] = dict( + actions=parsed_actions, observations=observations, positions=positions, other=other_kwargs + ) + return parsed_agents_conf def load_env_rules(self) -> List[Rule]: @@ -137,28 +160,69 @@ class FactoryConfigParser(object): rules.append({rule: {}}) return self._load_smth(rules, Rule) - pass - def load_env_tests(self) -> List[Test]: + def load_env_tests(self) -> List[Rule]: return self._load_smth(self.tests, None) # Test - pass def _load_smth(self, config, class_obj): rules = list() - rules_names = list() - - for rule in rules_names: + for rule in config: + e1 = e2 = e3 = None try: folder_path = (Path(__file__).parent.parent / DEFAULT_PATH) rule_class = locate_and_import_class(rule, folder_path) - except AttributeError: + except AttributeError as e: + e1 = e try: - folder_path = (Path(__file__).parent.parent / MODULE_PATH) - rule_class = locate_and_import_class(rule, folder_path) - except AttributeError: - rule_class = locate_and_import_class(rule, self.custom_modules_path) - # Fixme This check does not work! - # assert isinstance(rule_class, class_obj), f'{rule_class.__name__} is no valid "class_obj.__name__".' - rule_kwargs = config.get(rule, {}) - rules.append(rule_class(**rule_kwargs)) + module_path = (Path(__file__).parent.parent / MODULE_PATH) + rule_class = locate_and_import_class(rule, module_path) + except AttributeError as e: + e2 = e + if self.custom_modules_path: + try: + rule_class = locate_and_import_class(rule, self.custom_modules_path) + except AttributeError as e: + e3 = e + pass + if (e1 and e2) or e3: + ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]] + print('### Error ### Error ### Error ### Error ### Error ###') + print('') + print(f'Class "{rule}" was not found in "{module_path.name}"') + print(f'Class "{rule}" was not found in "{folder_path.name}"') + if self.custom_modules_path: + print(f'Class "{rule}" was not found in "{self.custom_modules_path}"') + print('Possible Entitys are:', str(ents)) + print('') + print('Goodbye') + print('') + exit(-99999) + + if issubclass(rule_class, class_obj): + rule_kwargs = config.get(rule, {}) + rules.append(rule_class(**(rule_kwargs or {}))) + return rules + + def load_entity_spawn_rules(self, entities) -> List[Rule]: + rules = list() + rules_dicts = list() + for e in entities: + try: + if spawn_rule := e.spawn_rule: + rules_dicts.append(spawn_rule) + except AttributeError: + pass + + for rule_dict in rules_dicts: + for rule_name, rule_kwargs in rule_dict.items(): + try: + folder_path = (Path(__file__).parent.parent / DEFAULT_PATH) + rule_class = locate_and_import_class(rule_name, folder_path) + except AttributeError: + try: + folder_path = (Path(__file__).parent.parent / MODULE_PATH) + rule_class = locate_and_import_class(rule_name, folder_path) + except AttributeError: + rule_class = locate_and_import_class(rule_name, self.custom_modules_path) + rules.append(rule_class(**rule_kwargs)) return rules diff --git a/marl_factory_grid/utils/helpers.py b/marl_factory_grid/utils/helpers.py index e2f3c9a..f5f6d00 100644 --- a/marl_factory_grid/utils/helpers.py +++ b/marl_factory_grid/utils/helpers.py @@ -2,7 +2,7 @@ import importlib from collections import defaultdict from pathlib import PurePath, Path -from typing import Union, Dict, List +from typing import Union, Dict, List, Iterable, Callable import numpy as np from numpy.typing import ArrayLike @@ -61,8 +61,8 @@ class ObservationTranslator: :param per_agent_named_obs_spaces: `Named observation space` one for each agent. Overloaded. type per_agent_named_obs_spaces: Dict[str, dict] - :param placeholder_fill_value: Currently not fully implemented!!! - :type placeholder_fill_value: Union[int, str] = 'N') + :param placeholder_fill_value: Currently, not fully implemented!!! + :type placeholder_fill_value: Union[int, str] = 'N' """ if isinstance(placeholder_fill_value, str): @@ -222,7 +222,7 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''): mod = importlib.import_module('.'.join(module_parts)) all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle()) and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union', - 'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin', + 'TickResult', 'ActionResult', 'Action', 'Agent', 'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin', 'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any' ]]) @@ -240,7 +240,13 @@ def add_bound_name(name_str, bound_e): def add_pos_name(name_str, bound_e): if bound_e.var_has_position: - return f'{name_str}({bound_e.pos})' + return f'{name_str}@{bound_e.pos}' return name_str +def get_first(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True): + return next((x for x in iterable if filter_by(x)), None) + + +def get_first_index(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True): + return next((idx for idx, x in enumerate(iterable) if filter_by(x)), None) diff --git a/marl_factory_grid/utils/level_parser.py b/marl_factory_grid/utils/level_parser.py index fc8b948..24a05df 100644 --- a/marl_factory_grid/utils/level_parser.py +++ b/marl_factory_grid/utils/level_parser.py @@ -47,6 +47,7 @@ class LevelParser(object): # All other for es_name in self.e_p_dict: e_class, e_kwargs = self.e_p_dict[es_name]['class'], self.e_p_dict[es_name]['kwargs'] + e_kwargs = e_kwargs if e_kwargs else {} if hasattr(e_class, 'symbol') and e_class.symbol is not None: symbols = e_class.symbol diff --git a/marl_factory_grid/utils/logging/envmonitor.py b/marl_factory_grid/utils/logging/envmonitor.py index 67eac73..e2551c8 100644 --- a/marl_factory_grid/utils/logging/envmonitor.py +++ b/marl_factory_grid/utils/logging/envmonitor.py @@ -9,7 +9,7 @@ from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS import pandas as pd -from marl_factory_grid.utils.plotting.compare_runs import plot_single_run +from marl_factory_grid.utils.plotting.plot_single_runs import plot_single_run class EnvMonitor(Wrapper): @@ -22,7 +22,6 @@ class EnvMonitor(Wrapper): self._monitor_df = pd.DataFrame() self._monitor_dict = dict() - def step(self, action): obs_type, obs, reward, done, info = self.env.step(action) self._read_info(info) diff --git a/marl_factory_grid/utils/logging/recorder.py b/marl_factory_grid/utils/logging/recorder.py index fac2e16..797866e 100644 --- a/marl_factory_grid/utils/logging/recorder.py +++ b/marl_factory_grid/utils/logging/recorder.py @@ -2,11 +2,9 @@ from os import PathLike from pathlib import Path from typing import Union, List -import yaml -from gymnasium import Wrapper - import numpy as np import pandas as pd +from gymnasium import Wrapper class EnvRecorder(Wrapper): @@ -106,7 +104,7 @@ class EnvRecorder(Wrapper): out_dict = {'episodes': self._recorder_out_list} out_dict.update( {'n_episodes': self._curr_episode, - 'metadata':dict( + 'metadata': dict( level_name=self.env.params['General']['level_name'], verbose=False, n_agents=len(self.env.params['Agents']), diff --git a/marl_factory_grid/utils/observation_builder.py b/marl_factory_grid/utils/observation_builder.py index 9fd1d26..55d6ec0 100644 --- a/marl_factory_grid/utils/observation_builder.py +++ b/marl_factory_grid/utils/observation_builder.py @@ -1,17 +1,16 @@ -import math import re from collections import defaultdict -from itertools import product from typing import Dict, List import numpy as np -from numba import njit from marl_factory_grid.environment import constants as c +from marl_factory_grid.environment.entity.object import Object from marl_factory_grid.environment.groups.utils import Combined -import marl_factory_grid.utils.helpers as h -from marl_factory_grid.utils.states import Gamestate from marl_factory_grid.utils.utility_classes import Floor +from marl_factory_grid.utils.ray_caster import RayCaster +from marl_factory_grid.utils.states import Gamestate +from marl_factory_grid.utils import helpers as h class OBSBuilder(object): @@ -77,11 +76,13 @@ class OBSBuilder(object): def place_entity_in_observation(self, obs_array, agent, e): x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r - try: - obs_array[x, y] += e.encoding - except IndexError: - # Seemded to be visible but is out of range - pass + if not min([y, x]) < 0: + try: + obs_array[x, y] += e.encoding + except IndexError: + # Seemded to be visible but is out of range + pass + pass def build_for_agent(self, agent, state) -> (List[str], np.ndarray): assert self._curr_env_step == state.curr_step, ( @@ -121,18 +122,24 @@ class OBSBuilder(object): e = self.all_obs[l_name] except KeyError: try: - # Look for bound entity names! - pattern = re.compile(f'{re.escape(l_name)}(.*){re.escape(agent.name)}') - name = next((x for x in self.all_obs if pattern.search(x)), None) + # Look for bound entity REPRs! + pattern = re.compile(f'{re.escape(l_name)}' + f'{re.escape("[")}(.*){re.escape("]")}' + f'{re.escape("(")}{re.escape(agent.name)}{re.escape(")")}') + name = next((key for key, val in self.all_obs.items() + if pattern.search(str(val)) and isinstance(val, Object)), None) e = self.all_obs[name] except KeyError: try: e = next(v for k, v in self.all_obs.items() if l_name in k and agent.name in k) except StopIteration: - raise KeyError( - f'Check for spelling errors! \n ' - f'No combination of "{l_name} and {agent.name}" could not be found in:\n ' - f'{list(dict(self.all_obs).keys())}') + print(f'# Check for spelling errors!') + print(f'# No combination of "{l_name}" and "{agent.name}" could not be found in:') + print(f'# {list(dict(self.all_obs).keys())}') + print('#') + print('# exiting...') + print('#') + exit(-99999) try: positional = e.var_has_position @@ -161,31 +168,30 @@ class OBSBuilder(object): try: light_map = np.zeros(self.obs_shape) visible_floor = self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False) - if self.pomdp_r: - for f in set(visible_floor): - self.place_entity_in_observation(light_map, agent, f) - else: - for f in set(visible_floor): - light_map[f.x, f.y] += f.encoding + + for f in set(visible_floor): + self.place_entity_in_observation(light_map, agent, f) + # else: + # for f in set(visible_floor): + # light_map[f.x, f.y] += f.encoding self.curr_lightmaps[agent.name] = light_map except (KeyError, ValueError): - print() pass return obs, self.obs_layers[agent.name] def _sort_and_name_observation_conf(self, agent): - ''' + """ Builds the useable observation scheme per agent from conf.yaml. :param agent: :return: - ''' + """ # Fixme: no asymetric shapes possible. self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape)) obs_layers = [] for obs_str in agent.observations: if isinstance(obs_str, dict): - obs_str, vals = next(obs_str.items().__iter__()) + obs_str, vals = h.get_first(obs_str.items()) else: vals = None if obs_str == c.SELF: @@ -214,129 +220,3 @@ class OBSBuilder(object): obs_layers.append(obs_str) self.obs_layers[agent.name] = obs_layers self.curr_lightmaps[agent.name] = np.zeros(self.obs_shape) - - -class RayCaster: - def __init__(self, agent, pomdp_r, degs=360): - self.agent = agent - self.pomdp_r = pomdp_r - self.n_rays = (self.pomdp_r + 1) * 8 - self.degs = degs - self.ray_targets = self.build_ray_targets() - self.obs_shape_cube = np.array([self.pomdp_r, self.pomdp_r]) - self._cache_dict = {} - - def __repr__(self): - return f'{self.__class__.__name__}({self.agent.name})' - - def build_ray_targets(self): - north = np.array([0, -1]) * self.pomdp_r - thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]] - rot_M = [ - [[math.cos(theta), -math.sin(theta)], - [math.sin(theta), math.cos(theta)]] for theta in thetas - ] - rot_M = np.stack(rot_M, 0) - rot_M = np.unique(np.round(rot_M @ north), axis=0) - return rot_M.astype(int) - - def ray_block_cache(self, key, callback): - if key not in self._cache_dict: - self._cache_dict[key] = callback() - return self._cache_dict[key] - - def visible_entities(self, pos_dict, reset_cache=True): - visible = list() - if reset_cache: - self._cache_dict = {} - - for ray in self.get_rays(): - rx, ry = ray[0] - for x, y in ray: - cx, cy = x - rx, y - ry - - entities_hit = pos_dict[(x, y)] - hits = self.ray_block_cache((x, y), - lambda: any(True for e in entities_hit if e.var_is_blocking_light) - ) - - diag_hits = all([ - self.ray_block_cache( - key, - lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) and bool( - pos_dict[key])) - for key in ((x, y - cy), (x - cx, y)) - ]) if (cx != 0 and cy != 0) else False - - visible += entities_hit if not diag_hits else [] - if hits or diag_hits: - break - rx, ry = x, y - return visible - - def get_rays(self): - a_pos = self.agent.pos - outline = self.ray_targets + a_pos - return self.bresenham_loop(a_pos, outline) - - # todo do this once and cache the points! - def get_fov_outline(self) -> np.ndarray: - return self.ray_targets + self.agent.pos - - def get_square_outline(self): - agent = self.agent - x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1) - y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1) - outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \ - + list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords)) - return outline - - @staticmethod - @njit - def bresenham_loop(a_pos, points): - results = [] - for end in points: - x1, y1 = a_pos - x2, y2 = end - dx = x2 - x1 - dy = y2 - y1 - - # Determine how steep the line is - is_steep = abs(dy) > abs(dx) - - # Rotate line - if is_steep: - x1, y1 = y1, x1 - x2, y2 = y2, x2 - - # Swap start and end points if necessary and store swap state - swapped = False - if x1 > x2: - x1, x2 = x2, x1 - y1, y2 = y2, y1 - swapped = True - - # Recalculate differentials - dx = x2 - x1 - dy = y2 - y1 - - # Calculate error - error = int(dx / 2.0) - ystep = 1 if y1 < y2 else -1 - - # Iterate over bounding box generating points between start and end - y = y1 - points = [] - for x in range(int(x1), int(x2) + 1): - coord = [y, x] if is_steep else [x, y] - points.append(coord) - error -= abs(dy) - if error < 0: - y += ystep - error += dx - - # Reverse the list if the coordinates were swapped - if swapped: - points.reverse() - results.append(points) - return results diff --git a/marl_factory_grid/utils/plotting/compare_runs.py b/marl_factory_grid/utils/plotting/plot_compare_runs.py similarity index 83% rename from marl_factory_grid/utils/plotting/compare_runs.py rename to marl_factory_grid/utils/plotting/plot_compare_runs.py index cb5c853..5115478 100644 --- a/marl_factory_grid/utils/plotting/compare_runs.py +++ b/marl_factory_grid/utils/plotting/plot_compare_runs.py @@ -7,50 +7,11 @@ from typing import Union, List import pandas as pd from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS -from marl_factory_grid.utils.plotting.plotting import prepare_plot +from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot MODEL_MAP = None -def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None): - run_path = Path(run_path) - df_list = list() - if run_path.is_dir(): - monitor_file = next(run_path.glob('*monitor*.pick')) - elif run_path.exists() and run_path.is_file(): - monitor_file = run_path - else: - raise ValueError - - with monitor_file.open('rb') as f: - monitor_df = pickle.load(f) - - monitor_df = monitor_df.fillna(0) - df_list.append(monitor_df) - - df = pd.concat(df_list, ignore_index=True) - df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode']) - if column_keys is not None: - columns = [col for col in column_keys if col in df.columns] - else: - columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS] - - roll_n = 50 - - non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean() - - df_melted = df[columns + ['Episode']].reset_index().melt( - id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score" - ) - - if df_melted['Episode'].max() > 800: - skip_n = round(df_melted['Episode'].max() * 0.02) - df_melted = df_melted[df_melted['Episode'] % skip_n == 0] - - prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex) - print('Plotting done.') - - def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False): run_path = Path(run_path) df_list = list() diff --git a/marl_factory_grid/utils/plotting/plot_single_runs.py b/marl_factory_grid/utils/plotting/plot_single_runs.py new file mode 100644 index 0000000..7316d6a --- /dev/null +++ b/marl_factory_grid/utils/plotting/plot_single_runs.py @@ -0,0 +1,48 @@ +import pickle +from os import PathLike +from pathlib import Path +from typing import Union + +import pandas as pd + +from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS +from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot + + +def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None, + file_key: str ='monitor', file_ext: str ='pkl'): + run_path = Path(run_path) + df_list = list() + if run_path.is_dir(): + monitor_file = next(run_path.glob(f'*{file_key}*.{file_ext}')) + elif run_path.exists() and run_path.is_file(): + monitor_file = run_path + else: + raise ValueError + + with monitor_file.open('rb') as f: + monitor_df = pickle.load(f) + + monitor_df = monitor_df.fillna(0) + df_list.append(monitor_df) + + df = pd.concat(df_list, ignore_index=True) + df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode']) + if column_keys is not None: + columns = [col for col in column_keys if col in df.columns] + else: + columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS] + + # roll_n = 50 + # non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean() + + df_melted = df[columns + ['Episode']].reset_index().melt( + id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score" + ) + + if df_melted['Episode'].max() > 800: + skip_n = round(df_melted['Episode'].max() * 0.02) + df_melted = df_melted[df_melted['Episode'] % skip_n == 0] + + prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex) + print('Plotting done.') diff --git a/marl_factory_grid/utils/plotting/plotting.py b/marl_factory_grid/utils/plotting/plotting_utils.py similarity index 98% rename from marl_factory_grid/utils/plotting/plotting.py rename to marl_factory_grid/utils/plotting/plotting_utils.py index 455f81a..17bb7ff 100644 --- a/marl_factory_grid/utils/plotting/plotting.py +++ b/marl_factory_grid/utils/plotting/plotting_utils.py @@ -60,7 +60,7 @@ def prepare_center_double_column_legend(df, hue, style, hue_order): print('Struggling to plot Figure using LaTeX - going back to normal.') plt.close('all') sns.set(rc={'text.usetex': False}, style='whitegrid') - fig = plt.figure(figsize=(10, 11)) + _ = plt.figure(figsize=(10, 11)) lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style, ci=95, palette=PALETTE, hue_order=hue_order, legend=False) # plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0) diff --git a/marl_factory_grid/utils/ray_caster.py b/marl_factory_grid/utils/ray_caster.py index cf17bd1..d89997e 100644 --- a/marl_factory_grid/utils/ray_caster.py +++ b/marl_factory_grid/utils/ray_caster.py @@ -19,7 +19,7 @@ class RayCaster: return f'{self.__class__.__name__}({self.agent.name})' def build_ray_targets(self): - north = np.array([0, -1])*self.pomdp_r + north = np.array([0, -1]) * self.pomdp_r thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]] rot_M = [ [[math.cos(theta), -math.sin(theta)], @@ -39,8 +39,9 @@ class RayCaster: if reset_cache: self._cache_dict = dict() - for ray in self.get_rays(): + for ray in self.get_rays(): # Do not check, just trust. rx, ry = ray[0] + # self.ray_block_cache(ray[0], lambda: False) We do not do that, because of doors etc... for x, y in ray: cx, cy = x - rx, y - ry @@ -52,8 +53,9 @@ class RayCaster: diag_hits = all([ self.ray_block_cache( key, - lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light)) - for key in ((x, y-cy), (x-cx, y)) + # lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) + lambda: any(True for e in pos_dict[key] if e.var_is_blocking_light)) + for key in ((x, y - cy), (x - cx, y)) ]) if (cx != 0 and cy != 0) else False visible += entities_hit if not diag_hits else [] @@ -75,8 +77,8 @@ class RayCaster: agent = self.agent x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1) y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1) - outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \ - + list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords)) + outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) + outline += list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords)) return outline @staticmethod diff --git a/marl_factory_grid/utils/renderer.py b/marl_factory_grid/utils/renderer.py index db6a93f..1976974 100644 --- a/marl_factory_grid/utils/renderer.py +++ b/marl_factory_grid/utils/renderer.py @@ -31,7 +31,7 @@ class Renderer: def __init__(self, lvl_shape: Tuple[int, int] = (16, 16), lvl_padded_shape: Union[Tuple[int, int], None] = None, - cell_size: int = 40, fps: int = 7, + cell_size: int = 40, fps: int = 7, factor: float = 0.9, grid_lines: bool = True, view_radius: int = 2): # TODO: Customn_assets paths self.grid_h, self.grid_w = lvl_shape @@ -45,7 +45,7 @@ class Renderer: self.screen = pygame.display.set_mode(self.screen_size) self.clock = pygame.time.Clock() assets = list(self.ASSETS.rglob('*.png')) - self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets} + self.assets = {path.stem: self.load_asset(str(path), factor) for path in assets} self.fill_bg() now = time.time() @@ -110,22 +110,22 @@ class Renderer: pygame.quit() sys.exit() self.fill_bg() - blits = deque() - for entity in [x for x in entities]: - bp = self.blit_params(entity) - blits.append(bp) - if entity.name.lower() == AGENT: - if self.view_radius > 0: - vis_rects = self.visibility_rects(bp, entity.aux) - blits.extendleft(vis_rects) - if entity.state != BLANK: - agent_state_blits = self.blit_params( - RenderEntity(entity.state, (entity.pos[0] + 0.12, entity.pos[1]), 0.48, SCALE) - ) - textsurface = self.font.render(str(entity.id), False, (0, 0, 0)) - text_blit = dict(source=textsurface, dest=(bp['dest'].center[0]-.07*self.cell_size, - bp['dest'].center[1])) - blits += [agent_state_blits, text_blit] + # First all others + blits = deque(self.blit_params(x) for x in entities if not x.name.lower() == AGENT) + # Then Agents, so that agents are rendered on top. + for agent in (x for x in entities if x.name.lower() == AGENT): + agent_blit = self.blit_params(agent) + if self.view_radius > 0: + vis_rects = self.visibility_rects(agent_blit, agent.aux) + blits.extendleft(vis_rects) + if agent.state != BLANK: + state_blit = self.blit_params( + RenderEntity(agent.state, (agent.pos[0] + 0.12, agent.pos[1]), 0.48, SCALE) + ) + textsurface = self.font.render(str(agent.id), False, (0, 0, 0)) + text_blit = dict(source=textsurface, dest=(agent_blit['dest'].center[0]-.07*self.cell_size, + agent_blit['dest'].center[1])) + blits += [agent_blit, state_blit, text_blit] for blit in blits: self.screen.blit(**blit) diff --git a/marl_factory_grid/utils/results.py b/marl_factory_grid/utils/results.py index 9f0fa38..b4b07fc 100644 --- a/marl_factory_grid/utils/results.py +++ b/marl_factory_grid/utils/results.py @@ -1,9 +1,12 @@ from typing import Union from dataclasses import dataclass +from marl_factory_grid.environment.entity.object import Object + TYPE_VALUE = 'value' TYPE_REWARD = 'reward' -types = [TYPE_VALUE, TYPE_REWARD] +TYPES = [TYPE_VALUE, TYPE_REWARD] + @dataclass class InfoObject: @@ -18,17 +21,21 @@ class Result: validity: bool reward: Union[float, None] = None value: Union[float, None] = None - entity: None = None + entity: Object = None def get_infos(self): n = self.entity.name if self.entity is not None else "Global" - return [InfoObject(identifier=f'{n}_{self.identifier}_{t}', - val_type=t, value=self.__getattribute__(t)) for t in types + # Return multiple Info Dicts + return [InfoObject(identifier=f'{n}_{self.identifier}', + val_type=t, value=self.__getattribute__(t)) for t in TYPES if self.__getattribute__(t) is not None] def __repr__(self): valid = "not " if not self.validity else "" - return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid: {self.reward})' + reward = f" | Reward: {self.reward}" if self.reward is not None else "" + value = f" | Value: {self.value}" if self.value is not None else "" + entity = f" | by: {self.entity.name}" if self.entity is not None else "" + return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value}{entity})' @dataclass diff --git a/marl_factory_grid/utils/states.py b/marl_factory_grid/utils/states.py index 1461826..d54db6a 100644 --- a/marl_factory_grid/utils/states.py +++ b/marl_factory_grid/utils/states.py @@ -1,9 +1,12 @@ -from typing import List, Dict, Tuple +from itertools import islice +from typing import List, Tuple import numpy as np from marl_factory_grid.environment import constants as c +from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.rules import Rule +from marl_factory_grid.utils.results import Result, DoneResult from marl_factory_grid.environment.tests import Test from marl_factory_grid.utils.results import Result @@ -60,7 +63,8 @@ class Gamestate(object): def moving_entites(self): return [y for x in self.entities for y in x if x.var_can_move] - def __init__(self, entities, agents_conf, rules: [Rule], tests: [Test], env_seed=69, verbose=False): + def __init__(self, entities, agents_conf, rules: List[Rule], tests: [Test], lvl_shape, env_seed=69, verbose=False): + self.lvl_shape = lvl_shape self.entities = entities self.curr_step = 0 self.curr_actions = None @@ -82,7 +86,52 @@ class Gamestate(object): def __repr__(self): return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})' - def tick(self, actions) -> List[Result]: + @property + def random_free_position(self) -> (int, int): + """ + Returns a single **free** position (x, y), which is **free** for spawning or walking. + No Entity at this position posses *var_is_blocking_pos* or *var_can_collide*. + + :return: Single **free** position. + """ + return self.get_n_random_free_positions(1)[0] + + def get_n_random_free_positions(self, n) -> list[tuple[int, int]]: + """ + Returns a list of *n* **free** positions [(x, y), ... ], which are **free** for spawning or walking. + No Entity at this position posses *var_is_blocking_pos* or *var_can_collide*. + + :return: List of n **free** position. + """ + return list(islice(self.entities.free_positions_generator, n)) + + @property + def random_position(self) -> (int, int): + """ + Returns a single available position (x, y), ignores all entity attributes. + + :return: Single random position. + """ + return self.get_n_random_positions(1)[0] + + def get_n_random_positions(self, n) -> list[tuple[int, int]]: + """ + Returns a list of *n* available positions [(x, y), ... ], ignores all entity attributes. + + :return: List of n random positions. + """ + return list(islice(self.entities.floorlist, n)) + + def tick(self, actions) -> list[Result]: + """ + Performs a single **Gamestate Tick**by calling the inner rule hooks in sequential order. + - tick_pre_step_all: Things to do before the agents do their actions. Statechange, Moving, Spawning etc... + - agent tick: Agents do their actions. + - tick_step_all: Things to do after the agents did their actions. Statechange, Moving, Spawning etc... + - tick_post_step_all: Things to do at the very end of each step. Counting, Reward calculations etc... + + :return: List of *Result*-objects. + """ results = list() test_results = list() self.curr_step += 1 @@ -112,11 +161,23 @@ class Gamestate(object): return results - def print(self, string): + def print(self, string) -> None: + """ + When *verbose* is active, print stuff. + + :param string: *String* to print. + :type string: str + :return: Nothing + """ if self.verbose: print(string) - def check_done(self): + def check_done(self) -> List[DoneResult]: + """ + Iterate all **Rules** that override tehe *on_ckeck_done* hook. + + :return: List of Results + """ results = list() for rule in self.rules: if on_check_done_result := rule.on_check_done(self): @@ -124,24 +185,47 @@ class Gamestate(object): return results def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]: - positions = [pos for pos, entity_list_for_position in self.entities.pos_dict.items() - if any([e.var_can_collide for e in entity_list_for_position])] + """ + Returns a list positions [(x, y), ... ] on which collisions occur. This does not include agents, + that were unable to move because their target direction was blocked, also a form of collision. + + :return: List of positions. + """ + positions = [pos for pos, entities in self.entities.pos_dict.items() if + len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2) + ] return positions - def check_move_validity(self, moving_entity, position): - if moving_entity.pos != position and not any( - entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]) and not ( - moving_entity.var_is_blocking_pos and self.entities.is_occupied(position)): - return True - else: - return False + def check_move_validity(self, moving_entity: Entity, target_position: (int, int)) -> bool: + """ + Whether it is safe to move to the target positions and moving entity does not introduce a blocking attribute, + when position is allready occupied. - def check_pos_validity(self, position): - if not any(entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]): - return True - else: - return False + :param moving_entity: Entity + :param target_position: pos + :return: Safe to move to + """ + is_not_blocked = self.check_pos_validity(target_position) + will_not_block_others = moving_entity.var_is_blocking_pos and self.entities.is_occupied(target_position) + + if moving_entity.pos != target_position and is_not_blocked and not will_not_block_others: + return c.VALID + else: + return c.NOT_VALID + + def check_pos_validity(self, pos: (int, int)) -> bool: + """ + Check if *pos* is a valid position to move or spawn to. + + :param pos: position to check + :return: Wheter pos is a valid target. + """ + + if not any(e.var_is_blocking_pos for e in self.entities.pos_dict[pos]) and pos in self.entities.floorlist: + return c.VALID + else: + return c.NOT_VALID class StepTests: def __init__(self, *args): diff --git a/marl_factory_grid/utils/tools.py b/marl_factory_grid/utils/tools.py index d2f9bd1..73fa50d 100644 --- a/marl_factory_grid/utils/tools.py +++ b/marl_factory_grid/utils/tools.py @@ -28,7 +28,9 @@ class ConfigExplainer: def explain_module(self, class_to_explain): parameters = inspect.signature(class_to_explain).parameters - explained = {class_to_explain.__name__: {key: val.default for key, val in parameters.items() if key not in EXCLUDED}} + explained = {class_to_explain.__name__: + {key: val.default for key, val in parameters.items() if key not in EXCLUDED} + } return explained def _load_and_compare(self, compare_class, paths): @@ -135,4 +137,3 @@ if __name__ == '__main__': ce.get_observations() ce.get_assets() all_conf = ce.get_all() - print() diff --git a/marl_factory_grid/utils/utility_classes.py b/marl_factory_grid/utils/utility_classes.py index 4844133..4d1cfe1 100644 --- a/marl_factory_grid/utils/utility_classes.py +++ b/marl_factory_grid/utils/utility_classes.py @@ -52,3 +52,6 @@ class Floor: def __hash__(self): return hash(self.name) + + def __repr__(self): + return f"Floor{self.pos}" diff --git a/random_testrun.py b/random_testrun.py index 9bebf17..ef8df08 100644 --- a/random_testrun.py +++ b/random_testrun.py @@ -6,18 +6,21 @@ from marl_factory_grid.environment.factory import Factory from marl_factory_grid.utils.logging.envmonitor import EnvMonitor from marl_factory_grid.utils.logging.recorder import EnvRecorder +from marl_factory_grid.utils.plotting.plot_single_runs import plot_single_run from marl_factory_grid.utils.tools import ConfigExplainer if __name__ == '__main__': # Render at each step? - render = True + render = False # Reveal all possible Modules (Entities, Rules, Agents[Actions, Observations], etc.) explain_config = False # Collect statistics? - monitor = False + monitor = True # Record as Protobuf? record = False + # Plot Results? + plotting = True run_path = Path('study_out') @@ -38,7 +41,7 @@ if __name__ == '__main__': factory = EnvRecorder(factory) # RL learn Loop - for episode in trange(500): + for episode in trange(10): _ = factory.reset() done = False if render: @@ -54,7 +57,10 @@ if __name__ == '__main__': break if monitor: - factory.save_run(run_path / 'test.pkl') + factory.save_run(run_path / 'test_monitor.pkl') if record: factory.save_records(run_path / 'test.pb') + if plotting: + plot_single_run(run_path) + print('Done!!! Goodbye....') diff --git a/reload_agent.py b/reload_agent.py index 8c16069..0fb8066 100644 --- a/reload_agent.py +++ b/reload_agent.py @@ -6,6 +6,7 @@ import yaml from marl_factory_grid.environment.factory import Factory from marl_factory_grid.utils.logging.envmonitor import EnvMonitor from marl_factory_grid.utils.logging.recorder import EnvRecorder +from marl_factory_grid.utils import helpers as h from marl_factory_grid.modules.doors import constants as d @@ -55,13 +56,14 @@ if __name__ == '__main__': for model_idx, model in enumerate(models)] else: actions = models[0].predict(env_state, deterministic=determin)[0] + # noinspection PyTupleAssignmentBalance env_state, step_r, done_bool, info_obj = env.step(actions) rew += step_r if render: env.render() try: - door = next(x for x in env.unwrapped.unwrapped[d.DOORS] if x.is_open) + door = h.get_first([x for x in env.unwrapped.unwrapped[d.DOORS] if x.is_open]) print('openDoor found') except StopIteration: pass diff --git a/studies/normalization_study.py b/studies/normalization_study.py index 37e10c4..7c72982 100644 --- a/studies/normalization_study.py +++ b/studies/normalization_study.py @@ -1,8 +1,8 @@ from algorithms.utils import Checkpointer from pathlib import Path from algorithms.utils import load_yaml_file, add_env_props, instantiate_class, load_class -#from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC +# from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC for i in range(0, 5): diff --git a/transform_wg_to_json_no_priv.py b/transform_wg_to_json_no_priv.py new file mode 100644 index 0000000..1b7ef3e --- /dev/null +++ b/transform_wg_to_json_no_priv.py @@ -0,0 +1,41 @@ +import configparser +import json +from datetime import datetime +from pathlib import Path + +if __name__ == '__main__': + + conf_path = Path('wg0') + wg0_conf = configparser.ConfigParser() + wg0_conf.read(conf_path/'wg0.conf') + interface = wg0_conf['Interface'] + # Iterate all pears + for client_name in wg0_conf.sections(): + if client_name == 'Interface': + continue + # Delete any old conf.json for the current peer + (conf_path / f'{client_name}.json').unlink(missing_ok=True) + + peer = wg0_conf[client_name] + + date_time = datetime.now().strftime('%Y-%m-%dT%H:%M:%S.%f000Z') + + jdict = dict( + id=client_name, + private_key=peer['PublicKey'], + public_key=peer['PublicKey'], + # preshared_key=wg0_conf[client_name_wg0]['PresharedKey'], + name=client_name, + email=f"sysadmin@mobile.ifi.lmu.de", + allocated_ips=[interface['Address'].replace('/24', '')], + allowed_ips=['10.4.0.0/24', '10.153.199.0/24'], + extra_allowed_ips=[], + use_server_dns=True, + enabled=True, + created_at=date_time, + updated_at=date_time + ) + + with (conf_path / f'{client_name}.json').open('w+') as f: + json.dump(jdict, f, indent='\t', separators=(',', ': ')) + print(client_name, ' written...')