diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..b58b603
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,5 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
diff --git a/README.md b/README.md
index 07d67bf..d0c0a19 100644
--- a/README.md
+++ b/README.md
@@ -94,7 +94,7 @@ All [Entites](marl_factory_grid/environment/entity/global_entities.py) are avail
#### Rules
-[Rules](marl_factory_grid/environment/entity/object.py) define how the environment behaves on micro-scale.
+[Rules](marl_factory_grid/environment/entity/object.py) define how the environment behaves on microscale.
Each of the hookes (`on_init`, `pre_step`, `on_step`, '`post_step`', `on_done`)
provide env-access to implement customn logic, calculate rewards, or gather information.
@@ -107,6 +107,7 @@ Make sure to bring your own assets for each Entity living in the Gridworld as th
PNG-files (transparent background) of square aspect-ratio should do the job, in general.
+
diff --git a/marl_factory_grid/__init__.py b/marl_factory_grid/__init__.py
index 49f5635..259e3cf 100644
--- a/marl_factory_grid/__init__.py
+++ b/marl_factory_grid/__init__.py
@@ -1 +1 @@
-from .quickstart import init
\ No newline at end of file
+from .quickstart import init
diff --git a/marl_factory_grid/algorithms/__init__.py b/marl_factory_grid/algorithms/__init__.py
index 0980070..cc2c489 100644
--- a/marl_factory_grid/algorithms/__init__.py
+++ b/marl_factory_grid/algorithms/__init__.py
@@ -1 +1,4 @@
-import os, sys; sys.path.append(os.path.dirname(os.path.realpath(__file__)))
+import os
+import sys
+
+sys.path.append(os.path.dirname(os.path.realpath(__file__)))
diff --git a/marl_factory_grid/algorithms/marl/__init__.py b/marl_factory_grid/algorithms/marl/__init__.py
index 984588c..a4c30ef 100644
--- a/marl_factory_grid/algorithms/marl/__init__.py
+++ b/marl_factory_grid/algorithms/marl/__init__.py
@@ -1 +1 @@
-from marl_factory_grid.algorithms.marl.memory import MARLActorCriticMemory
\ No newline at end of file
+from marl_factory_grid.algorithms.marl.memory import MARLActorCriticMemory
diff --git a/marl_factory_grid/algorithms/marl/base_ac.py b/marl_factory_grid/algorithms/marl/base_ac.py
index 3bb0318..ef195b7 100644
--- a/marl_factory_grid/algorithms/marl/base_ac.py
+++ b/marl_factory_grid/algorithms/marl/base_ac.py
@@ -28,6 +28,7 @@ class Names:
BATCH_SIZE = 'bnatch_size'
N_ACTIONS = 'n_actions'
+
nms = Names
ListOrTensor = Union[List, torch.Tensor]
@@ -112,10 +113,9 @@ class BaseActorCritic:
next_obs, reward, done, info = env.step(action)
done = [done] * self.n_agents if isinstance(done, bool) else done
- last_hiddens = dict(hidden_actor =out[nms.HIDDEN_ACTOR],
+ last_hiddens = dict(hidden_actor=out[nms.HIDDEN_ACTOR],
hidden_critic=out[nms.HIDDEN_CRITIC])
-
tm.add(observation=obs, action=action, reward=reward, done=done,
logits=out.get(nms.LOGITS, None), values=out.get(nms.CRITIC, None),
**last_hiddens)
@@ -142,7 +142,9 @@ class BaseActorCritic:
print(f'reward at episode: {episode} = {rew_log}')
episode += 1
df_results.append([episode, rew_log, *reward])
- df_results = pd.DataFrame(df_results, columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]])
+ df_results = pd.DataFrame(df_results,
+ columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]]
+ )
if checkpointer is not None:
df_results.to_csv(checkpointer.path / 'results.csv', index=False)
return df_results
@@ -157,24 +159,27 @@ class BaseActorCritic:
last_action, reward = [-1] * self.n_agents, [0.] * self.n_agents
done, rew_log, eps_rew = [False] * self.n_agents, 0, torch.zeros(self.n_agents)
while not all(done):
- if render: env.render()
+ if render:
+ env.render()
out = self.forward(obs, last_action, **last_hiddens)
action = self.get_actions(out)
next_obs, reward, done, info = env.step(action)
- if isinstance(done, bool): done = [done] * obs.shape[0]
+ if isinstance(done, bool):
+ done = [done] * obs.shape[0]
obs = next_obs
last_action = action
last_hiddens = dict(hidden_actor=out.get(nms.HIDDEN_ACTOR, None),
hidden_critic=out.get(nms.HIDDEN_CRITIC, None)
)
eps_rew += torch.tensor(reward)
- results.append(eps_rew.tolist() + [sum(eps_rew).item()] + [episode])
+ results.append(eps_rew.tolist() + [np.sum(eps_rew).item()] + [episode])
episode += 1
agent_columns = [f'agent#{i}' for i in range(self.cfg['environment']['n_agents'])]
results = pd.DataFrame(results, columns=agent_columns + ['sum', 'episode'])
- results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'], value_name='reward', var_name='agent')
+ results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'],
+ value_name='reward', var_name='agent')
return results
@staticmethod
diff --git a/marl_factory_grid/algorithms/marl/mappo.py b/marl_factory_grid/algorithms/marl/mappo.py
index d22fa08..faf3b0d 100644
--- a/marl_factory_grid/algorithms/marl/mappo.py
+++ b/marl_factory_grid/algorithms/marl/mappo.py
@@ -36,7 +36,7 @@ class LoopMAPPO(LoopSNAC):
rewards_ = torch.stack(rewards_, dim=1)
return rewards_
- def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **kwargs):
+ def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **__):
out = network(batch[nms.OBSERVATION], batch[nms.ACTION], batch[nms.HIDDEN_ACTOR], batch[nms.HIDDEN_CRITIC])
logits = out[nms.LOGITS][:, :-1] # last one only needed for v_{t+1}
@@ -45,7 +45,7 @@ class LoopMAPPO(LoopSNAC):
# monte carlo returns
mc_returns = self.monte_carlo_returns(batch[nms.REWARD], batch[nms.DONE], gamma)
- mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8) #todo: norm across agent ok?
+ mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8) # todo: norm across agent ok?
advantages = mc_returns - out[nms.CRITIC][:, :-1]
# policy loss
diff --git a/marl_factory_grid/algorithms/marl/networks.py b/marl_factory_grid/algorithms/marl/networks.py
index c4fdb72..796c03f 100644
--- a/marl_factory_grid/algorithms/marl/networks.py
+++ b/marl_factory_grid/algorithms/marl/networks.py
@@ -1,8 +1,7 @@
+import numpy as np
import torch
import torch.nn as nn
-import numpy as np
import torch.nn.functional as F
-from torch.nn.utils import spectral_norm
class RecurrentAC(nn.Module):
@@ -88,8 +87,8 @@ class NormalizedLinear(nn.Linear):
self.trainable_magnitude = trainable_magnitude
self.scale = nn.Parameter(torch.tensor([1.]), requires_grad=trainable_magnitude)
- def forward(self, input):
- normalized_input = F.normalize(input, dim=-1, p=2, eps=1e-5)
+ def forward(self, in_array):
+ normalized_input = F.normalize(in_array, dim=-1, p=2, eps=1e-5)
normalized_weight = F.normalize(self.weight, dim=-1, p=2, eps=1e-5)
return F.linear(normalized_input, normalized_weight) * self.d_sqrt * self.scale
diff --git a/marl_factory_grid/algorithms/marl/seac.py b/marl_factory_grid/algorithms/marl/seac.py
index 9c458c7..07e8267 100644
--- a/marl_factory_grid/algorithms/marl/seac.py
+++ b/marl_factory_grid/algorithms/marl/seac.py
@@ -16,7 +16,7 @@ class LoopSEAC(LoopIAC):
with torch.inference_mode(True):
true_action_logp = torch.stack([
torch.log_softmax(out[nms.LOGITS][ag_i, :-1], -1)
- .gather(index=actions[ag_i, 1:, None], dim=-1)
+ .gather(index=actions[ag_i, 1:, None], dim=-1)
for ag_i, out in enumerate(outputs)
], 0).squeeze()
@@ -38,7 +38,6 @@ class LoopSEAC(LoopIAC):
a2c_loss = (-iw*log_ap * advantages.detach()).mean(-1)
-
value_loss = (iw*advantages.pow(2)).mean(-1) # n_agent
# weighted loss
@@ -53,4 +52,4 @@ class LoopSEAC(LoopIAC):
self.optimizer[ag_i].zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.net[ag_i].parameters(), 0.5)
- self.optimizer[ag_i].step()
\ No newline at end of file
+ self.optimizer[ag_i].step()
diff --git a/marl_factory_grid/algorithms/marl/snac.py b/marl_factory_grid/algorithms/marl/snac.py
index b249754..11be902 100644
--- a/marl_factory_grid/algorithms/marl/snac.py
+++ b/marl_factory_grid/algorithms/marl/snac.py
@@ -30,4 +30,4 @@ class LoopSNAC(BaseActorCritic):
self._as_torch(actions).unsqueeze(1),
hidden_actor, hidden_critic
)
- return out
\ No newline at end of file
+ return out
diff --git a/marl_factory_grid/algorithms/static/TSP_base_agent.py b/marl_factory_grid/algorithms/static/TSP_base_agent.py
index bc48f7c..7d25f63 100644
--- a/marl_factory_grid/algorithms/static/TSP_base_agent.py
+++ b/marl_factory_grid/algorithms/static/TSP_base_agent.py
@@ -56,8 +56,8 @@ class TSPBaseAgent(ABC):
def _door_is_close(self, state):
try:
- # return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name)
- return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name)
+ return next(y for x in state.entities.neighboring_positions(self.state.pos)
+ for y in state.entities.pos_dict[x] if do.DOOR in y.name)
except StopIteration:
return None
diff --git a/marl_factory_grid/algorithms/static/TSP_target_agent.py b/marl_factory_grid/algorithms/static/TSP_target_agent.py
index 5e0f989..b0d8b29 100644
--- a/marl_factory_grid/algorithms/static/TSP_target_agent.py
+++ b/marl_factory_grid/algorithms/static/TSP_target_agent.py
@@ -14,8 +14,8 @@ class TSPTargetAgent(TSPBaseAgent):
def _handle_doors(self, state):
try:
- # return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name)
- return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name)
+ return next(y for x in state.entities.neighboring_positions(self.state.pos)
+ for y in state.entities.pos_dict[x] if do.DOOR in y.name)
except StopIteration:
return None
diff --git a/marl_factory_grid/algorithms/utils.py b/marl_factory_grid/algorithms/utils.py
index 59e78bd..8c60386 100644
--- a/marl_factory_grid/algorithms/utils.py
+++ b/marl_factory_grid/algorithms/utils.py
@@ -1,8 +1,9 @@
-import torch
-import numpy as np
-import yaml
from pathlib import Path
+import numpy as np
+import torch
+import yaml
+
def load_class(classname):
from importlib import import_module
@@ -42,7 +43,6 @@ def get_class(arguments):
def get_arguments(arguments):
- from importlib import import_module
d = dict(arguments)
if "classname" in d:
del d["classname"]
@@ -82,4 +82,4 @@ class Checkpointer(object):
for name, model in to_save:
self.save_experiment(name, model)
self.__current_checkpoint += 1
- self.__current_step += 1
\ No newline at end of file
+ self.__current_step += 1
diff --git a/marl_factory_grid/configs/narrow_corridor.yaml b/marl_factory_grid/configs/narrow_corridor.yaml
index 04f42ae..f53b972 100644
--- a/marl_factory_grid/configs/narrow_corridor.yaml
+++ b/marl_factory_grid/configs/narrow_corridor.yaml
@@ -1,4 +1,4 @@
-eneral:
+General:
# Your Seed
env_seed: 69
# Individual or global rewards?
@@ -86,4 +86,4 @@ Rules:
DoneAtDestinationReachAll:
# reward_at_done: 1
DoneAtMaxStepsReached:
- max_steps: 500
+ max_steps: 200
diff --git a/marl_factory_grid/environment/entity/entity.py b/marl_factory_grid/environment/entity/entity.py
index 4abf2af..999787b 100644
--- a/marl_factory_grid/environment/entity/entity.py
+++ b/marl_factory_grid/environment/entity/entity.py
@@ -1,15 +1,14 @@
import abc
-from collections import defaultdict
import numpy as np
-from .object import _Object
+from .object import Object
from .. import constants as c
from ...utils.results import ActionResult
from ...utils.utility_classes import RenderEntity
-class Entity(_Object, abc.ABC):
+class Entity(Object, abc.ABC):
"""Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc..."""
@property
@@ -96,8 +95,9 @@ class Entity(_Object, abc.ABC):
def __init__(self, pos, bind_to=None, **kwargs):
super().__init__(**kwargs)
+ self._view_directory = c.VALUE_NO_POS
self._status = None
- self.set_pos(pos)
+ self._pos = pos
self._last_pos = pos
if bind_to:
try:
@@ -113,10 +113,6 @@ class Entity(_Object, abc.ABC):
def render(self):
return RenderEntity(self.__class__.__name__.lower(), self.pos)
- @abc.abstractmethod
- def render(self):
- return RenderEntity(self.__class__.__name__.lower(), self.pos)
-
@property
def obs_tag(self):
try:
@@ -133,25 +129,3 @@ class Entity(_Object, abc.ABC):
self._collection.delete_env_object(self)
self._collection = other_collection
return self._collection == other_collection
-
- @classmethod
- def from_coordinates(cls, positions: [(int, int)], *args, entity_kwargs=None, **kwargs, ):
- collection = cls(*args, **kwargs)
- collection.add_items(
- [cls._entity(tuple(pos), **entity_kwargs if entity_kwargs is not None else {}) for pos in positions])
- return collection
-
- def notify_del_entity(self, entity):
- try:
- self.pos_dict[entity.pos].remove(entity)
- except (ValueError, AttributeError):
- pass
-
- def by_pos(self, pos: (int, int)):
- pos = tuple(pos)
- try:
- return self.state.entities.pos_dict[pos]
- except StopIteration:
- pass
- except ValueError:
- pass
diff --git a/marl_factory_grid/environment/entity/object.py b/marl_factory_grid/environment/entity/object.py
index 768f8b5..e8c69da 100644
--- a/marl_factory_grid/environment/entity/object.py
+++ b/marl_factory_grid/environment/entity/object.py
@@ -5,7 +5,7 @@ from marl_factory_grid.environment import constants as c
import marl_factory_grid.utils.helpers as h
-class _Object:
+class Object:
"""Generell Objects for Organisation and Maintanance such as Actions etc..."""
_u_idx = defaultdict(lambda: 0)
@@ -50,15 +50,15 @@ class _Object:
print(f'Following kwargs were passed, but ignored: {kwargs}')
def __repr__(self):
- name = self.name
- if self.bound_entity:
- name = h.add_bound_name(name, self.bound_entity)
- try:
- if self.var_has_position:
- name = h.add_pos_name(name, self)
- except (AttributeError):
- pass
- return name
+ name = self.name
+ if self.bound_entity:
+ name = h.add_bound_name(name, self.bound_entity)
+ try:
+ if self.var_has_position:
+ name = h.add_pos_name(name, self)
+ except AttributeError:
+ pass
+ return name
def __eq__(self, other) -> bool:
return other == self.identifier
@@ -67,8 +67,8 @@ class _Object:
return hash(self.identifier)
def _identify_and_count_up(self):
- idx = _Object._u_idx[self.__class__.__name__]
- _Object._u_idx[self.__class__.__name__] += 1
+ idx = Object._u_idx[self.__class__.__name__]
+ Object._u_idx[self.__class__.__name__] += 1
return idx
def set_collection(self, collection):
@@ -98,79 +98,3 @@ class _Object:
def unbind(self):
self._bound_entity = None
-
-
-# class EnvObject(_Object):
-# """Objects that hold Information that are observable, but have no position on the environment grid. Inventories etc..."""
-#
- # _u_idx = defaultdict(lambda: 0)
-#
-# @property
-# def obs_tag(self):
-# try:
-# return self._collection.name or self.name
-# except AttributeError:
-# return self.name
-#
-# @property
-# def var_is_blocking_light(self):
-# try:
-# return self._collection.var_is_blocking_light or False
-# except AttributeError:
-# return False
-#
-# @property
-# def var_can_be_bound(self):
-# try:
-# return self._collection.var_can_be_bound or False
-# except AttributeError:
-# return False
-#
-# @property
-# def var_can_move(self):
-# try:
-# return self._collection.var_can_move or False
-# except AttributeError:
-# return False
-#
-# @property
-# def var_is_blocking_pos(self):
-# try:
-# return self._collection.var_is_blocking_pos or False
-# except AttributeError:
-# return False
-#
-# @property
-# def var_has_position(self):
-# try:
-# return self._collection.var_has_position or False
-# except AttributeError:
-# return False
-#
-# @property
-# def var_can_collide(self):
-# try:
-# return self._collection.var_can_collide or False
-# except AttributeError:
-# return False
-#
-#
-# @property
-# def encoding(self):
-# return c.VALUE_OCCUPIED_CELL
-#
-#
-# def __init__(self, **kwargs):
-# self._bound_entity = None
-# super(EnvObject, self).__init__(**kwargs)
-#
-#
-# def change_parent_collection(self, other_collection):
-# other_collection.add_item(self)
-# self._collection.delete_env_object(self)
-# self._collection = other_collection
-# return self._collection == other_collection
-#
-#
-# def summarize_state(self):
-# return dict(name=str(self.name))
diff --git a/marl_factory_grid/environment/entity/util.py b/marl_factory_grid/environment/entity/util.py
index d43c53a..2a15c41 100644
--- a/marl_factory_grid/environment/entity/util.py
+++ b/marl_factory_grid/environment/entity/util.py
@@ -1,6 +1,6 @@
import numpy as np
-from marl_factory_grid.environment.entity.object import _Object
+from marl_factory_grid.environment.entity.object import Object
##########################################################################
@@ -8,7 +8,7 @@ from marl_factory_grid.environment.entity.object import _Object
##########################################################################
-class PlaceHolder(_Object):
+class PlaceHolder(Object):
def __init__(self, *args, fill_value=0, **kwargs):
super().__init__(*args, **kwargs)
@@ -27,7 +27,7 @@ class PlaceHolder(_Object):
return self.__class__.__name__
-class GlobalPosition(_Object):
+class GlobalPosition(Object):
@property
def encoding(self):
diff --git a/marl_factory_grid/environment/factory.py b/marl_factory_grid/environment/factory.py
index 3c5f7f6..97ce621 100644
--- a/marl_factory_grid/environment/factory.py
+++ b/marl_factory_grid/environment/factory.py
@@ -56,15 +56,18 @@ class Factory(gym.Env):
self.level_filepath = Path(custom_level_path)
else:
self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt'
- self._renderer = None # expensive - don't use; unless required !
parsed_entities = self.conf.load_entities()
self.map = LevelParser(self.level_filepath, parsed_entities, self.conf.pomdp_r)
# Init for later usage:
- self.state: Gamestate
- self.map: LevelParser
- self.obs_builder: OBSBuilder
+ # noinspection PyTypeChecker
+ self.state: Gamestate = None
+ # noinspection PyTypeChecker
+ self.obs_builder: OBSBuilder = None
+
+ # expensive - don't use; unless required !
+ self._renderer = None
# reset env to initial state, preparing env for new episode.
# returns tuple where the first dict contains initial observation for each agent in the env
@@ -74,7 +77,7 @@ class Factory(gym.Env):
return self.state.entities[item]
def reset(self) -> (dict, dict):
- if hasattr(self, 'state'):
+ if self.state is not None:
for entity_group in self.state.entities:
try:
entity_group[0].reset_uid()
@@ -160,7 +163,7 @@ class Factory(gym.Env):
# Finalize
reward, reward_info, done = self.summarize_step_results(tick_result, done_results)
- info = reward_info
+ info = dict(reward_info)
info.update(step_reward=sum(reward), step=self.state.curr_step)
diff --git a/marl_factory_grid/environment/groups/collection.py b/marl_factory_grid/environment/groups/collection.py
index 140c941..c0f0f6b 100644
--- a/marl_factory_grid/environment/groups/collection.py
+++ b/marl_factory_grid/environment/groups/collection.py
@@ -1,15 +1,15 @@
from typing import List, Tuple, Union, Dict
from marl_factory_grid.environment.entity.entity import Entity
-from marl_factory_grid.environment.groups.objects import _Objects
+from marl_factory_grid.environment.groups.objects import Objects
# noinspection PyProtectedMember
-from marl_factory_grid.environment.entity.object import _Object
+from marl_factory_grid.environment.entity.object import Object
import marl_factory_grid.environment.constants as c
from marl_factory_grid.utils.results import Result
-class Collection(_Objects):
- _entity = _Object # entity?
+class Collection(Objects):
+ _entity = Object # entity?
symbol = None
@property
@@ -58,7 +58,7 @@ class Collection(_Objects):
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, ignore_blocking=False, **entity_kwargs):
coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity
if self.var_has_position:
- if isinstance(coords_or_quantity, int):
+ if self.var_has_position and isinstance(coords_or_quantity, int):
if ignore_blocking or self._ignore_blocking:
coords_or_quantity = state.entities.floorlist[:coords_or_quantity]
else:
@@ -87,8 +87,8 @@ class Collection(_Objects):
raise ValueError(f'{self._entity.__name__} has no position!')
return c.VALID
- def despawn(self, items: List[_Object]):
- items = [items] if isinstance(items, _Object) else items
+ def despawn(self, items: List[Object]):
+ items = [items] if isinstance(items, Object) else items
for item in items:
del self[item]
diff --git a/marl_factory_grid/environment/groups/global_entities.py b/marl_factory_grid/environment/groups/global_entities.py
index 7a50de4..37779f9 100644
--- a/marl_factory_grid/environment/groups/global_entities.py
+++ b/marl_factory_grid/environment/groups/global_entities.py
@@ -3,12 +3,12 @@ from operator import itemgetter
from random import shuffle
from typing import Dict
-from marl_factory_grid.environment.groups.objects import _Objects
+from marl_factory_grid.environment.groups.objects import Objects
from marl_factory_grid.utils.helpers import POS_MASK
-class Entities(_Objects):
- _entity = _Objects
+class Entities(Objects):
+ _entity = Objects
@staticmethod
def neighboring_positions(pos):
@@ -87,7 +87,7 @@ class Entities(_Objects):
def __delitem__(self, name):
assert_str = 'This group of entity does not exist in this collection!'
assert any([key for key in name.keys() if key in self.keys()]), assert_str
- self[name]._observers.delete(self)
+ self[name].del_observer(self)
for entity in self[name]:
entity.del_observer(self)
return super(Entities, self).__delitem__(name)
diff --git a/marl_factory_grid/environment/groups/objects.py b/marl_factory_grid/environment/groups/objects.py
index d29cc2c..9229787 100644
--- a/marl_factory_grid/environment/groups/objects.py
+++ b/marl_factory_grid/environment/groups/objects.py
@@ -1,15 +1,15 @@
from collections import defaultdict
-from typing import List
+from typing import List, Iterator, Union
import numpy as np
-from marl_factory_grid.environment.entity.object import _Object
+from marl_factory_grid.environment.entity.object import Object
import marl_factory_grid.environment.constants as c
from marl_factory_grid.utils import helpers as h
-class _Objects:
- _entity = _Object
+class Objects:
+ _entity = Object
@property
def var_can_be_bound(self):
@@ -50,7 +50,7 @@ class _Objects:
def __len__(self):
return len(self._data)
- def __iter__(self):
+ def __iter__(self) -> Iterator[Union[Object, None]]:
return iter(self.values())
def add_item(self, item: _entity):
@@ -130,13 +130,14 @@ class _Objects:
repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]}
return f'{self.__class__.__name__}[{repr_dict}]'
- def notify_del_entity(self, entity: _Object):
+ def notify_del_entity(self, entity: Object):
try:
+ # noinspection PyUnresolvedReferences
self.pos_dict[entity.pos].remove(entity)
except (AttributeError, ValueError, IndexError):
pass
- def notify_add_entity(self, entity: _Object):
+ def notify_add_entity(self, entity: Object):
try:
if self not in entity.observers:
entity.add_observer(self)
diff --git a/marl_factory_grid/environment/rules.py b/marl_factory_grid/environment/rules.py
index 55bb2bc..f5b6836 100644
--- a/marl_factory_grid/environment/rules.py
+++ b/marl_factory_grid/environment/rules.py
@@ -1,11 +1,11 @@
import abc
from random import shuffle
-from typing import List, Collection, Union
+from typing import List, Collection
+from marl_factory_grid.environment import rewards as r, constants as c
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.utils import helpers as h
from marl_factory_grid.utils.results import TickResult, DoneResult
-from marl_factory_grid.environment import rewards as r, constants as c
class Rule(abc.ABC):
@@ -118,8 +118,7 @@ class AssignGlobalPositions(Rule):
def on_init(self, state, lvl_map):
from marl_factory_grid.environment.entity.util import GlobalPosition
for agent in state[c.AGENT]:
- gp = GlobalPosition(lvl_map.level_shape)
- gp.bind_to(agent)
+ gp = GlobalPosition(agent, lvl_map.level_shape)
state[c.GLOBALPOSITIONS].add_item(gp)
return []
diff --git a/marl_factory_grid/modules/_template/rules.py b/marl_factory_grid/modules/_template/rules.py
index 6ed2f2d..7696616 100644
--- a/marl_factory_grid/modules/_template/rules.py
+++ b/marl_factory_grid/modules/_template/rules.py
@@ -6,7 +6,9 @@ from marl_factory_grid.utils.results import TickResult, DoneResult
class TemplateRule(Rule):
def __init__(self, *args, **kwargs):
- super(TemplateRule, self).__init__(*args, **kwargs)
+ super(TemplateRule, self).__init__()
+ self.args = args
+ self.kwargs = kwargs
def on_init(self, state, lvl_map):
pass
diff --git a/marl_factory_grid/modules/batteries/actions.py b/marl_factory_grid/modules/batteries/actions.py
index bd755a2..7d1c4a2 100644
--- a/marl_factory_grid/modules/batteries/actions.py
+++ b/marl_factory_grid/modules/batteries/actions.py
@@ -1,6 +1,5 @@
from typing import Union
-import marl_factory_grid.modules.batteries.constants
from marl_factory_grid.environment.actions import Action
from marl_factory_grid.utils.results import ActionResult
@@ -24,5 +23,6 @@ class BtryCharge(Action):
else:
valid = c.NOT_VALID
state.print(f'{entity.name} failed to charged batteries at {entity.pos}.')
+
return ActionResult(entity=entity, identifier=self._identifier, validity=valid,
- reward=marl_factory_grid.modules.batteries.constants.REWARD_CHARGE_VALID if valid else marl_factory_grid.modules.batteries.constants.Reward_CHARGE_FAIL)
+ reward=b.REWARD_CHARGE_VALID if valid else b.Reward_CHARGE_FAIL)
diff --git a/marl_factory_grid/modules/batteries/entitites.py b/marl_factory_grid/modules/batteries/entitites.py
index 751d57b..7675fe9 100644
--- a/marl_factory_grid/modules/batteries/entitites.py
+++ b/marl_factory_grid/modules/batteries/entitites.py
@@ -1,11 +1,11 @@
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.entity import Entity
-from marl_factory_grid.environment.entity.object import _Object
+from marl_factory_grid.environment.entity.object import Object
from marl_factory_grid.modules.batteries import constants as b
from marl_factory_grid.utils.utility_classes import RenderEntity
-class Battery(_Object):
+class Battery(Object):
@property
def var_can_be_bound(self):
diff --git a/marl_factory_grid/modules/batteries/rules.py b/marl_factory_grid/modules/batteries/rules.py
index 84b7ef2..8a4725b 100644
--- a/marl_factory_grid/modules/batteries/rules.py
+++ b/marl_factory_grid/modules/batteries/rules.py
@@ -1,11 +1,9 @@
from typing import List, Union
-import marl_factory_grid.modules.batteries.constants
-from marl_factory_grid.environment.rules import Rule
-from marl_factory_grid.utils.results import TickResult, DoneResult
-
from marl_factory_grid.environment import constants as c
+from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.modules.batteries import constants as b
+from marl_factory_grid.utils.results import TickResult, DoneResult
class BatteryDecharge(Rule):
diff --git a/marl_factory_grid/modules/clean_up/entitites.py b/marl_factory_grid/modules/clean_up/entitites.py
index 19e703c..25c6eb1 100644
--- a/marl_factory_grid/modules/clean_up/entitites.py
+++ b/marl_factory_grid/modules/clean_up/entitites.py
@@ -1,5 +1,3 @@
-from numpy import random
-
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.utils.utility_classes import RenderEntity
from marl_factory_grid.modules.clean_up import constants as d
diff --git a/marl_factory_grid/modules/clean_up/groups.py b/marl_factory_grid/modules/clean_up/groups.py
index 2029171..7ae3247 100644
--- a/marl_factory_grid/modules/clean_up/groups.py
+++ b/marl_factory_grid/modules/clean_up/groups.py
@@ -1,9 +1,7 @@
-from typing import Union, List, Tuple
-
from marl_factory_grid.environment import constants as c
-from marl_factory_grid.utils.results import Result
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.modules.clean_up.entitites import DirtPile
+from marl_factory_grid.utils.results import Result
class DirtPiles(Collection):
diff --git a/marl_factory_grid/modules/clean_up/rules.py b/marl_factory_grid/modules/clean_up/rules.py
index be2f9b9..b81ee41 100644
--- a/marl_factory_grid/modules/clean_up/rules.py
+++ b/marl_factory_grid/modules/clean_up/rules.py
@@ -49,7 +49,7 @@ class RespawnDirt(Rule):
def tick_step(self, state):
collection = state[d.DIRT]
if self._next_dirt_spawn < 0:
- pass # No DirtPile Spawn
+ result = [] # No DirtPile Spawn
elif not self._next_dirt_spawn:
result = [collection.trigger_spawn(state, coords_or_quantity=self.respawn_n, amount=self.respawn_amount)]
self._next_dirt_spawn = self.respawn_freq
diff --git a/marl_factory_grid/modules/destinations/actions.py b/marl_factory_grid/modules/destinations/actions.py
index 13f7fe3..6367acd 100644
--- a/marl_factory_grid/modules/destinations/actions.py
+++ b/marl_factory_grid/modules/destinations/actions.py
@@ -21,4 +21,4 @@ class DestAction(Action):
valid = c.NOT_VALID
state.print(f'{entity.name} just tried to do_wait_action do_wait_action at {entity.pos} but failed')
return ActionResult(entity=entity, identifier=self._identifier, validity=valid,
- reward=marl_factory_grid.modules.destinations.constants.REWARD_WAIT_VALID if valid else marl_factory_grid.modules.destinations.constants.REWARD_WAIT_FAIL)
+ reward=d.REWARD_WAIT_VALID if valid else d.REWARD_WAIT_FAIL)
diff --git a/marl_factory_grid/modules/destinations/groups.py b/marl_factory_grid/modules/destinations/groups.py
index 5f0b654..f0b7f9e 100644
--- a/marl_factory_grid/modules/destinations/groups.py
+++ b/marl_factory_grid/modules/destinations/groups.py
@@ -1,7 +1,5 @@
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.modules.destinations.entitites import Destination
-from marl_factory_grid.environment import constants as c
-from marl_factory_grid.modules.destinations import constants as d
class Destinations(Collection):
diff --git a/marl_factory_grid/modules/doors/groups.py b/marl_factory_grid/modules/doors/groups.py
index a27d598..973d1ab 100644
--- a/marl_factory_grid/modules/doors/groups.py
+++ b/marl_factory_grid/modules/doors/groups.py
@@ -1,5 +1,3 @@
-from typing import Union
-
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.modules.doors import constants as d
from marl_factory_grid.modules.doors.entitites import Door
diff --git a/marl_factory_grid/modules/doors/rewards.py b/marl_factory_grid/modules/doors/rewards.py
index c87d123..b38c7c5 100644
--- a/marl_factory_grid/modules/doors/rewards.py
+++ b/marl_factory_grid/modules/doors/rewards.py
@@ -1,2 +1,2 @@
USE_DOOR_VALID: float = -0.00
-USE_DOOR_FAIL: float = -0.01
\ No newline at end of file
+USE_DOOR_FAIL: float = -0.01
diff --git a/marl_factory_grid/modules/factory/rules.py b/marl_factory_grid/modules/factory/rules.py
index d736f7a..e056135 100644
--- a/marl_factory_grid/modules/factory/rules.py
+++ b/marl_factory_grid/modules/factory/rules.py
@@ -1,8 +1,8 @@
import random
-from typing import List, Union
+from typing import List
-from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.environment import constants as c
+from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult
@@ -14,8 +14,6 @@ class AgentSingleZonePlacementBeta(Rule):
super().__init__()
def on_init(self, state, lvl_map):
- zones = state[c.ZONES]
- n_zones = state[c.ZONES]
agents = state[c.AGENT]
if len(self.coordinates) == len(agents):
coordinates = self.coordinates
@@ -31,4 +29,4 @@ class AgentSingleZonePlacementBeta(Rule):
return []
def tick_post_step(self, state) -> List[TickResult]:
- return []
\ No newline at end of file
+ return []
diff --git a/marl_factory_grid/modules/items/constants.py b/marl_factory_grid/modules/items/constants.py
index 86b8b0c..5cb82c3 100644
--- a/marl_factory_grid/modules/items/constants.py
+++ b/marl_factory_grid/modules/items/constants.py
@@ -1,6 +1,3 @@
-from typing import NamedTuple
-
-
SYMBOL_NO_ITEM = 0
SYMBOL_DROP_OFF = 1
# Item Env
diff --git a/marl_factory_grid/modules/items/entitites.py b/marl_factory_grid/modules/items/entitites.py
index ff34e23..8549134 100644
--- a/marl_factory_grid/modules/items/entitites.py
+++ b/marl_factory_grid/modules/items/entitites.py
@@ -14,27 +14,14 @@ class Item(Entity):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- @property
- def auto_despawn(self):
- return self._auto_despawn
-
@property
def encoding(self):
# Edit this if you want items to be drawn in the ops differently
return 1
- def set_auto_despawn(self, auto_despawn):
- self._auto_despawn = auto_despawn
-
- def summarize_state(self) -> dict:
- super_summarization = super(Item, self).summarize_state()
- super_summarization.update(dict(auto_despawn=self.auto_despawn))
- return super_summarization
-
class DropOffLocation(Entity):
-
def render(self):
return RenderEntity(i.DROP_OFF, self.pos)
@@ -42,18 +29,16 @@ class DropOffLocation(Entity):
def encoding(self):
return i.SYMBOL_DROP_OFF
- def __init__(self, *args, storage_size_until_full: int = 5, auto_item_despawn_interval: int = 5, **kwargs):
+ def __init__(self, *args, storage_size_until_full: int = 5, **kwargs):
super(DropOffLocation, self).__init__(*args, **kwargs)
- self.auto_item_despawn_interval = auto_item_despawn_interval
self.storage = deque(maxlen=storage_size_until_full or None)
def place_item(self, item: Item):
if self.is_full:
raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.")
- return bc.NOT_VALID # in Zeile 81 verschieben?
+ return bc.NOT_VALID
else:
self.storage.append(item)
- item.set_auto_despawn(self.auto_item_despawn_interval)
return c.VALID
@property
diff --git a/marl_factory_grid/modules/items/groups.py b/marl_factory_grid/modules/items/groups.py
index deb1812..be5ca49 100644
--- a/marl_factory_grid/modules/items/groups.py
+++ b/marl_factory_grid/modules/items/groups.py
@@ -1,12 +1,9 @@
-from random import shuffle
-
-from marl_factory_grid.modules.items import constants as i
from marl_factory_grid.environment import constants as c
-
-from marl_factory_grid.environment.groups.collection import Collection
-from marl_factory_grid.environment.groups.objects import _Objects
-from marl_factory_grid.environment.groups.mixins import IsBoundMixin
from marl_factory_grid.environment.entity.agent import Agent
+from marl_factory_grid.environment.groups.collection import Collection
+from marl_factory_grid.environment.groups.mixins import IsBoundMixin
+from marl_factory_grid.environment.groups.objects import Objects
+from marl_factory_grid.modules.items import constants as i
from marl_factory_grid.modules.items.entitites import Item, DropOffLocation
from marl_factory_grid.utils.results import Result
@@ -74,13 +71,12 @@ class Inventory(IsBoundMixin, Collection):
self._collection = collection
-class Inventories(_Objects):
+class Inventories(Objects):
_entity = Inventory
var_can_move = False
var_has_position = False
-
symbol = None
@property
@@ -116,7 +112,6 @@ class Inventories(_Objects):
return [val.summarize_states(**kwargs) for key, val in self.items()]
-
class DropOffLocations(Collection):
_entity = DropOffLocation
diff --git a/marl_factory_grid/modules/items/rewards.py b/marl_factory_grid/modules/items/rewards.py
index 40adf46..bcd2918 100644
--- a/marl_factory_grid/modules/items/rewards.py
+++ b/marl_factory_grid/modules/items/rewards.py
@@ -1,4 +1,4 @@
DROP_OFF_VALID: float = 0.1
DROP_OFF_FAIL: float = -0.1
PICK_UP_FAIL: float = -0.1
-PICK_UP_VALID: float = 0.1
\ No newline at end of file
+PICK_UP_VALID: float = 0.1
diff --git a/marl_factory_grid/modules/machines/actions.py b/marl_factory_grid/modules/machines/actions.py
index 970f85f..dbb303f 100644
--- a/marl_factory_grid/modules/machines/actions.py
+++ b/marl_factory_grid/modules/machines/actions.py
@@ -1,9 +1,10 @@
from typing import Union
+import marl_factory_grid.modules.machines.constants
from marl_factory_grid.environment.actions import Action
from marl_factory_grid.utils.results import ActionResult
-from marl_factory_grid.modules.machines import constants as m, rewards as r
+from marl_factory_grid.modules.machines import constants as m
from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils import helpers as h
@@ -16,8 +17,10 @@ class MachineAction(Action):
def do(self, entity, state) -> Union[None, ActionResult]:
if machine := h.get_first(state[m.MACHINES].by_pos(entity.pos)):
if valid := machine.maintain():
- return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_VALID)
+ return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_VALID)
else:
- return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_FAIL)
+ return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_FAIL)
else:
- return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.MAINTAIN_FAIL)
+ return ActionResult(entity=entity, identifier=self._identifier,
+ validity=c.NOT_VALID, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_FAIL
+ )
diff --git a/marl_factory_grid/modules/machines/constants.py b/marl_factory_grid/modules/machines/constants.py
index 29ce3bc..3771cbb 100644
--- a/marl_factory_grid/modules/machines/constants.py
+++ b/marl_factory_grid/modules/machines/constants.py
@@ -11,3 +11,7 @@ STATE_MAINTAIN = 'maintenance'
SYMBOL_WORK = 1
SYMBOL_IDLE = 0.6
SYMBOL_MAINTAIN = 0.3
+MAINTAIN_VALID: float = 0.5
+MAINTAIN_FAIL: float = -0.1
+FAIL_MISSING_MAINTENANCE: float = -0.5
+NONE: float = 0
diff --git a/marl_factory_grid/modules/machines/entitites.py b/marl_factory_grid/modules/machines/entitites.py
index f5775e1..581adf6 100644
--- a/marl_factory_grid/modules/machines/entitites.py
+++ b/marl_factory_grid/modules/machines/entitites.py
@@ -31,11 +31,10 @@ class Machine(Entity):
return c.NOT_VALID
def tick(self, state):
- # if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]):
- if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in state.entities.pos_dict[self.pos]]):
+ others = state.entities.pos_dict[self.pos]
+ if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in others]):
return TickResult(identifier=self.name, validity=c.VALID, entity=self)
- # elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]):
- elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in state.entities.pos_dict[self.pos]]):
+ elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in others]):
self.status = m.STATE_WORK
self.reset_counter()
return None
diff --git a/marl_factory_grid/modules/machines/groups.py b/marl_factory_grid/modules/machines/groups.py
index 5f2d970..9d89d6c 100644
--- a/marl_factory_grid/modules/machines/groups.py
+++ b/marl_factory_grid/modules/machines/groups.py
@@ -1,5 +1,3 @@
-from typing import Union, List, Tuple
-
from marl_factory_grid.environment.groups.collection import Collection
from .entitites import Machine
diff --git a/marl_factory_grid/modules/machines/rewards.py b/marl_factory_grid/modules/machines/rewards.py
deleted file mode 100644
index c868196..0000000
--- a/marl_factory_grid/modules/machines/rewards.py
+++ /dev/null
@@ -1,5 +0,0 @@
-MAINTAIN_VALID: float = 0.5
-MAINTAIN_FAIL: float = -0.1
-FAIL_MISSING_MAINTENANCE: float = -0.5
-
-NONE: float = 0
diff --git a/marl_factory_grid/modules/maintenance/constants.py b/marl_factory_grid/modules/maintenance/constants.py
index e0ab12c..3aed36c 100644
--- a/marl_factory_grid/modules/maintenance/constants.py
+++ b/marl_factory_grid/modules/maintenance/constants.py
@@ -1,3 +1,4 @@
MAINTAINER = 'Maintainer' # TEMPLATE _identifier. Define your own!
MAINTAINERS = 'Maintainers' # TEMPLATE _identifier. Define your own!
+MAINTAINER_COLLISION_REWARD = -5
diff --git a/marl_factory_grid/modules/maintenance/groups.py b/marl_factory_grid/modules/maintenance/groups.py
index 79f7480..5b09c9c 100644
--- a/marl_factory_grid/modules/maintenance/groups.py
+++ b/marl_factory_grid/modules/maintenance/groups.py
@@ -4,7 +4,6 @@ from marl_factory_grid.environment.groups.collection import Collection
from .entities import Maintainer
from ..machines import constants as mc
from ..machines.actions import MachineAction
-from ...utils.states import Gamestate
class Maintainers(Collection):
@@ -23,8 +22,6 @@ class Maintainers(Collection):
self.size = size
self._spawnrule = spawnrule
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):
self.add_items([self._entity(mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity])
diff --git a/marl_factory_grid/modules/maintenance/rewards.py b/marl_factory_grid/modules/maintenance/rewards.py
deleted file mode 100644
index 425ac3b..0000000
--- a/marl_factory_grid/modules/maintenance/rewards.py
+++ /dev/null
@@ -1 +0,0 @@
-MAINTAINER_COLLISION_REWARD = -5
\ No newline at end of file
diff --git a/marl_factory_grid/modules/maintenance/rules.py b/marl_factory_grid/modules/maintenance/rules.py
index fdefe42..92e6e75 100644
--- a/marl_factory_grid/modules/maintenance/rules.py
+++ b/marl_factory_grid/modules/maintenance/rules.py
@@ -1,15 +1,16 @@
from typing import List
+
+import marl_factory_grid.modules.maintenance.constants
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c
-from . import rewards as r
from . import constants as M
class MoveMaintainers(Rule):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
+ def __init__(self):
+ super().__init__()
def tick_step(self, state) -> List[TickResult]:
for maintainer in state[M.MAINTAINERS]:
@@ -20,8 +21,8 @@ class MoveMaintainers(Rule):
class DoneAtMaintainerCollision(Rule):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
+ def __init__(self):
+ super().__init__()
def on_check_done(self, state) -> List[DoneResult]:
agents = list(state[c.AGENT].values())
@@ -30,5 +31,5 @@ class DoneAtMaintainerCollision(Rule):
for agent in agents:
if agent.pos in m_pos:
done_results.append(DoneResult(entity=agent, validity=c.VALID, identifier=self.name,
- reward=r.MAINTAINER_COLLISION_REWARD))
+ reward=marl_factory_grid.modules.maintenance.constants.MAINTAINER_COLLISION_REWARD))
return done_results
diff --git a/marl_factory_grid/modules/zones/entitites.py b/marl_factory_grid/modules/zones/entitites.py
index cfd313f..4aa0f70 100644
--- a/marl_factory_grid/modules/zones/entitites.py
+++ b/marl_factory_grid/modules/zones/entitites.py
@@ -1,10 +1,10 @@
import random
from typing import List, Tuple
-from marl_factory_grid.environment.entity.object import _Object
+from marl_factory_grid.environment.entity.object import Object
-class Zone(_Object):
+class Zone(Object):
@property
def positions(self):
diff --git a/marl_factory_grid/modules/zones/groups.py b/marl_factory_grid/modules/zones/groups.py
index 71eb329..f5494cd 100644
--- a/marl_factory_grid/modules/zones/groups.py
+++ b/marl_factory_grid/modules/zones/groups.py
@@ -1,8 +1,8 @@
-from marl_factory_grid.environment.groups.objects import _Objects
+from marl_factory_grid.environment.groups.objects import Objects
from marl_factory_grid.modules.zones import Zone
-class Zones(_Objects):
+class Zones(Objects):
symbol = None
_entity = Zone
diff --git a/marl_factory_grid/utils/config_parser.py b/marl_factory_grid/utils/config_parser.py
index 7cdc9e6..8215ed2 100644
--- a/marl_factory_grid/utils/config_parser.py
+++ b/marl_factory_grid/utils/config_parser.py
@@ -58,7 +58,10 @@ class FactoryConfigParser(object):
return str(self.config)
def __getitem__(self, item):
- return self.config[item]
+ try:
+ return self.config[item]
+ except KeyError:
+ print(f'The mandatory {item} section could not be found in your .config gile. Check Spelling!')
def load_entities(self):
entity_classes = dict()
@@ -161,7 +164,6 @@ class FactoryConfigParser(object):
def _load_smth(self, config, class_obj):
rules = list()
- rules_names = list()
for rule in config:
e1 = e2 = e3 = None
try:
diff --git a/marl_factory_grid/utils/helpers.py b/marl_factory_grid/utils/helpers.py
index ae68bf7..f5f6d00 100644
--- a/marl_factory_grid/utils/helpers.py
+++ b/marl_factory_grid/utils/helpers.py
@@ -61,8 +61,8 @@ class ObservationTranslator:
:param per_agent_named_obs_spaces: `Named observation space` one for each agent. Overloaded.
type per_agent_named_obs_spaces: Dict[str, dict]
- :param placeholder_fill_value: Currently not fully implemented!!!
- :type placeholder_fill_value: Union[int, str] = 'N')
+ :param placeholder_fill_value: Currently, not fully implemented!!!
+ :type placeholder_fill_value: Union[int, str] = 'N'
"""
if isinstance(placeholder_fill_value, str):
diff --git a/marl_factory_grid/utils/logging/envmonitor.py b/marl_factory_grid/utils/logging/envmonitor.py
index 67eac73..e2551c8 100644
--- a/marl_factory_grid/utils/logging/envmonitor.py
+++ b/marl_factory_grid/utils/logging/envmonitor.py
@@ -9,7 +9,7 @@ from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
import pandas as pd
-from marl_factory_grid.utils.plotting.compare_runs import plot_single_run
+from marl_factory_grid.utils.plotting.plot_single_runs import plot_single_run
class EnvMonitor(Wrapper):
@@ -22,7 +22,6 @@ class EnvMonitor(Wrapper):
self._monitor_df = pd.DataFrame()
self._monitor_dict = dict()
-
def step(self, action):
obs_type, obs, reward, done, info = self.env.step(action)
self._read_info(info)
diff --git a/marl_factory_grid/utils/logging/recorder.py b/marl_factory_grid/utils/logging/recorder.py
index fac2e16..797866e 100644
--- a/marl_factory_grid/utils/logging/recorder.py
+++ b/marl_factory_grid/utils/logging/recorder.py
@@ -2,11 +2,9 @@ from os import PathLike
from pathlib import Path
from typing import Union, List
-import yaml
-from gymnasium import Wrapper
-
import numpy as np
import pandas as pd
+from gymnasium import Wrapper
class EnvRecorder(Wrapper):
@@ -106,7 +104,7 @@ class EnvRecorder(Wrapper):
out_dict = {'episodes': self._recorder_out_list}
out_dict.update(
{'n_episodes': self._curr_episode,
- 'metadata':dict(
+ 'metadata': dict(
level_name=self.env.params['General']['level_name'],
verbose=False,
n_agents=len(self.env.params['Agents']),
diff --git a/marl_factory_grid/utils/observation_builder.py b/marl_factory_grid/utils/observation_builder.py
index df10ae9..55d6ec0 100644
--- a/marl_factory_grid/utils/observation_builder.py
+++ b/marl_factory_grid/utils/observation_builder.py
@@ -5,7 +5,7 @@ from typing import Dict, List
import numpy as np
from marl_factory_grid.environment import constants as c
-from marl_factory_grid.environment.entity.object import _Object
+from marl_factory_grid.environment.entity.object import Object
from marl_factory_grid.environment.groups.utils import Combined
from marl_factory_grid.utils.utility_classes import Floor
from marl_factory_grid.utils.ray_caster import RayCaster
@@ -13,7 +13,6 @@ from marl_factory_grid.utils.states import Gamestate
from marl_factory_grid.utils import helpers as h
-
class OBSBuilder(object):
default_obs = [c.WALLS, c.OTHERS]
@@ -128,7 +127,7 @@ class OBSBuilder(object):
f'{re.escape("[")}(.*){re.escape("]")}'
f'{re.escape("(")}{re.escape(agent.name)}{re.escape(")")}')
name = next((key for key, val in self.all_obs.items()
- if pattern.search(str(val)) and isinstance(val, _Object)), None)
+ if pattern.search(str(val)) and isinstance(val, Object)), None)
e = self.all_obs[name]
except KeyError:
try:
@@ -181,11 +180,11 @@ class OBSBuilder(object):
return obs, self.obs_layers[agent.name]
def _sort_and_name_observation_conf(self, agent):
- '''
+ """
Builds the useable observation scheme per agent from conf.yaml.
:param agent:
:return:
- '''
+ """
# Fixme: no asymetric shapes possible.
self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape))
obs_layers = []
diff --git a/marl_factory_grid/utils/plotting/compare_runs.py b/marl_factory_grid/utils/plotting/plot_compare_runs.py
similarity index 83%
rename from marl_factory_grid/utils/plotting/compare_runs.py
rename to marl_factory_grid/utils/plotting/plot_compare_runs.py
index cb5c853..5115478 100644
--- a/marl_factory_grid/utils/plotting/compare_runs.py
+++ b/marl_factory_grid/utils/plotting/plot_compare_runs.py
@@ -7,50 +7,11 @@ from typing import Union, List
import pandas as pd
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
-from marl_factory_grid.utils.plotting.plotting import prepare_plot
+from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
MODEL_MAP = None
-def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None):
- run_path = Path(run_path)
- df_list = list()
- if run_path.is_dir():
- monitor_file = next(run_path.glob('*monitor*.pick'))
- elif run_path.exists() and run_path.is_file():
- monitor_file = run_path
- else:
- raise ValueError
-
- with monitor_file.open('rb') as f:
- monitor_df = pickle.load(f)
-
- monitor_df = monitor_df.fillna(0)
- df_list.append(monitor_df)
-
- df = pd.concat(df_list, ignore_index=True)
- df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
- if column_keys is not None:
- columns = [col for col in column_keys if col in df.columns]
- else:
- columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
-
- roll_n = 50
-
- non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean()
-
- df_melted = df[columns + ['Episode']].reset_index().melt(
- id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score"
- )
-
- if df_melted['Episode'].max() > 800:
- skip_n = round(df_melted['Episode'].max() * 0.02)
- df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
-
- prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
- print('Plotting done.')
-
-
def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
run_path = Path(run_path)
df_list = list()
diff --git a/marl_factory_grid/utils/plotting/plot_single_runs.py b/marl_factory_grid/utils/plotting/plot_single_runs.py
new file mode 100644
index 0000000..7316d6a
--- /dev/null
+++ b/marl_factory_grid/utils/plotting/plot_single_runs.py
@@ -0,0 +1,48 @@
+import pickle
+from os import PathLike
+from pathlib import Path
+from typing import Union
+
+import pandas as pd
+
+from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
+from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
+
+
+def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None,
+ file_key: str ='monitor', file_ext: str ='pkl'):
+ run_path = Path(run_path)
+ df_list = list()
+ if run_path.is_dir():
+ monitor_file = next(run_path.glob(f'*{file_key}*.{file_ext}'))
+ elif run_path.exists() and run_path.is_file():
+ monitor_file = run_path
+ else:
+ raise ValueError
+
+ with monitor_file.open('rb') as f:
+ monitor_df = pickle.load(f)
+
+ monitor_df = monitor_df.fillna(0)
+ df_list.append(monitor_df)
+
+ df = pd.concat(df_list, ignore_index=True)
+ df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
+ if column_keys is not None:
+ columns = [col for col in column_keys if col in df.columns]
+ else:
+ columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
+
+ # roll_n = 50
+ # non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean()
+
+ df_melted = df[columns + ['Episode']].reset_index().melt(
+ id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score"
+ )
+
+ if df_melted['Episode'].max() > 800:
+ skip_n = round(df_melted['Episode'].max() * 0.02)
+ df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
+
+ prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
+ print('Plotting done.')
diff --git a/marl_factory_grid/utils/plotting/plotting.py b/marl_factory_grid/utils/plotting/plotting_utils.py
similarity index 98%
rename from marl_factory_grid/utils/plotting/plotting.py
rename to marl_factory_grid/utils/plotting/plotting_utils.py
index 455f81a..17bb7ff 100644
--- a/marl_factory_grid/utils/plotting/plotting.py
+++ b/marl_factory_grid/utils/plotting/plotting_utils.py
@@ -60,7 +60,7 @@ def prepare_center_double_column_legend(df, hue, style, hue_order):
print('Struggling to plot Figure using LaTeX - going back to normal.')
plt.close('all')
sns.set(rc={'text.usetex': False}, style='whitegrid')
- fig = plt.figure(figsize=(10, 11))
+ _ = plt.figure(figsize=(10, 11))
lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
ci=95, palette=PALETTE, hue_order=hue_order, legend=False)
# plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
diff --git a/marl_factory_grid/utils/ray_caster.py b/marl_factory_grid/utils/ray_caster.py
index ecbac6d..d89997e 100644
--- a/marl_factory_grid/utils/ray_caster.py
+++ b/marl_factory_grid/utils/ray_caster.py
@@ -19,7 +19,7 @@ class RayCaster:
return f'{self.__class__.__name__}({self.agent.name})'
def build_ray_targets(self):
- north = np.array([0, -1])*self.pomdp_r
+ north = np.array([0, -1]) * self.pomdp_r
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
rot_M = [
[[math.cos(theta), -math.sin(theta)],
@@ -53,9 +53,9 @@ class RayCaster:
diag_hits = all([
self.ray_block_cache(
key,
- lambda: any(True for e in pos_dict[key] if e.var_is_blocking_light))
- # lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light))
- for key in ((x, y-cy), (x-cx, y))
+ # lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light)
+ lambda: any(True for e in pos_dict[key] if e.var_is_blocking_light))
+ for key in ((x, y - cy), (x - cx, y))
]) if (cx != 0 and cy != 0) else False
visible += entities_hit if not diag_hits else []
@@ -77,8 +77,8 @@ class RayCaster:
agent = self.agent
x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
- outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \
- + list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
+ outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r]))
+ outline += list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
return outline
@staticmethod
diff --git a/marl_factory_grid/utils/results.py b/marl_factory_grid/utils/results.py
index 6abf11c..b4b07fc 100644
--- a/marl_factory_grid/utils/results.py
+++ b/marl_factory_grid/utils/results.py
@@ -1,9 +1,12 @@
from typing import Union
from dataclasses import dataclass
+from marl_factory_grid.environment.entity.object import Object
+
TYPE_VALUE = 'value'
TYPE_REWARD = 'reward'
-types = [TYPE_VALUE, TYPE_REWARD]
+TYPES = [TYPE_VALUE, TYPE_REWARD]
+
@dataclass
class InfoObject:
@@ -18,12 +21,13 @@ class Result:
validity: bool
reward: Union[float, None] = None
value: Union[float, None] = None
- entity: None = None
+ entity: Object = None
def get_infos(self):
n = self.entity.name if self.entity is not None else "Global"
- return [InfoObject(identifier=f'{n}_{self.identifier}_{t}',
- val_type=t, value=self.__getattribute__(t)) for t in types
+ # Return multiple Info Dicts
+ return [InfoObject(identifier=f'{n}_{self.identifier}',
+ val_type=t, value=self.__getattribute__(t)) for t in TYPES
if self.__getattribute__(t) is not None]
def __repr__(self):
@@ -31,7 +35,7 @@ class Result:
reward = f" | Reward: {self.reward}" if self.reward is not None else ""
value = f" | Value: {self.value}" if self.value is not None else ""
entity = f" | by: {self.entity.name}" if self.entity is not None else ""
- return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value})'
+ return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value}{entity})'
@dataclass
diff --git a/marl_factory_grid/utils/states.py b/marl_factory_grid/utils/states.py
index fc07b95..f38f7f9 100644
--- a/marl_factory_grid/utils/states.py
+++ b/marl_factory_grid/utils/states.py
@@ -1,11 +1,12 @@
from itertools import islice
-from typing import List, Dict, Tuple
+from typing import List, Tuple
import numpy as np
from marl_factory_grid.environment import constants as c
+from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.rules import Rule
-from marl_factory_grid.utils.results import Result
+from marl_factory_grid.utils.results import Result, DoneResult
class StepRules:
@@ -83,13 +84,51 @@ class Gamestate(object):
return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})'
@property
- def random_free_position(self):
+ def random_free_position(self) -> (int, int):
+ """
+ Returns a single **free** position (x, y), which is **free** for spawning or walking.
+ No Entity at this position posses *var_is_blocking_pos* or *var_can_collide*.
+
+ :return: Single **free** position.
+ """
return self.get_n_random_free_positions(1)[0]
- def get_n_random_free_positions(self, n):
+ def get_n_random_free_positions(self, n) -> list[tuple[int, int]]:
+ """
+ Returns a list of *n* **free** positions [(x, y), ... ], which are **free** for spawning or walking.
+ No Entity at this position posses *var_is_blocking_pos* or *var_can_collide*.
+
+ :return: List of n **free** position.
+ """
return list(islice(self.entities.free_positions_generator, n))
- def tick(self, actions) -> List[Result]:
+ @property
+ def random_position(self) -> (int, int):
+ """
+ Returns a single available position (x, y), ignores all entity attributes.
+
+ :return: Single random position.
+ """
+ return self.get_n_random_positions(1)[0]
+
+ def get_n_random_positions(self, n) -> list[tuple[int, int]]:
+ """
+ Returns a list of *n* available positions [(x, y), ... ], ignores all entity attributes.
+
+ :return: List of n random positions.
+ """
+ return list(islice(self.entities.floorlist, n))
+
+ def tick(self, actions) -> list[Result]:
+ """
+ Performs a single **Gamestate Tick**by calling the inner rule hooks in sequential order.
+ - tick_pre_step_all: Things to do before the agents do their actions. Statechange, Moving, Spawning etc...
+ - agent tick: Agents do their actions.
+ - tick_step_all: Things to do after the agents did their actions. Statechange, Moving, Spawning etc...
+ - tick_post_step_all: Things to do at the very end of each step. Counting, Reward calculations etc...
+
+ :return: List of *Result*-objects.
+ """
results = list()
self.curr_step += 1
@@ -112,11 +151,23 @@ class Gamestate(object):
return results
- def print(self, string):
+ def print(self, string) -> None:
+ """
+ When *verbose* is active, print stuff.
+
+ :param string: *String* to print.
+ :type string: str
+ :return: Nothing
+ """
if self.verbose:
print(string)
- def check_done(self):
+ def check_done(self) -> List[DoneResult]:
+ """
+ Iterate all **Rules** that override tehe *on_ckeck_done* hook.
+
+ :return: List of Results
+ """
results = list()
for rule in self.rules:
if on_check_done_result := rule.on_check_done(self):
@@ -124,20 +175,44 @@ class Gamestate(object):
return results
def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]:
- positions = [pos for pos, entities in self.entities.pos_dict.items() if len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2)]
+ """
+ Returns a list positions [(x, y), ... ] on which collisions occur. This does not include agents,
+ that were unable to move because their target direction was blocked, also a form of collision.
+
+ :return: List of positions.
+ """
+ positions = [pos for pos, entities in self.entities.pos_dict.items() if
+ len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2)
+ ]
return positions
- def check_move_validity(self, moving_entity, position):
- if moving_entity.pos != position and not any(
- entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]) and not (
- moving_entity.var_is_blocking_pos and self.entities.is_occupied(position)):
- return True
- else:
- return False
+ def check_move_validity(self, moving_entity: Entity, target_position: (int, int)) -> bool:
+ """
+ Whether it is safe to move to the target positions and moving entity does not introduce a blocking attribute,
+ when position is allready occupied.
- def check_pos_validity(self, position):
- if not any(entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]):
- return True
- else:
- return False
+ :param moving_entity: Entity
+ :param target_position: pos
+ :return: Safe to move to
+ """
+ is_not_blocked = self.check_pos_validity(target_position)
+ will_not_block_others = moving_entity.var_is_blocking_pos and self.entities.is_occupied(target_position)
+
+ if moving_entity.pos != target_position and is_not_blocked and not will_not_block_others:
+ return c.VALID
+ else:
+ return c.NOT_VALID
+
+ def check_pos_validity(self, pos: (int, int)) -> bool:
+ """
+ Check if *pos* is a valid position to move or spawn to.
+
+ :param pos: position to check
+ :return: Wheter pos is a valid target.
+ """
+
+ if not any(e.var_is_blocking_pos for e in self.entities.pos_dict[pos]) and pos in self.entities.floorlist:
+ return c.VALID
+ else:
+ return c.NOT_VALID
diff --git a/marl_factory_grid/utils/tools.py b/marl_factory_grid/utils/tools.py
index 63c9f69..73fa50d 100644
--- a/marl_factory_grid/utils/tools.py
+++ b/marl_factory_grid/utils/tools.py
@@ -28,7 +28,9 @@ class ConfigExplainer:
def explain_module(self, class_to_explain):
parameters = inspect.signature(class_to_explain).parameters
- explained = {class_to_explain.__name__: {key: val.default for key, val in parameters.items() if key not in EXCLUDED}}
+ explained = {class_to_explain.__name__:
+ {key: val.default for key, val in parameters.items() if key not in EXCLUDED}
+ }
return explained
def _load_and_compare(self, compare_class, paths):
diff --git a/random_testrun.py b/random_testrun.py
index 9bebf17..ef8df08 100644
--- a/random_testrun.py
+++ b/random_testrun.py
@@ -6,18 +6,21 @@ from marl_factory_grid.environment.factory import Factory
from marl_factory_grid.utils.logging.envmonitor import EnvMonitor
from marl_factory_grid.utils.logging.recorder import EnvRecorder
+from marl_factory_grid.utils.plotting.plot_single_runs import plot_single_run
from marl_factory_grid.utils.tools import ConfigExplainer
if __name__ == '__main__':
# Render at each step?
- render = True
+ render = False
# Reveal all possible Modules (Entities, Rules, Agents[Actions, Observations], etc.)
explain_config = False
# Collect statistics?
- monitor = False
+ monitor = True
# Record as Protobuf?
record = False
+ # Plot Results?
+ plotting = True
run_path = Path('study_out')
@@ -38,7 +41,7 @@ if __name__ == '__main__':
factory = EnvRecorder(factory)
# RL learn Loop
- for episode in trange(500):
+ for episode in trange(10):
_ = factory.reset()
done = False
if render:
@@ -54,7 +57,10 @@ if __name__ == '__main__':
break
if monitor:
- factory.save_run(run_path / 'test.pkl')
+ factory.save_run(run_path / 'test_monitor.pkl')
if record:
factory.save_records(run_path / 'test.pb')
+ if plotting:
+ plot_single_run(run_path)
+
print('Done!!! Goodbye....')
diff --git a/reload_agent.py b/reload_agent.py
index f0ed389..0fb8066 100644
--- a/reload_agent.py
+++ b/reload_agent.py
@@ -56,6 +56,7 @@ if __name__ == '__main__':
for model_idx, model in enumerate(models)]
else:
actions = models[0].predict(env_state, deterministic=determin)[0]
+ # noinspection PyTupleAssignmentBalance
env_state, step_r, done_bool, info_obj = env.step(actions)
rew += step_r
diff --git a/transform_wg_to_json_no_priv.py b/transform_wg_to_json_no_priv.py
index d9bc8e1..1b7ef3e 100644
--- a/transform_wg_to_json_no_priv.py
+++ b/transform_wg_to_json_no_priv.py
@@ -5,7 +5,6 @@ from pathlib import Path
if __name__ == '__main__':
-
conf_path = Path('wg0')
wg0_conf = configparser.ConfigParser()
wg0_conf.read(conf_path/'wg0.conf')
@@ -17,7 +16,6 @@ if __name__ == '__main__':
# Delete any old conf.json for the current peer
(conf_path / f'{client_name}.json').unlink(missing_ok=True)
-
peer = wg0_conf[client_name]
date_time = datetime.now().strftime('%Y-%m-%dT%H:%M:%S.%f000Z')