Resolved some warnings and style issues

This commit is contained in:
Steffen Illium 2023-11-10 09:29:54 +01:00
parent a9462a8b6f
commit 6711a0976b
64 changed files with 331 additions and 361 deletions

5
.idea/.gitignore generated vendored Normal file
View File

@ -0,0 +1,5 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/

View File

@ -94,7 +94,7 @@ All [Entites](marl_factory_grid/environment/entity/global_entities.py) are avail
#### Rules
[Rules](marl_factory_grid/environment/entity/object.py) define how the environment behaves on micro-scale.
[Rules](marl_factory_grid/environment/entity/object.py) define how the environment behaves on microscale.
Each of the hookes (`on_init`, `pre_step`, `on_step`, '`post_step`', `on_done`)
provide env-access to implement customn logic, calculate rewards, or gather information.
@ -107,6 +107,7 @@ Make sure to bring your own assets for each Entity living in the Gridworld as th
PNG-files (transparent background) of square aspect-ratio should do the job, in general.
<img src="/marl_factory_grid/environment/assets/wall.png" width="5%">
<!--suppress HtmlUnknownAttribute -->
<html &nbsp&nbsp&nbsp&nbsp html>
<img src="/marl_factory_grid/environment/assets/agent/agent.png" width="5%">

View File

@ -1 +1 @@
from .quickstart import init
from .quickstart import init

View File

@ -1 +1,4 @@
import os, sys; sys.path.append(os.path.dirname(os.path.realpath(__file__)))
import os
import sys
sys.path.append(os.path.dirname(os.path.realpath(__file__)))

View File

@ -1 +1 @@
from marl_factory_grid.algorithms.marl.memory import MARLActorCriticMemory
from marl_factory_grid.algorithms.marl.memory import MARLActorCriticMemory

View File

@ -28,6 +28,7 @@ class Names:
BATCH_SIZE = 'bnatch_size'
N_ACTIONS = 'n_actions'
nms = Names
ListOrTensor = Union[List, torch.Tensor]
@ -112,10 +113,9 @@ class BaseActorCritic:
next_obs, reward, done, info = env.step(action)
done = [done] * self.n_agents if isinstance(done, bool) else done
last_hiddens = dict(hidden_actor =out[nms.HIDDEN_ACTOR],
last_hiddens = dict(hidden_actor=out[nms.HIDDEN_ACTOR],
hidden_critic=out[nms.HIDDEN_CRITIC])
tm.add(observation=obs, action=action, reward=reward, done=done,
logits=out.get(nms.LOGITS, None), values=out.get(nms.CRITIC, None),
**last_hiddens)
@ -142,7 +142,9 @@ class BaseActorCritic:
print(f'reward at episode: {episode} = {rew_log}')
episode += 1
df_results.append([episode, rew_log, *reward])
df_results = pd.DataFrame(df_results, columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]])
df_results = pd.DataFrame(df_results,
columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]]
)
if checkpointer is not None:
df_results.to_csv(checkpointer.path / 'results.csv', index=False)
return df_results
@ -157,24 +159,27 @@ class BaseActorCritic:
last_action, reward = [-1] * self.n_agents, [0.] * self.n_agents
done, rew_log, eps_rew = [False] * self.n_agents, 0, torch.zeros(self.n_agents)
while not all(done):
if render: env.render()
if render:
env.render()
out = self.forward(obs, last_action, **last_hiddens)
action = self.get_actions(out)
next_obs, reward, done, info = env.step(action)
if isinstance(done, bool): done = [done] * obs.shape[0]
if isinstance(done, bool):
done = [done] * obs.shape[0]
obs = next_obs
last_action = action
last_hiddens = dict(hidden_actor=out.get(nms.HIDDEN_ACTOR, None),
hidden_critic=out.get(nms.HIDDEN_CRITIC, None)
)
eps_rew += torch.tensor(reward)
results.append(eps_rew.tolist() + [sum(eps_rew).item()] + [episode])
results.append(eps_rew.tolist() + [np.sum(eps_rew).item()] + [episode])
episode += 1
agent_columns = [f'agent#{i}' for i in range(self.cfg['environment']['n_agents'])]
results = pd.DataFrame(results, columns=agent_columns + ['sum', 'episode'])
results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'], value_name='reward', var_name='agent')
results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'],
value_name='reward', var_name='agent')
return results
@staticmethod

View File

@ -36,7 +36,7 @@ class LoopMAPPO(LoopSNAC):
rewards_ = torch.stack(rewards_, dim=1)
return rewards_
def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **kwargs):
def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **__):
out = network(batch[nms.OBSERVATION], batch[nms.ACTION], batch[nms.HIDDEN_ACTOR], batch[nms.HIDDEN_CRITIC])
logits = out[nms.LOGITS][:, :-1] # last one only needed for v_{t+1}
@ -45,7 +45,7 @@ class LoopMAPPO(LoopSNAC):
# monte carlo returns
mc_returns = self.monte_carlo_returns(batch[nms.REWARD], batch[nms.DONE], gamma)
mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8) #todo: norm across agent ok?
mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8) # todo: norm across agent ok?
advantages = mc_returns - out[nms.CRITIC][:, :-1]
# policy loss

View File

@ -1,8 +1,7 @@
import numpy as np
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
class RecurrentAC(nn.Module):
@ -88,8 +87,8 @@ class NormalizedLinear(nn.Linear):
self.trainable_magnitude = trainable_magnitude
self.scale = nn.Parameter(torch.tensor([1.]), requires_grad=trainable_magnitude)
def forward(self, input):
normalized_input = F.normalize(input, dim=-1, p=2, eps=1e-5)
def forward(self, in_array):
normalized_input = F.normalize(in_array, dim=-1, p=2, eps=1e-5)
normalized_weight = F.normalize(self.weight, dim=-1, p=2, eps=1e-5)
return F.linear(normalized_input, normalized_weight) * self.d_sqrt * self.scale

View File

@ -16,7 +16,7 @@ class LoopSEAC(LoopIAC):
with torch.inference_mode(True):
true_action_logp = torch.stack([
torch.log_softmax(out[nms.LOGITS][ag_i, :-1], -1)
.gather(index=actions[ag_i, 1:, None], dim=-1)
.gather(index=actions[ag_i, 1:, None], dim=-1)
for ag_i, out in enumerate(outputs)
], 0).squeeze()
@ -38,7 +38,6 @@ class LoopSEAC(LoopIAC):
a2c_loss = (-iw*log_ap * advantages.detach()).mean(-1)
value_loss = (iw*advantages.pow(2)).mean(-1) # n_agent
# weighted loss
@ -53,4 +52,4 @@ class LoopSEAC(LoopIAC):
self.optimizer[ag_i].zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.net[ag_i].parameters(), 0.5)
self.optimizer[ag_i].step()
self.optimizer[ag_i].step()

View File

@ -30,4 +30,4 @@ class LoopSNAC(BaseActorCritic):
self._as_torch(actions).unsqueeze(1),
hidden_actor, hidden_critic
)
return out
return out

View File

@ -56,8 +56,8 @@ class TSPBaseAgent(ABC):
def _door_is_close(self, state):
try:
# return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name)
return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name)
return next(y for x in state.entities.neighboring_positions(self.state.pos)
for y in state.entities.pos_dict[x] if do.DOOR in y.name)
except StopIteration:
return None

View File

@ -14,8 +14,8 @@ class TSPTargetAgent(TSPBaseAgent):
def _handle_doors(self, state):
try:
# return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name)
return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name)
return next(y for x in state.entities.neighboring_positions(self.state.pos)
for y in state.entities.pos_dict[x] if do.DOOR in y.name)
except StopIteration:
return None

View File

@ -1,8 +1,9 @@
import torch
import numpy as np
import yaml
from pathlib import Path
import numpy as np
import torch
import yaml
def load_class(classname):
from importlib import import_module
@ -42,7 +43,6 @@ def get_class(arguments):
def get_arguments(arguments):
from importlib import import_module
d = dict(arguments)
if "classname" in d:
del d["classname"]
@ -82,4 +82,4 @@ class Checkpointer(object):
for name, model in to_save:
self.save_experiment(name, model)
self.__current_checkpoint += 1
self.__current_step += 1
self.__current_step += 1

View File

@ -1,4 +1,4 @@
eneral:
General:
# Your Seed
env_seed: 69
# Individual or global rewards?
@ -86,4 +86,4 @@ Rules:
DoneAtDestinationReachAll:
# reward_at_done: 1
DoneAtMaxStepsReached:
max_steps: 500
max_steps: 200

View File

@ -1,15 +1,14 @@
import abc
from collections import defaultdict
import numpy as np
from .object import _Object
from .object import Object
from .. import constants as c
from ...utils.results import ActionResult
from ...utils.utility_classes import RenderEntity
class Entity(_Object, abc.ABC):
class Entity(Object, abc.ABC):
"""Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc..."""
@property
@ -96,8 +95,9 @@ class Entity(_Object, abc.ABC):
def __init__(self, pos, bind_to=None, **kwargs):
super().__init__(**kwargs)
self._view_directory = c.VALUE_NO_POS
self._status = None
self.set_pos(pos)
self._pos = pos
self._last_pos = pos
if bind_to:
try:
@ -113,10 +113,6 @@ class Entity(_Object, abc.ABC):
def render(self):
return RenderEntity(self.__class__.__name__.lower(), self.pos)
@abc.abstractmethod
def render(self):
return RenderEntity(self.__class__.__name__.lower(), self.pos)
@property
def obs_tag(self):
try:
@ -133,25 +129,3 @@ class Entity(_Object, abc.ABC):
self._collection.delete_env_object(self)
self._collection = other_collection
return self._collection == other_collection
@classmethod
def from_coordinates(cls, positions: [(int, int)], *args, entity_kwargs=None, **kwargs, ):
collection = cls(*args, **kwargs)
collection.add_items(
[cls._entity(tuple(pos), **entity_kwargs if entity_kwargs is not None else {}) for pos in positions])
return collection
def notify_del_entity(self, entity):
try:
self.pos_dict[entity.pos].remove(entity)
except (ValueError, AttributeError):
pass
def by_pos(self, pos: (int, int)):
pos = tuple(pos)
try:
return self.state.entities.pos_dict[pos]
except StopIteration:
pass
except ValueError:
pass

View File

@ -5,7 +5,7 @@ from marl_factory_grid.environment import constants as c
import marl_factory_grid.utils.helpers as h
class _Object:
class Object:
"""Generell Objects for Organisation and Maintanance such as Actions etc..."""
_u_idx = defaultdict(lambda: 0)
@ -50,15 +50,15 @@ class _Object:
print(f'Following kwargs were passed, but ignored: {kwargs}')
def __repr__(self):
name = self.name
if self.bound_entity:
name = h.add_bound_name(name, self.bound_entity)
try:
if self.var_has_position:
name = h.add_pos_name(name, self)
except (AttributeError):
pass
return name
name = self.name
if self.bound_entity:
name = h.add_bound_name(name, self.bound_entity)
try:
if self.var_has_position:
name = h.add_pos_name(name, self)
except AttributeError:
pass
return name
def __eq__(self, other) -> bool:
return other == self.identifier
@ -67,8 +67,8 @@ class _Object:
return hash(self.identifier)
def _identify_and_count_up(self):
idx = _Object._u_idx[self.__class__.__name__]
_Object._u_idx[self.__class__.__name__] += 1
idx = Object._u_idx[self.__class__.__name__]
Object._u_idx[self.__class__.__name__] += 1
return idx
def set_collection(self, collection):
@ -98,79 +98,3 @@ class _Object:
def unbind(self):
self._bound_entity = None
# class EnvObject(_Object):
# """Objects that hold Information that are observable, but have no position on the environment grid. Inventories etc..."""
#
# _u_idx = defaultdict(lambda: 0)
#
# @property
# def obs_tag(self):
# try:
# return self._collection.name or self.name
# except AttributeError:
# return self.name
#
# @property
# def var_is_blocking_light(self):
# try:
# return self._collection.var_is_blocking_light or False
# except AttributeError:
# return False
#
# @property
# def var_can_be_bound(self):
# try:
# return self._collection.var_can_be_bound or False
# except AttributeError:
# return False
#
# @property
# def var_can_move(self):
# try:
# return self._collection.var_can_move or False
# except AttributeError:
# return False
#
# @property
# def var_is_blocking_pos(self):
# try:
# return self._collection.var_is_blocking_pos or False
# except AttributeError:
# return False
#
# @property
# def var_has_position(self):
# try:
# return self._collection.var_has_position or False
# except AttributeError:
# return False
#
# @property
# def var_can_collide(self):
# try:
# return self._collection.var_can_collide or False
# except AttributeError:
# return False
#
#
# @property
# def encoding(self):
# return c.VALUE_OCCUPIED_CELL
#
#
# def __init__(self, **kwargs):
# self._bound_entity = None
# super(EnvObject, self).__init__(**kwargs)
#
#
# def change_parent_collection(self, other_collection):
# other_collection.add_item(self)
# self._collection.delete_env_object(self)
# self._collection = other_collection
# return self._collection == other_collection
#
#
# def summarize_state(self):
# return dict(name=str(self.name))

View File

@ -1,6 +1,6 @@
import numpy as np
from marl_factory_grid.environment.entity.object import _Object
from marl_factory_grid.environment.entity.object import Object
##########################################################################
@ -8,7 +8,7 @@ from marl_factory_grid.environment.entity.object import _Object
##########################################################################
class PlaceHolder(_Object):
class PlaceHolder(Object):
def __init__(self, *args, fill_value=0, **kwargs):
super().__init__(*args, **kwargs)
@ -27,7 +27,7 @@ class PlaceHolder(_Object):
return self.__class__.__name__
class GlobalPosition(_Object):
class GlobalPosition(Object):
@property
def encoding(self):

View File

@ -56,15 +56,18 @@ class Factory(gym.Env):
self.level_filepath = Path(custom_level_path)
else:
self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt'
self._renderer = None # expensive - don't use; unless required !
parsed_entities = self.conf.load_entities()
self.map = LevelParser(self.level_filepath, parsed_entities, self.conf.pomdp_r)
# Init for later usage:
self.state: Gamestate
self.map: LevelParser
self.obs_builder: OBSBuilder
# noinspection PyTypeChecker
self.state: Gamestate = None
# noinspection PyTypeChecker
self.obs_builder: OBSBuilder = None
# expensive - don't use; unless required !
self._renderer = None
# reset env to initial state, preparing env for new episode.
# returns tuple where the first dict contains initial observation for each agent in the env
@ -74,7 +77,7 @@ class Factory(gym.Env):
return self.state.entities[item]
def reset(self) -> (dict, dict):
if hasattr(self, 'state'):
if self.state is not None:
for entity_group in self.state.entities:
try:
entity_group[0].reset_uid()
@ -160,7 +163,7 @@ class Factory(gym.Env):
# Finalize
reward, reward_info, done = self.summarize_step_results(tick_result, done_results)
info = reward_info
info = dict(reward_info)
info.update(step_reward=sum(reward), step=self.state.curr_step)

View File

@ -1,15 +1,15 @@
from typing import List, Tuple, Union, Dict
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.groups.objects import _Objects
from marl_factory_grid.environment.groups.objects import Objects
# noinspection PyProtectedMember
from marl_factory_grid.environment.entity.object import _Object
from marl_factory_grid.environment.entity.object import Object
import marl_factory_grid.environment.constants as c
from marl_factory_grid.utils.results import Result
class Collection(_Objects):
_entity = _Object # entity?
class Collection(Objects):
_entity = Object # entity?
symbol = None
@property
@ -58,7 +58,7 @@ class Collection(_Objects):
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, ignore_blocking=False, **entity_kwargs):
coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity
if self.var_has_position:
if isinstance(coords_or_quantity, int):
if self.var_has_position and isinstance(coords_or_quantity, int):
if ignore_blocking or self._ignore_blocking:
coords_or_quantity = state.entities.floorlist[:coords_or_quantity]
else:
@ -87,8 +87,8 @@ class Collection(_Objects):
raise ValueError(f'{self._entity.__name__} has no position!')
return c.VALID
def despawn(self, items: List[_Object]):
items = [items] if isinstance(items, _Object) else items
def despawn(self, items: List[Object]):
items = [items] if isinstance(items, Object) else items
for item in items:
del self[item]

View File

@ -3,12 +3,12 @@ from operator import itemgetter
from random import shuffle
from typing import Dict
from marl_factory_grid.environment.groups.objects import _Objects
from marl_factory_grid.environment.groups.objects import Objects
from marl_factory_grid.utils.helpers import POS_MASK
class Entities(_Objects):
_entity = _Objects
class Entities(Objects):
_entity = Objects
@staticmethod
def neighboring_positions(pos):
@ -87,7 +87,7 @@ class Entities(_Objects):
def __delitem__(self, name):
assert_str = 'This group of entity does not exist in this collection!'
assert any([key for key in name.keys() if key in self.keys()]), assert_str
self[name]._observers.delete(self)
self[name].del_observer(self)
for entity in self[name]:
entity.del_observer(self)
return super(Entities, self).__delitem__(name)

View File

@ -1,15 +1,15 @@
from collections import defaultdict
from typing import List
from typing import List, Iterator, Union
import numpy as np
from marl_factory_grid.environment.entity.object import _Object
from marl_factory_grid.environment.entity.object import Object
import marl_factory_grid.environment.constants as c
from marl_factory_grid.utils import helpers as h
class _Objects:
_entity = _Object
class Objects:
_entity = Object
@property
def var_can_be_bound(self):
@ -50,7 +50,7 @@ class _Objects:
def __len__(self):
return len(self._data)
def __iter__(self):
def __iter__(self) -> Iterator[Union[Object, None]]:
return iter(self.values())
def add_item(self, item: _entity):
@ -130,13 +130,14 @@ class _Objects:
repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]}
return f'{self.__class__.__name__}[{repr_dict}]'
def notify_del_entity(self, entity: _Object):
def notify_del_entity(self, entity: Object):
try:
# noinspection PyUnresolvedReferences
self.pos_dict[entity.pos].remove(entity)
except (AttributeError, ValueError, IndexError):
pass
def notify_add_entity(self, entity: _Object):
def notify_add_entity(self, entity: Object):
try:
if self not in entity.observers:
entity.add_observer(self)

View File

@ -1,11 +1,11 @@
import abc
from random import shuffle
from typing import List, Collection, Union
from typing import List, Collection
from marl_factory_grid.environment import rewards as r, constants as c
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.utils import helpers as h
from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import rewards as r, constants as c
class Rule(abc.ABC):
@ -118,8 +118,7 @@ class AssignGlobalPositions(Rule):
def on_init(self, state, lvl_map):
from marl_factory_grid.environment.entity.util import GlobalPosition
for agent in state[c.AGENT]:
gp = GlobalPosition(lvl_map.level_shape)
gp.bind_to(agent)
gp = GlobalPosition(agent, lvl_map.level_shape)
state[c.GLOBALPOSITIONS].add_item(gp)
return []

View File

@ -6,7 +6,9 @@ from marl_factory_grid.utils.results import TickResult, DoneResult
class TemplateRule(Rule):
def __init__(self, *args, **kwargs):
super(TemplateRule, self).__init__(*args, **kwargs)
super(TemplateRule, self).__init__()
self.args = args
self.kwargs = kwargs
def on_init(self, state, lvl_map):
pass

View File

@ -1,6 +1,5 @@
from typing import Union
import marl_factory_grid.modules.batteries.constants
from marl_factory_grid.environment.actions import Action
from marl_factory_grid.utils.results import ActionResult
@ -24,5 +23,6 @@ class BtryCharge(Action):
else:
valid = c.NOT_VALID
state.print(f'{entity.name} failed to charged batteries at {entity.pos}.')
return ActionResult(entity=entity, identifier=self._identifier, validity=valid,
reward=marl_factory_grid.modules.batteries.constants.REWARD_CHARGE_VALID if valid else marl_factory_grid.modules.batteries.constants.Reward_CHARGE_FAIL)
reward=b.REWARD_CHARGE_VALID if valid else b.Reward_CHARGE_FAIL)

View File

@ -1,11 +1,11 @@
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.entity.object import _Object
from marl_factory_grid.environment.entity.object import Object
from marl_factory_grid.modules.batteries import constants as b
from marl_factory_grid.utils.utility_classes import RenderEntity
class Battery(_Object):
class Battery(Object):
@property
def var_can_be_bound(self):

View File

@ -1,11 +1,9 @@
from typing import List, Union
import marl_factory_grid.modules.batteries.constants
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.modules.batteries import constants as b
from marl_factory_grid.utils.results import TickResult, DoneResult
class BatteryDecharge(Rule):

View File

@ -1,5 +1,3 @@
from numpy import random
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.utils.utility_classes import RenderEntity
from marl_factory_grid.modules.clean_up import constants as d

View File

@ -1,9 +1,7 @@
from typing import Union, List, Tuple
from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils.results import Result
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.modules.clean_up.entitites import DirtPile
from marl_factory_grid.utils.results import Result
class DirtPiles(Collection):

View File

@ -49,7 +49,7 @@ class RespawnDirt(Rule):
def tick_step(self, state):
collection = state[d.DIRT]
if self._next_dirt_spawn < 0:
pass # No DirtPile Spawn
result = [] # No DirtPile Spawn
elif not self._next_dirt_spawn:
result = [collection.trigger_spawn(state, coords_or_quantity=self.respawn_n, amount=self.respawn_amount)]
self._next_dirt_spawn = self.respawn_freq

View File

@ -21,4 +21,4 @@ class DestAction(Action):
valid = c.NOT_VALID
state.print(f'{entity.name} just tried to do_wait_action do_wait_action at {entity.pos} but failed')
return ActionResult(entity=entity, identifier=self._identifier, validity=valid,
reward=marl_factory_grid.modules.destinations.constants.REWARD_WAIT_VALID if valid else marl_factory_grid.modules.destinations.constants.REWARD_WAIT_FAIL)
reward=d.REWARD_WAIT_VALID if valid else d.REWARD_WAIT_FAIL)

View File

@ -1,7 +1,5 @@
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.modules.destinations.entitites import Destination
from marl_factory_grid.environment import constants as c
from marl_factory_grid.modules.destinations import constants as d
class Destinations(Collection):

View File

@ -1,5 +1,3 @@
from typing import Union
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.modules.doors import constants as d
from marl_factory_grid.modules.doors.entitites import Door

View File

@ -1,2 +1,2 @@
USE_DOOR_VALID: float = -0.00
USE_DOOR_FAIL: float = -0.01
USE_DOOR_FAIL: float = -0.01

View File

@ -1,8 +1,8 @@
import random
from typing import List, Union
from typing import List
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult
@ -14,8 +14,6 @@ class AgentSingleZonePlacementBeta(Rule):
super().__init__()
def on_init(self, state, lvl_map):
zones = state[c.ZONES]
n_zones = state[c.ZONES]
agents = state[c.AGENT]
if len(self.coordinates) == len(agents):
coordinates = self.coordinates
@ -31,4 +29,4 @@ class AgentSingleZonePlacementBeta(Rule):
return []
def tick_post_step(self, state) -> List[TickResult]:
return []
return []

View File

@ -1,6 +1,3 @@
from typing import NamedTuple
SYMBOL_NO_ITEM = 0
SYMBOL_DROP_OFF = 1
# Item Env

View File

@ -14,27 +14,14 @@ class Item(Entity):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@property
def auto_despawn(self):
return self._auto_despawn
@property
def encoding(self):
# Edit this if you want items to be drawn in the ops differently
return 1
def set_auto_despawn(self, auto_despawn):
self._auto_despawn = auto_despawn
def summarize_state(self) -> dict:
super_summarization = super(Item, self).summarize_state()
super_summarization.update(dict(auto_despawn=self.auto_despawn))
return super_summarization
class DropOffLocation(Entity):
def render(self):
return RenderEntity(i.DROP_OFF, self.pos)
@ -42,18 +29,16 @@ class DropOffLocation(Entity):
def encoding(self):
return i.SYMBOL_DROP_OFF
def __init__(self, *args, storage_size_until_full: int = 5, auto_item_despawn_interval: int = 5, **kwargs):
def __init__(self, *args, storage_size_until_full: int = 5, **kwargs):
super(DropOffLocation, self).__init__(*args, **kwargs)
self.auto_item_despawn_interval = auto_item_despawn_interval
self.storage = deque(maxlen=storage_size_until_full or None)
def place_item(self, item: Item):
if self.is_full:
raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.")
return bc.NOT_VALID # in Zeile 81 verschieben?
return bc.NOT_VALID
else:
self.storage.append(item)
item.set_auto_despawn(self.auto_item_despawn_interval)
return c.VALID
@property

View File

@ -1,12 +1,9 @@
from random import shuffle
from marl_factory_grid.modules.items import constants as i
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.environment.groups.objects import _Objects
from marl_factory_grid.environment.groups.mixins import IsBoundMixin
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.environment.groups.mixins import IsBoundMixin
from marl_factory_grid.environment.groups.objects import Objects
from marl_factory_grid.modules.items import constants as i
from marl_factory_grid.modules.items.entitites import Item, DropOffLocation
from marl_factory_grid.utils.results import Result
@ -74,13 +71,12 @@ class Inventory(IsBoundMixin, Collection):
self._collection = collection
class Inventories(_Objects):
class Inventories(Objects):
_entity = Inventory
var_can_move = False
var_has_position = False
symbol = None
@property
@ -116,7 +112,6 @@ class Inventories(_Objects):
return [val.summarize_states(**kwargs) for key, val in self.items()]
class DropOffLocations(Collection):
_entity = DropOffLocation

View File

@ -1,4 +1,4 @@
DROP_OFF_VALID: float = 0.1
DROP_OFF_FAIL: float = -0.1
PICK_UP_FAIL: float = -0.1
PICK_UP_VALID: float = 0.1
PICK_UP_VALID: float = 0.1

View File

@ -1,9 +1,10 @@
from typing import Union
import marl_factory_grid.modules.machines.constants
from marl_factory_grid.environment.actions import Action
from marl_factory_grid.utils.results import ActionResult
from marl_factory_grid.modules.machines import constants as m, rewards as r
from marl_factory_grid.modules.machines import constants as m
from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils import helpers as h
@ -16,8 +17,10 @@ class MachineAction(Action):
def do(self, entity, state) -> Union[None, ActionResult]:
if machine := h.get_first(state[m.MACHINES].by_pos(entity.pos)):
if valid := machine.maintain():
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_VALID)
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_VALID)
else:
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_FAIL)
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_FAIL)
else:
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.MAINTAIN_FAIL)
return ActionResult(entity=entity, identifier=self._identifier,
validity=c.NOT_VALID, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_FAIL
)

View File

@ -11,3 +11,7 @@ STATE_MAINTAIN = 'maintenance'
SYMBOL_WORK = 1
SYMBOL_IDLE = 0.6
SYMBOL_MAINTAIN = 0.3
MAINTAIN_VALID: float = 0.5
MAINTAIN_FAIL: float = -0.1
FAIL_MISSING_MAINTENANCE: float = -0.5
NONE: float = 0

View File

@ -31,11 +31,10 @@ class Machine(Entity):
return c.NOT_VALID
def tick(self, state):
# if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]):
if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in state.entities.pos_dict[self.pos]]):
others = state.entities.pos_dict[self.pos]
if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in others]):
return TickResult(identifier=self.name, validity=c.VALID, entity=self)
# elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]):
elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in state.entities.pos_dict[self.pos]]):
elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in others]):
self.status = m.STATE_WORK
self.reset_counter()
return None

View File

@ -1,5 +1,3 @@
from typing import Union, List, Tuple
from marl_factory_grid.environment.groups.collection import Collection
from .entitites import Machine

View File

@ -1,5 +0,0 @@
MAINTAIN_VALID: float = 0.5
MAINTAIN_FAIL: float = -0.1
FAIL_MISSING_MAINTENANCE: float = -0.5
NONE: float = 0

View File

@ -1,3 +1,4 @@
MAINTAINER = 'Maintainer' # TEMPLATE _identifier. Define your own!
MAINTAINERS = 'Maintainers' # TEMPLATE _identifier. Define your own!
MAINTAINER_COLLISION_REWARD = -5

View File

@ -4,7 +4,6 @@ from marl_factory_grid.environment.groups.collection import Collection
from .entities import Maintainer
from ..machines import constants as mc
from ..machines.actions import MachineAction
from ...utils.states import Gamestate
class Maintainers(Collection):
@ -23,8 +22,6 @@ class Maintainers(Collection):
self.size = size
self._spawnrule = spawnrule
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):
self.add_items([self._entity(mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity])

View File

@ -1 +0,0 @@
MAINTAINER_COLLISION_REWARD = -5

View File

@ -1,15 +1,16 @@
from typing import List
import marl_factory_grid.modules.maintenance.constants
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c
from . import rewards as r
from . import constants as M
class MoveMaintainers(Rule):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self):
super().__init__()
def tick_step(self, state) -> List[TickResult]:
for maintainer in state[M.MAINTAINERS]:
@ -20,8 +21,8 @@ class MoveMaintainers(Rule):
class DoneAtMaintainerCollision(Rule):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self):
super().__init__()
def on_check_done(self, state) -> List[DoneResult]:
agents = list(state[c.AGENT].values())
@ -30,5 +31,5 @@ class DoneAtMaintainerCollision(Rule):
for agent in agents:
if agent.pos in m_pos:
done_results.append(DoneResult(entity=agent, validity=c.VALID, identifier=self.name,
reward=r.MAINTAINER_COLLISION_REWARD))
reward=marl_factory_grid.modules.maintenance.constants.MAINTAINER_COLLISION_REWARD))
return done_results

View File

@ -1,10 +1,10 @@
import random
from typing import List, Tuple
from marl_factory_grid.environment.entity.object import _Object
from marl_factory_grid.environment.entity.object import Object
class Zone(_Object):
class Zone(Object):
@property
def positions(self):

View File

@ -1,8 +1,8 @@
from marl_factory_grid.environment.groups.objects import _Objects
from marl_factory_grid.environment.groups.objects import Objects
from marl_factory_grid.modules.zones import Zone
class Zones(_Objects):
class Zones(Objects):
symbol = None
_entity = Zone

View File

@ -58,7 +58,10 @@ class FactoryConfigParser(object):
return str(self.config)
def __getitem__(self, item):
return self.config[item]
try:
return self.config[item]
except KeyError:
print(f'The mandatory {item} section could not be found in your .config gile. Check Spelling!')
def load_entities(self):
entity_classes = dict()
@ -161,7 +164,6 @@ class FactoryConfigParser(object):
def _load_smth(self, config, class_obj):
rules = list()
rules_names = list()
for rule in config:
e1 = e2 = e3 = None
try:

View File

@ -61,8 +61,8 @@ class ObservationTranslator:
:param per_agent_named_obs_spaces: `Named observation space` one for each agent. Overloaded.
type per_agent_named_obs_spaces: Dict[str, dict]
:param placeholder_fill_value: Currently not fully implemented!!!
:type placeholder_fill_value: Union[int, str] = 'N')
:param placeholder_fill_value: Currently, not fully implemented!!!
:type placeholder_fill_value: Union[int, str] = 'N'
"""
if isinstance(placeholder_fill_value, str):

View File

@ -9,7 +9,7 @@ from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
import pandas as pd
from marl_factory_grid.utils.plotting.compare_runs import plot_single_run
from marl_factory_grid.utils.plotting.plot_single_runs import plot_single_run
class EnvMonitor(Wrapper):
@ -22,7 +22,6 @@ class EnvMonitor(Wrapper):
self._monitor_df = pd.DataFrame()
self._monitor_dict = dict()
def step(self, action):
obs_type, obs, reward, done, info = self.env.step(action)
self._read_info(info)

View File

@ -2,11 +2,9 @@ from os import PathLike
from pathlib import Path
from typing import Union, List
import yaml
from gymnasium import Wrapper
import numpy as np
import pandas as pd
from gymnasium import Wrapper
class EnvRecorder(Wrapper):
@ -106,7 +104,7 @@ class EnvRecorder(Wrapper):
out_dict = {'episodes': self._recorder_out_list}
out_dict.update(
{'n_episodes': self._curr_episode,
'metadata':dict(
'metadata': dict(
level_name=self.env.params['General']['level_name'],
verbose=False,
n_agents=len(self.env.params['Agents']),

View File

@ -5,7 +5,7 @@ from typing import Dict, List
import numpy as np
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.object import _Object
from marl_factory_grid.environment.entity.object import Object
from marl_factory_grid.environment.groups.utils import Combined
from marl_factory_grid.utils.utility_classes import Floor
from marl_factory_grid.utils.ray_caster import RayCaster
@ -13,7 +13,6 @@ from marl_factory_grid.utils.states import Gamestate
from marl_factory_grid.utils import helpers as h
class OBSBuilder(object):
default_obs = [c.WALLS, c.OTHERS]
@ -128,7 +127,7 @@ class OBSBuilder(object):
f'{re.escape("[")}(.*){re.escape("]")}'
f'{re.escape("(")}{re.escape(agent.name)}{re.escape(")")}')
name = next((key for key, val in self.all_obs.items()
if pattern.search(str(val)) and isinstance(val, _Object)), None)
if pattern.search(str(val)) and isinstance(val, Object)), None)
e = self.all_obs[name]
except KeyError:
try:
@ -181,11 +180,11 @@ class OBSBuilder(object):
return obs, self.obs_layers[agent.name]
def _sort_and_name_observation_conf(self, agent):
'''
"""
Builds the useable observation scheme per agent from conf.yaml.
:param agent:
:return:
'''
"""
# Fixme: no asymetric shapes possible.
self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape))
obs_layers = []

View File

@ -7,50 +7,11 @@ from typing import Union, List
import pandas as pd
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
from marl_factory_grid.utils.plotting.plotting import prepare_plot
from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
MODEL_MAP = None
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None):
run_path = Path(run_path)
df_list = list()
if run_path.is_dir():
monitor_file = next(run_path.glob('*monitor*.pick'))
elif run_path.exists() and run_path.is_file():
monitor_file = run_path
else:
raise ValueError
with monitor_file.open('rb') as f:
monitor_df = pickle.load(f)
monitor_df = monitor_df.fillna(0)
df_list.append(monitor_df)
df = pd.concat(df_list, ignore_index=True)
df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
if column_keys is not None:
columns = [col for col in column_keys if col in df.columns]
else:
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
roll_n = 50
non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean()
df_melted = df[columns + ['Episode']].reset_index().melt(
id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score"
)
if df_melted['Episode'].max() > 800:
skip_n = round(df_melted['Episode'].max() * 0.02)
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
print('Plotting done.')
def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
run_path = Path(run_path)
df_list = list()

View File

@ -0,0 +1,48 @@
import pickle
from os import PathLike
from pathlib import Path
from typing import Union
import pandas as pd
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None,
file_key: str ='monitor', file_ext: str ='pkl'):
run_path = Path(run_path)
df_list = list()
if run_path.is_dir():
monitor_file = next(run_path.glob(f'*{file_key}*.{file_ext}'))
elif run_path.exists() and run_path.is_file():
monitor_file = run_path
else:
raise ValueError
with monitor_file.open('rb') as f:
monitor_df = pickle.load(f)
monitor_df = monitor_df.fillna(0)
df_list.append(monitor_df)
df = pd.concat(df_list, ignore_index=True)
df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
if column_keys is not None:
columns = [col for col in column_keys if col in df.columns]
else:
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
# roll_n = 50
# non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean()
df_melted = df[columns + ['Episode']].reset_index().melt(
id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score"
)
if df_melted['Episode'].max() > 800:
skip_n = round(df_melted['Episode'].max() * 0.02)
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
print('Plotting done.')

View File

@ -60,7 +60,7 @@ def prepare_center_double_column_legend(df, hue, style, hue_order):
print('Struggling to plot Figure using LaTeX - going back to normal.')
plt.close('all')
sns.set(rc={'text.usetex': False}, style='whitegrid')
fig = plt.figure(figsize=(10, 11))
_ = plt.figure(figsize=(10, 11))
lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
ci=95, palette=PALETTE, hue_order=hue_order, legend=False)
# plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)

View File

@ -19,7 +19,7 @@ class RayCaster:
return f'{self.__class__.__name__}({self.agent.name})'
def build_ray_targets(self):
north = np.array([0, -1])*self.pomdp_r
north = np.array([0, -1]) * self.pomdp_r
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
rot_M = [
[[math.cos(theta), -math.sin(theta)],
@ -53,9 +53,9 @@ class RayCaster:
diag_hits = all([
self.ray_block_cache(
key,
lambda: any(True for e in pos_dict[key] if e.var_is_blocking_light))
# lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light))
for key in ((x, y-cy), (x-cx, y))
# lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light)
lambda: any(True for e in pos_dict[key] if e.var_is_blocking_light))
for key in ((x, y - cy), (x - cx, y))
]) if (cx != 0 and cy != 0) else False
visible += entities_hit if not diag_hits else []
@ -77,8 +77,8 @@ class RayCaster:
agent = self.agent
x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \
+ list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r]))
outline += list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
return outline
@staticmethod

View File

@ -1,9 +1,12 @@
from typing import Union
from dataclasses import dataclass
from marl_factory_grid.environment.entity.object import Object
TYPE_VALUE = 'value'
TYPE_REWARD = 'reward'
types = [TYPE_VALUE, TYPE_REWARD]
TYPES = [TYPE_VALUE, TYPE_REWARD]
@dataclass
class InfoObject:
@ -18,12 +21,13 @@ class Result:
validity: bool
reward: Union[float, None] = None
value: Union[float, None] = None
entity: None = None
entity: Object = None
def get_infos(self):
n = self.entity.name if self.entity is not None else "Global"
return [InfoObject(identifier=f'{n}_{self.identifier}_{t}',
val_type=t, value=self.__getattribute__(t)) for t in types
# Return multiple Info Dicts
return [InfoObject(identifier=f'{n}_{self.identifier}',
val_type=t, value=self.__getattribute__(t)) for t in TYPES
if self.__getattribute__(t) is not None]
def __repr__(self):
@ -31,7 +35,7 @@ class Result:
reward = f" | Reward: {self.reward}" if self.reward is not None else ""
value = f" | Value: {self.value}" if self.value is not None else ""
entity = f" | by: {self.entity.name}" if self.entity is not None else ""
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value})'
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value}{entity})'
@dataclass

View File

@ -1,11 +1,12 @@
from itertools import islice
from typing import List, Dict, Tuple
from typing import List, Tuple
import numpy as np
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import Result
from marl_factory_grid.utils.results import Result, DoneResult
class StepRules:
@ -83,13 +84,51 @@ class Gamestate(object):
return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})'
@property
def random_free_position(self):
def random_free_position(self) -> (int, int):
"""
Returns a single **free** position (x, y), which is **free** for spawning or walking.
No Entity at this position posses *var_is_blocking_pos* or *var_can_collide*.
:return: Single **free** position.
"""
return self.get_n_random_free_positions(1)[0]
def get_n_random_free_positions(self, n):
def get_n_random_free_positions(self, n) -> list[tuple[int, int]]:
"""
Returns a list of *n* **free** positions [(x, y), ... ], which are **free** for spawning or walking.
No Entity at this position posses *var_is_blocking_pos* or *var_can_collide*.
:return: List of n **free** position.
"""
return list(islice(self.entities.free_positions_generator, n))
def tick(self, actions) -> List[Result]:
@property
def random_position(self) -> (int, int):
"""
Returns a single available position (x, y), ignores all entity attributes.
:return: Single random position.
"""
return self.get_n_random_positions(1)[0]
def get_n_random_positions(self, n) -> list[tuple[int, int]]:
"""
Returns a list of *n* available positions [(x, y), ... ], ignores all entity attributes.
:return: List of n random positions.
"""
return list(islice(self.entities.floorlist, n))
def tick(self, actions) -> list[Result]:
"""
Performs a single **Gamestate Tick**by calling the inner rule hooks in sequential order.
- tick_pre_step_all: Things to do before the agents do their actions. Statechange, Moving, Spawning etc...
- agent tick: Agents do their actions.
- tick_step_all: Things to do after the agents did their actions. Statechange, Moving, Spawning etc...
- tick_post_step_all: Things to do at the very end of each step. Counting, Reward calculations etc...
:return: List of *Result*-objects.
"""
results = list()
self.curr_step += 1
@ -112,11 +151,23 @@ class Gamestate(object):
return results
def print(self, string):
def print(self, string) -> None:
"""
When *verbose* is active, print stuff.
:param string: *String* to print.
:type string: str
:return: Nothing
"""
if self.verbose:
print(string)
def check_done(self):
def check_done(self) -> List[DoneResult]:
"""
Iterate all **Rules** that override tehe *on_ckeck_done* hook.
:return: List of Results
"""
results = list()
for rule in self.rules:
if on_check_done_result := rule.on_check_done(self):
@ -124,20 +175,44 @@ class Gamestate(object):
return results
def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]:
positions = [pos for pos, entities in self.entities.pos_dict.items() if len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2)]
"""
Returns a list positions [(x, y), ... ] on which collisions occur. This does not include agents,
that were unable to move because their target direction was blocked, also a form of collision.
:return: List of positions.
"""
positions = [pos for pos, entities in self.entities.pos_dict.items() if
len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2)
]
return positions
def check_move_validity(self, moving_entity, position):
if moving_entity.pos != position and not any(
entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]) and not (
moving_entity.var_is_blocking_pos and self.entities.is_occupied(position)):
return True
else:
return False
def check_move_validity(self, moving_entity: Entity, target_position: (int, int)) -> bool:
"""
Whether it is safe to move to the target positions and moving entity does not introduce a blocking attribute,
when position is allready occupied.
def check_pos_validity(self, position):
if not any(entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]):
return True
else:
return False
:param moving_entity: Entity
:param target_position: pos
:return: Safe to move to
"""
is_not_blocked = self.check_pos_validity(target_position)
will_not_block_others = moving_entity.var_is_blocking_pos and self.entities.is_occupied(target_position)
if moving_entity.pos != target_position and is_not_blocked and not will_not_block_others:
return c.VALID
else:
return c.NOT_VALID
def check_pos_validity(self, pos: (int, int)) -> bool:
"""
Check if *pos* is a valid position to move or spawn to.
:param pos: position to check
:return: Wheter pos is a valid target.
"""
if not any(e.var_is_blocking_pos for e in self.entities.pos_dict[pos]) and pos in self.entities.floorlist:
return c.VALID
else:
return c.NOT_VALID

View File

@ -28,7 +28,9 @@ class ConfigExplainer:
def explain_module(self, class_to_explain):
parameters = inspect.signature(class_to_explain).parameters
explained = {class_to_explain.__name__: {key: val.default for key, val in parameters.items() if key not in EXCLUDED}}
explained = {class_to_explain.__name__:
{key: val.default for key, val in parameters.items() if key not in EXCLUDED}
}
return explained
def _load_and_compare(self, compare_class, paths):

View File

@ -6,18 +6,21 @@ from marl_factory_grid.environment.factory import Factory
from marl_factory_grid.utils.logging.envmonitor import EnvMonitor
from marl_factory_grid.utils.logging.recorder import EnvRecorder
from marl_factory_grid.utils.plotting.plot_single_runs import plot_single_run
from marl_factory_grid.utils.tools import ConfigExplainer
if __name__ == '__main__':
# Render at each step?
render = True
render = False
# Reveal all possible Modules (Entities, Rules, Agents[Actions, Observations], etc.)
explain_config = False
# Collect statistics?
monitor = False
monitor = True
# Record as Protobuf?
record = False
# Plot Results?
plotting = True
run_path = Path('study_out')
@ -38,7 +41,7 @@ if __name__ == '__main__':
factory = EnvRecorder(factory)
# RL learn Loop
for episode in trange(500):
for episode in trange(10):
_ = factory.reset()
done = False
if render:
@ -54,7 +57,10 @@ if __name__ == '__main__':
break
if monitor:
factory.save_run(run_path / 'test.pkl')
factory.save_run(run_path / 'test_monitor.pkl')
if record:
factory.save_records(run_path / 'test.pb')
if plotting:
plot_single_run(run_path)
print('Done!!! Goodbye....')

View File

@ -56,6 +56,7 @@ if __name__ == '__main__':
for model_idx, model in enumerate(models)]
else:
actions = models[0].predict(env_state, deterministic=determin)[0]
# noinspection PyTupleAssignmentBalance
env_state, step_r, done_bool, info_obj = env.step(actions)
rew += step_r

View File

@ -5,7 +5,6 @@ from pathlib import Path
if __name__ == '__main__':
conf_path = Path('wg0')
wg0_conf = configparser.ConfigParser()
wg0_conf.read(conf_path/'wg0.conf')
@ -17,7 +16,6 @@ if __name__ == '__main__':
# Delete any old conf.json for the current peer
(conf_path / f'{client_name}.json').unlink(missing_ok=True)
peer = wg0_conf[client_name]
date_time = datetime.now().strftime('%Y-%m-%dT%H:%M:%S.%f000Z')