diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..b58b603 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,5 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ diff --git a/README.md b/README.md index 07d67bf..d0c0a19 100644 --- a/README.md +++ b/README.md @@ -94,7 +94,7 @@ All [Entites](marl_factory_grid/environment/entity/global_entities.py) are avail #### Rules -[Rules](marl_factory_grid/environment/entity/object.py) define how the environment behaves on micro-scale. +[Rules](marl_factory_grid/environment/entity/object.py) define how the environment behaves on microscale. Each of the hookes (`on_init`, `pre_step`, `on_step`, '`post_step`', `on_done`) provide env-access to implement customn logic, calculate rewards, or gather information. @@ -107,6 +107,7 @@ Make sure to bring your own assets for each Entity living in the Gridworld as th PNG-files (transparent background) of square aspect-ratio should do the job, in general. 
+ diff --git a/marl_factory_grid/__init__.py b/marl_factory_grid/__init__.py index 49f5635..259e3cf 100644 --- a/marl_factory_grid/__init__.py +++ b/marl_factory_grid/__init__.py @@ -1 +1 @@ -from .quickstart import init \ No newline at end of file +from .quickstart import init diff --git a/marl_factory_grid/algorithms/__init__.py b/marl_factory_grid/algorithms/__init__.py index 0980070..cc2c489 100644 --- a/marl_factory_grid/algorithms/__init__.py +++ b/marl_factory_grid/algorithms/__init__.py @@ -1 +1,4 @@ -import os, sys; sys.path.append(os.path.dirname(os.path.realpath(__file__))) +import os +import sys + +sys.path.append(os.path.dirname(os.path.realpath(__file__))) diff --git a/marl_factory_grid/algorithms/marl/__init__.py b/marl_factory_grid/algorithms/marl/__init__.py index 984588c..a4c30ef 100644 --- a/marl_factory_grid/algorithms/marl/__init__.py +++ b/marl_factory_grid/algorithms/marl/__init__.py @@ -1 +1 @@ -from marl_factory_grid.algorithms.marl.memory import MARLActorCriticMemory \ No newline at end of file +from marl_factory_grid.algorithms.marl.memory import MARLActorCriticMemory diff --git a/marl_factory_grid/algorithms/marl/base_ac.py b/marl_factory_grid/algorithms/marl/base_ac.py index 3bb0318..ef195b7 100644 --- a/marl_factory_grid/algorithms/marl/base_ac.py +++ b/marl_factory_grid/algorithms/marl/base_ac.py @@ -28,6 +28,7 @@ class Names: BATCH_SIZE = 'bnatch_size' N_ACTIONS = 'n_actions' + nms = Names ListOrTensor = Union[List, torch.Tensor] @@ -112,10 +113,9 @@ class BaseActorCritic: next_obs, reward, done, info = env.step(action) done = [done] * self.n_agents if isinstance(done, bool) else done - last_hiddens = dict(hidden_actor =out[nms.HIDDEN_ACTOR], + last_hiddens = dict(hidden_actor=out[nms.HIDDEN_ACTOR], hidden_critic=out[nms.HIDDEN_CRITIC]) - tm.add(observation=obs, action=action, reward=reward, done=done, logits=out.get(nms.LOGITS, None), values=out.get(nms.CRITIC, None), **last_hiddens) @@ -142,7 +142,9 @@ class BaseActorCritic: 
print(f'reward at episode: {episode} = {rew_log}') episode += 1 df_results.append([episode, rew_log, *reward]) - df_results = pd.DataFrame(df_results, columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]]) + df_results = pd.DataFrame(df_results, + columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]] + ) if checkpointer is not None: df_results.to_csv(checkpointer.path / 'results.csv', index=False) return df_results @@ -157,24 +159,27 @@ class BaseActorCritic: last_action, reward = [-1] * self.n_agents, [0.] * self.n_agents done, rew_log, eps_rew = [False] * self.n_agents, 0, torch.zeros(self.n_agents) while not all(done): - if render: env.render() + if render: + env.render() out = self.forward(obs, last_action, **last_hiddens) action = self.get_actions(out) next_obs, reward, done, info = env.step(action) - if isinstance(done, bool): done = [done] * obs.shape[0] + if isinstance(done, bool): + done = [done] * obs.shape[0] obs = next_obs last_action = action last_hiddens = dict(hidden_actor=out.get(nms.HIDDEN_ACTOR, None), hidden_critic=out.get(nms.HIDDEN_CRITIC, None) ) eps_rew += torch.tensor(reward) - results.append(eps_rew.tolist() + [sum(eps_rew).item()] + [episode]) + results.append(eps_rew.tolist() + [np.sum(eps_rew).item()] + [episode]) episode += 1 agent_columns = [f'agent#{i}' for i in range(self.cfg['environment']['n_agents'])] results = pd.DataFrame(results, columns=agent_columns + ['sum', 'episode']) - results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'], value_name='reward', var_name='agent') + results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'], + value_name='reward', var_name='agent') return results @staticmethod diff --git a/marl_factory_grid/algorithms/marl/mappo.py b/marl_factory_grid/algorithms/marl/mappo.py index d22fa08..faf3b0d 100644 --- a/marl_factory_grid/algorithms/marl/mappo.py +++ b/marl_factory_grid/algorithms/marl/mappo.py @@ -36,7 
+36,7 @@ class LoopMAPPO(LoopSNAC): rewards_ = torch.stack(rewards_, dim=1) return rewards_ - def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **kwargs): + def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **__): out = network(batch[nms.OBSERVATION], batch[nms.ACTION], batch[nms.HIDDEN_ACTOR], batch[nms.HIDDEN_CRITIC]) logits = out[nms.LOGITS][:, :-1] # last one only needed for v_{t+1} @@ -45,7 +45,7 @@ class LoopMAPPO(LoopSNAC): # monte carlo returns mc_returns = self.monte_carlo_returns(batch[nms.REWARD], batch[nms.DONE], gamma) - mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8) #todo: norm across agent ok? + mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8) # todo: norm across agent ok? advantages = mc_returns - out[nms.CRITIC][:, :-1] # policy loss diff --git a/marl_factory_grid/algorithms/marl/networks.py b/marl_factory_grid/algorithms/marl/networks.py index c4fdb72..796c03f 100644 --- a/marl_factory_grid/algorithms/marl/networks.py +++ b/marl_factory_grid/algorithms/marl/networks.py @@ -1,8 +1,7 @@ +import numpy as np import torch import torch.nn as nn -import numpy as np import torch.nn.functional as F -from torch.nn.utils import spectral_norm class RecurrentAC(nn.Module): @@ -88,8 +87,8 @@ class NormalizedLinear(nn.Linear): self.trainable_magnitude = trainable_magnitude self.scale = nn.Parameter(torch.tensor([1.]), requires_grad=trainable_magnitude) - def forward(self, input): - normalized_input = F.normalize(input, dim=-1, p=2, eps=1e-5) + def forward(self, in_array): + normalized_input = F.normalize(in_array, dim=-1, p=2, eps=1e-5) normalized_weight = F.normalize(self.weight, dim=-1, p=2, eps=1e-5) return F.linear(normalized_input, normalized_weight) * self.d_sqrt * self.scale diff --git a/marl_factory_grid/algorithms/marl/seac.py b/marl_factory_grid/algorithms/marl/seac.py index 9c458c7..07e8267 100644 --- a/marl_factory_grid/algorithms/marl/seac.py +++ 
b/marl_factory_grid/algorithms/marl/seac.py @@ -16,7 +16,7 @@ class LoopSEAC(LoopIAC): with torch.inference_mode(True): true_action_logp = torch.stack([ torch.log_softmax(out[nms.LOGITS][ag_i, :-1], -1) - .gather(index=actions[ag_i, 1:, None], dim=-1) + .gather(index=actions[ag_i, 1:, None], dim=-1) for ag_i, out in enumerate(outputs) ], 0).squeeze() @@ -38,7 +38,6 @@ class LoopSEAC(LoopIAC): a2c_loss = (-iw*log_ap * advantages.detach()).mean(-1) - value_loss = (iw*advantages.pow(2)).mean(-1) # n_agent # weighted loss @@ -53,4 +52,4 @@ class LoopSEAC(LoopIAC): self.optimizer[ag_i].zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(self.net[ag_i].parameters(), 0.5) - self.optimizer[ag_i].step() \ No newline at end of file + self.optimizer[ag_i].step() diff --git a/marl_factory_grid/algorithms/marl/snac.py b/marl_factory_grid/algorithms/marl/snac.py index b249754..11be902 100644 --- a/marl_factory_grid/algorithms/marl/snac.py +++ b/marl_factory_grid/algorithms/marl/snac.py @@ -30,4 +30,4 @@ class LoopSNAC(BaseActorCritic): self._as_torch(actions).unsqueeze(1), hidden_actor, hidden_critic ) - return out \ No newline at end of file + return out diff --git a/marl_factory_grid/algorithms/static/TSP_base_agent.py b/marl_factory_grid/algorithms/static/TSP_base_agent.py index bc48f7c..7d25f63 100644 --- a/marl_factory_grid/algorithms/static/TSP_base_agent.py +++ b/marl_factory_grid/algorithms/static/TSP_base_agent.py @@ -56,8 +56,8 @@ class TSPBaseAgent(ABC): def _door_is_close(self, state): try: - # return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name) - return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name) + return next(y for x in state.entities.neighboring_positions(self.state.pos) + for y in state.entities.pos_dict[x] if do.DOOR in y.name) except StopIteration: return None diff --git a/marl_factory_grid/algorithms/static/TSP_target_agent.py 
b/marl_factory_grid/algorithms/static/TSP_target_agent.py index 5e0f989..b0d8b29 100644 --- a/marl_factory_grid/algorithms/static/TSP_target_agent.py +++ b/marl_factory_grid/algorithms/static/TSP_target_agent.py @@ -14,8 +14,8 @@ class TSPTargetAgent(TSPBaseAgent): def _handle_doors(self, state): try: - # return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name) - return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name) + return next(y for x in state.entities.neighboring_positions(self.state.pos) + for y in state.entities.pos_dict[x] if do.DOOR in y.name) except StopIteration: return None diff --git a/marl_factory_grid/algorithms/utils.py b/marl_factory_grid/algorithms/utils.py index 59e78bd..8c60386 100644 --- a/marl_factory_grid/algorithms/utils.py +++ b/marl_factory_grid/algorithms/utils.py @@ -1,8 +1,9 @@ -import torch -import numpy as np -import yaml from pathlib import Path +import numpy as np +import torch +import yaml + def load_class(classname): from importlib import import_module @@ -42,7 +43,6 @@ def get_class(arguments): def get_arguments(arguments): - from importlib import import_module d = dict(arguments) if "classname" in d: del d["classname"] @@ -82,4 +82,4 @@ class Checkpointer(object): for name, model in to_save: self.save_experiment(name, model) self.__current_checkpoint += 1 - self.__current_step += 1 \ No newline at end of file + self.__current_step += 1 diff --git a/marl_factory_grid/configs/narrow_corridor.yaml b/marl_factory_grid/configs/narrow_corridor.yaml index 04f42ae..f53b972 100644 --- a/marl_factory_grid/configs/narrow_corridor.yaml +++ b/marl_factory_grid/configs/narrow_corridor.yaml @@ -1,4 +1,4 @@ -eneral: +General: # Your Seed env_seed: 69 # Individual or global rewards? 
@@ -86,4 +86,4 @@ Rules: DoneAtDestinationReachAll: # reward_at_done: 1 DoneAtMaxStepsReached: - max_steps: 500 + max_steps: 200 diff --git a/marl_factory_grid/environment/entity/entity.py b/marl_factory_grid/environment/entity/entity.py index 4abf2af..999787b 100644 --- a/marl_factory_grid/environment/entity/entity.py +++ b/marl_factory_grid/environment/entity/entity.py @@ -1,15 +1,14 @@ import abc -from collections import defaultdict import numpy as np -from .object import _Object +from .object import Object from .. import constants as c from ...utils.results import ActionResult from ...utils.utility_classes import RenderEntity -class Entity(_Object, abc.ABC): +class Entity(Object, abc.ABC): """Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc...""" @property @@ -96,8 +95,9 @@ class Entity(_Object, abc.ABC): def __init__(self, pos, bind_to=None, **kwargs): super().__init__(**kwargs) + self._view_directory = c.VALUE_NO_POS self._status = None - self.set_pos(pos) + self._pos = pos self._last_pos = pos if bind_to: try: @@ -113,10 +113,6 @@ class Entity(_Object, abc.ABC): def render(self): return RenderEntity(self.__class__.__name__.lower(), self.pos) - @abc.abstractmethod - def render(self): - return RenderEntity(self.__class__.__name__.lower(), self.pos) - @property def obs_tag(self): try: @@ -133,25 +129,3 @@ class Entity(_Object, abc.ABC): self._collection.delete_env_object(self) self._collection = other_collection return self._collection == other_collection - - @classmethod - def from_coordinates(cls, positions: [(int, int)], *args, entity_kwargs=None, **kwargs, ): - collection = cls(*args, **kwargs) - collection.add_items( - [cls._entity(tuple(pos), **entity_kwargs if entity_kwargs is not None else {}) for pos in positions]) - return collection - - def notify_del_entity(self, entity): - try: - self.pos_dict[entity.pos].remove(entity) - except (ValueError, AttributeError): - pass - - def by_pos(self, pos: (int, int)): - pos = 
tuple(pos) - try: - return self.state.entities.pos_dict[pos] - except StopIteration: - pass - except ValueError: - pass diff --git a/marl_factory_grid/environment/entity/object.py b/marl_factory_grid/environment/entity/object.py index 768f8b5..e8c69da 100644 --- a/marl_factory_grid/environment/entity/object.py +++ b/marl_factory_grid/environment/entity/object.py @@ -5,7 +5,7 @@ from marl_factory_grid.environment import constants as c import marl_factory_grid.utils.helpers as h -class _Object: +class Object: """Generell Objects for Organisation and Maintanance such as Actions etc...""" _u_idx = defaultdict(lambda: 0) @@ -50,15 +50,15 @@ class _Object: print(f'Following kwargs were passed, but ignored: {kwargs}') def __repr__(self): - name = self.name - if self.bound_entity: - name = h.add_bound_name(name, self.bound_entity) - try: - if self.var_has_position: - name = h.add_pos_name(name, self) - except (AttributeError): - pass - return name + name = self.name + if self.bound_entity: + name = h.add_bound_name(name, self.bound_entity) + try: + if self.var_has_position: + name = h.add_pos_name(name, self) + except AttributeError: + pass + return name def __eq__(self, other) -> bool: return other == self.identifier @@ -67,8 +67,8 @@ class _Object: return hash(self.identifier) def _identify_and_count_up(self): - idx = _Object._u_idx[self.__class__.__name__] - _Object._u_idx[self.__class__.__name__] += 1 + idx = Object._u_idx[self.__class__.__name__] + Object._u_idx[self.__class__.__name__] += 1 return idx def set_collection(self, collection): @@ -98,79 +98,3 @@ class _Object: def unbind(self): self._bound_entity = None - - -# class EnvObject(_Object): -# """Objects that hold Information that are observable, but have no position on the environment grid. 
Inventories etc...""" -# - # _u_idx = defaultdict(lambda: 0) -# -# @property -# def obs_tag(self): -# try: -# return self._collection.name or self.name -# except AttributeError: -# return self.name -# -# @property -# def var_is_blocking_light(self): -# try: -# return self._collection.var_is_blocking_light or False -# except AttributeError: -# return False -# -# @property -# def var_can_be_bound(self): -# try: -# return self._collection.var_can_be_bound or False -# except AttributeError: -# return False -# -# @property -# def var_can_move(self): -# try: -# return self._collection.var_can_move or False -# except AttributeError: -# return False -# -# @property -# def var_is_blocking_pos(self): -# try: -# return self._collection.var_is_blocking_pos or False -# except AttributeError: -# return False -# -# @property -# def var_has_position(self): -# try: -# return self._collection.var_has_position or False -# except AttributeError: -# return False -# -# @property -# def var_can_collide(self): -# try: -# return self._collection.var_can_collide or False -# except AttributeError: -# return False -# -# -# @property -# def encoding(self): -# return c.VALUE_OCCUPIED_CELL -# -# -# def __init__(self, **kwargs): -# self._bound_entity = None -# super(EnvObject, self).__init__(**kwargs) -# -# -# def change_parent_collection(self, other_collection): -# other_collection.add_item(self) -# self._collection.delete_env_object(self) -# self._collection = other_collection -# return self._collection == other_collection -# -# -# def summarize_state(self): -# return dict(name=str(self.name)) diff --git a/marl_factory_grid/environment/entity/util.py b/marl_factory_grid/environment/entity/util.py index d43c53a..2a15c41 100644 --- a/marl_factory_grid/environment/entity/util.py +++ b/marl_factory_grid/environment/entity/util.py @@ -1,6 +1,6 @@ import numpy as np -from marl_factory_grid.environment.entity.object import _Object +from marl_factory_grid.environment.entity.object import Object 
########################################################################## @@ -8,7 +8,7 @@ from marl_factory_grid.environment.entity.object import _Object ########################################################################## -class PlaceHolder(_Object): +class PlaceHolder(Object): def __init__(self, *args, fill_value=0, **kwargs): super().__init__(*args, **kwargs) @@ -27,7 +27,7 @@ class PlaceHolder(_Object): return self.__class__.__name__ -class GlobalPosition(_Object): +class GlobalPosition(Object): @property def encoding(self): diff --git a/marl_factory_grid/environment/factory.py b/marl_factory_grid/environment/factory.py index 3c5f7f6..97ce621 100644 --- a/marl_factory_grid/environment/factory.py +++ b/marl_factory_grid/environment/factory.py @@ -56,15 +56,18 @@ class Factory(gym.Env): self.level_filepath = Path(custom_level_path) else: self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt' - self._renderer = None # expensive - don't use; unless required ! parsed_entities = self.conf.load_entities() self.map = LevelParser(self.level_filepath, parsed_entities, self.conf.pomdp_r) # Init for later usage: - self.state: Gamestate - self.map: LevelParser - self.obs_builder: OBSBuilder + # noinspection PyTypeChecker + self.state: Gamestate = None + # noinspection PyTypeChecker + self.obs_builder: OBSBuilder = None + + # expensive - don't use; unless required ! + self._renderer = None # reset env to initial state, preparing env for new episode. 
# returns tuple where the first dict contains initial observation for each agent in the env @@ -74,7 +77,7 @@ class Factory(gym.Env): return self.state.entities[item] def reset(self) -> (dict, dict): - if hasattr(self, 'state'): + if self.state is not None: for entity_group in self.state.entities: try: entity_group[0].reset_uid() @@ -160,7 +163,7 @@ class Factory(gym.Env): # Finalize reward, reward_info, done = self.summarize_step_results(tick_result, done_results) - info = reward_info + info = dict(reward_info) info.update(step_reward=sum(reward), step=self.state.curr_step) diff --git a/marl_factory_grid/environment/groups/collection.py b/marl_factory_grid/environment/groups/collection.py index 140c941..c0f0f6b 100644 --- a/marl_factory_grid/environment/groups/collection.py +++ b/marl_factory_grid/environment/groups/collection.py @@ -1,15 +1,15 @@ from typing import List, Tuple, Union, Dict from marl_factory_grid.environment.entity.entity import Entity -from marl_factory_grid.environment.groups.objects import _Objects +from marl_factory_grid.environment.groups.objects import Objects # noinspection PyProtectedMember -from marl_factory_grid.environment.entity.object import _Object +from marl_factory_grid.environment.entity.object import Object import marl_factory_grid.environment.constants as c from marl_factory_grid.utils.results import Result -class Collection(_Objects): - _entity = _Object # entity? +class Collection(Objects): + _entity = Object # entity? 
symbol = None @property @@ -58,7 +58,7 @@ class Collection(_Objects): def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, ignore_blocking=False, **entity_kwargs): coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity if self.var_has_position: - if isinstance(coords_or_quantity, int): + if self.var_has_position and isinstance(coords_or_quantity, int): if ignore_blocking or self._ignore_blocking: coords_or_quantity = state.entities.floorlist[:coords_or_quantity] else: @@ -87,8 +87,8 @@ class Collection(_Objects): raise ValueError(f'{self._entity.__name__} has no position!') return c.VALID - def despawn(self, items: List[_Object]): - items = [items] if isinstance(items, _Object) else items + def despawn(self, items: List[Object]): + items = [items] if isinstance(items, Object) else items for item in items: del self[item] diff --git a/marl_factory_grid/environment/groups/global_entities.py b/marl_factory_grid/environment/groups/global_entities.py index 7a50de4..37779f9 100644 --- a/marl_factory_grid/environment/groups/global_entities.py +++ b/marl_factory_grid/environment/groups/global_entities.py @@ -3,12 +3,12 @@ from operator import itemgetter from random import shuffle from typing import Dict -from marl_factory_grid.environment.groups.objects import _Objects +from marl_factory_grid.environment.groups.objects import Objects from marl_factory_grid.utils.helpers import POS_MASK -class Entities(_Objects): - _entity = _Objects +class Entities(Objects): + _entity = Objects @staticmethod def neighboring_positions(pos): @@ -87,7 +87,7 @@ class Entities(_Objects): def __delitem__(self, name): assert_str = 'This group of entity does not exist in this collection!' 
assert any([key for key in name.keys() if key in self.keys()]), assert_str - self[name]._observers.delete(self) + self[name].del_observer(self) for entity in self[name]: entity.del_observer(self) return super(Entities, self).__delitem__(name) diff --git a/marl_factory_grid/environment/groups/objects.py b/marl_factory_grid/environment/groups/objects.py index d29cc2c..9229787 100644 --- a/marl_factory_grid/environment/groups/objects.py +++ b/marl_factory_grid/environment/groups/objects.py @@ -1,15 +1,15 @@ from collections import defaultdict -from typing import List +from typing import List, Iterator, Union import numpy as np -from marl_factory_grid.environment.entity.object import _Object +from marl_factory_grid.environment.entity.object import Object import marl_factory_grid.environment.constants as c from marl_factory_grid.utils import helpers as h -class _Objects: - _entity = _Object +class Objects: + _entity = Object @property def var_can_be_bound(self): @@ -50,7 +50,7 @@ class _Objects: def __len__(self): return len(self._data) - def __iter__(self): + def __iter__(self) -> Iterator[Union[Object, None]]: return iter(self.values()) def add_item(self, item: _entity): @@ -130,13 +130,14 @@ class _Objects: repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]} return f'{self.__class__.__name__}[{repr_dict}]' - def notify_del_entity(self, entity: _Object): + def notify_del_entity(self, entity: Object): try: + # noinspection PyUnresolvedReferences self.pos_dict[entity.pos].remove(entity) except (AttributeError, ValueError, IndexError): pass - def notify_add_entity(self, entity: _Object): + def notify_add_entity(self, entity: Object): try: if self not in entity.observers: entity.add_observer(self) diff --git a/marl_factory_grid/environment/rules.py b/marl_factory_grid/environment/rules.py index 55bb2bc..f5b6836 100644 --- a/marl_factory_grid/environment/rules.py +++ b/marl_factory_grid/environment/rules.py @@ -1,11 +1,11 @@ import abc from 
random import shuffle -from typing import List, Collection, Union +from typing import List, Collection +from marl_factory_grid.environment import rewards as r, constants as c from marl_factory_grid.environment.entity.agent import Agent from marl_factory_grid.utils import helpers as h from marl_factory_grid.utils.results import TickResult, DoneResult -from marl_factory_grid.environment import rewards as r, constants as c class Rule(abc.ABC): @@ -118,8 +118,7 @@ class AssignGlobalPositions(Rule): def on_init(self, state, lvl_map): from marl_factory_grid.environment.entity.util import GlobalPosition for agent in state[c.AGENT]: - gp = GlobalPosition(lvl_map.level_shape) - gp.bind_to(agent) + gp = GlobalPosition(agent, lvl_map.level_shape) state[c.GLOBALPOSITIONS].add_item(gp) return [] diff --git a/marl_factory_grid/modules/_template/rules.py b/marl_factory_grid/modules/_template/rules.py index 6ed2f2d..7696616 100644 --- a/marl_factory_grid/modules/_template/rules.py +++ b/marl_factory_grid/modules/_template/rules.py @@ -6,7 +6,9 @@ from marl_factory_grid.utils.results import TickResult, DoneResult class TemplateRule(Rule): def __init__(self, *args, **kwargs): - super(TemplateRule, self).__init__(*args, **kwargs) + super(TemplateRule, self).__init__() + self.args = args + self.kwargs = kwargs def on_init(self, state, lvl_map): pass diff --git a/marl_factory_grid/modules/batteries/actions.py b/marl_factory_grid/modules/batteries/actions.py index bd755a2..7d1c4a2 100644 --- a/marl_factory_grid/modules/batteries/actions.py +++ b/marl_factory_grid/modules/batteries/actions.py @@ -1,6 +1,5 @@ from typing import Union -import marl_factory_grid.modules.batteries.constants from marl_factory_grid.environment.actions import Action from marl_factory_grid.utils.results import ActionResult @@ -24,5 +23,6 @@ class BtryCharge(Action): else: valid = c.NOT_VALID state.print(f'{entity.name} failed to charged batteries at {entity.pos}.') + return ActionResult(entity=entity, 
identifier=self._identifier, validity=valid, - reward=marl_factory_grid.modules.batteries.constants.REWARD_CHARGE_VALID if valid else marl_factory_grid.modules.batteries.constants.Reward_CHARGE_FAIL) + reward=b.REWARD_CHARGE_VALID if valid else b.Reward_CHARGE_FAIL) diff --git a/marl_factory_grid/modules/batteries/entitites.py b/marl_factory_grid/modules/batteries/entitites.py index 751d57b..7675fe9 100644 --- a/marl_factory_grid/modules/batteries/entitites.py +++ b/marl_factory_grid/modules/batteries/entitites.py @@ -1,11 +1,11 @@ from marl_factory_grid.environment import constants as c from marl_factory_grid.environment.entity.entity import Entity -from marl_factory_grid.environment.entity.object import _Object +from marl_factory_grid.environment.entity.object import Object from marl_factory_grid.modules.batteries import constants as b from marl_factory_grid.utils.utility_classes import RenderEntity -class Battery(_Object): +class Battery(Object): @property def var_can_be_bound(self): diff --git a/marl_factory_grid/modules/batteries/rules.py b/marl_factory_grid/modules/batteries/rules.py index 84b7ef2..8a4725b 100644 --- a/marl_factory_grid/modules/batteries/rules.py +++ b/marl_factory_grid/modules/batteries/rules.py @@ -1,11 +1,9 @@ from typing import List, Union -import marl_factory_grid.modules.batteries.constants -from marl_factory_grid.environment.rules import Rule -from marl_factory_grid.utils.results import TickResult, DoneResult - from marl_factory_grid.environment import constants as c +from marl_factory_grid.environment.rules import Rule from marl_factory_grid.modules.batteries import constants as b +from marl_factory_grid.utils.results import TickResult, DoneResult class BatteryDecharge(Rule): diff --git a/marl_factory_grid/modules/clean_up/entitites.py b/marl_factory_grid/modules/clean_up/entitites.py index 19e703c..25c6eb1 100644 --- a/marl_factory_grid/modules/clean_up/entitites.py +++ b/marl_factory_grid/modules/clean_up/entitites.py @@ -1,5 +1,3 
@@ -from numpy import random - from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.utils.utility_classes import RenderEntity from marl_factory_grid.modules.clean_up import constants as d diff --git a/marl_factory_grid/modules/clean_up/groups.py b/marl_factory_grid/modules/clean_up/groups.py index 2029171..7ae3247 100644 --- a/marl_factory_grid/modules/clean_up/groups.py +++ b/marl_factory_grid/modules/clean_up/groups.py @@ -1,9 +1,7 @@ -from typing import Union, List, Tuple - from marl_factory_grid.environment import constants as c -from marl_factory_grid.utils.results import Result from marl_factory_grid.environment.groups.collection import Collection from marl_factory_grid.modules.clean_up.entitites import DirtPile +from marl_factory_grid.utils.results import Result class DirtPiles(Collection): diff --git a/marl_factory_grid/modules/clean_up/rules.py b/marl_factory_grid/modules/clean_up/rules.py index be2f9b9..b81ee41 100644 --- a/marl_factory_grid/modules/clean_up/rules.py +++ b/marl_factory_grid/modules/clean_up/rules.py @@ -49,7 +49,7 @@ class RespawnDirt(Rule): def tick_step(self, state): collection = state[d.DIRT] if self._next_dirt_spawn < 0: - pass # No DirtPile Spawn + result = [] # No DirtPile Spawn elif not self._next_dirt_spawn: result = [collection.trigger_spawn(state, coords_or_quantity=self.respawn_n, amount=self.respawn_amount)] self._next_dirt_spawn = self.respawn_freq diff --git a/marl_factory_grid/modules/destinations/actions.py b/marl_factory_grid/modules/destinations/actions.py index 13f7fe3..6367acd 100644 --- a/marl_factory_grid/modules/destinations/actions.py +++ b/marl_factory_grid/modules/destinations/actions.py @@ -21,4 +21,4 @@ class DestAction(Action): valid = c.NOT_VALID state.print(f'{entity.name} just tried to do_wait_action do_wait_action at {entity.pos} but failed') return ActionResult(entity=entity, identifier=self._identifier, validity=valid, - 
reward=marl_factory_grid.modules.destinations.constants.REWARD_WAIT_VALID if valid else marl_factory_grid.modules.destinations.constants.REWARD_WAIT_FAIL) + reward=d.REWARD_WAIT_VALID if valid else d.REWARD_WAIT_FAIL) diff --git a/marl_factory_grid/modules/destinations/groups.py b/marl_factory_grid/modules/destinations/groups.py index 5f0b654..f0b7f9e 100644 --- a/marl_factory_grid/modules/destinations/groups.py +++ b/marl_factory_grid/modules/destinations/groups.py @@ -1,7 +1,5 @@ from marl_factory_grid.environment.groups.collection import Collection from marl_factory_grid.modules.destinations.entitites import Destination -from marl_factory_grid.environment import constants as c -from marl_factory_grid.modules.destinations import constants as d class Destinations(Collection): diff --git a/marl_factory_grid/modules/doors/groups.py b/marl_factory_grid/modules/doors/groups.py index a27d598..973d1ab 100644 --- a/marl_factory_grid/modules/doors/groups.py +++ b/marl_factory_grid/modules/doors/groups.py @@ -1,5 +1,3 @@ -from typing import Union - from marl_factory_grid.environment.groups.collection import Collection from marl_factory_grid.modules.doors import constants as d from marl_factory_grid.modules.doors.entitites import Door diff --git a/marl_factory_grid/modules/doors/rewards.py b/marl_factory_grid/modules/doors/rewards.py index c87d123..b38c7c5 100644 --- a/marl_factory_grid/modules/doors/rewards.py +++ b/marl_factory_grid/modules/doors/rewards.py @@ -1,2 +1,2 @@ USE_DOOR_VALID: float = -0.00 -USE_DOOR_FAIL: float = -0.01 \ No newline at end of file +USE_DOOR_FAIL: float = -0.01 diff --git a/marl_factory_grid/modules/factory/rules.py b/marl_factory_grid/modules/factory/rules.py index d736f7a..e056135 100644 --- a/marl_factory_grid/modules/factory/rules.py +++ b/marl_factory_grid/modules/factory/rules.py @@ -1,8 +1,8 @@ import random -from typing import List, Union +from typing import List -from marl_factory_grid.environment.rules import Rule from 
marl_factory_grid.environment import constants as c +from marl_factory_grid.environment.rules import Rule from marl_factory_grid.utils.results import TickResult @@ -14,8 +14,6 @@ class AgentSingleZonePlacementBeta(Rule): super().__init__() def on_init(self, state, lvl_map): - zones = state[c.ZONES] - n_zones = state[c.ZONES] agents = state[c.AGENT] if len(self.coordinates) == len(agents): coordinates = self.coordinates @@ -31,4 +29,4 @@ class AgentSingleZonePlacementBeta(Rule): return [] def tick_post_step(self, state) -> List[TickResult]: - return [] \ No newline at end of file + return [] diff --git a/marl_factory_grid/modules/items/constants.py b/marl_factory_grid/modules/items/constants.py index 86b8b0c..5cb82c3 100644 --- a/marl_factory_grid/modules/items/constants.py +++ b/marl_factory_grid/modules/items/constants.py @@ -1,6 +1,3 @@ -from typing import NamedTuple - - SYMBOL_NO_ITEM = 0 SYMBOL_DROP_OFF = 1 # Item Env diff --git a/marl_factory_grid/modules/items/entitites.py b/marl_factory_grid/modules/items/entitites.py index ff34e23..8549134 100644 --- a/marl_factory_grid/modules/items/entitites.py +++ b/marl_factory_grid/modules/items/entitites.py @@ -14,27 +14,14 @@ class Item(Entity): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - @property - def auto_despawn(self): - return self._auto_despawn - @property def encoding(self): # Edit this if you want items to be drawn in the ops differently return 1 - def set_auto_despawn(self, auto_despawn): - self._auto_despawn = auto_despawn - - def summarize_state(self) -> dict: - super_summarization = super(Item, self).summarize_state() - super_summarization.update(dict(auto_despawn=self.auto_despawn)) - return super_summarization - class DropOffLocation(Entity): - def render(self): return RenderEntity(i.DROP_OFF, self.pos) @@ -42,18 +29,16 @@ class DropOffLocation(Entity): def encoding(self): return i.SYMBOL_DROP_OFF - def __init__(self, *args, storage_size_until_full: int = 5, 
auto_item_despawn_interval: int = 5, **kwargs): + def __init__(self, *args, storage_size_until_full: int = 5, **kwargs): super(DropOffLocation, self).__init__(*args, **kwargs) - self.auto_item_despawn_interval = auto_item_despawn_interval self.storage = deque(maxlen=storage_size_until_full or None) def place_item(self, item: Item): if self.is_full: raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.") - return bc.NOT_VALID # in Zeile 81 verschieben? + return bc.NOT_VALID else: self.storage.append(item) - item.set_auto_despawn(self.auto_item_despawn_interval) return c.VALID @property diff --git a/marl_factory_grid/modules/items/groups.py b/marl_factory_grid/modules/items/groups.py index deb1812..be5ca49 100644 --- a/marl_factory_grid/modules/items/groups.py +++ b/marl_factory_grid/modules/items/groups.py @@ -1,12 +1,9 @@ -from random import shuffle - -from marl_factory_grid.modules.items import constants as i from marl_factory_grid.environment import constants as c - -from marl_factory_grid.environment.groups.collection import Collection -from marl_factory_grid.environment.groups.objects import _Objects -from marl_factory_grid.environment.groups.mixins import IsBoundMixin from marl_factory_grid.environment.entity.agent import Agent +from marl_factory_grid.environment.groups.collection import Collection +from marl_factory_grid.environment.groups.mixins import IsBoundMixin +from marl_factory_grid.environment.groups.objects import Objects +from marl_factory_grid.modules.items import constants as i from marl_factory_grid.modules.items.entitites import Item, DropOffLocation from marl_factory_grid.utils.results import Result @@ -74,13 +71,12 @@ class Inventory(IsBoundMixin, Collection): self._collection = collection -class Inventories(_Objects): +class Inventories(Objects): _entity = Inventory var_can_move = False var_has_position = False - symbol = None @property @@ -116,7 +112,6 @@ class Inventories(_Objects): return 
[val.summarize_states(**kwargs) for key, val in self.items()] - class DropOffLocations(Collection): _entity = DropOffLocation diff --git a/marl_factory_grid/modules/items/rewards.py b/marl_factory_grid/modules/items/rewards.py index 40adf46..bcd2918 100644 --- a/marl_factory_grid/modules/items/rewards.py +++ b/marl_factory_grid/modules/items/rewards.py @@ -1,4 +1,4 @@ DROP_OFF_VALID: float = 0.1 DROP_OFF_FAIL: float = -0.1 PICK_UP_FAIL: float = -0.1 -PICK_UP_VALID: float = 0.1 \ No newline at end of file +PICK_UP_VALID: float = 0.1 diff --git a/marl_factory_grid/modules/machines/actions.py b/marl_factory_grid/modules/machines/actions.py index 970f85f..dbb303f 100644 --- a/marl_factory_grid/modules/machines/actions.py +++ b/marl_factory_grid/modules/machines/actions.py @@ -1,9 +1,10 @@ from typing import Union +import marl_factory_grid.modules.machines.constants from marl_factory_grid.environment.actions import Action from marl_factory_grid.utils.results import ActionResult -from marl_factory_grid.modules.machines import constants as m, rewards as r +from marl_factory_grid.modules.machines import constants as m from marl_factory_grid.environment import constants as c from marl_factory_grid.utils import helpers as h @@ -16,8 +17,10 @@ class MachineAction(Action): def do(self, entity, state) -> Union[None, ActionResult]: if machine := h.get_first(state[m.MACHINES].by_pos(entity.pos)): if valid := machine.maintain(): - return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_VALID) + return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_VALID) else: - return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_FAIL) + return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_FAIL) else: - return ActionResult(entity=entity, 
identifier=self._identifier, validity=c.NOT_VALID, reward=r.MAINTAIN_FAIL) + return ActionResult(entity=entity, identifier=self._identifier, + validity=c.NOT_VALID, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_FAIL + ) diff --git a/marl_factory_grid/modules/machines/constants.py b/marl_factory_grid/modules/machines/constants.py index 29ce3bc..3771cbb 100644 --- a/marl_factory_grid/modules/machines/constants.py +++ b/marl_factory_grid/modules/machines/constants.py @@ -11,3 +11,7 @@ STATE_MAINTAIN = 'maintenance' SYMBOL_WORK = 1 SYMBOL_IDLE = 0.6 SYMBOL_MAINTAIN = 0.3 +MAINTAIN_VALID: float = 0.5 +MAINTAIN_FAIL: float = -0.1 +FAIL_MISSING_MAINTENANCE: float = -0.5 +NONE: float = 0 diff --git a/marl_factory_grid/modules/machines/entitites.py b/marl_factory_grid/modules/machines/entitites.py index f5775e1..581adf6 100644 --- a/marl_factory_grid/modules/machines/entitites.py +++ b/marl_factory_grid/modules/machines/entitites.py @@ -31,11 +31,10 @@ class Machine(Entity): return c.NOT_VALID def tick(self, state): - # if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]): - if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in state.entities.pos_dict[self.pos]]): + others = state.entities.pos_dict[self.pos] + if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in others]): return TickResult(identifier=self.name, validity=c.VALID, entity=self) - # elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]): - elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in state.entities.pos_dict[self.pos]]): + elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in others]): self.status = m.STATE_WORK self.reset_counter() return None diff --git a/marl_factory_grid/modules/machines/groups.py b/marl_factory_grid/modules/machines/groups.py index 5f2d970..9d89d6c 100644 --- a/marl_factory_grid/modules/machines/groups.py +++ 
b/marl_factory_grid/modules/machines/groups.py @@ -1,5 +1,3 @@ -from typing import Union, List, Tuple - from marl_factory_grid.environment.groups.collection import Collection from .entitites import Machine diff --git a/marl_factory_grid/modules/machines/rewards.py b/marl_factory_grid/modules/machines/rewards.py deleted file mode 100644 index c868196..0000000 --- a/marl_factory_grid/modules/machines/rewards.py +++ /dev/null @@ -1,5 +0,0 @@ -MAINTAIN_VALID: float = 0.5 -MAINTAIN_FAIL: float = -0.1 -FAIL_MISSING_MAINTENANCE: float = -0.5 - -NONE: float = 0 diff --git a/marl_factory_grid/modules/maintenance/constants.py b/marl_factory_grid/modules/maintenance/constants.py index e0ab12c..3aed36c 100644 --- a/marl_factory_grid/modules/maintenance/constants.py +++ b/marl_factory_grid/modules/maintenance/constants.py @@ -1,3 +1,4 @@ MAINTAINER = 'Maintainer' # TEMPLATE _identifier. Define your own! MAINTAINERS = 'Maintainers' # TEMPLATE _identifier. Define your own! +MAINTAINER_COLLISION_REWARD = -5 diff --git a/marl_factory_grid/modules/maintenance/groups.py b/marl_factory_grid/modules/maintenance/groups.py index 79f7480..5b09c9c 100644 --- a/marl_factory_grid/modules/maintenance/groups.py +++ b/marl_factory_grid/modules/maintenance/groups.py @@ -4,7 +4,6 @@ from marl_factory_grid.environment.groups.collection import Collection from .entities import Maintainer from ..machines import constants as mc from ..machines.actions import MachineAction -from ...utils.states import Gamestate class Maintainers(Collection): @@ -23,8 +22,6 @@ class Maintainers(Collection): self.size = size self._spawnrule = spawnrule - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): self.add_items([self._entity(mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity]) diff --git a/marl_factory_grid/modules/maintenance/rewards.py b/marl_factory_grid/modules/maintenance/rewards.py 
deleted file mode 100644 index 425ac3b..0000000 --- a/marl_factory_grid/modules/maintenance/rewards.py +++ /dev/null @@ -1 +0,0 @@ -MAINTAINER_COLLISION_REWARD = -5 \ No newline at end of file diff --git a/marl_factory_grid/modules/maintenance/rules.py b/marl_factory_grid/modules/maintenance/rules.py index fdefe42..92e6e75 100644 --- a/marl_factory_grid/modules/maintenance/rules.py +++ b/marl_factory_grid/modules/maintenance/rules.py @@ -1,15 +1,16 @@ from typing import List + +import marl_factory_grid.modules.maintenance.constants from marl_factory_grid.environment.rules import Rule from marl_factory_grid.utils.results import TickResult, DoneResult from marl_factory_grid.environment import constants as c -from . import rewards as r from . import constants as M class MoveMaintainers(Rule): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self): + super().__init__() def tick_step(self, state) -> List[TickResult]: for maintainer in state[M.MAINTAINERS]: @@ -20,8 +21,8 @@ class MoveMaintainers(Rule): class DoneAtMaintainerCollision(Rule): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self): + super().__init__() def on_check_done(self, state) -> List[DoneResult]: agents = list(state[c.AGENT].values()) @@ -30,5 +31,5 @@ class DoneAtMaintainerCollision(Rule): for agent in agents: if agent.pos in m_pos: done_results.append(DoneResult(entity=agent, validity=c.VALID, identifier=self.name, - reward=r.MAINTAINER_COLLISION_REWARD)) + reward=marl_factory_grid.modules.maintenance.constants.MAINTAINER_COLLISION_REWARD)) return done_results diff --git a/marl_factory_grid/modules/zones/entitites.py b/marl_factory_grid/modules/zones/entitites.py index cfd313f..4aa0f70 100644 --- a/marl_factory_grid/modules/zones/entitites.py +++ b/marl_factory_grid/modules/zones/entitites.py @@ -1,10 +1,10 @@ import random from typing import List, Tuple -from marl_factory_grid.environment.entity.object import 
_Object +from marl_factory_grid.environment.entity.object import Object -class Zone(_Object): +class Zone(Object): @property def positions(self): diff --git a/marl_factory_grid/modules/zones/groups.py b/marl_factory_grid/modules/zones/groups.py index 71eb329..f5494cd 100644 --- a/marl_factory_grid/modules/zones/groups.py +++ b/marl_factory_grid/modules/zones/groups.py @@ -1,8 +1,8 @@ -from marl_factory_grid.environment.groups.objects import _Objects +from marl_factory_grid.environment.groups.objects import Objects from marl_factory_grid.modules.zones import Zone -class Zones(_Objects): +class Zones(Objects): symbol = None _entity = Zone diff --git a/marl_factory_grid/utils/config_parser.py b/marl_factory_grid/utils/config_parser.py index 7cdc9e6..8215ed2 100644 --- a/marl_factory_grid/utils/config_parser.py +++ b/marl_factory_grid/utils/config_parser.py @@ -58,7 +58,10 @@ class FactoryConfigParser(object): return str(self.config) def __getitem__(self, item): - return self.config[item] + try: + return self.config[item] + except KeyError: + print(f'The mandatory {item} section could not be found in your .config gile. Check Spelling!') def load_entities(self): entity_classes = dict() @@ -161,7 +164,6 @@ class FactoryConfigParser(object): def _load_smth(self, config, class_obj): rules = list() - rules_names = list() for rule in config: e1 = e2 = e3 = None try: diff --git a/marl_factory_grid/utils/helpers.py b/marl_factory_grid/utils/helpers.py index ae68bf7..f5f6d00 100644 --- a/marl_factory_grid/utils/helpers.py +++ b/marl_factory_grid/utils/helpers.py @@ -61,8 +61,8 @@ class ObservationTranslator: :param per_agent_named_obs_spaces: `Named observation space` one for each agent. Overloaded. type per_agent_named_obs_spaces: Dict[str, dict] - :param placeholder_fill_value: Currently not fully implemented!!! - :type placeholder_fill_value: Union[int, str] = 'N') + :param placeholder_fill_value: Currently, not fully implemented!!! 
+ :type placeholder_fill_value: Union[int, str] = 'N' """ if isinstance(placeholder_fill_value, str): diff --git a/marl_factory_grid/utils/logging/envmonitor.py b/marl_factory_grid/utils/logging/envmonitor.py index 67eac73..e2551c8 100644 --- a/marl_factory_grid/utils/logging/envmonitor.py +++ b/marl_factory_grid/utils/logging/envmonitor.py @@ -9,7 +9,7 @@ from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS import pandas as pd -from marl_factory_grid.utils.plotting.compare_runs import plot_single_run +from marl_factory_grid.utils.plotting.plot_single_runs import plot_single_run class EnvMonitor(Wrapper): @@ -22,7 +22,6 @@ class EnvMonitor(Wrapper): self._monitor_df = pd.DataFrame() self._monitor_dict = dict() - def step(self, action): obs_type, obs, reward, done, info = self.env.step(action) self._read_info(info) diff --git a/marl_factory_grid/utils/logging/recorder.py b/marl_factory_grid/utils/logging/recorder.py index fac2e16..797866e 100644 --- a/marl_factory_grid/utils/logging/recorder.py +++ b/marl_factory_grid/utils/logging/recorder.py @@ -2,11 +2,9 @@ from os import PathLike from pathlib import Path from typing import Union, List -import yaml -from gymnasium import Wrapper - import numpy as np import pandas as pd +from gymnasium import Wrapper class EnvRecorder(Wrapper): @@ -106,7 +104,7 @@ class EnvRecorder(Wrapper): out_dict = {'episodes': self._recorder_out_list} out_dict.update( {'n_episodes': self._curr_episode, - 'metadata':dict( + 'metadata': dict( level_name=self.env.params['General']['level_name'], verbose=False, n_agents=len(self.env.params['Agents']), diff --git a/marl_factory_grid/utils/observation_builder.py b/marl_factory_grid/utils/observation_builder.py index df10ae9..55d6ec0 100644 --- a/marl_factory_grid/utils/observation_builder.py +++ b/marl_factory_grid/utils/observation_builder.py @@ -5,7 +5,7 @@ from typing import Dict, List import numpy as np from marl_factory_grid.environment import constants as c -from 
marl_factory_grid.environment.entity.object import _Object +from marl_factory_grid.environment.entity.object import Object from marl_factory_grid.environment.groups.utils import Combined from marl_factory_grid.utils.utility_classes import Floor from marl_factory_grid.utils.ray_caster import RayCaster @@ -13,7 +13,6 @@ from marl_factory_grid.utils.states import Gamestate from marl_factory_grid.utils import helpers as h - class OBSBuilder(object): default_obs = [c.WALLS, c.OTHERS] @@ -128,7 +127,7 @@ class OBSBuilder(object): f'{re.escape("[")}(.*){re.escape("]")}' f'{re.escape("(")}{re.escape(agent.name)}{re.escape(")")}') name = next((key for key, val in self.all_obs.items() - if pattern.search(str(val)) and isinstance(val, _Object)), None) + if pattern.search(str(val)) and isinstance(val, Object)), None) e = self.all_obs[name] except KeyError: try: @@ -181,11 +180,11 @@ class OBSBuilder(object): return obs, self.obs_layers[agent.name] def _sort_and_name_observation_conf(self, agent): - ''' + """ Builds the useable observation scheme per agent from conf.yaml. :param agent: :return: - ''' + """ # Fixme: no asymetric shapes possible. 
self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape)) obs_layers = [] diff --git a/marl_factory_grid/utils/plotting/compare_runs.py b/marl_factory_grid/utils/plotting/plot_compare_runs.py similarity index 83% rename from marl_factory_grid/utils/plotting/compare_runs.py rename to marl_factory_grid/utils/plotting/plot_compare_runs.py index cb5c853..5115478 100644 --- a/marl_factory_grid/utils/plotting/compare_runs.py +++ b/marl_factory_grid/utils/plotting/plot_compare_runs.py @@ -7,50 +7,11 @@ from typing import Union, List import pandas as pd from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS -from marl_factory_grid.utils.plotting.plotting import prepare_plot +from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot MODEL_MAP = None -def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None): - run_path = Path(run_path) - df_list = list() - if run_path.is_dir(): - monitor_file = next(run_path.glob('*monitor*.pick')) - elif run_path.exists() and run_path.is_file(): - monitor_file = run_path - else: - raise ValueError - - with monitor_file.open('rb') as f: - monitor_df = pickle.load(f) - - monitor_df = monitor_df.fillna(0) - df_list.append(monitor_df) - - df = pd.concat(df_list, ignore_index=True) - df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode']) - if column_keys is not None: - columns = [col for col in column_keys if col in df.columns] - else: - columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS] - - roll_n = 50 - - non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean() - - df_melted = df[columns + ['Episode']].reset_index().melt( - id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score" - ) - - if df_melted['Episode'].max() > 800: - skip_n = round(df_melted['Episode'].max() * 0.02) - df_melted = df_melted[df_melted['Episode'] % skip_n == 0] - - prepare_plot(run_path.parent / 
f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex) - print('Plotting done.') - - def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False): run_path = Path(run_path) df_list = list() diff --git a/marl_factory_grid/utils/plotting/plot_single_runs.py b/marl_factory_grid/utils/plotting/plot_single_runs.py new file mode 100644 index 0000000..7316d6a --- /dev/null +++ b/marl_factory_grid/utils/plotting/plot_single_runs.py @@ -0,0 +1,48 @@ +import pickle +from os import PathLike +from pathlib import Path +from typing import Union + +import pandas as pd + +from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS +from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot + + +def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None, + file_key: str ='monitor', file_ext: str ='pkl'): + run_path = Path(run_path) + df_list = list() + if run_path.is_dir(): + monitor_file = next(run_path.glob(f'*{file_key}*.{file_ext}')) + elif run_path.exists() and run_path.is_file(): + monitor_file = run_path + else: + raise ValueError + + with monitor_file.open('rb') as f: + monitor_df = pickle.load(f) + + monitor_df = monitor_df.fillna(0) + df_list.append(monitor_df) + + df = pd.concat(df_list, ignore_index=True) + df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode']) + if column_keys is not None: + columns = [col for col in column_keys if col in df.columns] + else: + columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS] + + # roll_n = 50 + # non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean() + + df_melted = df[columns + ['Episode']].reset_index().melt( + id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score" + ) + + if df_melted['Episode'].max() > 800: + skip_n = round(df_melted['Episode'].max() * 0.02) + df_melted = df_melted[df_melted['Episode'] % skip_n == 0] + + 
prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex) + print('Plotting done.') diff --git a/marl_factory_grid/utils/plotting/plotting.py b/marl_factory_grid/utils/plotting/plotting_utils.py similarity index 98% rename from marl_factory_grid/utils/plotting/plotting.py rename to marl_factory_grid/utils/plotting/plotting_utils.py index 455f81a..17bb7ff 100644 --- a/marl_factory_grid/utils/plotting/plotting.py +++ b/marl_factory_grid/utils/plotting/plotting_utils.py @@ -60,7 +60,7 @@ def prepare_center_double_column_legend(df, hue, style, hue_order): print('Struggling to plot Figure using LaTeX - going back to normal.') plt.close('all') sns.set(rc={'text.usetex': False}, style='whitegrid') - fig = plt.figure(figsize=(10, 11)) + _ = plt.figure(figsize=(10, 11)) lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style, ci=95, palette=PALETTE, hue_order=hue_order, legend=False) # plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0) diff --git a/marl_factory_grid/utils/ray_caster.py b/marl_factory_grid/utils/ray_caster.py index ecbac6d..d89997e 100644 --- a/marl_factory_grid/utils/ray_caster.py +++ b/marl_factory_grid/utils/ray_caster.py @@ -19,7 +19,7 @@ class RayCaster: return f'{self.__class__.__name__}({self.agent.name})' def build_ray_targets(self): - north = np.array([0, -1])*self.pomdp_r + north = np.array([0, -1]) * self.pomdp_r thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]] rot_M = [ [[math.cos(theta), -math.sin(theta)], @@ -53,9 +53,9 @@ class RayCaster: diag_hits = all([ self.ray_block_cache( key, - lambda: any(True for e in pos_dict[key] if e.var_is_blocking_light)) - # lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light)) - for key in ((x, y-cy), (x-cx, y)) + # lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) + lambda: any(True for e in pos_dict[key] if 
e.var_is_blocking_light)) + for key in ((x, y - cy), (x - cx, y)) ]) if (cx != 0 and cy != 0) else False visible += entities_hit if not diag_hits else [] @@ -77,8 +77,8 @@ class RayCaster: agent = self.agent x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1) y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1) - outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \ - + list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords)) + outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) + outline += list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords)) return outline @staticmethod diff --git a/marl_factory_grid/utils/results.py b/marl_factory_grid/utils/results.py index 6abf11c..b4b07fc 100644 --- a/marl_factory_grid/utils/results.py +++ b/marl_factory_grid/utils/results.py @@ -1,9 +1,12 @@ from typing import Union from dataclasses import dataclass +from marl_factory_grid.environment.entity.object import Object + TYPE_VALUE = 'value' TYPE_REWARD = 'reward' -types = [TYPE_VALUE, TYPE_REWARD] +TYPES = [TYPE_VALUE, TYPE_REWARD] + @dataclass class InfoObject: @@ -18,12 +21,13 @@ class Result: validity: bool reward: Union[float, None] = None value: Union[float, None] = None - entity: None = None + entity: Object = None def get_infos(self): n = self.entity.name if self.entity is not None else "Global" - return [InfoObject(identifier=f'{n}_{self.identifier}_{t}', - val_type=t, value=self.__getattribute__(t)) for t in types + # Return multiple Info Dicts + return [InfoObject(identifier=f'{n}_{self.identifier}', + val_type=t, value=self.__getattribute__(t)) for t in TYPES if self.__getattribute__(t) is not None] def __repr__(self): @@ -31,7 +35,7 @@ class Result: reward = f" | Reward: {self.reward}" if self.reward is not None else "" value = f" | Value: {self.value}" if self.value is not None else "" entity = f" | by: {self.entity.name}" if 
self.entity is not None else "" - return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value})' + return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value}{entity})' @dataclass diff --git a/marl_factory_grid/utils/states.py b/marl_factory_grid/utils/states.py index fc07b95..f38f7f9 100644 --- a/marl_factory_grid/utils/states.py +++ b/marl_factory_grid/utils/states.py @@ -1,11 +1,12 @@ from itertools import islice -from typing import List, Dict, Tuple +from typing import List, Tuple import numpy as np from marl_factory_grid.environment import constants as c +from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.rules import Rule -from marl_factory_grid.utils.results import Result +from marl_factory_grid.utils.results import Result, DoneResult class StepRules: @@ -83,13 +84,51 @@ class Gamestate(object): return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})' @property - def random_free_position(self): + def random_free_position(self) -> (int, int): + """ + Returns a single **free** position (x, y), which is **free** for spawning or walking. + No Entity at this position possesses *var_is_blocking_pos* or *var_can_collide*. + + :return: Single **free** position. + """ return self.get_n_random_free_positions(1)[0] - def get_n_random_free_positions(self, n): + def get_n_random_free_positions(self, n) -> list[tuple[int, int]]: + """ + Returns a list of *n* **free** positions [(x, y), ... ], which are **free** for spawning or walking. + No Entity at this position possesses *var_is_blocking_pos* or *var_can_collide*. + + :return: List of n **free** positions. + """ return list(islice(self.entities.free_positions_generator, n)) - def tick(self, actions) -> List[Result]: + @property + def random_position(self) -> (int, int): + """ + Returns a single available position (x, y), ignores all entity attributes.
+ + :return: Single random position. + """ + return self.get_n_random_positions(1)[0] + + def get_n_random_positions(self, n) -> list[tuple[int, int]]: + """ + Returns a list of *n* available positions [(x, y), ... ], ignores all entity attributes. + + :return: List of n random positions. + """ + return list(islice(self.entities.floorlist, n)) + + def tick(self, actions) -> list[Result]: + """ + Performs a single **Gamestate Tick** by calling the inner rule hooks in sequential order. + - tick_pre_step_all: Things to do before the agents do their actions. Statechange, Moving, Spawning etc... + - agent tick: Agents do their actions. + - tick_step_all: Things to do after the agents did their actions. Statechange, Moving, Spawning etc... + - tick_post_step_all: Things to do at the very end of each step. Counting, Reward calculations etc... + + :return: List of *Result*-objects. + """ results = list() self.curr_step += 1 @@ -112,11 +151,23 @@ class Gamestate(object): return results - def print(self, string): + def print(self, string) -> None: + """ + When *verbose* is active, print stuff. + + :param string: *String* to print. + :type string: str + :return: Nothing + """ if self.verbose: print(string) - def check_done(self): + def check_done(self) -> List[DoneResult]: + """ + Iterate all **Rules** that override the *on_check_done* hook. + + :return: List of Results + """ results = list() for rule in self.rules: if on_check_done_result := rule.on_check_done(self): @@ -124,20 +175,44 @@ class Gamestate(object): return results def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]: - positions = [pos for pos, entities in self.entities.pos_dict.items() if len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2)] + """ + Returns a list of positions [(x, y), ... ] on which collisions occur. This does not include agents, + that were unable to move because their target direction was blocked, also a form of collision. + + :return: List of positions.
+ """ + positions = [pos for pos, entities in self.entities.pos_dict.items() if + len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2) + ] return positions - def check_move_validity(self, moving_entity, position): - if moving_entity.pos != position and not any( - entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]) and not ( - moving_entity.var_is_blocking_pos and self.entities.is_occupied(position)): - return True - else: - return False + def check_move_validity(self, moving_entity: Entity, target_position: (int, int)) -> bool: + """ + Whether it is safe to move to the target position and moving entity does not introduce a blocking attribute, + when position is already occupied. - def check_pos_validity(self, position): - if not any(entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]): - return True - else: - return False + :param moving_entity: Entity + :param target_position: pos + :return: Safe to move to + """ + is_not_blocked = self.check_pos_validity(target_position) + will_not_block_others = moving_entity.var_is_blocking_pos and self.entities.is_occupied(target_position) + + if moving_entity.pos != target_position and is_not_blocked and not will_not_block_others: + return c.VALID + else: + return c.NOT_VALID + + def check_pos_validity(self, pos: (int, int)) -> bool: + """ + Check if *pos* is a valid position to move or spawn to. + + :param pos: position to check + :return: Whether pos is a valid target.
+ """ + + if not any(e.var_is_blocking_pos for e in self.entities.pos_dict[pos]) and pos in self.entities.floorlist: + return c.VALID + else: + return c.NOT_VALID diff --git a/marl_factory_grid/utils/tools.py b/marl_factory_grid/utils/tools.py index 63c9f69..73fa50d 100644 --- a/marl_factory_grid/utils/tools.py +++ b/marl_factory_grid/utils/tools.py @@ -28,7 +28,9 @@ class ConfigExplainer: def explain_module(self, class_to_explain): parameters = inspect.signature(class_to_explain).parameters - explained = {class_to_explain.__name__: {key: val.default for key, val in parameters.items() if key not in EXCLUDED}} + explained = {class_to_explain.__name__: + {key: val.default for key, val in parameters.items() if key not in EXCLUDED} + } return explained def _load_and_compare(self, compare_class, paths): diff --git a/random_testrun.py b/random_testrun.py index 9bebf17..ef8df08 100644 --- a/random_testrun.py +++ b/random_testrun.py @@ -6,18 +6,21 @@ from marl_factory_grid.environment.factory import Factory from marl_factory_grid.utils.logging.envmonitor import EnvMonitor from marl_factory_grid.utils.logging.recorder import EnvRecorder +from marl_factory_grid.utils.plotting.plot_single_runs import plot_single_run from marl_factory_grid.utils.tools import ConfigExplainer if __name__ == '__main__': # Render at each step? - render = True + render = False # Reveal all possible Modules (Entities, Rules, Agents[Actions, Observations], etc.) explain_config = False # Collect statistics? - monitor = False + monitor = True # Record as Protobuf? record = False + # Plot Results? 
+ plotting = True run_path = Path('study_out') @@ -38,7 +41,7 @@ if __name__ == '__main__': factory = EnvRecorder(factory) # RL learn Loop - for episode in trange(500): + for episode in trange(10): _ = factory.reset() done = False if render: @@ -54,7 +57,10 @@ if __name__ == '__main__': break if monitor: - factory.save_run(run_path / 'test.pkl') + factory.save_run(run_path / 'test_monitor.pkl') if record: factory.save_records(run_path / 'test.pb') + if plotting: + plot_single_run(run_path) + print('Done!!! Goodbye....') diff --git a/reload_agent.py b/reload_agent.py index f0ed389..0fb8066 100644 --- a/reload_agent.py +++ b/reload_agent.py @@ -56,6 +56,7 @@ if __name__ == '__main__': for model_idx, model in enumerate(models)] else: actions = models[0].predict(env_state, deterministic=determin)[0] + # noinspection PyTupleAssignmentBalance env_state, step_r, done_bool, info_obj = env.step(actions) rew += step_r diff --git a/transform_wg_to_json_no_priv.py b/transform_wg_to_json_no_priv.py index d9bc8e1..1b7ef3e 100644 --- a/transform_wg_to_json_no_priv.py +++ b/transform_wg_to_json_no_priv.py @@ -5,7 +5,6 @@ from pathlib import Path if __name__ == '__main__': - conf_path = Path('wg0') wg0_conf = configparser.ConfigParser() wg0_conf.read(conf_path/'wg0.conf') @@ -17,7 +16,6 @@ if __name__ == '__main__': # Delete any old conf.json for the current peer (conf_path / f'{client_name}.json').unlink(missing_ok=True) - peer = wg0_conf[client_name] date_time = datetime.now().strftime('%Y-%m-%dT%H:%M:%S.%f000Z')