mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-22 14:56:43 +02:00
Resolved some warnings and style issues
This commit is contained in:
parent
a9462a8b6f
commit
6711a0976b
5
.idea/.gitignore
generated
vendored
Normal file
5
.idea/.gitignore
generated
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# Editor-based HTTP Client requests
|
||||
/httpRequests/
|
@ -94,7 +94,7 @@ All [Entites](marl_factory_grid/environment/entity/global_entities.py) are avail
|
||||
|
||||
|
||||
#### Rules
|
||||
[Rules](marl_factory_grid/environment/entity/object.py) define how the environment behaves on micro-scale.
|
||||
[Rules](marl_factory_grid/environment/entity/object.py) define how the environment behaves on microscale.
|
||||
Each of the hooks (`on_init`, `pre_step`, `on_step`, `post_step`, `on_done`)
|
||||
provides env-access to implement custom logic, calculate rewards, or gather information.
|
||||
|
||||
@ -107,6 +107,7 @@ Make sure to bring your own assets for each Entity living in the Gridworld as th
|
||||
PNG-files (transparent background) of square aspect-ratio should do the job, in general.
|
||||
|
||||
<img src="/marl_factory_grid/environment/assets/wall.png" width="5%">
|
||||
<!--suppress HtmlUnknownAttribute -->
|
||||
<html      html>
|
||||
<img src="/marl_factory_grid/environment/assets/agent/agent.png" width="5%">
|
||||
|
||||
|
@ -1 +1,4 @@
|
||||
import os, sys; sys.path.append(os.path.dirname(os.path.realpath(__file__)))
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
|
||||
|
@ -28,6 +28,7 @@ class Names:
|
||||
BATCH_SIZE = 'bnatch_size'
|
||||
N_ACTIONS = 'n_actions'
|
||||
|
||||
|
||||
nms = Names
|
||||
ListOrTensor = Union[List, torch.Tensor]
|
||||
|
||||
@ -112,10 +113,9 @@ class BaseActorCritic:
|
||||
next_obs, reward, done, info = env.step(action)
|
||||
done = [done] * self.n_agents if isinstance(done, bool) else done
|
||||
|
||||
last_hiddens = dict(hidden_actor =out[nms.HIDDEN_ACTOR],
|
||||
last_hiddens = dict(hidden_actor=out[nms.HIDDEN_ACTOR],
|
||||
hidden_critic=out[nms.HIDDEN_CRITIC])
|
||||
|
||||
|
||||
tm.add(observation=obs, action=action, reward=reward, done=done,
|
||||
logits=out.get(nms.LOGITS, None), values=out.get(nms.CRITIC, None),
|
||||
**last_hiddens)
|
||||
@ -142,7 +142,9 @@ class BaseActorCritic:
|
||||
print(f'reward at episode: {episode} = {rew_log}')
|
||||
episode += 1
|
||||
df_results.append([episode, rew_log, *reward])
|
||||
df_results = pd.DataFrame(df_results, columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]])
|
||||
df_results = pd.DataFrame(df_results,
|
||||
columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]]
|
||||
)
|
||||
if checkpointer is not None:
|
||||
df_results.to_csv(checkpointer.path / 'results.csv', index=False)
|
||||
return df_results
|
||||
@ -157,24 +159,27 @@ class BaseActorCritic:
|
||||
last_action, reward = [-1] * self.n_agents, [0.] * self.n_agents
|
||||
done, rew_log, eps_rew = [False] * self.n_agents, 0, torch.zeros(self.n_agents)
|
||||
while not all(done):
|
||||
if render: env.render()
|
||||
if render:
|
||||
env.render()
|
||||
|
||||
out = self.forward(obs, last_action, **last_hiddens)
|
||||
action = self.get_actions(out)
|
||||
next_obs, reward, done, info = env.step(action)
|
||||
|
||||
if isinstance(done, bool): done = [done] * obs.shape[0]
|
||||
if isinstance(done, bool):
|
||||
done = [done] * obs.shape[0]
|
||||
obs = next_obs
|
||||
last_action = action
|
||||
last_hiddens = dict(hidden_actor=out.get(nms.HIDDEN_ACTOR, None),
|
||||
hidden_critic=out.get(nms.HIDDEN_CRITIC, None)
|
||||
)
|
||||
eps_rew += torch.tensor(reward)
|
||||
results.append(eps_rew.tolist() + [sum(eps_rew).item()] + [episode])
|
||||
results.append(eps_rew.tolist() + [np.sum(eps_rew).item()] + [episode])
|
||||
episode += 1
|
||||
agent_columns = [f'agent#{i}' for i in range(self.cfg['environment']['n_agents'])]
|
||||
results = pd.DataFrame(results, columns=agent_columns + ['sum', 'episode'])
|
||||
results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'], value_name='reward', var_name='agent')
|
||||
results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'],
|
||||
value_name='reward', var_name='agent')
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
|
@ -36,7 +36,7 @@ class LoopMAPPO(LoopSNAC):
|
||||
rewards_ = torch.stack(rewards_, dim=1)
|
||||
return rewards_
|
||||
|
||||
def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **kwargs):
|
||||
def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **__):
|
||||
out = network(batch[nms.OBSERVATION], batch[nms.ACTION], batch[nms.HIDDEN_ACTOR], batch[nms.HIDDEN_CRITIC])
|
||||
logits = out[nms.LOGITS][:, :-1] # last one only needed for v_{t+1}
|
||||
|
||||
@ -45,7 +45,7 @@ class LoopMAPPO(LoopSNAC):
|
||||
|
||||
# monte carlo returns
|
||||
mc_returns = self.monte_carlo_returns(batch[nms.REWARD], batch[nms.DONE], gamma)
|
||||
mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8) #todo: norm across agent ok?
|
||||
mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8) # todo: norm across agent ok?
|
||||
advantages = mc_returns - out[nms.CRITIC][:, :-1]
|
||||
|
||||
# policy loss
|
||||
|
@ -1,8 +1,7 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import spectral_norm
|
||||
|
||||
|
||||
class RecurrentAC(nn.Module):
|
||||
@ -88,8 +87,8 @@ class NormalizedLinear(nn.Linear):
|
||||
self.trainable_magnitude = trainable_magnitude
|
||||
self.scale = nn.Parameter(torch.tensor([1.]), requires_grad=trainable_magnitude)
|
||||
|
||||
def forward(self, input):
|
||||
normalized_input = F.normalize(input, dim=-1, p=2, eps=1e-5)
|
||||
def forward(self, in_array):
|
||||
normalized_input = F.normalize(in_array, dim=-1, p=2, eps=1e-5)
|
||||
normalized_weight = F.normalize(self.weight, dim=-1, p=2, eps=1e-5)
|
||||
return F.linear(normalized_input, normalized_weight) * self.d_sqrt * self.scale
|
||||
|
||||
|
@ -38,7 +38,6 @@ class LoopSEAC(LoopIAC):
|
||||
|
||||
a2c_loss = (-iw*log_ap * advantages.detach()).mean(-1)
|
||||
|
||||
|
||||
value_loss = (iw*advantages.pow(2)).mean(-1) # n_agent
|
||||
|
||||
# weighted loss
|
||||
|
@ -56,8 +56,8 @@ class TSPBaseAgent(ABC):
|
||||
|
||||
def _door_is_close(self, state):
|
||||
try:
|
||||
# return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name)
|
||||
return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name)
|
||||
return next(y for x in state.entities.neighboring_positions(self.state.pos)
|
||||
for y in state.entities.pos_dict[x] if do.DOOR in y.name)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
|
@ -14,8 +14,8 @@ class TSPTargetAgent(TSPBaseAgent):
|
||||
def _handle_doors(self, state):
|
||||
|
||||
try:
|
||||
# return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name)
|
||||
return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name)
|
||||
return next(y for x in state.entities.neighboring_positions(self.state.pos)
|
||||
for y in state.entities.pos_dict[x] if do.DOOR in y.name)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
|
@ -1,8 +1,9 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
|
||||
def load_class(classname):
|
||||
from importlib import import_module
|
||||
@ -42,7 +43,6 @@ def get_class(arguments):
|
||||
|
||||
|
||||
def get_arguments(arguments):
|
||||
from importlib import import_module
|
||||
d = dict(arguments)
|
||||
if "classname" in d:
|
||||
del d["classname"]
|
||||
|
@ -1,4 +1,4 @@
|
||||
eneral:
|
||||
General:
|
||||
# Your Seed
|
||||
env_seed: 69
|
||||
# Individual or global rewards?
|
||||
@ -86,4 +86,4 @@ Rules:
|
||||
DoneAtDestinationReachAll:
|
||||
# reward_at_done: 1
|
||||
DoneAtMaxStepsReached:
|
||||
max_steps: 500
|
||||
max_steps: 200
|
||||
|
@ -1,15 +1,14 @@
|
||||
import abc
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .object import _Object
|
||||
from .object import Object
|
||||
from .. import constants as c
|
||||
from ...utils.results import ActionResult
|
||||
from ...utils.utility_classes import RenderEntity
|
||||
|
||||
|
||||
class Entity(_Object, abc.ABC):
|
||||
class Entity(Object, abc.ABC):
|
||||
"""Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc..."""
|
||||
|
||||
@property
|
||||
@ -96,8 +95,9 @@ class Entity(_Object, abc.ABC):
|
||||
|
||||
def __init__(self, pos, bind_to=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._view_directory = c.VALUE_NO_POS
|
||||
self._status = None
|
||||
self.set_pos(pos)
|
||||
self._pos = pos
|
||||
self._last_pos = pos
|
||||
if bind_to:
|
||||
try:
|
||||
@ -113,10 +113,6 @@ class Entity(_Object, abc.ABC):
|
||||
def render(self):
|
||||
return RenderEntity(self.__class__.__name__.lower(), self.pos)
|
||||
|
||||
@abc.abstractmethod
|
||||
def render(self):
|
||||
return RenderEntity(self.__class__.__name__.lower(), self.pos)
|
||||
|
||||
@property
|
||||
def obs_tag(self):
|
||||
try:
|
||||
@ -133,25 +129,3 @@ class Entity(_Object, abc.ABC):
|
||||
self._collection.delete_env_object(self)
|
||||
self._collection = other_collection
|
||||
return self._collection == other_collection
|
||||
|
||||
@classmethod
|
||||
def from_coordinates(cls, positions: [(int, int)], *args, entity_kwargs=None, **kwargs, ):
|
||||
collection = cls(*args, **kwargs)
|
||||
collection.add_items(
|
||||
[cls._entity(tuple(pos), **entity_kwargs if entity_kwargs is not None else {}) for pos in positions])
|
||||
return collection
|
||||
|
||||
def notify_del_entity(self, entity):
|
||||
try:
|
||||
self.pos_dict[entity.pos].remove(entity)
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
def by_pos(self, pos: (int, int)):
|
||||
pos = tuple(pos)
|
||||
try:
|
||||
return self.state.entities.pos_dict[pos]
|
||||
except StopIteration:
|
||||
pass
|
||||
except ValueError:
|
||||
pass
|
||||
|
@ -5,7 +5,7 @@ from marl_factory_grid.environment import constants as c
|
||||
import marl_factory_grid.utils.helpers as h
|
||||
|
||||
|
||||
class _Object:
|
||||
class Object:
|
||||
"""Generell Objects for Organisation and Maintanance such as Actions etc..."""
|
||||
|
||||
_u_idx = defaultdict(lambda: 0)
|
||||
@ -56,7 +56,7 @@ class _Object:
|
||||
try:
|
||||
if self.var_has_position:
|
||||
name = h.add_pos_name(name, self)
|
||||
except (AttributeError):
|
||||
except AttributeError:
|
||||
pass
|
||||
return name
|
||||
|
||||
@ -67,8 +67,8 @@ class _Object:
|
||||
return hash(self.identifier)
|
||||
|
||||
def _identify_and_count_up(self):
|
||||
idx = _Object._u_idx[self.__class__.__name__]
|
||||
_Object._u_idx[self.__class__.__name__] += 1
|
||||
idx = Object._u_idx[self.__class__.__name__]
|
||||
Object._u_idx[self.__class__.__name__] += 1
|
||||
return idx
|
||||
|
||||
def set_collection(self, collection):
|
||||
@ -98,79 +98,3 @@ class _Object:
|
||||
|
||||
def unbind(self):
|
||||
self._bound_entity = None
|
||||
|
||||
|
||||
# class EnvObject(_Object):
|
||||
# """Objects that hold Information that are observable, but have no position on the environment grid. Inventories etc..."""
|
||||
#
|
||||
# _u_idx = defaultdict(lambda: 0)
|
||||
#
|
||||
# @property
|
||||
# def obs_tag(self):
|
||||
# try:
|
||||
# return self._collection.name or self.name
|
||||
# except AttributeError:
|
||||
# return self.name
|
||||
#
|
||||
# @property
|
||||
# def var_is_blocking_light(self):
|
||||
# try:
|
||||
# return self._collection.var_is_blocking_light or False
|
||||
# except AttributeError:
|
||||
# return False
|
||||
#
|
||||
# @property
|
||||
# def var_can_be_bound(self):
|
||||
# try:
|
||||
# return self._collection.var_can_be_bound or False
|
||||
# except AttributeError:
|
||||
# return False
|
||||
#
|
||||
# @property
|
||||
# def var_can_move(self):
|
||||
# try:
|
||||
# return self._collection.var_can_move or False
|
||||
# except AttributeError:
|
||||
# return False
|
||||
#
|
||||
# @property
|
||||
# def var_is_blocking_pos(self):
|
||||
# try:
|
||||
# return self._collection.var_is_blocking_pos or False
|
||||
# except AttributeError:
|
||||
# return False
|
||||
#
|
||||
# @property
|
||||
# def var_has_position(self):
|
||||
# try:
|
||||
# return self._collection.var_has_position or False
|
||||
# except AttributeError:
|
||||
# return False
|
||||
#
|
||||
# @property
|
||||
# def var_can_collide(self):
|
||||
# try:
|
||||
# return self._collection.var_can_collide or False
|
||||
# except AttributeError:
|
||||
# return False
|
||||
#
|
||||
#
|
||||
# @property
|
||||
# def encoding(self):
|
||||
# return c.VALUE_OCCUPIED_CELL
|
||||
#
|
||||
#
|
||||
# def __init__(self, **kwargs):
|
||||
# self._bound_entity = None
|
||||
# super(EnvObject, self).__init__(**kwargs)
|
||||
#
|
||||
#
|
||||
# def change_parent_collection(self, other_collection):
|
||||
# other_collection.add_item(self)
|
||||
# self._collection.delete_env_object(self)
|
||||
# self._collection = other_collection
|
||||
# return self._collection == other_collection
|
||||
#
|
||||
#
|
||||
# def summarize_state(self):
|
||||
# return dict(name=str(self.name))
|
||||
|
@ -1,6 +1,6 @@
|
||||
import numpy as np
|
||||
|
||||
from marl_factory_grid.environment.entity.object import _Object
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
|
||||
|
||||
##########################################################################
|
||||
@ -8,7 +8,7 @@ from marl_factory_grid.environment.entity.object import _Object
|
||||
##########################################################################
|
||||
|
||||
|
||||
class PlaceHolder(_Object):
|
||||
class PlaceHolder(Object):
|
||||
|
||||
def __init__(self, *args, fill_value=0, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@ -27,7 +27,7 @@ class PlaceHolder(_Object):
|
||||
return self.__class__.__name__
|
||||
|
||||
|
||||
class GlobalPosition(_Object):
|
||||
class GlobalPosition(Object):
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
|
@ -56,15 +56,18 @@ class Factory(gym.Env):
|
||||
self.level_filepath = Path(custom_level_path)
|
||||
else:
|
||||
self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt'
|
||||
self._renderer = None # expensive - don't use; unless required !
|
||||
|
||||
parsed_entities = self.conf.load_entities()
|
||||
self.map = LevelParser(self.level_filepath, parsed_entities, self.conf.pomdp_r)
|
||||
|
||||
# Init for later usage:
|
||||
self.state: Gamestate
|
||||
self.map: LevelParser
|
||||
self.obs_builder: OBSBuilder
|
||||
# noinspection PyTypeChecker
|
||||
self.state: Gamestate = None
|
||||
# noinspection PyTypeChecker
|
||||
self.obs_builder: OBSBuilder = None
|
||||
|
||||
# expensive - don't use; unless required !
|
||||
self._renderer = None
|
||||
|
||||
# reset env to initial state, preparing env for new episode.
|
||||
# returns tuple where the first dict contains initial observation for each agent in the env
|
||||
@ -74,7 +77,7 @@ class Factory(gym.Env):
|
||||
return self.state.entities[item]
|
||||
|
||||
def reset(self) -> (dict, dict):
|
||||
if hasattr(self, 'state'):
|
||||
if self.state is not None:
|
||||
for entity_group in self.state.entities:
|
||||
try:
|
||||
entity_group[0].reset_uid()
|
||||
@ -160,7 +163,7 @@ class Factory(gym.Env):
|
||||
# Finalize
|
||||
reward, reward_info, done = self.summarize_step_results(tick_result, done_results)
|
||||
|
||||
info = reward_info
|
||||
info = dict(reward_info)
|
||||
|
||||
info.update(step_reward=sum(reward), step=self.state.curr_step)
|
||||
|
||||
|
@ -1,15 +1,15 @@
|
||||
from typing import List, Tuple, Union, Dict
|
||||
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.environment.groups.objects import _Objects
|
||||
from marl_factory_grid.environment.groups.objects import Objects
|
||||
# noinspection PyProtectedMember
|
||||
from marl_factory_grid.environment.entity.object import _Object
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
import marl_factory_grid.environment.constants as c
|
||||
from marl_factory_grid.utils.results import Result
|
||||
|
||||
|
||||
class Collection(_Objects):
|
||||
_entity = _Object # entity?
|
||||
class Collection(Objects):
|
||||
_entity = Object # entity?
|
||||
symbol = None
|
||||
|
||||
@property
|
||||
@ -58,7 +58,7 @@ class Collection(_Objects):
|
||||
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, ignore_blocking=False, **entity_kwargs):
|
||||
coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity
|
||||
if self.var_has_position:
|
||||
if isinstance(coords_or_quantity, int):
|
||||
if self.var_has_position and isinstance(coords_or_quantity, int):
|
||||
if ignore_blocking or self._ignore_blocking:
|
||||
coords_or_quantity = state.entities.floorlist[:coords_or_quantity]
|
||||
else:
|
||||
@ -87,8 +87,8 @@ class Collection(_Objects):
|
||||
raise ValueError(f'{self._entity.__name__} has no position!')
|
||||
return c.VALID
|
||||
|
||||
def despawn(self, items: List[_Object]):
|
||||
items = [items] if isinstance(items, _Object) else items
|
||||
def despawn(self, items: List[Object]):
|
||||
items = [items] if isinstance(items, Object) else items
|
||||
for item in items:
|
||||
del self[item]
|
||||
|
||||
|
@ -3,12 +3,12 @@ from operator import itemgetter
|
||||
from random import shuffle
|
||||
from typing import Dict
|
||||
|
||||
from marl_factory_grid.environment.groups.objects import _Objects
|
||||
from marl_factory_grid.environment.groups.objects import Objects
|
||||
from marl_factory_grid.utils.helpers import POS_MASK
|
||||
|
||||
|
||||
class Entities(_Objects):
|
||||
_entity = _Objects
|
||||
class Entities(Objects):
|
||||
_entity = Objects
|
||||
|
||||
@staticmethod
|
||||
def neighboring_positions(pos):
|
||||
@ -87,7 +87,7 @@ class Entities(_Objects):
|
||||
def __delitem__(self, name):
|
||||
assert_str = 'This group of entity does not exist in this collection!'
|
||||
assert any([key for key in name.keys() if key in self.keys()]), assert_str
|
||||
self[name]._observers.delete(self)
|
||||
self[name].del_observer(self)
|
||||
for entity in self[name]:
|
||||
entity.del_observer(self)
|
||||
return super(Entities, self).__delitem__(name)
|
||||
|
@ -1,15 +1,15 @@
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
from typing import List, Iterator, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from marl_factory_grid.environment.entity.object import _Object
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
import marl_factory_grid.environment.constants as c
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
|
||||
|
||||
class _Objects:
|
||||
_entity = _Object
|
||||
class Objects:
|
||||
_entity = Object
|
||||
|
||||
@property
|
||||
def var_can_be_bound(self):
|
||||
@ -50,7 +50,7 @@ class _Objects:
|
||||
def __len__(self):
|
||||
return len(self._data)
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterator[Union[Object, None]]:
|
||||
return iter(self.values())
|
||||
|
||||
def add_item(self, item: _entity):
|
||||
@ -130,13 +130,14 @@ class _Objects:
|
||||
repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]}
|
||||
return f'{self.__class__.__name__}[{repr_dict}]'
|
||||
|
||||
def notify_del_entity(self, entity: _Object):
|
||||
def notify_del_entity(self, entity: Object):
|
||||
try:
|
||||
# noinspection PyUnresolvedReferences
|
||||
self.pos_dict[entity.pos].remove(entity)
|
||||
except (AttributeError, ValueError, IndexError):
|
||||
pass
|
||||
|
||||
def notify_add_entity(self, entity: _Object):
|
||||
def notify_add_entity(self, entity: Object):
|
||||
try:
|
||||
if self not in entity.observers:
|
||||
entity.add_observer(self)
|
||||
|
@ -1,11 +1,11 @@
|
||||
import abc
|
||||
from random import shuffle
|
||||
from typing import List, Collection, Union
|
||||
from typing import List, Collection
|
||||
|
||||
from marl_factory_grid.environment import rewards as r, constants as c
|
||||
from marl_factory_grid.environment.entity.agent import Agent
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
from marl_factory_grid.utils.results import TickResult, DoneResult
|
||||
from marl_factory_grid.environment import rewards as r, constants as c
|
||||
|
||||
|
||||
class Rule(abc.ABC):
|
||||
@ -118,8 +118,7 @@ class AssignGlobalPositions(Rule):
|
||||
def on_init(self, state, lvl_map):
|
||||
from marl_factory_grid.environment.entity.util import GlobalPosition
|
||||
for agent in state[c.AGENT]:
|
||||
gp = GlobalPosition(lvl_map.level_shape)
|
||||
gp.bind_to(agent)
|
||||
gp = GlobalPosition(agent, lvl_map.level_shape)
|
||||
state[c.GLOBALPOSITIONS].add_item(gp)
|
||||
return []
|
||||
|
||||
|
@ -6,7 +6,9 @@ from marl_factory_grid.utils.results import TickResult, DoneResult
|
||||
class TemplateRule(Rule):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(TemplateRule, self).__init__(*args, **kwargs)
|
||||
super(TemplateRule, self).__init__()
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
pass
|
||||
|
@ -1,6 +1,5 @@
|
||||
from typing import Union
|
||||
|
||||
import marl_factory_grid.modules.batteries.constants
|
||||
from marl_factory_grid.environment.actions import Action
|
||||
from marl_factory_grid.utils.results import ActionResult
|
||||
|
||||
@ -24,5 +23,6 @@ class BtryCharge(Action):
|
||||
else:
|
||||
valid = c.NOT_VALID
|
||||
state.print(f'{entity.name} failed to charged batteries at {entity.pos}.')
|
||||
|
||||
return ActionResult(entity=entity, identifier=self._identifier, validity=valid,
|
||||
reward=marl_factory_grid.modules.batteries.constants.REWARD_CHARGE_VALID if valid else marl_factory_grid.modules.batteries.constants.Reward_CHARGE_FAIL)
|
||||
reward=b.REWARD_CHARGE_VALID if valid else b.Reward_CHARGE_FAIL)
|
||||
|
@ -1,11 +1,11 @@
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.environment.entity.object import _Object
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
from marl_factory_grid.modules.batteries import constants as b
|
||||
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||
|
||||
|
||||
class Battery(_Object):
|
||||
class Battery(Object):
|
||||
|
||||
@property
|
||||
def var_can_be_bound(self):
|
||||
|
@ -1,11 +1,9 @@
|
||||
from typing import List, Union
|
||||
|
||||
import marl_factory_grid.modules.batteries.constants
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.utils.results import TickResult, DoneResult
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.modules.batteries import constants as b
|
||||
from marl_factory_grid.utils.results import TickResult, DoneResult
|
||||
|
||||
|
||||
class BatteryDecharge(Rule):
|
||||
|
@ -1,5 +1,3 @@
|
||||
from numpy import random
|
||||
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||
from marl_factory_grid.modules.clean_up import constants as d
|
||||
|
@ -1,9 +1,7 @@
|
||||
from typing import Union, List, Tuple
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.utils.results import Result
|
||||
from marl_factory_grid.environment.groups.collection import Collection
|
||||
from marl_factory_grid.modules.clean_up.entitites import DirtPile
|
||||
from marl_factory_grid.utils.results import Result
|
||||
|
||||
|
||||
class DirtPiles(Collection):
|
||||
|
@ -49,7 +49,7 @@ class RespawnDirt(Rule):
|
||||
def tick_step(self, state):
|
||||
collection = state[d.DIRT]
|
||||
if self._next_dirt_spawn < 0:
|
||||
pass # No DirtPile Spawn
|
||||
result = [] # No DirtPile Spawn
|
||||
elif not self._next_dirt_spawn:
|
||||
result = [collection.trigger_spawn(state, coords_or_quantity=self.respawn_n, amount=self.respawn_amount)]
|
||||
self._next_dirt_spawn = self.respawn_freq
|
||||
|
@ -21,4 +21,4 @@ class DestAction(Action):
|
||||
valid = c.NOT_VALID
|
||||
state.print(f'{entity.name} just tried to do_wait_action do_wait_action at {entity.pos} but failed')
|
||||
return ActionResult(entity=entity, identifier=self._identifier, validity=valid,
|
||||
reward=marl_factory_grid.modules.destinations.constants.REWARD_WAIT_VALID if valid else marl_factory_grid.modules.destinations.constants.REWARD_WAIT_FAIL)
|
||||
reward=d.REWARD_WAIT_VALID if valid else d.REWARD_WAIT_FAIL)
|
||||
|
@ -1,7 +1,5 @@
|
||||
from marl_factory_grid.environment.groups.collection import Collection
|
||||
from marl_factory_grid.modules.destinations.entitites import Destination
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.modules.destinations import constants as d
|
||||
|
||||
|
||||
class Destinations(Collection):
|
||||
|
@ -1,5 +1,3 @@
|
||||
from typing import Union
|
||||
|
||||
from marl_factory_grid.environment.groups.collection import Collection
|
||||
from marl_factory_grid.modules.doors import constants as d
|
||||
from marl_factory_grid.modules.doors.entitites import Door
|
||||
|
@ -1,8 +1,8 @@
|
||||
import random
|
||||
from typing import List, Union
|
||||
from typing import List
|
||||
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.utils.results import TickResult
|
||||
|
||||
|
||||
@ -14,8 +14,6 @@ class AgentSingleZonePlacementBeta(Rule):
|
||||
super().__init__()
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
zones = state[c.ZONES]
|
||||
n_zones = state[c.ZONES]
|
||||
agents = state[c.AGENT]
|
||||
if len(self.coordinates) == len(agents):
|
||||
coordinates = self.coordinates
|
||||
|
@ -1,6 +1,3 @@
|
||||
from typing import NamedTuple
|
||||
|
||||
|
||||
SYMBOL_NO_ITEM = 0
|
||||
SYMBOL_DROP_OFF = 1
|
||||
# Item Env
|
||||
|
@ -14,27 +14,14 @@ class Item(Entity):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def auto_despawn(self):
|
||||
return self._auto_despawn
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
# Edit this if you want items to be drawn in the ops differently
|
||||
return 1
|
||||
|
||||
def set_auto_despawn(self, auto_despawn):
|
||||
self._auto_despawn = auto_despawn
|
||||
|
||||
def summarize_state(self) -> dict:
|
||||
super_summarization = super(Item, self).summarize_state()
|
||||
super_summarization.update(dict(auto_despawn=self.auto_despawn))
|
||||
return super_summarization
|
||||
|
||||
|
||||
class DropOffLocation(Entity):
|
||||
|
||||
|
||||
def render(self):
|
||||
return RenderEntity(i.DROP_OFF, self.pos)
|
||||
|
||||
@ -42,18 +29,16 @@ class DropOffLocation(Entity):
|
||||
def encoding(self):
|
||||
return i.SYMBOL_DROP_OFF
|
||||
|
||||
def __init__(self, *args, storage_size_until_full: int = 5, auto_item_despawn_interval: int = 5, **kwargs):
|
||||
def __init__(self, *args, storage_size_until_full: int = 5, **kwargs):
|
||||
super(DropOffLocation, self).__init__(*args, **kwargs)
|
||||
self.auto_item_despawn_interval = auto_item_despawn_interval
|
||||
self.storage = deque(maxlen=storage_size_until_full or None)
|
||||
|
||||
def place_item(self, item: Item):
|
||||
if self.is_full:
|
||||
raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.")
|
||||
return bc.NOT_VALID # in Zeile 81 verschieben?
|
||||
return bc.NOT_VALID
|
||||
else:
|
||||
self.storage.append(item)
|
||||
item.set_auto_despawn(self.auto_item_despawn_interval)
|
||||
return c.VALID
|
||||
|
||||
@property
|
||||
|
@ -1,12 +1,9 @@
|
||||
from random import shuffle
|
||||
|
||||
from marl_factory_grid.modules.items import constants as i
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
from marl_factory_grid.environment.groups.collection import Collection
|
||||
from marl_factory_grid.environment.groups.objects import _Objects
|
||||
from marl_factory_grid.environment.groups.mixins import IsBoundMixin
|
||||
from marl_factory_grid.environment.entity.agent import Agent
|
||||
from marl_factory_grid.environment.groups.collection import Collection
|
||||
from marl_factory_grid.environment.groups.mixins import IsBoundMixin
|
||||
from marl_factory_grid.environment.groups.objects import Objects
|
||||
from marl_factory_grid.modules.items import constants as i
|
||||
from marl_factory_grid.modules.items.entitites import Item, DropOffLocation
|
||||
from marl_factory_grid.utils.results import Result
|
||||
|
||||
@ -74,13 +71,12 @@ class Inventory(IsBoundMixin, Collection):
|
||||
self._collection = collection
|
||||
|
||||
|
||||
class Inventories(_Objects):
|
||||
class Inventories(Objects):
|
||||
_entity = Inventory
|
||||
|
||||
var_can_move = False
|
||||
var_has_position = False
|
||||
|
||||
|
||||
symbol = None
|
||||
|
||||
@property
|
||||
@ -116,7 +112,6 @@ class Inventories(_Objects):
|
||||
return [val.summarize_states(**kwargs) for key, val in self.items()]
|
||||
|
||||
|
||||
|
||||
class DropOffLocations(Collection):
|
||||
_entity = DropOffLocation
|
||||
|
||||
|
@ -1,9 +1,10 @@
|
||||
from typing import Union
|
||||
|
||||
import marl_factory_grid.modules.machines.constants
|
||||
from marl_factory_grid.environment.actions import Action
|
||||
from marl_factory_grid.utils.results import ActionResult
|
||||
|
||||
from marl_factory_grid.modules.machines import constants as m, rewards as r
|
||||
from marl_factory_grid.modules.machines import constants as m
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
|
||||
@ -16,8 +17,10 @@ class MachineAction(Action):
|
||||
def do(self, entity, state) -> Union[None, ActionResult]:
|
||||
if machine := h.get_first(state[m.MACHINES].by_pos(entity.pos)):
|
||||
if valid := machine.maintain():
|
||||
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_VALID)
|
||||
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_VALID)
|
||||
else:
|
||||
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_FAIL)
|
||||
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_FAIL)
|
||||
else:
|
||||
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.MAINTAIN_FAIL)
|
||||
return ActionResult(entity=entity, identifier=self._identifier,
|
||||
validity=c.NOT_VALID, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_FAIL
|
||||
)
|
||||
|
@ -11,3 +11,7 @@ STATE_MAINTAIN = 'maintenance'
|
||||
SYMBOL_WORK = 1
|
||||
SYMBOL_IDLE = 0.6
|
||||
SYMBOL_MAINTAIN = 0.3
|
||||
MAINTAIN_VALID: float = 0.5
|
||||
MAINTAIN_FAIL: float = -0.1
|
||||
FAIL_MISSING_MAINTENANCE: float = -0.5
|
||||
NONE: float = 0
|
||||
|
@ -31,11 +31,10 @@ class Machine(Entity):
|
||||
return c.NOT_VALID
|
||||
|
||||
def tick(self, state):
|
||||
# if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]):
|
||||
if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in state.entities.pos_dict[self.pos]]):
|
||||
others = state.entities.pos_dict[self.pos]
|
||||
if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in others]):
|
||||
return TickResult(identifier=self.name, validity=c.VALID, entity=self)
|
||||
# elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]):
|
||||
elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in state.entities.pos_dict[self.pos]]):
|
||||
elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in others]):
|
||||
self.status = m.STATE_WORK
|
||||
self.reset_counter()
|
||||
return None
|
||||
|
@ -1,5 +1,3 @@
|
||||
from typing import Union, List, Tuple
|
||||
|
||||
from marl_factory_grid.environment.groups.collection import Collection
|
||||
|
||||
from .entitites import Machine
|
||||
|
@ -1,5 +0,0 @@
|
||||
MAINTAIN_VALID: float = 0.5
|
||||
MAINTAIN_FAIL: float = -0.1
|
||||
FAIL_MISSING_MAINTENANCE: float = -0.5
|
||||
|
||||
NONE: float = 0
|
@ -1,3 +1,4 @@
|
||||
MAINTAINER = 'Maintainer' # TEMPLATE _identifier. Define your own!
|
||||
MAINTAINERS = 'Maintainers' # TEMPLATE _identifier. Define your own!
|
||||
|
||||
MAINTAINER_COLLISION_REWARD = -5
|
||||
|
@ -4,7 +4,6 @@ from marl_factory_grid.environment.groups.collection import Collection
|
||||
from .entities import Maintainer
|
||||
from ..machines import constants as mc
|
||||
from ..machines.actions import MachineAction
|
||||
from ...utils.states import Gamestate
|
||||
|
||||
|
||||
class Maintainers(Collection):
|
||||
@ -23,8 +22,6 @@ class Maintainers(Collection):
|
||||
self.size = size
|
||||
self._spawnrule = spawnrule
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):
|
||||
self.add_items([self._entity(mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity])
|
||||
|
@ -1 +0,0 @@
|
||||
MAINTAINER_COLLISION_REWARD = -5
|
@ -1,15 +1,16 @@
|
||||
from typing import List
|
||||
|
||||
import marl_factory_grid.modules.maintenance.constants
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.utils.results import TickResult, DoneResult
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from . import rewards as r
|
||||
from . import constants as M
|
||||
|
||||
|
||||
class MoveMaintainers(Rule):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def tick_step(self, state) -> List[TickResult]:
|
||||
for maintainer in state[M.MAINTAINERS]:
|
||||
@ -20,8 +21,8 @@ class MoveMaintainers(Rule):
|
||||
|
||||
class DoneAtMaintainerCollision(Rule):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
agents = list(state[c.AGENT].values())
|
||||
@ -30,5 +31,5 @@ class DoneAtMaintainerCollision(Rule):
|
||||
for agent in agents:
|
||||
if agent.pos in m_pos:
|
||||
done_results.append(DoneResult(entity=agent, validity=c.VALID, identifier=self.name,
|
||||
reward=r.MAINTAINER_COLLISION_REWARD))
|
||||
reward=marl_factory_grid.modules.maintenance.constants.MAINTAINER_COLLISION_REWARD))
|
||||
return done_results
|
||||
|
@ -1,10 +1,10 @@
|
||||
import random
|
||||
from typing import List, Tuple
|
||||
|
||||
from marl_factory_grid.environment.entity.object import _Object
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
|
||||
|
||||
class Zone(_Object):
|
||||
class Zone(Object):
|
||||
|
||||
@property
|
||||
def positions(self):
|
||||
|
@ -1,8 +1,8 @@
|
||||
from marl_factory_grid.environment.groups.objects import _Objects
|
||||
from marl_factory_grid.environment.groups.objects import Objects
|
||||
from marl_factory_grid.modules.zones import Zone
|
||||
|
||||
|
||||
class Zones(_Objects):
|
||||
class Zones(Objects):
|
||||
symbol = None
|
||||
_entity = Zone
|
||||
|
||||
|
@ -58,7 +58,10 @@ class FactoryConfigParser(object):
|
||||
return str(self.config)
|
||||
|
||||
def __getitem__(self, item):
    """
    Return the section *item* of the parsed config.

    A missing section is reported on stdout instead of being raised.

    :param item: Key of the requested config section.
    :return: ``self.config[item]``, or None when the key is not present.
    """
    try:
        return self.config[item]
    except KeyError:
        # Report the missing section and hand back None rather than propagating the KeyError.
        print(f'The mandatory {item} section could not be found in your .config file. Check Spelling!')
        return None
|
||||
|
||||
def load_entities(self):
|
||||
entity_classes = dict()
|
||||
@ -161,7 +164,6 @@ class FactoryConfigParser(object):
|
||||
|
||||
def _load_smth(self, config, class_obj):
|
||||
rules = list()
|
||||
rules_names = list()
|
||||
for rule in config:
|
||||
e1 = e2 = e3 = None
|
||||
try:
|
||||
|
@ -61,8 +61,8 @@ class ObservationTranslator:
|
||||
:param per_agent_named_obs_spaces: `Named observation space` one for each agent. Overloaded.
|
||||
:type per_agent_named_obs_spaces: Dict[str, dict]
|
||||
|
||||
:param placeholder_fill_value: Currently not fully implemented!!!
|
||||
:type placeholder_fill_value: Union[int, str] = 'N')
|
||||
:param placeholder_fill_value: Currently, not fully implemented!!!
|
||||
:type placeholder_fill_value: Union[int, str] = 'N'
|
||||
"""
|
||||
|
||||
if isinstance(placeholder_fill_value, str):
|
||||
|
@ -9,7 +9,7 @@ from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from marl_factory_grid.utils.plotting.compare_runs import plot_single_run
|
||||
from marl_factory_grid.utils.plotting.plot_single_runs import plot_single_run
|
||||
|
||||
|
||||
class EnvMonitor(Wrapper):
|
||||
@ -22,7 +22,6 @@ class EnvMonitor(Wrapper):
|
||||
self._monitor_df = pd.DataFrame()
|
||||
self._monitor_dict = dict()
|
||||
|
||||
|
||||
def step(self, action):
|
||||
obs_type, obs, reward, done, info = self.env.step(action)
|
||||
self._read_info(info)
|
||||
|
@ -2,11 +2,9 @@ from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Union, List
|
||||
|
||||
import yaml
|
||||
from gymnasium import Wrapper
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from gymnasium import Wrapper
|
||||
|
||||
|
||||
class EnvRecorder(Wrapper):
|
||||
@ -106,7 +104,7 @@ class EnvRecorder(Wrapper):
|
||||
out_dict = {'episodes': self._recorder_out_list}
|
||||
out_dict.update(
|
||||
{'n_episodes': self._curr_episode,
|
||||
'metadata':dict(
|
||||
'metadata': dict(
|
||||
level_name=self.env.params['General']['level_name'],
|
||||
verbose=False,
|
||||
n_agents=len(self.env.params['Agents']),
|
||||
|
@ -5,7 +5,7 @@ from typing import Dict, List
|
||||
import numpy as np
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.object import _Object
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
from marl_factory_grid.environment.groups.utils import Combined
|
||||
from marl_factory_grid.utils.utility_classes import Floor
|
||||
from marl_factory_grid.utils.ray_caster import RayCaster
|
||||
@ -13,7 +13,6 @@ from marl_factory_grid.utils.states import Gamestate
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
|
||||
|
||||
|
||||
class OBSBuilder(object):
|
||||
default_obs = [c.WALLS, c.OTHERS]
|
||||
|
||||
@ -128,7 +127,7 @@ class OBSBuilder(object):
|
||||
f'{re.escape("[")}(.*){re.escape("]")}'
|
||||
f'{re.escape("(")}{re.escape(agent.name)}{re.escape(")")}')
|
||||
name = next((key for key, val in self.all_obs.items()
|
||||
if pattern.search(str(val)) and isinstance(val, _Object)), None)
|
||||
if pattern.search(str(val)) and isinstance(val, Object)), None)
|
||||
e = self.all_obs[name]
|
||||
except KeyError:
|
||||
try:
|
||||
@ -181,11 +180,11 @@ class OBSBuilder(object):
|
||||
return obs, self.obs_layers[agent.name]
|
||||
|
||||
def _sort_and_name_observation_conf(self, agent):
|
||||
'''
|
||||
"""
|
||||
Builds the useable observation scheme per agent from conf.yaml.
|
||||
:param agent:
|
||||
:return:
|
||||
'''
|
||||
"""
|
||||
# Fixme: no asymmetric shapes possible.
|
||||
self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape))
|
||||
obs_layers = []
|
||||
|
@ -7,50 +7,11 @@ from typing import Union, List
|
||||
import pandas as pd
|
||||
|
||||
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
|
||||
from marl_factory_grid.utils.plotting.plotting import prepare_plot
|
||||
from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
|
||||
|
||||
MODEL_MAP = None
|
||||
|
||||
|
||||
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None):
|
||||
run_path = Path(run_path)
|
||||
df_list = list()
|
||||
if run_path.is_dir():
|
||||
monitor_file = next(run_path.glob('*monitor*.pick'))
|
||||
elif run_path.exists() and run_path.is_file():
|
||||
monitor_file = run_path
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
with monitor_file.open('rb') as f:
|
||||
monitor_df = pickle.load(f)
|
||||
|
||||
monitor_df = monitor_df.fillna(0)
|
||||
df_list.append(monitor_df)
|
||||
|
||||
df = pd.concat(df_list, ignore_index=True)
|
||||
df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
|
||||
if column_keys is not None:
|
||||
columns = [col for col in column_keys if col in df.columns]
|
||||
else:
|
||||
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
|
||||
|
||||
roll_n = 50
|
||||
|
||||
non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean()
|
||||
|
||||
df_melted = df[columns + ['Episode']].reset_index().melt(
|
||||
id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score"
|
||||
)
|
||||
|
||||
if df_melted['Episode'].max() > 800:
|
||||
skip_n = round(df_melted['Episode'].max() * 0.02)
|
||||
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||
|
||||
prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
||||
print('Plotting done.')
|
||||
|
||||
|
||||
def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
|
||||
run_path = Path(run_path)
|
||||
df_list = list()
|
48
marl_factory_grid/utils/plotting/plot_single_runs.py
Normal file
48
marl_factory_grid/utils/plotting/plot_single_runs.py
Normal file
@ -0,0 +1,48 @@
|
||||
import pickle
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
|
||||
from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
|
||||
|
||||
|
||||
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None,
                    file_key: str = 'monitor', file_ext: str = 'pkl'):
    """
    Plot the measurements stored in the pickled monitor file of a single run as a line plot.

    :param run_path: Either the pickled monitor file itself or a directory which contains it.
    :param use_tex: Passed on to ``prepare_plot``.
    :param column_keys: Names of the columns to plot; names that are not part of the data are ignored.
                        When None, every column that is not listed in ``IGNORED_DF_COLUMNS`` is plotted.
    :param file_key: Substring a file name has to contain to be picked up, when *run_path* is a directory.
    :param file_ext: Extension (without the leading dot) of the file to pick up, when *run_path* is a directory.
    :return: None; the melted data is handed to ``prepare_plot`` together with an output path next to *run_path*.
    :raises ValueError: If *run_path* is neither an existing file nor a directory.
    :raises FileNotFoundError: If *run_path* is a directory that holds no matching file.
    """
    run_path = Path(run_path)
    if run_path.is_dir():
        pattern = f'*{file_key}*.{file_ext}'
        # Sorted, so the pick is reproducible when several files match the pattern.
        monitor_file = next(iter(sorted(run_path.glob(pattern))), None)
        if monitor_file is None:
            raise FileNotFoundError(f'No file matching "{pattern}" found in "{run_path}".')
    elif run_path.is_file():
        monitor_file = run_path
    else:
        raise ValueError(f'"{run_path}" is neither an existing file nor a directory.')

    # NOTE: unpickling can execute arbitrary code; only load monitor files from runs you trust.
    with monitor_file.open('rb') as f:
        monitor_df = pickle.load(f)

    df = monitor_df.fillna(0).reset_index(drop=True)
    df = df.rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
    if column_keys is not None:
        columns = [col for col in column_keys if col in df.columns]
    else:
        columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]

    df_melted = df[columns + ['Episode']].reset_index().melt(
        id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score"
    )

    # Thin out long runs: beyond 800 episodes only every (2% of max episode)-th episode is kept.
    if df_melted['Episode'].max() > 800:
        skip_n = round(df_melted['Episode'].max() * 0.02)
        df_melted = df_melted[df_melted['Episode'] % skip_n == 0]

    prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
    print('Plotting done.')
|
@ -60,7 +60,7 @@ def prepare_center_double_column_legend(df, hue, style, hue_order):
|
||||
print('Struggling to plot Figure using LaTeX - going back to normal.')
|
||||
plt.close('all')
|
||||
sns.set(rc={'text.usetex': False}, style='whitegrid')
|
||||
fig = plt.figure(figsize=(10, 11))
|
||||
_ = plt.figure(figsize=(10, 11))
|
||||
lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
|
||||
ci=95, palette=PALETTE, hue_order=hue_order, legend=False)
|
||||
# plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
|
@ -19,7 +19,7 @@ class RayCaster:
|
||||
return f'{self.__class__.__name__}({self.agent.name})'
|
||||
|
||||
def build_ray_targets(self):
|
||||
north = np.array([0, -1])*self.pomdp_r
|
||||
north = np.array([0, -1]) * self.pomdp_r
|
||||
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
|
||||
rot_M = [
|
||||
[[math.cos(theta), -math.sin(theta)],
|
||||
@ -53,9 +53,9 @@ class RayCaster:
|
||||
diag_hits = all([
|
||||
self.ray_block_cache(
|
||||
key,
|
||||
# lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light)
|
||||
lambda: any(True for e in pos_dict[key] if e.var_is_blocking_light))
|
||||
# lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light))
|
||||
for key in ((x, y-cy), (x-cx, y))
|
||||
for key in ((x, y - cy), (x - cx, y))
|
||||
]) if (cx != 0 and cy != 0) else False
|
||||
|
||||
visible += entities_hit if not diag_hits else []
|
||||
@ -77,8 +77,8 @@ class RayCaster:
|
||||
agent = self.agent
|
||||
x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
|
||||
y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
|
||||
outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \
|
||||
+ list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
|
||||
outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r]))
|
||||
outline += list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
|
||||
return outline
|
||||
|
||||
@staticmethod
|
||||
|
@ -1,9 +1,12 @@
|
||||
from typing import Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
|
||||
TYPE_VALUE = 'value'
|
||||
TYPE_REWARD = 'reward'
|
||||
types = [TYPE_VALUE, TYPE_REWARD]
|
||||
TYPES = [TYPE_VALUE, TYPE_REWARD]
|
||||
|
||||
|
||||
@dataclass
|
||||
class InfoObject:
|
||||
@ -18,12 +21,13 @@ class Result:
|
||||
validity: bool
|
||||
reward: Union[float, None] = None
|
||||
value: Union[float, None] = None
|
||||
entity: None = None
|
||||
entity: Object = None
|
||||
|
||||
def get_infos(self):
|
||||
n = self.entity.name if self.entity is not None else "Global"
|
||||
return [InfoObject(identifier=f'{n}_{self.identifier}_{t}',
|
||||
val_type=t, value=self.__getattribute__(t)) for t in types
|
||||
# Return multiple Info Dicts
|
||||
return [InfoObject(identifier=f'{n}_{self.identifier}',
|
||||
val_type=t, value=self.__getattribute__(t)) for t in TYPES
|
||||
if self.__getattribute__(t) is not None]
|
||||
|
||||
def __repr__(self):
|
||||
@ -31,7 +35,7 @@ class Result:
|
||||
reward = f" | Reward: {self.reward}" if self.reward is not None else ""
|
||||
value = f" | Value: {self.value}" if self.value is not None else ""
|
||||
entity = f" | by: {self.entity.name}" if self.entity is not None else ""
|
||||
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value})'
|
||||
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value}{entity})'
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -1,11 +1,12 @@
|
||||
from itertools import islice
|
||||
from typing import List, Dict, Tuple
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.utils.results import Result
|
||||
from marl_factory_grid.utils.results import Result, DoneResult
|
||||
|
||||
|
||||
class StepRules:
|
||||
@ -83,13 +84,51 @@ class Gamestate(object):
|
||||
return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})'
|
||||
|
||||
@property
|
||||
def random_free_position(self):
|
||||
def random_free_position(self) -> (int, int):
|
||||
"""
|
||||
Returns a single **free** position (x, y), which is **free** for spawning or walking.
|
||||
No Entity at this position possesses *var_is_blocking_pos* or *var_can_collide*.
|
||||
|
||||
:return: Single **free** position.
|
||||
"""
|
||||
return self.get_n_random_free_positions(1)[0]
|
||||
|
||||
def get_n_random_free_positions(self, n):
|
||||
def get_n_random_free_positions(self, n) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Returns a list of *n* **free** positions [(x, y), ... ], which are **free** for spawning or walking.
|
||||
No Entity at this position possesses *var_is_blocking_pos* or *var_can_collide*.
|
||||
|
||||
:return: List of n **free** positions.
|
||||
"""
|
||||
return list(islice(self.entities.free_positions_generator, n))
|
||||
|
||||
def tick(self, actions) -> List[Result]:
|
||||
@property
def random_position(self) -> tuple[int, int]:
    """
    Returns a single available position (x, y), ignores all entity attributes.

    :return: The one position delivered by ``get_n_random_positions(1)``.
    """
    return self.get_n_random_positions(1)[0]
|
||||
|
||||
def get_n_random_positions(self, n) -> list[tuple[int, int]]:
    """
    Returns a list of up to *n* available positions [(x, y), ... ], ignores all entity attributes.

    The positions are the first *n* items of ``self.entities.floorlist``.
    NOTE(review): any randomness has to come from the order of the floorlist itself - TODO confirm.

    :param n: Number of positions to take.
    :return: List of positions, shorter than *n* if the floorlist holds fewer items.
    """
    first_n = islice(self.entities.floorlist, n)
    return list(first_n)
|
||||
|
||||
def tick(self, actions) -> list[Result]:
|
||||
"""
|
||||
Performs a single **Gamestate Tick** by calling the inner rule hooks in sequential order.
|
||||
- tick_pre_step_all: Things to do before the agents do their actions. Statechange, Moving, Spawning etc...
|
||||
- agent tick: Agents do their actions.
|
||||
- tick_step_all: Things to do after the agents did their actions. Statechange, Moving, Spawning etc...
|
||||
- tick_post_step_all: Things to do at the very end of each step. Counting, Reward calculations etc...
|
||||
|
||||
:return: List of *Result*-objects.
|
||||
"""
|
||||
results = list()
|
||||
self.curr_step += 1
|
||||
|
||||
@ -112,11 +151,23 @@ class Gamestate(object):
|
||||
|
||||
return results
|
||||
|
||||
def print(self, string):
|
||||
def print(self, string) -> None:
|
||||
"""
|
||||
When *verbose* is active, print stuff.
|
||||
|
||||
:param string: *String* to print.
|
||||
:type string: str
|
||||
:return: Nothing
|
||||
"""
|
||||
if self.verbose:
|
||||
print(string)
|
||||
|
||||
def check_done(self):
|
||||
def check_done(self) -> List[DoneResult]:
|
||||
"""
|
||||
Iterate all **Rules** that override the *on_check_done* hook.
|
||||
|
||||
:return: List of Results
|
||||
"""
|
||||
results = list()
|
||||
for rule in self.rules:
|
||||
if on_check_done_result := rule.on_check_done(self):
|
||||
@ -124,20 +175,44 @@ class Gamestate(object):
|
||||
return results
|
||||
|
||||
def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]:
|
||||
positions = [pos for pos, entities in self.entities.pos_dict.items() if len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2)]
|
||||
"""
|
||||
Returns a list of positions [(x, y), ... ] on which collisions occur. This does not include agents,
|
||||
that were unable to move because their target direction was blocked, also a form of collision.
|
||||
|
||||
:return: List of positions.
|
||||
"""
|
||||
positions = [pos for pos, entities in self.entities.pos_dict.items() if
|
||||
len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2)
|
||||
]
|
||||
return positions
|
||||
|
||||
def check_move_validity(self, moving_entity, position):
|
||||
if moving_entity.pos != position and not any(
|
||||
entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]) and not (
|
||||
moving_entity.var_is_blocking_pos and self.entities.is_occupied(position)):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
def check_move_validity(self, moving_entity: Entity, target_position: (int, int)) -> bool:
|
||||
"""
|
||||
Whether it is safe to move to the target position and the moving entity does not introduce a blocking attribute,
|
||||
when position is already occupied.
|
||||
|
||||
def check_pos_validity(self, position):
|
||||
if not any(entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
:param moving_entity: Entity
|
||||
:param target_position: pos
|
||||
:return: Safe to move to
|
||||
"""
|
||||
|
||||
is_not_blocked = self.check_pos_validity(target_position)
|
||||
will_not_block_others = moving_entity.var_is_blocking_pos and self.entities.is_occupied(target_position)
|
||||
|
||||
if moving_entity.pos != target_position and is_not_blocked and not will_not_block_others:
|
||||
return c.VALID
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
|
||||
def check_pos_validity(self, pos: tuple[int, int]) -> bool:
    """
    Check if *pos* is a valid position to move or spawn to.

    A position is valid when no entity registered at it has *var_is_blocking_pos* set and
    *pos* is contained in the floorlist of the entities.

    :param pos: Position (x, y) to check.
    :return: *c.VALID* when pos is a valid target, *c.NOT_VALID* otherwise.
    """
    # The blocking check is evaluated first, exactly as before, so the lookup order stays unchanged.
    is_blocked = any(e.var_is_blocking_pos for e in self.entities.pos_dict[pos])
    if not is_blocked and pos in self.entities.floorlist:
        return c.VALID
    return c.NOT_VALID
|
||||
|
@ -28,7 +28,9 @@ class ConfigExplainer:
|
||||
|
||||
def explain_module(self, class_to_explain):
|
||||
parameters = inspect.signature(class_to_explain).parameters
|
||||
explained = {class_to_explain.__name__: {key: val.default for key, val in parameters.items() if key not in EXCLUDED}}
|
||||
explained = {class_to_explain.__name__:
|
||||
{key: val.default for key, val in parameters.items() if key not in EXCLUDED}
|
||||
}
|
||||
return explained
|
||||
|
||||
def _load_and_compare(self, compare_class, paths):
|
||||
|
@ -6,18 +6,21 @@ from marl_factory_grid.environment.factory import Factory
|
||||
|
||||
from marl_factory_grid.utils.logging.envmonitor import EnvMonitor
|
||||
from marl_factory_grid.utils.logging.recorder import EnvRecorder
|
||||
from marl_factory_grid.utils.plotting.plot_single_runs import plot_single_run
|
||||
from marl_factory_grid.utils.tools import ConfigExplainer
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Render at each step?
|
||||
render = True
|
||||
render = False
|
||||
# Reveal all possible Modules (Entities, Rules, Agents[Actions, Observations], etc.)
|
||||
explain_config = False
|
||||
# Collect statistics?
|
||||
monitor = False
|
||||
monitor = True
|
||||
# Record as Protobuf?
|
||||
record = False
|
||||
# Plot Results?
|
||||
plotting = True
|
||||
|
||||
run_path = Path('study_out')
|
||||
|
||||
@ -38,7 +41,7 @@ if __name__ == '__main__':
|
||||
factory = EnvRecorder(factory)
|
||||
|
||||
# RL learn Loop
|
||||
for episode in trange(500):
|
||||
for episode in trange(10):
|
||||
_ = factory.reset()
|
||||
done = False
|
||||
if render:
|
||||
@ -54,7 +57,10 @@ if __name__ == '__main__':
|
||||
break
|
||||
|
||||
if monitor:
|
||||
factory.save_run(run_path / 'test.pkl')
|
||||
factory.save_run(run_path / 'test_monitor.pkl')
|
||||
if record:
|
||||
factory.save_records(run_path / 'test.pb')
|
||||
if plotting:
|
||||
plot_single_run(run_path)
|
||||
|
||||
print('Done!!! Goodbye....')
|
||||
|
@ -56,6 +56,7 @@ if __name__ == '__main__':
|
||||
for model_idx, model in enumerate(models)]
|
||||
else:
|
||||
actions = models[0].predict(env_state, deterministic=determin)[0]
|
||||
# noinspection PyTupleAssignmentBalance
|
||||
env_state, step_r, done_bool, info_obj = env.step(actions)
|
||||
|
||||
rew += step_r
|
||||
|
@ -5,7 +5,6 @@ from pathlib import Path
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
|
||||
conf_path = Path('wg0')
|
||||
wg0_conf = configparser.ConfigParser()
|
||||
wg0_conf.read(conf_path/'wg0.conf')
|
||||
@ -17,7 +16,6 @@ if __name__ == '__main__':
|
||||
# Delete any old conf.json for the current peer
|
||||
(conf_path / f'{client_name}.json').unlink(missing_ok=True)
|
||||
|
||||
|
||||
peer = wg0_conf[client_name]
|
||||
|
||||
date_time = datetime.now().strftime('%Y-%m-%dT%H:%M:%S.%f000Z')
|
||||
|
Loading…
x
Reference in New Issue
Block a user