Resolved some warnings and style issues

This commit is contained in:
Steffen Illium 2023-11-10 09:29:54 +01:00
parent a9462a8b6f
commit 6711a0976b
64 changed files with 331 additions and 361 deletions

5
.idea/.gitignore generated vendored Normal file
View File

@ -0,0 +1,5 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/

View File

@ -94,7 +94,7 @@ All [Entites](marl_factory_grid/environment/entity/global_entities.py) are avail
#### Rules #### Rules
[Rules](marl_factory_grid/environment/entity/object.py) define how the environment behaves on micro-scale. [Rules](marl_factory_grid/environment/entity/object.py) define how the environment behaves on microscale.
Each of the hookes (`on_init`, `pre_step`, `on_step`, '`post_step`', `on_done`) Each of the hookes (`on_init`, `pre_step`, `on_step`, '`post_step`', `on_done`)
provide env-access to implement customn logic, calculate rewards, or gather information. provide env-access to implement customn logic, calculate rewards, or gather information.
@ -107,6 +107,7 @@ Make sure to bring your own assets for each Entity living in the Gridworld as th
PNG-files (transparent background) of square aspect-ratio should do the job, in general. PNG-files (transparent background) of square aspect-ratio should do the job, in general.
<img src="/marl_factory_grid/environment/assets/wall.png" width="5%"> <img src="/marl_factory_grid/environment/assets/wall.png" width="5%">
<!--suppress HtmlUnknownAttribute -->
<html &nbsp&nbsp&nbsp&nbsp html> <html &nbsp&nbsp&nbsp&nbsp html>
<img src="/marl_factory_grid/environment/assets/agent/agent.png" width="5%"> <img src="/marl_factory_grid/environment/assets/agent/agent.png" width="5%">

View File

@ -1 +1 @@
from .quickstart import init from .quickstart import init

View File

@ -1 +1,4 @@
import os, sys; sys.path.append(os.path.dirname(os.path.realpath(__file__))) import os
import sys
sys.path.append(os.path.dirname(os.path.realpath(__file__)))

View File

@ -1 +1 @@
from marl_factory_grid.algorithms.marl.memory import MARLActorCriticMemory from marl_factory_grid.algorithms.marl.memory import MARLActorCriticMemory

View File

@ -28,6 +28,7 @@ class Names:
BATCH_SIZE = 'bnatch_size' BATCH_SIZE = 'bnatch_size'
N_ACTIONS = 'n_actions' N_ACTIONS = 'n_actions'
nms = Names nms = Names
ListOrTensor = Union[List, torch.Tensor] ListOrTensor = Union[List, torch.Tensor]
@ -112,10 +113,9 @@ class BaseActorCritic:
next_obs, reward, done, info = env.step(action) next_obs, reward, done, info = env.step(action)
done = [done] * self.n_agents if isinstance(done, bool) else done done = [done] * self.n_agents if isinstance(done, bool) else done
last_hiddens = dict(hidden_actor =out[nms.HIDDEN_ACTOR], last_hiddens = dict(hidden_actor=out[nms.HIDDEN_ACTOR],
hidden_critic=out[nms.HIDDEN_CRITIC]) hidden_critic=out[nms.HIDDEN_CRITIC])
tm.add(observation=obs, action=action, reward=reward, done=done, tm.add(observation=obs, action=action, reward=reward, done=done,
logits=out.get(nms.LOGITS, None), values=out.get(nms.CRITIC, None), logits=out.get(nms.LOGITS, None), values=out.get(nms.CRITIC, None),
**last_hiddens) **last_hiddens)
@ -142,7 +142,9 @@ class BaseActorCritic:
print(f'reward at episode: {episode} = {rew_log}') print(f'reward at episode: {episode} = {rew_log}')
episode += 1 episode += 1
df_results.append([episode, rew_log, *reward]) df_results.append([episode, rew_log, *reward])
df_results = pd.DataFrame(df_results, columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]]) df_results = pd.DataFrame(df_results,
columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]]
)
if checkpointer is not None: if checkpointer is not None:
df_results.to_csv(checkpointer.path / 'results.csv', index=False) df_results.to_csv(checkpointer.path / 'results.csv', index=False)
return df_results return df_results
@ -157,24 +159,27 @@ class BaseActorCritic:
last_action, reward = [-1] * self.n_agents, [0.] * self.n_agents last_action, reward = [-1] * self.n_agents, [0.] * self.n_agents
done, rew_log, eps_rew = [False] * self.n_agents, 0, torch.zeros(self.n_agents) done, rew_log, eps_rew = [False] * self.n_agents, 0, torch.zeros(self.n_agents)
while not all(done): while not all(done):
if render: env.render() if render:
env.render()
out = self.forward(obs, last_action, **last_hiddens) out = self.forward(obs, last_action, **last_hiddens)
action = self.get_actions(out) action = self.get_actions(out)
next_obs, reward, done, info = env.step(action) next_obs, reward, done, info = env.step(action)
if isinstance(done, bool): done = [done] * obs.shape[0] if isinstance(done, bool):
done = [done] * obs.shape[0]
obs = next_obs obs = next_obs
last_action = action last_action = action
last_hiddens = dict(hidden_actor=out.get(nms.HIDDEN_ACTOR, None), last_hiddens = dict(hidden_actor=out.get(nms.HIDDEN_ACTOR, None),
hidden_critic=out.get(nms.HIDDEN_CRITIC, None) hidden_critic=out.get(nms.HIDDEN_CRITIC, None)
) )
eps_rew += torch.tensor(reward) eps_rew += torch.tensor(reward)
results.append(eps_rew.tolist() + [sum(eps_rew).item()] + [episode]) results.append(eps_rew.tolist() + [np.sum(eps_rew).item()] + [episode])
episode += 1 episode += 1
agent_columns = [f'agent#{i}' for i in range(self.cfg['environment']['n_agents'])] agent_columns = [f'agent#{i}' for i in range(self.cfg['environment']['n_agents'])]
results = pd.DataFrame(results, columns=agent_columns + ['sum', 'episode']) results = pd.DataFrame(results, columns=agent_columns + ['sum', 'episode'])
results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'], value_name='reward', var_name='agent') results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'],
value_name='reward', var_name='agent')
return results return results
@staticmethod @staticmethod

View File

@ -36,7 +36,7 @@ class LoopMAPPO(LoopSNAC):
rewards_ = torch.stack(rewards_, dim=1) rewards_ = torch.stack(rewards_, dim=1)
return rewards_ return rewards_
def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **kwargs): def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **__):
out = network(batch[nms.OBSERVATION], batch[nms.ACTION], batch[nms.HIDDEN_ACTOR], batch[nms.HIDDEN_CRITIC]) out = network(batch[nms.OBSERVATION], batch[nms.ACTION], batch[nms.HIDDEN_ACTOR], batch[nms.HIDDEN_CRITIC])
logits = out[nms.LOGITS][:, :-1] # last one only needed for v_{t+1} logits = out[nms.LOGITS][:, :-1] # last one only needed for v_{t+1}
@ -45,7 +45,7 @@ class LoopMAPPO(LoopSNAC):
# monte carlo returns # monte carlo returns
mc_returns = self.monte_carlo_returns(batch[nms.REWARD], batch[nms.DONE], gamma) mc_returns = self.monte_carlo_returns(batch[nms.REWARD], batch[nms.DONE], gamma)
mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8) #todo: norm across agent ok? mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8) # todo: norm across agent ok?
advantages = mc_returns - out[nms.CRITIC][:, :-1] advantages = mc_returns - out[nms.CRITIC][:, :-1]
# policy loss # policy loss

View File

@ -1,8 +1,7 @@
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.utils import spectral_norm
class RecurrentAC(nn.Module): class RecurrentAC(nn.Module):
@ -88,8 +87,8 @@ class NormalizedLinear(nn.Linear):
self.trainable_magnitude = trainable_magnitude self.trainable_magnitude = trainable_magnitude
self.scale = nn.Parameter(torch.tensor([1.]), requires_grad=trainable_magnitude) self.scale = nn.Parameter(torch.tensor([1.]), requires_grad=trainable_magnitude)
def forward(self, input): def forward(self, in_array):
normalized_input = F.normalize(input, dim=-1, p=2, eps=1e-5) normalized_input = F.normalize(in_array, dim=-1, p=2, eps=1e-5)
normalized_weight = F.normalize(self.weight, dim=-1, p=2, eps=1e-5) normalized_weight = F.normalize(self.weight, dim=-1, p=2, eps=1e-5)
return F.linear(normalized_input, normalized_weight) * self.d_sqrt * self.scale return F.linear(normalized_input, normalized_weight) * self.d_sqrt * self.scale

View File

@ -16,7 +16,7 @@ class LoopSEAC(LoopIAC):
with torch.inference_mode(True): with torch.inference_mode(True):
true_action_logp = torch.stack([ true_action_logp = torch.stack([
torch.log_softmax(out[nms.LOGITS][ag_i, :-1], -1) torch.log_softmax(out[nms.LOGITS][ag_i, :-1], -1)
.gather(index=actions[ag_i, 1:, None], dim=-1) .gather(index=actions[ag_i, 1:, None], dim=-1)
for ag_i, out in enumerate(outputs) for ag_i, out in enumerate(outputs)
], 0).squeeze() ], 0).squeeze()
@ -38,7 +38,6 @@ class LoopSEAC(LoopIAC):
a2c_loss = (-iw*log_ap * advantages.detach()).mean(-1) a2c_loss = (-iw*log_ap * advantages.detach()).mean(-1)
value_loss = (iw*advantages.pow(2)).mean(-1) # n_agent value_loss = (iw*advantages.pow(2)).mean(-1) # n_agent
# weighted loss # weighted loss
@ -53,4 +52,4 @@ class LoopSEAC(LoopIAC):
self.optimizer[ag_i].zero_grad() self.optimizer[ag_i].zero_grad()
loss.backward() loss.backward()
torch.nn.utils.clip_grad_norm_(self.net[ag_i].parameters(), 0.5) torch.nn.utils.clip_grad_norm_(self.net[ag_i].parameters(), 0.5)
self.optimizer[ag_i].step() self.optimizer[ag_i].step()

View File

@ -30,4 +30,4 @@ class LoopSNAC(BaseActorCritic):
self._as_torch(actions).unsqueeze(1), self._as_torch(actions).unsqueeze(1),
hidden_actor, hidden_critic hidden_actor, hidden_critic
) )
return out return out

View File

@ -56,8 +56,8 @@ class TSPBaseAgent(ABC):
def _door_is_close(self, state): def _door_is_close(self, state):
try: try:
# return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name) return next(y for x in state.entities.neighboring_positions(self.state.pos)
return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name) for y in state.entities.pos_dict[x] if do.DOOR in y.name)
except StopIteration: except StopIteration:
return None return None

View File

@ -14,8 +14,8 @@ class TSPTargetAgent(TSPBaseAgent):
def _handle_doors(self, state): def _handle_doors(self, state):
try: try:
# return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name) return next(y for x in state.entities.neighboring_positions(self.state.pos)
return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name) for y in state.entities.pos_dict[x] if do.DOOR in y.name)
except StopIteration: except StopIteration:
return None return None

View File

@ -1,8 +1,9 @@
import torch
import numpy as np
import yaml
from pathlib import Path from pathlib import Path
import numpy as np
import torch
import yaml
def load_class(classname): def load_class(classname):
from importlib import import_module from importlib import import_module
@ -42,7 +43,6 @@ def get_class(arguments):
def get_arguments(arguments): def get_arguments(arguments):
from importlib import import_module
d = dict(arguments) d = dict(arguments)
if "classname" in d: if "classname" in d:
del d["classname"] del d["classname"]
@ -82,4 +82,4 @@ class Checkpointer(object):
for name, model in to_save: for name, model in to_save:
self.save_experiment(name, model) self.save_experiment(name, model)
self.__current_checkpoint += 1 self.__current_checkpoint += 1
self.__current_step += 1 self.__current_step += 1

View File

@ -1,4 +1,4 @@
eneral: General:
# Your Seed # Your Seed
env_seed: 69 env_seed: 69
# Individual or global rewards? # Individual or global rewards?
@ -86,4 +86,4 @@ Rules:
DoneAtDestinationReachAll: DoneAtDestinationReachAll:
# reward_at_done: 1 # reward_at_done: 1
DoneAtMaxStepsReached: DoneAtMaxStepsReached:
max_steps: 500 max_steps: 200

View File

@ -1,15 +1,14 @@
import abc import abc
from collections import defaultdict
import numpy as np import numpy as np
from .object import _Object from .object import Object
from .. import constants as c from .. import constants as c
from ...utils.results import ActionResult from ...utils.results import ActionResult
from ...utils.utility_classes import RenderEntity from ...utils.utility_classes import RenderEntity
class Entity(_Object, abc.ABC): class Entity(Object, abc.ABC):
"""Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc...""" """Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc..."""
@property @property
@ -96,8 +95,9 @@ class Entity(_Object, abc.ABC):
def __init__(self, pos, bind_to=None, **kwargs): def __init__(self, pos, bind_to=None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self._view_directory = c.VALUE_NO_POS
self._status = None self._status = None
self.set_pos(pos) self._pos = pos
self._last_pos = pos self._last_pos = pos
if bind_to: if bind_to:
try: try:
@ -113,10 +113,6 @@ class Entity(_Object, abc.ABC):
def render(self): def render(self):
return RenderEntity(self.__class__.__name__.lower(), self.pos) return RenderEntity(self.__class__.__name__.lower(), self.pos)
@abc.abstractmethod
def render(self):
return RenderEntity(self.__class__.__name__.lower(), self.pos)
@property @property
def obs_tag(self): def obs_tag(self):
try: try:
@ -133,25 +129,3 @@ class Entity(_Object, abc.ABC):
self._collection.delete_env_object(self) self._collection.delete_env_object(self)
self._collection = other_collection self._collection = other_collection
return self._collection == other_collection return self._collection == other_collection
@classmethod
def from_coordinates(cls, positions: [(int, int)], *args, entity_kwargs=None, **kwargs, ):
collection = cls(*args, **kwargs)
collection.add_items(
[cls._entity(tuple(pos), **entity_kwargs if entity_kwargs is not None else {}) for pos in positions])
return collection
def notify_del_entity(self, entity):
try:
self.pos_dict[entity.pos].remove(entity)
except (ValueError, AttributeError):
pass
def by_pos(self, pos: (int, int)):
pos = tuple(pos)
try:
return self.state.entities.pos_dict[pos]
except StopIteration:
pass
except ValueError:
pass

View File

@ -5,7 +5,7 @@ from marl_factory_grid.environment import constants as c
import marl_factory_grid.utils.helpers as h import marl_factory_grid.utils.helpers as h
class _Object: class Object:
"""Generell Objects for Organisation and Maintanance such as Actions etc...""" """Generell Objects for Organisation and Maintanance such as Actions etc..."""
_u_idx = defaultdict(lambda: 0) _u_idx = defaultdict(lambda: 0)
@ -50,15 +50,15 @@ class _Object:
print(f'Following kwargs were passed, but ignored: {kwargs}') print(f'Following kwargs were passed, but ignored: {kwargs}')
def __repr__(self): def __repr__(self):
name = self.name name = self.name
if self.bound_entity: if self.bound_entity:
name = h.add_bound_name(name, self.bound_entity) name = h.add_bound_name(name, self.bound_entity)
try: try:
if self.var_has_position: if self.var_has_position:
name = h.add_pos_name(name, self) name = h.add_pos_name(name, self)
except (AttributeError): except AttributeError:
pass pass
return name return name
def __eq__(self, other) -> bool: def __eq__(self, other) -> bool:
return other == self.identifier return other == self.identifier
@ -67,8 +67,8 @@ class _Object:
return hash(self.identifier) return hash(self.identifier)
def _identify_and_count_up(self): def _identify_and_count_up(self):
idx = _Object._u_idx[self.__class__.__name__] idx = Object._u_idx[self.__class__.__name__]
_Object._u_idx[self.__class__.__name__] += 1 Object._u_idx[self.__class__.__name__] += 1
return idx return idx
def set_collection(self, collection): def set_collection(self, collection):
@ -98,79 +98,3 @@ class _Object:
def unbind(self): def unbind(self):
self._bound_entity = None self._bound_entity = None
# class EnvObject(_Object):
# """Objects that hold Information that are observable, but have no position on the environment grid. Inventories etc..."""
#
# _u_idx = defaultdict(lambda: 0)
#
# @property
# def obs_tag(self):
# try:
# return self._collection.name or self.name
# except AttributeError:
# return self.name
#
# @property
# def var_is_blocking_light(self):
# try:
# return self._collection.var_is_blocking_light or False
# except AttributeError:
# return False
#
# @property
# def var_can_be_bound(self):
# try:
# return self._collection.var_can_be_bound or False
# except AttributeError:
# return False
#
# @property
# def var_can_move(self):
# try:
# return self._collection.var_can_move or False
# except AttributeError:
# return False
#
# @property
# def var_is_blocking_pos(self):
# try:
# return self._collection.var_is_blocking_pos or False
# except AttributeError:
# return False
#
# @property
# def var_has_position(self):
# try:
# return self._collection.var_has_position or False
# except AttributeError:
# return False
#
# @property
# def var_can_collide(self):
# try:
# return self._collection.var_can_collide or False
# except AttributeError:
# return False
#
#
# @property
# def encoding(self):
# return c.VALUE_OCCUPIED_CELL
#
#
# def __init__(self, **kwargs):
# self._bound_entity = None
# super(EnvObject, self).__init__(**kwargs)
#
#
# def change_parent_collection(self, other_collection):
# other_collection.add_item(self)
# self._collection.delete_env_object(self)
# self._collection = other_collection
# return self._collection == other_collection
#
#
# def summarize_state(self):
# return dict(name=str(self.name))

View File

@ -1,6 +1,6 @@
import numpy as np import numpy as np
from marl_factory_grid.environment.entity.object import _Object from marl_factory_grid.environment.entity.object import Object
########################################################################## ##########################################################################
@ -8,7 +8,7 @@ from marl_factory_grid.environment.entity.object import _Object
########################################################################## ##########################################################################
class PlaceHolder(_Object): class PlaceHolder(Object):
def __init__(self, *args, fill_value=0, **kwargs): def __init__(self, *args, fill_value=0, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -27,7 +27,7 @@ class PlaceHolder(_Object):
return self.__class__.__name__ return self.__class__.__name__
class GlobalPosition(_Object): class GlobalPosition(Object):
@property @property
def encoding(self): def encoding(self):

View File

@ -56,15 +56,18 @@ class Factory(gym.Env):
self.level_filepath = Path(custom_level_path) self.level_filepath = Path(custom_level_path)
else: else:
self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt' self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt'
self._renderer = None # expensive - don't use; unless required !
parsed_entities = self.conf.load_entities() parsed_entities = self.conf.load_entities()
self.map = LevelParser(self.level_filepath, parsed_entities, self.conf.pomdp_r) self.map = LevelParser(self.level_filepath, parsed_entities, self.conf.pomdp_r)
# Init for later usage: # Init for later usage:
self.state: Gamestate # noinspection PyTypeChecker
self.map: LevelParser self.state: Gamestate = None
self.obs_builder: OBSBuilder # noinspection PyTypeChecker
self.obs_builder: OBSBuilder = None
# expensive - don't use; unless required !
self._renderer = None
# reset env to initial state, preparing env for new episode. # reset env to initial state, preparing env for new episode.
# returns tuple where the first dict contains initial observation for each agent in the env # returns tuple where the first dict contains initial observation for each agent in the env
@ -74,7 +77,7 @@ class Factory(gym.Env):
return self.state.entities[item] return self.state.entities[item]
def reset(self) -> (dict, dict): def reset(self) -> (dict, dict):
if hasattr(self, 'state'): if self.state is not None:
for entity_group in self.state.entities: for entity_group in self.state.entities:
try: try:
entity_group[0].reset_uid() entity_group[0].reset_uid()
@ -160,7 +163,7 @@ class Factory(gym.Env):
# Finalize # Finalize
reward, reward_info, done = self.summarize_step_results(tick_result, done_results) reward, reward_info, done = self.summarize_step_results(tick_result, done_results)
info = reward_info info = dict(reward_info)
info.update(step_reward=sum(reward), step=self.state.curr_step) info.update(step_reward=sum(reward), step=self.state.curr_step)

View File

@ -1,15 +1,15 @@
from typing import List, Tuple, Union, Dict from typing import List, Tuple, Union, Dict
from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.groups.objects import _Objects from marl_factory_grid.environment.groups.objects import Objects
# noinspection PyProtectedMember # noinspection PyProtectedMember
from marl_factory_grid.environment.entity.object import _Object from marl_factory_grid.environment.entity.object import Object
import marl_factory_grid.environment.constants as c import marl_factory_grid.environment.constants as c
from marl_factory_grid.utils.results import Result from marl_factory_grid.utils.results import Result
class Collection(_Objects): class Collection(Objects):
_entity = _Object # entity? _entity = Object # entity?
symbol = None symbol = None
@property @property
@ -58,7 +58,7 @@ class Collection(_Objects):
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, ignore_blocking=False, **entity_kwargs): def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, ignore_blocking=False, **entity_kwargs):
coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity
if self.var_has_position: if self.var_has_position:
if isinstance(coords_or_quantity, int): if self.var_has_position and isinstance(coords_or_quantity, int):
if ignore_blocking or self._ignore_blocking: if ignore_blocking or self._ignore_blocking:
coords_or_quantity = state.entities.floorlist[:coords_or_quantity] coords_or_quantity = state.entities.floorlist[:coords_or_quantity]
else: else:
@ -87,8 +87,8 @@ class Collection(_Objects):
raise ValueError(f'{self._entity.__name__} has no position!') raise ValueError(f'{self._entity.__name__} has no position!')
return c.VALID return c.VALID
def despawn(self, items: List[_Object]): def despawn(self, items: List[Object]):
items = [items] if isinstance(items, _Object) else items items = [items] if isinstance(items, Object) else items
for item in items: for item in items:
del self[item] del self[item]

View File

@ -3,12 +3,12 @@ from operator import itemgetter
from random import shuffle from random import shuffle
from typing import Dict from typing import Dict
from marl_factory_grid.environment.groups.objects import _Objects from marl_factory_grid.environment.groups.objects import Objects
from marl_factory_grid.utils.helpers import POS_MASK from marl_factory_grid.utils.helpers import POS_MASK
class Entities(_Objects): class Entities(Objects):
_entity = _Objects _entity = Objects
@staticmethod @staticmethod
def neighboring_positions(pos): def neighboring_positions(pos):
@ -87,7 +87,7 @@ class Entities(_Objects):
def __delitem__(self, name): def __delitem__(self, name):
assert_str = 'This group of entity does not exist in this collection!' assert_str = 'This group of entity does not exist in this collection!'
assert any([key for key in name.keys() if key in self.keys()]), assert_str assert any([key for key in name.keys() if key in self.keys()]), assert_str
self[name]._observers.delete(self) self[name].del_observer(self)
for entity in self[name]: for entity in self[name]:
entity.del_observer(self) entity.del_observer(self)
return super(Entities, self).__delitem__(name) return super(Entities, self).__delitem__(name)

View File

@ -1,15 +1,15 @@
from collections import defaultdict from collections import defaultdict
from typing import List from typing import List, Iterator, Union
import numpy as np import numpy as np
from marl_factory_grid.environment.entity.object import _Object from marl_factory_grid.environment.entity.object import Object
import marl_factory_grid.environment.constants as c import marl_factory_grid.environment.constants as c
from marl_factory_grid.utils import helpers as h from marl_factory_grid.utils import helpers as h
class _Objects: class Objects:
_entity = _Object _entity = Object
@property @property
def var_can_be_bound(self): def var_can_be_bound(self):
@ -50,7 +50,7 @@ class _Objects:
def __len__(self): def __len__(self):
return len(self._data) return len(self._data)
def __iter__(self): def __iter__(self) -> Iterator[Union[Object, None]]:
return iter(self.values()) return iter(self.values())
def add_item(self, item: _entity): def add_item(self, item: _entity):
@ -130,13 +130,14 @@ class _Objects:
repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]} repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]}
return f'{self.__class__.__name__}[{repr_dict}]' return f'{self.__class__.__name__}[{repr_dict}]'
def notify_del_entity(self, entity: _Object): def notify_del_entity(self, entity: Object):
try: try:
# noinspection PyUnresolvedReferences
self.pos_dict[entity.pos].remove(entity) self.pos_dict[entity.pos].remove(entity)
except (AttributeError, ValueError, IndexError): except (AttributeError, ValueError, IndexError):
pass pass
def notify_add_entity(self, entity: _Object): def notify_add_entity(self, entity: Object):
try: try:
if self not in entity.observers: if self not in entity.observers:
entity.add_observer(self) entity.add_observer(self)

View File

@ -1,11 +1,11 @@
import abc import abc
from random import shuffle from random import shuffle
from typing import List, Collection, Union from typing import List, Collection
from marl_factory_grid.environment import rewards as r, constants as c
from marl_factory_grid.environment.entity.agent import Agent from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.utils import helpers as h from marl_factory_grid.utils import helpers as h
from marl_factory_grid.utils.results import TickResult, DoneResult from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import rewards as r, constants as c
class Rule(abc.ABC): class Rule(abc.ABC):
@ -118,8 +118,7 @@ class AssignGlobalPositions(Rule):
def on_init(self, state, lvl_map): def on_init(self, state, lvl_map):
from marl_factory_grid.environment.entity.util import GlobalPosition from marl_factory_grid.environment.entity.util import GlobalPosition
for agent in state[c.AGENT]: for agent in state[c.AGENT]:
gp = GlobalPosition(lvl_map.level_shape) gp = GlobalPosition(agent, lvl_map.level_shape)
gp.bind_to(agent)
state[c.GLOBALPOSITIONS].add_item(gp) state[c.GLOBALPOSITIONS].add_item(gp)
return [] return []

View File

@ -6,7 +6,9 @@ from marl_factory_grid.utils.results import TickResult, DoneResult
class TemplateRule(Rule): class TemplateRule(Rule):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(TemplateRule, self).__init__(*args, **kwargs) super(TemplateRule, self).__init__()
self.args = args
self.kwargs = kwargs
def on_init(self, state, lvl_map): def on_init(self, state, lvl_map):
pass pass

View File

@ -1,6 +1,5 @@
from typing import Union from typing import Union
import marl_factory_grid.modules.batteries.constants
from marl_factory_grid.environment.actions import Action from marl_factory_grid.environment.actions import Action
from marl_factory_grid.utils.results import ActionResult from marl_factory_grid.utils.results import ActionResult
@ -24,5 +23,6 @@ class BtryCharge(Action):
else: else:
valid = c.NOT_VALID valid = c.NOT_VALID
state.print(f'{entity.name} failed to charged batteries at {entity.pos}.') state.print(f'{entity.name} failed to charged batteries at {entity.pos}.')
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, return ActionResult(entity=entity, identifier=self._identifier, validity=valid,
reward=marl_factory_grid.modules.batteries.constants.REWARD_CHARGE_VALID if valid else marl_factory_grid.modules.batteries.constants.Reward_CHARGE_FAIL) reward=b.REWARD_CHARGE_VALID if valid else b.Reward_CHARGE_FAIL)

View File

@ -1,11 +1,11 @@
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.entity.object import _Object from marl_factory_grid.environment.entity.object import Object
from marl_factory_grid.modules.batteries import constants as b from marl_factory_grid.modules.batteries import constants as b
from marl_factory_grid.utils.utility_classes import RenderEntity from marl_factory_grid.utils.utility_classes import RenderEntity
class Battery(_Object): class Battery(Object):
@property @property
def var_can_be_bound(self): def var_can_be_bound(self):

View File

@ -1,11 +1,9 @@
from typing import List, Union from typing import List, Union
import marl_factory_grid.modules.batteries.constants
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.modules.batteries import constants as b from marl_factory_grid.modules.batteries import constants as b
from marl_factory_grid.utils.results import TickResult, DoneResult
class BatteryDecharge(Rule): class BatteryDecharge(Rule):

View File

@ -1,5 +1,3 @@
from numpy import random
from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.utils.utility_classes import RenderEntity from marl_factory_grid.utils.utility_classes import RenderEntity
from marl_factory_grid.modules.clean_up import constants as d from marl_factory_grid.modules.clean_up import constants as d

View File

@ -1,9 +1,7 @@
from typing import Union, List, Tuple
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils.results import Result
from marl_factory_grid.environment.groups.collection import Collection from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.modules.clean_up.entitites import DirtPile from marl_factory_grid.modules.clean_up.entitites import DirtPile
from marl_factory_grid.utils.results import Result
class DirtPiles(Collection): class DirtPiles(Collection):

View File

@ -49,7 +49,7 @@ class RespawnDirt(Rule):
def tick_step(self, state): def tick_step(self, state):
collection = state[d.DIRT] collection = state[d.DIRT]
if self._next_dirt_spawn < 0: if self._next_dirt_spawn < 0:
pass # No DirtPile Spawn result = [] # No DirtPile Spawn
elif not self._next_dirt_spawn: elif not self._next_dirt_spawn:
result = [collection.trigger_spawn(state, coords_or_quantity=self.respawn_n, amount=self.respawn_amount)] result = [collection.trigger_spawn(state, coords_or_quantity=self.respawn_n, amount=self.respawn_amount)]
self._next_dirt_spawn = self.respawn_freq self._next_dirt_spawn = self.respawn_freq

View File

@ -21,4 +21,4 @@ class DestAction(Action):
valid = c.NOT_VALID valid = c.NOT_VALID
state.print(f'{entity.name} just tried to do_wait_action do_wait_action at {entity.pos} but failed') state.print(f'{entity.name} just tried to do_wait_action do_wait_action at {entity.pos} but failed')
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, return ActionResult(entity=entity, identifier=self._identifier, validity=valid,
reward=marl_factory_grid.modules.destinations.constants.REWARD_WAIT_VALID if valid else marl_factory_grid.modules.destinations.constants.REWARD_WAIT_FAIL) reward=d.REWARD_WAIT_VALID if valid else d.REWARD_WAIT_FAIL)

View File

@ -1,7 +1,5 @@
from marl_factory_grid.environment.groups.collection import Collection from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.modules.destinations.entitites import Destination from marl_factory_grid.modules.destinations.entitites import Destination
from marl_factory_grid.environment import constants as c
from marl_factory_grid.modules.destinations import constants as d
class Destinations(Collection): class Destinations(Collection):

View File

@ -1,5 +1,3 @@
from typing import Union
from marl_factory_grid.environment.groups.collection import Collection from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.modules.doors import constants as d from marl_factory_grid.modules.doors import constants as d
from marl_factory_grid.modules.doors.entitites import Door from marl_factory_grid.modules.doors.entitites import Door

View File

@ -1,2 +1,2 @@
USE_DOOR_VALID: float = -0.00 USE_DOOR_VALID: float = -0.00
USE_DOOR_FAIL: float = -0.01 USE_DOOR_FAIL: float = -0.01

View File

@ -1,8 +1,8 @@
import random import random
from typing import List, Union from typing import List
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult from marl_factory_grid.utils.results import TickResult
@ -14,8 +14,6 @@ class AgentSingleZonePlacementBeta(Rule):
super().__init__() super().__init__()
def on_init(self, state, lvl_map): def on_init(self, state, lvl_map):
zones = state[c.ZONES]
n_zones = state[c.ZONES]
agents = state[c.AGENT] agents = state[c.AGENT]
if len(self.coordinates) == len(agents): if len(self.coordinates) == len(agents):
coordinates = self.coordinates coordinates = self.coordinates
@ -31,4 +29,4 @@ class AgentSingleZonePlacementBeta(Rule):
return [] return []
def tick_post_step(self, state) -> List[TickResult]: def tick_post_step(self, state) -> List[TickResult]:
return [] return []

View File

@ -1,6 +1,3 @@
from typing import NamedTuple
SYMBOL_NO_ITEM = 0 SYMBOL_NO_ITEM = 0
SYMBOL_DROP_OFF = 1 SYMBOL_DROP_OFF = 1
# Item Env # Item Env

View File

@ -14,27 +14,14 @@ class Item(Entity):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@property
def auto_despawn(self):
return self._auto_despawn
@property @property
def encoding(self): def encoding(self):
# Edit this if you want items to be drawn in the ops differently # Edit this if you want items to be drawn in the ops differently
return 1 return 1
def set_auto_despawn(self, auto_despawn):
self._auto_despawn = auto_despawn
def summarize_state(self) -> dict:
super_summarization = super(Item, self).summarize_state()
super_summarization.update(dict(auto_despawn=self.auto_despawn))
return super_summarization
class DropOffLocation(Entity): class DropOffLocation(Entity):
def render(self): def render(self):
return RenderEntity(i.DROP_OFF, self.pos) return RenderEntity(i.DROP_OFF, self.pos)
@ -42,18 +29,16 @@ class DropOffLocation(Entity):
def encoding(self): def encoding(self):
return i.SYMBOL_DROP_OFF return i.SYMBOL_DROP_OFF
def __init__(self, *args, storage_size_until_full: int = 5, auto_item_despawn_interval: int = 5, **kwargs): def __init__(self, *args, storage_size_until_full: int = 5, **kwargs):
super(DropOffLocation, self).__init__(*args, **kwargs) super(DropOffLocation, self).__init__(*args, **kwargs)
self.auto_item_despawn_interval = auto_item_despawn_interval
self.storage = deque(maxlen=storage_size_until_full or None) self.storage = deque(maxlen=storage_size_until_full or None)
def place_item(self, item: Item): def place_item(self, item: Item):
if self.is_full: if self.is_full:
raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.") raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.")
return bc.NOT_VALID # in Zeile 81 verschieben? return bc.NOT_VALID
else: else:
self.storage.append(item) self.storage.append(item)
item.set_auto_despawn(self.auto_item_despawn_interval)
return c.VALID return c.VALID
@property @property

View File

@ -1,12 +1,9 @@
from random import shuffle
from marl_factory_grid.modules.items import constants as i
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.environment.groups.objects import _Objects
from marl_factory_grid.environment.groups.mixins import IsBoundMixin
from marl_factory_grid.environment.entity.agent import Agent from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.environment.groups.mixins import IsBoundMixin
from marl_factory_grid.environment.groups.objects import Objects
from marl_factory_grid.modules.items import constants as i
from marl_factory_grid.modules.items.entitites import Item, DropOffLocation from marl_factory_grid.modules.items.entitites import Item, DropOffLocation
from marl_factory_grid.utils.results import Result from marl_factory_grid.utils.results import Result
@ -74,13 +71,12 @@ class Inventory(IsBoundMixin, Collection):
self._collection = collection self._collection = collection
class Inventories(_Objects): class Inventories(Objects):
_entity = Inventory _entity = Inventory
var_can_move = False var_can_move = False
var_has_position = False var_has_position = False
symbol = None symbol = None
@property @property
@ -116,7 +112,6 @@ class Inventories(_Objects):
return [val.summarize_states(**kwargs) for key, val in self.items()] return [val.summarize_states(**kwargs) for key, val in self.items()]
class DropOffLocations(Collection): class DropOffLocations(Collection):
_entity = DropOffLocation _entity = DropOffLocation

View File

@ -1,4 +1,4 @@
DROP_OFF_VALID: float = 0.1 DROP_OFF_VALID: float = 0.1
DROP_OFF_FAIL: float = -0.1 DROP_OFF_FAIL: float = -0.1
PICK_UP_FAIL: float = -0.1 PICK_UP_FAIL: float = -0.1
PICK_UP_VALID: float = 0.1 PICK_UP_VALID: float = 0.1

View File

@ -1,9 +1,10 @@
from typing import Union from typing import Union
import marl_factory_grid.modules.machines.constants
from marl_factory_grid.environment.actions import Action from marl_factory_grid.environment.actions import Action
from marl_factory_grid.utils.results import ActionResult from marl_factory_grid.utils.results import ActionResult
from marl_factory_grid.modules.machines import constants as m, rewards as r from marl_factory_grid.modules.machines import constants as m
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils import helpers as h from marl_factory_grid.utils import helpers as h
@ -16,8 +17,10 @@ class MachineAction(Action):
def do(self, entity, state) -> Union[None, ActionResult]: def do(self, entity, state) -> Union[None, ActionResult]:
if machine := h.get_first(state[m.MACHINES].by_pos(entity.pos)): if machine := h.get_first(state[m.MACHINES].by_pos(entity.pos)):
if valid := machine.maintain(): if valid := machine.maintain():
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_VALID) return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_VALID)
else: else:
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_FAIL) return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_FAIL)
else: else:
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.MAINTAIN_FAIL) return ActionResult(entity=entity, identifier=self._identifier,
validity=c.NOT_VALID, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_FAIL
)

View File

@ -11,3 +11,7 @@ STATE_MAINTAIN = 'maintenance'
SYMBOL_WORK = 1 SYMBOL_WORK = 1
SYMBOL_IDLE = 0.6 SYMBOL_IDLE = 0.6
SYMBOL_MAINTAIN = 0.3 SYMBOL_MAINTAIN = 0.3
MAINTAIN_VALID: float = 0.5
MAINTAIN_FAIL: float = -0.1
FAIL_MISSING_MAINTENANCE: float = -0.5
NONE: float = 0

View File

@ -31,11 +31,10 @@ class Machine(Entity):
return c.NOT_VALID return c.NOT_VALID
def tick(self, state): def tick(self, state):
# if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]): others = state.entities.pos_dict[self.pos]
if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in state.entities.pos_dict[self.pos]]): if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in others]):
return TickResult(identifier=self.name, validity=c.VALID, entity=self) return TickResult(identifier=self.name, validity=c.VALID, entity=self)
# elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]): elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in others]):
elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in state.entities.pos_dict[self.pos]]):
self.status = m.STATE_WORK self.status = m.STATE_WORK
self.reset_counter() self.reset_counter()
return None return None

View File

@ -1,5 +1,3 @@
from typing import Union, List, Tuple
from marl_factory_grid.environment.groups.collection import Collection from marl_factory_grid.environment.groups.collection import Collection
from .entitites import Machine from .entitites import Machine

View File

@ -1,5 +0,0 @@
MAINTAIN_VALID: float = 0.5
MAINTAIN_FAIL: float = -0.1
FAIL_MISSING_MAINTENANCE: float = -0.5
NONE: float = 0

View File

@ -1,3 +1,4 @@
MAINTAINER = 'Maintainer' # TEMPLATE _identifier. Define your own! MAINTAINER = 'Maintainer' # TEMPLATE _identifier. Define your own!
MAINTAINERS = 'Maintainers' # TEMPLATE _identifier. Define your own! MAINTAINERS = 'Maintainers' # TEMPLATE _identifier. Define your own!
MAINTAINER_COLLISION_REWARD = -5

View File

@ -4,7 +4,6 @@ from marl_factory_grid.environment.groups.collection import Collection
from .entities import Maintainer from .entities import Maintainer
from ..machines import constants as mc from ..machines import constants as mc
from ..machines.actions import MachineAction from ..machines.actions import MachineAction
from ...utils.states import Gamestate
class Maintainers(Collection): class Maintainers(Collection):
@ -23,8 +22,6 @@ class Maintainers(Collection):
self.size = size self.size = size
self._spawnrule = spawnrule self._spawnrule = spawnrule
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):
self.add_items([self._entity(mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity]) self.add_items([self._entity(mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity])

View File

@ -1 +0,0 @@
MAINTAINER_COLLISION_REWARD = -5

View File

@ -1,15 +1,16 @@
from typing import List from typing import List
import marl_factory_grid.modules.maintenance.constants
from marl_factory_grid.environment.rules import Rule from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult, DoneResult from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from . import rewards as r
from . import constants as M from . import constants as M
class MoveMaintainers(Rule): class MoveMaintainers(Rule):
def __init__(self, *args, **kwargs): def __init__(self):
super().__init__(*args, **kwargs) super().__init__()
def tick_step(self, state) -> List[TickResult]: def tick_step(self, state) -> List[TickResult]:
for maintainer in state[M.MAINTAINERS]: for maintainer in state[M.MAINTAINERS]:
@ -20,8 +21,8 @@ class MoveMaintainers(Rule):
class DoneAtMaintainerCollision(Rule): class DoneAtMaintainerCollision(Rule):
def __init__(self, *args, **kwargs): def __init__(self):
super().__init__(*args, **kwargs) super().__init__()
def on_check_done(self, state) -> List[DoneResult]: def on_check_done(self, state) -> List[DoneResult]:
agents = list(state[c.AGENT].values()) agents = list(state[c.AGENT].values())
@ -30,5 +31,5 @@ class DoneAtMaintainerCollision(Rule):
for agent in agents: for agent in agents:
if agent.pos in m_pos: if agent.pos in m_pos:
done_results.append(DoneResult(entity=agent, validity=c.VALID, identifier=self.name, done_results.append(DoneResult(entity=agent, validity=c.VALID, identifier=self.name,
reward=r.MAINTAINER_COLLISION_REWARD)) reward=marl_factory_grid.modules.maintenance.constants.MAINTAINER_COLLISION_REWARD))
return done_results return done_results

View File

@ -1,10 +1,10 @@
import random import random
from typing import List, Tuple from typing import List, Tuple
from marl_factory_grid.environment.entity.object import _Object from marl_factory_grid.environment.entity.object import Object
class Zone(_Object): class Zone(Object):
@property @property
def positions(self): def positions(self):

View File

@ -1,8 +1,8 @@
from marl_factory_grid.environment.groups.objects import _Objects from marl_factory_grid.environment.groups.objects import Objects
from marl_factory_grid.modules.zones import Zone from marl_factory_grid.modules.zones import Zone
class Zones(_Objects): class Zones(Objects):
symbol = None symbol = None
_entity = Zone _entity = Zone

View File

@ -58,7 +58,10 @@ class FactoryConfigParser(object):
return str(self.config) return str(self.config)
def __getitem__(self, item): def __getitem__(self, item):
return self.config[item] try:
return self.config[item]
except KeyError:
print(f'The mandatory {item} section could not be found in your .config gile. Check Spelling!')
def load_entities(self): def load_entities(self):
entity_classes = dict() entity_classes = dict()
@ -161,7 +164,6 @@ class FactoryConfigParser(object):
def _load_smth(self, config, class_obj): def _load_smth(self, config, class_obj):
rules = list() rules = list()
rules_names = list()
for rule in config: for rule in config:
e1 = e2 = e3 = None e1 = e2 = e3 = None
try: try:

View File

@ -61,8 +61,8 @@ class ObservationTranslator:
:param per_agent_named_obs_spaces: `Named observation space` one for each agent. Overloaded. :param per_agent_named_obs_spaces: `Named observation space` one for each agent. Overloaded.
type per_agent_named_obs_spaces: Dict[str, dict] type per_agent_named_obs_spaces: Dict[str, dict]
:param placeholder_fill_value: Currently not fully implemented!!! :param placeholder_fill_value: Currently, not fully implemented!!!
:type placeholder_fill_value: Union[int, str] = 'N') :type placeholder_fill_value: Union[int, str] = 'N'
""" """
if isinstance(placeholder_fill_value, str): if isinstance(placeholder_fill_value, str):

View File

@ -9,7 +9,7 @@ from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
import pandas as pd import pandas as pd
from marl_factory_grid.utils.plotting.compare_runs import plot_single_run from marl_factory_grid.utils.plotting.plot_single_runs import plot_single_run
class EnvMonitor(Wrapper): class EnvMonitor(Wrapper):
@ -22,7 +22,6 @@ class EnvMonitor(Wrapper):
self._monitor_df = pd.DataFrame() self._monitor_df = pd.DataFrame()
self._monitor_dict = dict() self._monitor_dict = dict()
def step(self, action): def step(self, action):
obs_type, obs, reward, done, info = self.env.step(action) obs_type, obs, reward, done, info = self.env.step(action)
self._read_info(info) self._read_info(info)

View File

@ -2,11 +2,9 @@ from os import PathLike
from pathlib import Path from pathlib import Path
from typing import Union, List from typing import Union, List
import yaml
from gymnasium import Wrapper
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from gymnasium import Wrapper
class EnvRecorder(Wrapper): class EnvRecorder(Wrapper):
@ -106,7 +104,7 @@ class EnvRecorder(Wrapper):
out_dict = {'episodes': self._recorder_out_list} out_dict = {'episodes': self._recorder_out_list}
out_dict.update( out_dict.update(
{'n_episodes': self._curr_episode, {'n_episodes': self._curr_episode,
'metadata':dict( 'metadata': dict(
level_name=self.env.params['General']['level_name'], level_name=self.env.params['General']['level_name'],
verbose=False, verbose=False,
n_agents=len(self.env.params['Agents']), n_agents=len(self.env.params['Agents']),

View File

@ -5,7 +5,7 @@ from typing import Dict, List
import numpy as np import numpy as np
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.object import _Object from marl_factory_grid.environment.entity.object import Object
from marl_factory_grid.environment.groups.utils import Combined from marl_factory_grid.environment.groups.utils import Combined
from marl_factory_grid.utils.utility_classes import Floor from marl_factory_grid.utils.utility_classes import Floor
from marl_factory_grid.utils.ray_caster import RayCaster from marl_factory_grid.utils.ray_caster import RayCaster
@ -13,7 +13,6 @@ from marl_factory_grid.utils.states import Gamestate
from marl_factory_grid.utils import helpers as h from marl_factory_grid.utils import helpers as h
class OBSBuilder(object): class OBSBuilder(object):
default_obs = [c.WALLS, c.OTHERS] default_obs = [c.WALLS, c.OTHERS]
@ -128,7 +127,7 @@ class OBSBuilder(object):
f'{re.escape("[")}(.*){re.escape("]")}' f'{re.escape("[")}(.*){re.escape("]")}'
f'{re.escape("(")}{re.escape(agent.name)}{re.escape(")")}') f'{re.escape("(")}{re.escape(agent.name)}{re.escape(")")}')
name = next((key for key, val in self.all_obs.items() name = next((key for key, val in self.all_obs.items()
if pattern.search(str(val)) and isinstance(val, _Object)), None) if pattern.search(str(val)) and isinstance(val, Object)), None)
e = self.all_obs[name] e = self.all_obs[name]
except KeyError: except KeyError:
try: try:
@ -181,11 +180,11 @@ class OBSBuilder(object):
return obs, self.obs_layers[agent.name] return obs, self.obs_layers[agent.name]
def _sort_and_name_observation_conf(self, agent): def _sort_and_name_observation_conf(self, agent):
''' """
Builds the useable observation scheme per agent from conf.yaml. Builds the useable observation scheme per agent from conf.yaml.
:param agent: :param agent:
:return: :return:
''' """
# Fixme: no asymetric shapes possible. # Fixme: no asymetric shapes possible.
self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape)) self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape))
obs_layers = [] obs_layers = []

View File

@ -7,50 +7,11 @@ from typing import Union, List
import pandas as pd import pandas as pd
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
from marl_factory_grid.utils.plotting.plotting import prepare_plot from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
MODEL_MAP = None MODEL_MAP = None
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None):
run_path = Path(run_path)
df_list = list()
if run_path.is_dir():
monitor_file = next(run_path.glob('*monitor*.pick'))
elif run_path.exists() and run_path.is_file():
monitor_file = run_path
else:
raise ValueError
with monitor_file.open('rb') as f:
monitor_df = pickle.load(f)
monitor_df = monitor_df.fillna(0)
df_list.append(monitor_df)
df = pd.concat(df_list, ignore_index=True)
df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
if column_keys is not None:
columns = [col for col in column_keys if col in df.columns]
else:
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
roll_n = 50
non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean()
df_melted = df[columns + ['Episode']].reset_index().melt(
id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score"
)
if df_melted['Episode'].max() > 800:
skip_n = round(df_melted['Episode'].max() * 0.02)
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
print('Plotting done.')
def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False): def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
run_path = Path(run_path) run_path = Path(run_path)
df_list = list() df_list = list()

View File

@ -0,0 +1,48 @@
import pickle
from os import PathLike
from pathlib import Path
from typing import Union
import pandas as pd
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None,
file_key: str ='monitor', file_ext: str ='pkl'):
run_path = Path(run_path)
df_list = list()
if run_path.is_dir():
monitor_file = next(run_path.glob(f'*{file_key}*.{file_ext}'))
elif run_path.exists() and run_path.is_file():
monitor_file = run_path
else:
raise ValueError
with monitor_file.open('rb') as f:
monitor_df = pickle.load(f)
monitor_df = monitor_df.fillna(0)
df_list.append(monitor_df)
df = pd.concat(df_list, ignore_index=True)
df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
if column_keys is not None:
columns = [col for col in column_keys if col in df.columns]
else:
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
# roll_n = 50
# non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean()
df_melted = df[columns + ['Episode']].reset_index().melt(
id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score"
)
if df_melted['Episode'].max() > 800:
skip_n = round(df_melted['Episode'].max() * 0.02)
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
print('Plotting done.')

View File

@ -60,7 +60,7 @@ def prepare_center_double_column_legend(df, hue, style, hue_order):
print('Struggling to plot Figure using LaTeX - going back to normal.') print('Struggling to plot Figure using LaTeX - going back to normal.')
plt.close('all') plt.close('all')
sns.set(rc={'text.usetex': False}, style='whitegrid') sns.set(rc={'text.usetex': False}, style='whitegrid')
fig = plt.figure(figsize=(10, 11)) _ = plt.figure(figsize=(10, 11))
lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style, lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
ci=95, palette=PALETTE, hue_order=hue_order, legend=False) ci=95, palette=PALETTE, hue_order=hue_order, legend=False)
# plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0) # plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)

View File

@ -19,7 +19,7 @@ class RayCaster:
return f'{self.__class__.__name__}({self.agent.name})' return f'{self.__class__.__name__}({self.agent.name})'
def build_ray_targets(self): def build_ray_targets(self):
north = np.array([0, -1])*self.pomdp_r north = np.array([0, -1]) * self.pomdp_r
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]] thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
rot_M = [ rot_M = [
[[math.cos(theta), -math.sin(theta)], [[math.cos(theta), -math.sin(theta)],
@ -53,9 +53,9 @@ class RayCaster:
diag_hits = all([ diag_hits = all([
self.ray_block_cache( self.ray_block_cache(
key, key,
lambda: any(True for e in pos_dict[key] if e.var_is_blocking_light)) # lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light)
# lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light)) lambda: any(True for e in pos_dict[key] if e.var_is_blocking_light))
for key in ((x, y-cy), (x-cx, y)) for key in ((x, y - cy), (x - cx, y))
]) if (cx != 0 and cy != 0) else False ]) if (cx != 0 and cy != 0) else False
visible += entities_hit if not diag_hits else [] visible += entities_hit if not diag_hits else []
@ -77,8 +77,8 @@ class RayCaster:
agent = self.agent agent = self.agent
x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1) x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1) y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \ outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r]))
+ list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords)) outline += list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
return outline return outline
@staticmethod @staticmethod

View File

@ -1,9 +1,12 @@
from typing import Union from typing import Union
from dataclasses import dataclass from dataclasses import dataclass
from marl_factory_grid.environment.entity.object import Object
TYPE_VALUE = 'value' TYPE_VALUE = 'value'
TYPE_REWARD = 'reward' TYPE_REWARD = 'reward'
types = [TYPE_VALUE, TYPE_REWARD] TYPES = [TYPE_VALUE, TYPE_REWARD]
@dataclass @dataclass
class InfoObject: class InfoObject:
@ -18,12 +21,13 @@ class Result:
validity: bool validity: bool
reward: Union[float, None] = None reward: Union[float, None] = None
value: Union[float, None] = None value: Union[float, None] = None
entity: None = None entity: Object = None
def get_infos(self): def get_infos(self):
n = self.entity.name if self.entity is not None else "Global" n = self.entity.name if self.entity is not None else "Global"
return [InfoObject(identifier=f'{n}_{self.identifier}_{t}', # Return multiple Info Dicts
val_type=t, value=self.__getattribute__(t)) for t in types return [InfoObject(identifier=f'{n}_{self.identifier}',
val_type=t, value=self.__getattribute__(t)) for t in TYPES
if self.__getattribute__(t) is not None] if self.__getattribute__(t) is not None]
def __repr__(self): def __repr__(self):
@ -31,7 +35,7 @@ class Result:
reward = f" | Reward: {self.reward}" if self.reward is not None else "" reward = f" | Reward: {self.reward}" if self.reward is not None else ""
value = f" | Value: {self.value}" if self.value is not None else "" value = f" | Value: {self.value}" if self.value is not None else ""
entity = f" | by: {self.entity.name}" if self.entity is not None else "" entity = f" | by: {self.entity.name}" if self.entity is not None else ""
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value})' return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value}{entity})'
@dataclass @dataclass

View File

@ -1,11 +1,12 @@
from itertools import islice from itertools import islice
from typing import List, Dict, Tuple from typing import List, Tuple
import numpy as np import numpy as np
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.rules import Rule from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import Result from marl_factory_grid.utils.results import Result, DoneResult
class StepRules: class StepRules:
@ -83,13 +84,51 @@ class Gamestate(object):
return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})' return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})'
@property @property
def random_free_position(self): def random_free_position(self) -> (int, int):
"""
Returns a single **free** position (x, y), which is **free** for spawning or walking.
No Entity at this position posses *var_is_blocking_pos* or *var_can_collide*.
:return: Single **free** position.
"""
return self.get_n_random_free_positions(1)[0] return self.get_n_random_free_positions(1)[0]
def get_n_random_free_positions(self, n): def get_n_random_free_positions(self, n) -> list[tuple[int, int]]:
"""
Returns a list of *n* **free** positions [(x, y), ... ], which are **free** for spawning or walking.
No Entity at this position posses *var_is_blocking_pos* or *var_can_collide*.
:return: List of n **free** position.
"""
return list(islice(self.entities.free_positions_generator, n)) return list(islice(self.entities.free_positions_generator, n))
def tick(self, actions) -> List[Result]: @property
def random_position(self) -> (int, int):
"""
Returns a single available position (x, y), ignores all entity attributes.
:return: Single random position.
"""
return self.get_n_random_positions(1)[0]
def get_n_random_positions(self, n) -> list[tuple[int, int]]:
"""
Returns a list of *n* available positions [(x, y), ... ], ignores all entity attributes.
:return: List of n random positions.
"""
return list(islice(self.entities.floorlist, n))
def tick(self, actions) -> list[Result]:
"""
Performs a single **Gamestate Tick**by calling the inner rule hooks in sequential order.
- tick_pre_step_all: Things to do before the agents do their actions. Statechange, Moving, Spawning etc...
- agent tick: Agents do their actions.
- tick_step_all: Things to do after the agents did their actions. Statechange, Moving, Spawning etc...
- tick_post_step_all: Things to do at the very end of each step. Counting, Reward calculations etc...
:return: List of *Result*-objects.
"""
results = list() results = list()
self.curr_step += 1 self.curr_step += 1
@ -112,11 +151,23 @@ class Gamestate(object):
return results return results
def print(self, string): def print(self, string) -> None:
"""
When *verbose* is active, print stuff.
:param string: *String* to print.
:type string: str
:return: Nothing
"""
if self.verbose: if self.verbose:
print(string) print(string)
def check_done(self): def check_done(self) -> List[DoneResult]:
"""
Iterate all **Rules** that override tehe *on_ckeck_done* hook.
:return: List of Results
"""
results = list() results = list()
for rule in self.rules: for rule in self.rules:
if on_check_done_result := rule.on_check_done(self): if on_check_done_result := rule.on_check_done(self):
@ -124,20 +175,44 @@ class Gamestate(object):
return results return results
def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]: def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]:
positions = [pos for pos, entities in self.entities.pos_dict.items() if len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2)] """
Returns a list positions [(x, y), ... ] on which collisions occur. This does not include agents,
that were unable to move because their target direction was blocked, also a form of collision.
:return: List of positions.
"""
positions = [pos for pos, entities in self.entities.pos_dict.items() if
len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2)
]
return positions return positions
def check_move_validity(self, moving_entity, position): def check_move_validity(self, moving_entity: Entity, target_position: (int, int)) -> bool:
if moving_entity.pos != position and not any( """
entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]) and not ( Whether it is safe to move to the target positions and moving entity does not introduce a blocking attribute,
moving_entity.var_is_blocking_pos and self.entities.is_occupied(position)): when position is allready occupied.
return True
else:
return False
def check_pos_validity(self, position): :param moving_entity: Entity
if not any(entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]): :param target_position: pos
return True :return: Safe to move to
else: """
return False
is_not_blocked = self.check_pos_validity(target_position)
will_not_block_others = moving_entity.var_is_blocking_pos and self.entities.is_occupied(target_position)
if moving_entity.pos != target_position and is_not_blocked and not will_not_block_others:
return c.VALID
else:
return c.NOT_VALID
def check_pos_validity(self, pos: (int, int)) -> bool:
"""
Check if *pos* is a valid position to move or spawn to.
:param pos: position to check
:return: Wheter pos is a valid target.
"""
if not any(e.var_is_blocking_pos for e in self.entities.pos_dict[pos]) and pos in self.entities.floorlist:
return c.VALID
else:
return c.NOT_VALID

View File

@ -28,7 +28,9 @@ class ConfigExplainer:
def explain_module(self, class_to_explain): def explain_module(self, class_to_explain):
parameters = inspect.signature(class_to_explain).parameters parameters = inspect.signature(class_to_explain).parameters
explained = {class_to_explain.__name__: {key: val.default for key, val in parameters.items() if key not in EXCLUDED}} explained = {class_to_explain.__name__:
{key: val.default for key, val in parameters.items() if key not in EXCLUDED}
}
return explained return explained
def _load_and_compare(self, compare_class, paths): def _load_and_compare(self, compare_class, paths):

View File

@ -6,18 +6,21 @@ from marl_factory_grid.environment.factory import Factory
from marl_factory_grid.utils.logging.envmonitor import EnvMonitor from marl_factory_grid.utils.logging.envmonitor import EnvMonitor
from marl_factory_grid.utils.logging.recorder import EnvRecorder from marl_factory_grid.utils.logging.recorder import EnvRecorder
from marl_factory_grid.utils.plotting.plot_single_runs import plot_single_run
from marl_factory_grid.utils.tools import ConfigExplainer from marl_factory_grid.utils.tools import ConfigExplainer
if __name__ == '__main__': if __name__ == '__main__':
# Render at each step? # Render at each step?
render = True render = False
# Reveal all possible Modules (Entities, Rules, Agents[Actions, Observations], etc.) # Reveal all possible Modules (Entities, Rules, Agents[Actions, Observations], etc.)
explain_config = False explain_config = False
# Collect statistics? # Collect statistics?
monitor = False monitor = True
# Record as Protobuf? # Record as Protobuf?
record = False record = False
# Plot Results?
plotting = True
run_path = Path('study_out') run_path = Path('study_out')
@ -38,7 +41,7 @@ if __name__ == '__main__':
factory = EnvRecorder(factory) factory = EnvRecorder(factory)
# RL learn Loop # RL learn Loop
for episode in trange(500): for episode in trange(10):
_ = factory.reset() _ = factory.reset()
done = False done = False
if render: if render:
@ -54,7 +57,10 @@ if __name__ == '__main__':
break break
if monitor: if monitor:
factory.save_run(run_path / 'test.pkl') factory.save_run(run_path / 'test_monitor.pkl')
if record: if record:
factory.save_records(run_path / 'test.pb') factory.save_records(run_path / 'test.pb')
if plotting:
plot_single_run(run_path)
print('Done!!! Goodbye....') print('Done!!! Goodbye....')

View File

@ -56,6 +56,7 @@ if __name__ == '__main__':
for model_idx, model in enumerate(models)] for model_idx, model in enumerate(models)]
else: else:
actions = models[0].predict(env_state, deterministic=determin)[0] actions = models[0].predict(env_state, deterministic=determin)[0]
# noinspection PyTupleAssignmentBalance
env_state, step_r, done_bool, info_obj = env.step(actions) env_state, step_r, done_bool, info_obj = env.step(actions)
rew += step_r rew += step_r

View File

@ -5,7 +5,6 @@ from pathlib import Path
if __name__ == '__main__': if __name__ == '__main__':
conf_path = Path('wg0') conf_path = Path('wg0')
wg0_conf = configparser.ConfigParser() wg0_conf = configparser.ConfigParser()
wg0_conf.read(conf_path/'wg0.conf') wg0_conf.read(conf_path/'wg0.conf')
@ -17,7 +16,6 @@ if __name__ == '__main__':
# Delete any old conf.json for the current peer # Delete any old conf.json for the current peer
(conf_path / f'{client_name}.json').unlink(missing_ok=True) (conf_path / f'{client_name}.json').unlink(missing_ok=True)
peer = wg0_conf[client_name] peer = wg0_conf[client_name]
date_time = datetime.now().strftime('%Y-%m-%dT%H:%M:%S.%f000Z') date_time = datetime.now().strftime('%Y-%m-%dT%H:%M:%S.%f000Z')