Merge branch 'main' into unit_testing

# Conflicts:
#	marl_factory_grid/modules/doors/groups.py
#	marl_factory_grid/utils/states.py
This commit is contained in:
Chanumask 2023-11-23 12:58:12 +01:00
commit fcea1972a4
63 changed files with 1477 additions and 330 deletions

5
.gitignore vendored
View File

@ -81,7 +81,7 @@ acs-*.bib
# endnotes
*.ent
# fixme
# f i x m e
*.lox
# feynmf/feynmp
@ -143,7 +143,6 @@ acs-*.bib
# knitr
*-concordance.tex
# TODO Comment the next line if you want to keep your tikz graphics files
*.tikz
*-tikzDictionary
@ -225,7 +224,7 @@ pythontex-files-*/
*.hst
*.ver
# easy-todo
# easy-t o d o
*.lod
# xcolor

View File

@ -10,12 +10,14 @@ build-job: # This job runs in the build stage, which runs first.
variables:
TWINE_USERNAME: $USER_NAME
TWINE_PASSWORD: $API_KEY
TWINE_REPOSITORY: marl-factory-grid
image: python:slim
script:
- echo "Compiling the code..."
- pip install twine --upgrade
- python setup.py sdist
- echo "Compile complete."
- twine upload dist/* --username $USER_NAME --password $API_KEY --repository marl-factory-grid
- pip install -U twine
- python setup.py sdist bdist_wheel
- twine check dist/*
# try upload in test platform before the official
- twine upload --repository-url https://upload.pypi.org/legacy/ dist/*
- echo "Upload complete."

View File

@ -1,7 +1,6 @@
from pathlib import Path
import numpy as np
import torch
import yaml
@ -74,6 +73,7 @@ class Checkpointer(object):
def save_experiment(self, name: str, model):
cpt_path = self.path / f'checkpoint_{self.__current_checkpoint}'
cpt_path.mkdir(exist_ok=True, parents=True)
import torch
torch.save(model.state_dict(), cpt_path / f'{name}.pt')
def step(self, to_save):

View File

@ -0,0 +1,74 @@
Agents:
Wolfgang:
Actions:
- Move8
- DoorUse
- Clean
- Noop
Observations:
- Walls
- Doors
- Other
- DirtPiles
Clones: 8
Juergen:
Actions:
- Move8
- DoorUse
- ItemAction
- Noop
Observations:
- Walls
- Doors
- Other
- Items
- DropOffLocations
- Inventory
Entities:
DirtPiles:
coords_or_quantity: 10
initial_amount: 2
clean_amount: 1
dirt_spawn_r_var: 0.1
max_global_amount: 20
max_local_amount: 5
Doors:
DropOffLocations:
coords_or_quantity: 1
max_dropoff_storage_size: 0
Inventories: {}
Items:
coords_or_quantity: 5
General:
env_seed: 69
individual_rewards: true
level_name: rooms
pomdp_r: 3
verbose: True
tests: false
Rules:
# Environment Dynamics
EntitiesSmearDirtOnMove:
smear_ratio: 0.2
DoorAutoClose:
close_frequency: 7
# Respawn Stuff
RespawnDirt:
respawn_freq: 30
RespawnItems:
respawn_freq: 50
# Utilities
WatchCollisions:
done_at_collisions: false
# Done Conditions
DoneOnAllDirtCleaned:
DoneAtMaxStepsReached:
max_steps: 500

View File

@ -3,7 +3,7 @@ Agents:
Actions:
- Noop
- Charge
- CleanUp
- Clean
- DestAction
- DoorUse
- ItemAction
@ -23,6 +23,7 @@ Agents:
- DropOffLocations
- Maintainers
Entities:
Batteries:
initial_charge: 0.8
per_action_costs: 0.02
@ -57,7 +58,7 @@ General:
individual_rewards: true
level_name: large
pomdp_r: 3
verbose: True
verbose: false
tests: false
Rules:

View File

@ -28,9 +28,10 @@ class Action(abc.ABC):
def __repr__(self):
return f'Action[{self._identifier}]'
def get_result(self, validity, entity):
def get_result(self, validity, entity, action_introduced_collision=False):
reward = self.valid_reward if validity else self.fail_reward
return ActionResult(self.__class__.__name__, validity, reward=reward, entity=entity)
return ActionResult(self.__class__.__name__, validity, reward=reward, entity=entity,
action_introduced_collision=action_introduced_collision)
class Noop(Action):
@ -50,19 +51,24 @@ class Move(Action, abc.ABC):
def do(self, entity, state):
new_pos = self._calc_new_pos(entity.pos)
collision = False
if state.check_move_validity(entity, new_pos):
valid = entity.move(new_pos, state)
# Aftermath Collision Check
if len([x for x in state.entities.by_pos(entity.pos) if x.var_can_collide]) > 1:
# The entity did move, but there was something to collide with...
collision = True
else:
# There is no place to go, probably collision
# This is currently handled by the WatchCollisions rule, so that it can be switched on and off by conf.yml
# return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION)
valid = c.NOT_VALID
collision = True
if valid:
state.print(f'{entity.name} just moved to {entity.pos}.')
else:
state.print(f'{entity.name} just tried to move to {new_pos} but either failed or hat a Collision.')
return self.get_result(valid, entity)
return self.get_result(valid, entity, action_introduced_collision=collision)
def _calc_new_pos(self, pos):
x_diff, y_diff = MOVEMAP[self._identifier]

View File

@ -24,6 +24,7 @@ SYMBOL_FLOOR = '-'
VALID = True # Identifier to rename boolean values in the context of actions.
NOT_VALID = False # Identifier to rename boolean values in the context of actions.
VALUE_FREE_CELL = 0 # Free-Cell value used in observation
VALUE_OCCUPIED_CELL = 1 # Occupied-Cell value used in observation
VALUE_NO_POS = (-9999, -9999) # Invalid Position value used in the environment (smth. is off-grid)

View File

@ -14,36 +14,70 @@ class Agent(Entity):
@property
def var_is_paralyzed(self):
"""
TODO
:return:
"""
return len(self._paralyzed)
@property
def paralyze_reasons(self):
"""
TODO
:return:
"""
return [x for x in self._paralyzed]
@property
def obs_tag(self):
"""Internal Usage"""
return self.name
@property
def actions(self):
"""
TODO
:return:
"""
return self._actions
@property
def observations(self):
"""
TODO
:return:
"""
return self._observations
def step_result(self):
pass
"""
TODO
FIXME: I think this is legacy... not used any more
@property
def collection(self):
return self._collection
:return:
"""
pass
@property
def var_is_blocking_pos(self):
return self._is_blocking_pos
def __init__(self, actions: List[Action], observations: List[str], *args, is_blocking_pos=False, **kwargs):
"""
TODO
:return:
"""
super(Agent, self).__init__(*args, **kwargs)
self._paralyzed = set()
self.step_result = dict()
@ -53,29 +87,58 @@ class Agent(Entity):
self._is_blocking_pos = is_blocking_pos
def summarize_state(self):
"""
TODO
:return:
"""
state_dict = super().summarize_state()
state_dict.update(valid=bool(self.state.validity), action=str(self.state.identifier))
return state_dict
def set_state(self, action_result):
self._status = action_result
def set_state(self, state):
"""
TODO
:return:
"""
self._status = state
return c.VALID
def paralyze(self, reason):
"""
TODO
:return:
"""
self._paralyzed.add(reason)
return c.VALID
def de_paralyze(self, reason):
def de_paralyze(self, reason) -> bool:
"""
TODO
:return:
"""
try:
self._paralyzed.remove(reason)
return c.VALID
except KeyError:
return c.NOT_VALID
def render(self):
i = next(idx for idx, x in enumerate(self._collection) if x.name == self.name)
def render(self) -> RenderEntity:
i = self.collection.idx_by_entity(self)
assert i is not None
curr_state = self.state
name = c.AGENT
if curr_state.identifier == c.COLLISION:
render_state = renderer.STATE_COLLISION
name = renderer.STATE_COLLISION
render_state=None
elif curr_state.validity:
if curr_state.identifier == c.NOOP:
render_state = renderer.STATE_IDLE
@ -86,4 +149,4 @@ class Agent(Entity):
else:
render_state = renderer.STATE_INVALID
return RenderEntity(c.AGENT, self.pos, 1, 'none', render_state, i + 1, real_name=self.name)
return RenderEntity(name, self.pos, 1, 'none', render_state, i + 1, real_name=self.name)

View File

@ -4,23 +4,40 @@ import numpy as np
from .object import Object
from .. import constants as c
from ...utils.results import ActionResult
from ...utils.results import State
from ...utils.utility_classes import RenderEntity
class Entity(Object, abc.ABC):
"""Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc..."""
@property
def state(self):
return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID)
"""
TODO
:return:
"""
return self._status or State(entity=self, identifier=c.NOOP, validity=c.VALID)
@property
def var_has_position(self):
"""
TODO
:return:
"""
return self.pos != c.VALUE_NO_POS
@property
def var_is_blocking_light(self):
"""
TODO
:return:
"""
try:
return self._collection.var_is_blocking_light or False
except AttributeError:
@ -28,6 +45,12 @@ class Entity(Object, abc.ABC):
@property
def var_can_move(self):
"""
TODO
:return:
"""
try:
return self._collection.var_can_move or False
except AttributeError:
@ -35,6 +58,12 @@ class Entity(Object, abc.ABC):
@property
def var_is_blocking_pos(self):
"""
TODO
:return:
"""
try:
return self._collection.var_is_blocking_pos or False
except AttributeError:
@ -42,6 +71,12 @@ class Entity(Object, abc.ABC):
@property
def var_can_collide(self):
"""
TODO
:return:
"""
try:
return self._collection.var_can_collide or False
except AttributeError:
@ -49,22 +84,53 @@ class Entity(Object, abc.ABC):
@property
def x(self):
"""
TODO
:return:
"""
return self.pos[0]
@property
def y(self):
"""
TODO
:return:
"""
return self.pos[1]
@property
def pos(self):
"""
TODO
:return:
"""
return self._pos
def set_pos(self, pos):
def set_pos(self, pos) -> bool:
"""
TODO
:return:
"""
assert isinstance(pos, tuple) and len(pos) == 2
self._pos = pos
return c.VALID
@property
def last_pos(self):
"""
TODO
:return:
"""
try:
return self._last_pos
except AttributeError:
@ -74,12 +140,24 @@ class Entity(Object, abc.ABC):
@property
def direction_of_view(self):
"""
TODO
:return:
"""
if self._last_pos != c.VALUE_NO_POS:
return 0, 0
else:
return np.subtract(self._last_pos, self.pos)
def move(self, next_pos, state):
"""
TODO
:return:
"""
next_pos = next_pos
curr_pos = self._pos
if not_same_pos := curr_pos != next_pos:
@ -90,22 +168,24 @@ class Entity(Object, abc.ABC):
self.set_pos(next_pos)
for observer in self.observers:
observer.notify_add_entity(self)
# Aftermath Collision Check
if len([x for x in state.entities.by_pos(next_pos) if x.var_can_collide]) > 1:
# The entity did move, but there was something to collide with...
# Is then reported as a non-valid move, which did work.
valid = False
return valid
# Bad naming... False means it was the same pos, i.e. not moving.
return not_same_pos
def __init__(self, pos, bind_to=None, **kwargs):
"""
Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc...
TODO
:return:
"""
super().__init__(**kwargs)
self._view_directory = c.VALUE_NO_POS
self._status = None
self._pos = pos
self._last_pos = pos
self._collection = None
if bind_to:
try:
self.bind_to(bind_to)
@ -114,14 +194,27 @@ class Entity(Object, abc.ABC):
exit()
def summarize_state(self) -> dict:
"""
TODO
:return:
"""
return dict(name=str(self.name), x=int(self.x), y=int(self.y), can_collide=bool(self.var_can_collide))
@abc.abstractmethod
def render(self):
"""
TODO
:return:
"""
return RenderEntity(self.__class__.__name__.lower(), self.pos)
@property
def obs_tag(self):
"""Internal Usage"""
try:
return self._collection.name or self.name
except AttributeError:
@ -129,10 +222,32 @@ class Entity(Object, abc.ABC):
@property
def encoding(self):
"""
TODO
:return:
"""
return c.VALUE_OCCUPIED_CELL
def change_parent_collection(self, other_collection):
"""
TODO
:return:
"""
other_collection.add_item(self)
self._collection.delete_env_object(self)
self._collection = other_collection
return self._collection == other_collection
@property
def collection(self):
"""
TODO
:return:
"""
return self._collection

View File

@ -6,40 +6,85 @@ import marl_factory_grid.utils.helpers as h
class Object:
"""Generell Objects for Organisation and Maintanance such as Actions etc..."""
_u_idx = defaultdict(lambda: 0)
def __bool__(self):
return True
@property
def bound_entity(self):
"""
TODO
:return:
"""
return self._bound_entity
@property
def var_can_be_bound(self):
def var_can_be_bound(self) -> bool:
"""
TODO
Indicates if it is possible to bind this object to another Entity or Object.
:return: Whether this object can be bound.
"""
try:
return self._collection.var_can_be_bound or False
except AttributeError:
return False
@property
def observers(self):
def observers(self) -> set:
"""
TODO
:return:
"""
return self._observers
@property
def name(self):
"""
TODO
:return:
"""
return f'{self.__class__.__name__}[{self.identifier}]'
@property
def identifier(self):
"""
TODO
:return:
"""
if self._str_ident is not None:
return self._str_ident
else:
return self.u_int
def reset_uid(self):
"""
TODO
:return:
"""
self._u_idx = defaultdict(lambda: 0)
return True
def __init__(self, str_ident: Union[str, None] = None, **kwargs):
"""
General Objects for Organisation and Maintenance such as Actions etc...
TODO
:param str_ident:
:return:
"""
self._status = None
self._bound_entity = None
self._observers = set()
@ -50,6 +95,9 @@ class Object:
if kwargs:
print(f'Following kwargs were passed, but ignored: {kwargs}')
def __bool__(self) -> bool:
return True
def __repr__(self):
name = self.name
if self.bound_entity:
@ -67,39 +115,61 @@ class Object:
def __hash__(self):
return hash(self.identifier)
def _identify_and_count_up(self):
def _identify_and_count_up(self) -> int:
"""Internal Usage"""
idx = Object._u_idx[self.__class__.__name__]
Object._u_idx[self.__class__.__name__] += 1
return idx
def set_collection(self, collection):
"""Internal Usage"""
self._collection = collection
return self
def add_observer(self, observer):
"""Internal Usage"""
self.observers.add(observer)
observer.notify_add_entity(self)
return self
def del_observer(self, observer):
"""Internal Usage"""
self.observers.remove(observer)
return self
def summarize_state(self):
return dict()
def clear_temp_state(self):
"""Internal Usage"""
self._status = None
return self
def bind_to(self, entity):
# noinspection PyAttributeOutsideInit
"""
TODO
:return:
"""
self._bound_entity = entity
return c.VALID
def belongs_to_entity(self, entity):
"""
TODO
:return:
"""
return self._bound_entity == entity
@property
def bound_entity(self):
return self._bound_entity
def unbind(self):
"""
TODO
:return:
"""
previously_bound = self._bound_entity
self._bound_entity = None
return previously_bound

View File

@ -11,15 +11,33 @@ from marl_factory_grid.environment.entity.object import Object
class PlaceHolder(Object):
def __init__(self, *args, fill_value=0, **kwargs):
"""
TODO
:return:
"""
super().__init__(*args, **kwargs)
self._fill_value = fill_value
@property
def can_collide(self):
def var_can_collide(self):
"""
TODO
:return:
"""
return False
@property
def encoding(self):
"""
TODO
:return:
"""
return self._fill_value
@property
@ -29,14 +47,30 @@ class PlaceHolder(Object):
class GlobalPosition(Object):
@property
def obs_tag(self):
return self.name
@property
def encoding(self):
"""
TODO
:return:
"""
if self._normalized:
return tuple(np.divide(self._bound_entity.pos, self._shape))
else:
return self.bound_entity.pos
def __init__(self, agent, level_shape, *args, normalized: bool = True, **kwargs):
"""
TODO
:return:
"""
super(GlobalPosition, self).__init__(*args, **kwargs)
self.bind_to(agent)
self._normalized = normalized

View File

@ -6,6 +6,12 @@ from marl_factory_grid.utils.utility_classes import RenderEntity
class Wall(Entity):
def __init__(self, *args, **kwargs):
"""
TODO
:return:
"""
super().__init__(*args, **kwargs)
@property

View File

@ -23,22 +23,52 @@ class Factory(gym.Env):
@property
def action_space(self):
"""
TODO
:return:
"""
return self.state[c.AGENT].action_space
@property
def named_action_space(self):
"""
TODO
:return:
"""
return self.state[c.AGENT].named_action_space
@property
def observation_space(self):
"""
TODO
:return:
"""
return self.obs_builder.observation_space(self.state)
@property
def named_observation_space(self):
"""
TODO
:return:
"""
return self.obs_builder.named_observation_space(self.state)
@property
def params(self) -> dict:
"""
FIXME LEGACY
:return:
"""
import yaml
config_path = Path(self._config_file)
config_dict = yaml.safe_load(config_path.open())
@ -49,6 +79,12 @@ class Factory(gym.Env):
def __init__(self, config_file: Union[str, PathLike], custom_modules_path: Union[None, PathLike] = None,
custom_level_path: Union[None, PathLike] = None):
"""
TODO
:return:
"""
self._config_file = config_file
self.conf = FactoryConfigParser(self._config_file, custom_modules_path)
# Attribute Assignment

View File

@ -5,6 +5,12 @@ from marl_factory_grid.environment.groups.collection import Collection
class Agents(Collection):
_entity = Agent
@property
def obs_pairs(self):
pair_list = [(self.name, self)]
pair_list.extend([(a.name, a) for a in self])
return pair_list
@property
def spawn_rule(self):
return {}
@ -21,17 +27,33 @@ class Agents(Collection):
def var_has_position(self):
return True
@property
def var_can_collide(self):
return True
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@property
def action_space(self):
"""
TODO
:return:
"""
from gymnasium import spaces
space = spaces.Tuple([spaces.Discrete(len(x.actions)) for x in self])
return space
@property
def named_action_space(self):
def named_action_space(self) -> dict[str, dict[str, list[int]]]:
"""
TODO
:return:
"""
named_space = dict()
for agent in self:
named_space[agent.name] = {action.name: idx for idx, action in enumerate(agent.actions)}

View File

@ -118,12 +118,6 @@ class Collection(Objects):
except (StopIteration, AttributeError):
return None
def idx_by_entity(self, entity):
try:
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
except (StopIteration, AttributeError):
return None
def render(self):
if self.var_has_position:
return [y for y in [x.render() for x in self] if y is not None]

View File

@ -28,9 +28,3 @@ class HasBoundMixin:
return next((x for x in self if x.belongs_to_entity(entity)))
except (StopIteration, AttributeError):
return None
def idx_by_entity(self, entity):
try:
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
except (StopIteration, AttributeError):
return None

View File

@ -160,7 +160,7 @@ class Objects:
def idx_by_entity(self, entity):
try:
return h.get_first_index(self, filter_by=lambda x: x.belongs_to_entity(entity))
return h.get_first_index(self, filter_by=lambda x: x == entity)
except (StopIteration, AttributeError):
return None

View File

@ -1,4 +1,4 @@
from typing import List, Union
from typing import List, Union, Iterable
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.util import GlobalPosition
@ -39,17 +39,36 @@ class GlobalPositions(Collection):
_entity = GlobalPosition
var_is_blocking_light = False
var_can_be_bound = True
var_can_collide = False
var_has_position = False
@property
def var_is_blocking_light(self):
return False
@property
def var_has_position(self):
return False
@property
def var_can_collide(self):
return False
@property
def var_can_be_bound(self):
return True
def __init__(self, *args, **kwargs):
"""
TODO
:return:
"""
super(GlobalPositions, self).__init__(*args, **kwargs)
def spawn(self, agents, level_shape, *args, **kwargs):
def spawn(self, agents, level_shape, *args, **kwargs) -> list[Result]:
self.add_items([self._entity(agent, level_shape, *args, **kwargs) for agent in agents])
return [Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))]
def trigger_spawn(self, state: Gamestate, *args, **kwargs) -> [Result]:
return self.spawn(state[c.AGENT], state.lvl_shape, *args, **kwargs)
def trigger_spawn(self, state: Gamestate, *args, **kwargs) -> list[Result]:
result = self.spawn(state[c.AGENT], state.lvl_shape, *args, **kwargs)
state.print(f'{len(self)} new {self.__class__.__name__} have been spawned for {[x for x in state[c.AGENT]]}')
return result

View File

@ -12,44 +12,94 @@ class Rule(abc.ABC):
@property
def name(self):
"""
TODO
:return:
"""
return self.__class__.__name__
def __init__(self):
"""
TODO
:return:
"""
pass
def __repr__(self):
return f'{self.name}'
def on_init(self, state, lvl_map):
"""
TODO
:return:
"""
return []
def on_reset(self, state) -> List[TickResult]:
"""
TODO
:return:
"""
return []
def tick_pre_step(self, state) -> List[TickResult]:
"""
TODO
:return:
"""
return []
def tick_step(self, state) -> List[TickResult]:
"""
TODO
:return:
"""
return []
def tick_post_step(self, state) -> List[TickResult]:
"""
TODO
:return:
"""
return []
def on_check_done(self, state) -> List[DoneResult]:
"""
TODO
:return:
"""
return []
class SpawnEntity(Rule):
@property
def _collection(self) -> Collection:
return Collection()
@property
def name(self):
return f'{self.__class__.__name__}({self.collection.name})'
def __init__(self, collection, coords_or_quantity, ignore_blocking=False):
"""
TODO
:return:
"""
super().__init__()
self.coords_or_quantity = coords_or_quantity
self.collection = collection
@ -65,6 +115,12 @@ class SpawnEntity(Rule):
class SpawnAgents(Rule):
def __init__(self):
"""
TODO
:return:
"""
super().__init__()
pass
@ -91,6 +147,12 @@ class SpawnAgents(Rule):
class DoneAtMaxStepsReached(Rule):
def __init__(self, max_steps: int = 500):
"""
TODO
:return:
"""
super().__init__()
self.max_steps = max_steps
@ -103,6 +165,12 @@ class DoneAtMaxStepsReached(Rule):
class AssignGlobalPositions(Rule):
def __init__(self):
"""
TODO
:return:
"""
super().__init__()
def on_reset(self, state, lvl_map):
@ -116,6 +184,12 @@ class AssignGlobalPositions(Rule):
class WatchCollisions(Rule):
def __init__(self, reward=r.COLLISION, done_at_collisions: bool = False, reward_at_done=r.COLLISION_DONE):
"""
TODO
:return:
"""
super().__init__()
self.reward_at_done = reward_at_done
self.reward = reward
@ -124,9 +198,14 @@ class WatchCollisions(Rule):
def tick_post_step(self, state) -> List[TickResult]:
self.curr_done = False
pos_with_collisions = state.get_collision_positions()
results = list()
for pos in pos_with_collisions:
for agent in state[c.AGENT]:
a_s = agent.state
if h.is_move(a_s.identifier) and a_s.action_introduced_collision:
results.append(TickResult(entity=agent, identifier=c.COLLISION,
reward=self.reward, validity=c.VALID))
for pos in state.get_collision_positions():
guests = [x for x in state.entities.pos_dict[pos] if x.var_can_collide]
if len(guests) >= 2:
for i, guest in enumerate(guests):
@ -136,15 +215,18 @@ class WatchCollisions(Rule):
)
except AttributeError:
pass
results.append(TickResult(entity=guest, identifier=c.COLLISION,
reward=self.reward, validity=c.VALID))
if not any([x.entity == guest for x in results]):
results.append(TickResult(entity=guest, identifier=c.COLLISION,
reward=self.reward, validity=c.VALID))
self.curr_done = True if self.done_at_collisions else False
return results
def on_check_done(self, state) -> List[DoneResult]:
if self.done_at_collisions:
inter_entity_collision_detected = self.curr_done
move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT])
if inter_entity_collision_detected or move_failed:
collision_in_step = any(h.is_move(x.state.identifier) and x.state.action_introduced_collision
for x in state[c.AGENT]
)
if inter_entity_collision_detected or collision_in_step:
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=self.reward_at_done)]
return []

View File

@ -11,6 +11,10 @@ from marl_factory_grid.utils import helpers as h
class Charge(Action):
def __init__(self):
"""
Checks if a charge pod is present at the entity's position.
If found, it attempts to charge the battery using the charge pod.
"""
super().__init__(b.ACTION_CHARGE, b.REWARD_CHARGE_VALID, b.Reward_CHARGE_FAIL)
def do(self, entity, state) -> Union[None, ActionResult]:

View File

@ -13,7 +13,12 @@ class Battery(Object):
return True
@property
def is_discharged(self):
def is_discharged(self) -> bool:
"""
Indicates whether the Batteries charge level is at 0 or not.
:return: Whether this battery is empty.
"""
return self.charge_level == 0
@property
@ -24,12 +29,27 @@ class Battery(Object):
def encoding(self):
return self.charge_level
def __init__(self, initial_charge_level: float, owner: Entity, *args, **kwargs):
def __init__(self, initial_charge_level, owner, *args, **kwargs):
"""
Represents a battery entity in the environment that can be bound to an agent and charged at chargepods.
:param initial_charge_level: The current charge level of the battery, ranging from 0 to 1.
:type initial_charge_level: float
:param owner: The entity to which the battery is bound.
:type owner: Entity
"""
super(Battery, self).__init__(*args, **kwargs)
self.charge_level = initial_charge_level
self.bind_to(owner)
def do_charge_action(self, amount):
def do_charge_action(self, amount) -> bool:
"""
Updates the Battery's charge level accordingly.
:param amount: Amount added to the Battery's charge level.
:returns: whether the battery could be charged. if not, it was already fully charged.
"""
if self.charge_level < 1:
# noinspection PyTypeChecker
self.charge_level = min(1, amount + self.charge_level)
@ -37,7 +57,10 @@ class Battery(Object):
else:
return c.NOT_VALID
def decharge(self, amount) -> float:
def decharge(self, amount) -> bool:
"""
Decreases the charge value of a battery. Currently only triggered by the battery-decharge rule.
"""
if self.charge_level != 0:
# noinspection PyTypeChecker
self.charge_level = max(0, amount + self.charge_level)
@ -57,13 +80,27 @@ class ChargePod(Entity):
def encoding(self):
return b.CHARGE_POD_SYMBOL
def __init__(self, *args, charge_rate: float = 0.4,
multi_charge: bool = False, **kwargs):
def __init__(self, *args, charge_rate: float = 0.4, multi_charge: bool = False, **kwargs):
"""
Represents a charging pod for batteries in the environment.
:param charge_rate: The rate at which the charging pod charges batteries. Default is 0.4.
:type charge_rate: float
:param multi_charge: Indicates whether the charging pod supports charging multiple batteries simultaneously.
Default is False.
:type multi_charge: bool
"""
super(ChargePod, self).__init__(*args, **kwargs)
self.charge_rate = charge_rate
self.multi_charge = multi_charge
def charge_battery(self, entity, state):
def charge_battery(self, entity, state) -> bool:
"""
Checks whether the battery can be charged. If so, triggers the charge action.
:returns: whether the action was successful (valid) or not.
"""
battery = state[b.BATTERIES].by_entity(entity)
if battery.charge_level >= 1.0:
return c.NOT_VALID
@ -76,6 +113,6 @@ class ChargePod(Entity):
return RenderEntity(b.CHARGE_PODS, self.pos)
def summarize_state(self) -> dict:
summery = super().summarize_state()
summery.update(charge_rate=self.charge_rate)
return summery
summary = super().summarize_state()
summary.update(charge_rate=self.charge_rate)
return summary

View File

@ -9,22 +9,32 @@ from marl_factory_grid.utils.results import Result
class Batteries(Collection):
_entity = Battery
var_has_position = False
var_can_be_bound = True
@property
def var_has_position(self):
return False
@property
def obs_tag(self):
return self.__class__.__name__
def var_can_be_bound(self):
return True
def __init__(self, size, initial_charge_level: float=1.0, *args, **kwargs):
def __init__(self, size, initial_charge_level=1.0, *args, **kwargs):
"""
A collection of batteries that can spawn batteries.
:param size: The maximum allowed size of the collection. Ensures that the collection does not exceed this size.
:type size: int
:param initial_charge_level: The initial charge level of the battery.
:type initial_charge_level: float
"""
super(Batteries, self).__init__(size, *args, **kwargs)
self.initial_charge_level = initial_charge_level
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], agents, *entity_args, **entity_kwargs):
batteries = [self._entity(self.initial_charge_level, agent) for _, agent in enumerate(agents)]
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args, **entity_kwargs):
batteries = [self._entity(self.initial_charge_level, agent) for _, agent in enumerate(entity_args[0])]
self.add_items(batteries)
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, **entity_kwargs):
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, **entity_kwargs):
self.spawn(0, state[c.AGENT])
return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))
@ -33,6 +43,9 @@ class ChargePods(Collection):
_entity = ChargePod
def __init__(self, *args, **kwargs):
"""
A collection of charge pods in the environment.
"""
super(ChargePods, self).__init__(*args, **kwargs)
def __repr__(self):

View File

@ -24,16 +24,16 @@ class BatteryDecharge(Rule):
2. float: each action "costs" the same.
----
!!! Does not introduce any Env.-Done condition.
!!! Batterys can only be charged if agent posses the "Charge(Action.
!!! Batterys can only be charged if there are "Charpods" and they are spawned!
!!! Batteries can only be charged if the agent possesses the "Charge" Action.
!!! Batteries can only be charged if there are "Charge Pods" and they are spawned!
----
:type initial_charge: float
:param initial_charge: How much juice they have.
:type battery_discharge_reward: float
:param battery_discharge_reward: Negativ reward, when agents let their batters discharge.
:param battery_discharge_reward: Negative reward, when agents let their batteries discharge.
Default: {b.REWARD_BATTERY_DISCHARGED}
:type battery_failed_reward: float
:param battery_failed_reward: Negativ reward, when agent cannot charge, but do (overcharge, not on station).
:param battery_failed_reward: Negative reward, when agent cannot charge, but do (overcharge, not on station).
Default: {b.Reward_CHARGE_FAIL}
:type battery_charge_reward: float
:param battery_charge_reward: Positive reward, when agent actually charge their battery.
@ -48,7 +48,6 @@ class BatteryDecharge(Rule):
self.initial_charge = initial_charge
def tick_step(self, state) -> List[TickResult]:
# Decharge
batteries = state[b.BATTERIES]
results = []
@ -104,13 +103,13 @@ class DoneAtBatteryDischarge(BatteryDecharge):
:type initial_charge: float
:param initial_charge: How much juice they have.
:type reward_discharge_done: float
:param reward_discharge_done: Global negativ reward, when agents let their batters discharge.
:param reward_discharge_done: Global negative reward, when agents let their batteries discharge.
Default: {b.REWARD_BATTERY_DISCHARGED}
:type battery_discharge_reward: float
:param battery_discharge_reward: Negativ reward, when agents let their batters discharge.
:param battery_discharge_reward: Negative reward, when agents let their batteries discharge.
Default: {b.REWARD_BATTERY_DISCHARGED}
:type battery_failed_reward: float
:param battery_failed_reward: Negativ reward, when agent cannot charge, but do (overcharge, not on station).
:param battery_failed_reward: Negative reward, when agent cannot charge, but do (overcharge, not on station).
Default: {b.Reward_CHARGE_FAIL}
:type battery_charge_reward: float
:param battery_charge_reward: Positive reward, when agent actually charge their battery.

View File

@ -1,4 +1,4 @@
from .actions import CleanUp
from .actions import Clean
from .entitites import DirtPile
from .groups import DirtPiles
from .rules import EntitiesSmearDirtOnMove, DoneOnAllDirtCleaned

View File

@ -3,15 +3,18 @@ from typing import Union
from marl_factory_grid.environment.actions import Action
from marl_factory_grid.utils.results import ActionResult
from marl_factory_grid.modules.clean_up import constants as d, rewards as r
from marl_factory_grid.modules.clean_up import constants as d
from marl_factory_grid.environment import constants as c
class CleanUp(Action):
class Clean(Action):
def __init__(self):
super().__init__(d.CLEAN_UP, r.CLEAN_UP_VALID, r.CLEAN_UP_FAIL)
"""
Attempts to reduce dirt amount on entity's position.
"""
super().__init__(d.CLEAN_UP, d.REWARD_CLEAN_UP_VALID, d.REWARD_CLEAN_UP_FAIL)
def do(self, entity, state) -> Union[None, ActionResult]:
if dirt := next((x for x in state.entities.pos_dict[entity.pos] if "dirt" in x.name.lower()), None):

View File

@ -5,3 +5,7 @@ CLEAN_UP = 'do_cleanup_action'
CLEAN_UP_VALID = 'clean_up_valid'
CLEAN_UP_FAIL = 'clean_up_fail'
CLEAN_UP_ALL = 'all_cleaned_up'
REWARD_CLEAN_UP_VALID: float = 0.5
REWARD_CLEAN_UP_FAIL: float = -0.1
REWARD_CLEAN_UP_ALL: float = 4.5

View File

@ -7,19 +7,33 @@ class DirtPile(Entity):
@property
def amount(self):
"""
Internal Usage
"""
return self._amount
@property
def encoding(self):
# Edit this if you want items to be drawn in the ops differently
return self._amount
def __init__(self, *args, amount=2, max_local_amount=5, **kwargs):
"""
Represents a pile of dirt at a specific position in the environment.
:param amount: The amount of dirt in the pile.
:type amount: float
:param max_local_amount: The maximum amount of dirt allowed in a single pile at one position.
:type max_local_amount: float
"""
super(DirtPile, self).__init__(*args, **kwargs)
self._amount = amount
self.max_local_amount = max_local_amount
def set_new_amount(self, amount):
"""
Internal Usage
"""
self._amount = min(amount, self.max_local_amount)
def summarize_state(self):

View File

@ -2,29 +2,62 @@ from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.modules.clean_up.entitites import DirtPile
from marl_factory_grid.utils.results import Result
from marl_factory_grid.utils import helpers as h
class DirtPiles(Collection):
_entity = DirtPile
var_is_blocking_light = False
var_can_collide = False
var_can_move = False
var_has_position = True
@property
def var_is_blocking_light(self):
return False
@property
def global_amount(self):
def var_can_collide(self):
return False
@property
def var_can_move(self):
return False
@property
def var_has_position(self):
return True
@property
def global_amount(self) -> float:
"""
Internal Usage
"""
return sum([dirt.amount for dirt in self])
def __init__(self, *args,
max_local_amount=5,
clean_amount=1,
max_global_amount: int = 20,
coords_or_quantity=10,
initial_amount=2,
amount_var=0.2,
n_var=0.2,
**kwargs):
def __init__(self, *args, max_local_amount=5, clean_amount=1, max_global_amount: int = 20, coords_or_quantity=10,
initial_amount=2, amount_var=0.2, n_var=0.2, **kwargs):
"""
A Collection of dirt piles that triggers their spawn.
:param max_local_amount: The maximum amount of dirt allowed in a single pile at one position.
:type max_local_amount: int
:param clean_amount: The amount of dirt removed by a single cleaning action.
:type clean_amount: int
:param max_global_amount: The maximum total amount of dirt allowed in the environment.
:type max_global_amount: int
:param coords_or_quantity: Determines whether to use coordinates or quantity when triggering dirt pile spawn.
:type coords_or_quantity: Union[Tuple[int, int], int]
:param initial_amount: The initial amount of dirt in each newly spawned pile.
:type initial_amount: int
:param amount_var: The variability in the initial amount of dirt in each pile.
:type amount_var: float
:param n_var: The variability in the number of new dirt piles spawned.
:type n_var: float
"""
super(DirtPiles, self).__init__(*args, **kwargs)
self.amount_var = amount_var
self.n_var = n_var
@ -50,7 +83,7 @@ class DirtPiles(Collection):
for idx, (pos, a) in enumerate(zip(n_new, amounts)):
if not self.global_amount > self.max_global_amount:
if dirt := self.by_pos(pos):
dirt = next(dirt.iter())
dirt = h.get_first(dirt)
new_value = dirt.amount + a
dirt.set_new_amount(new_value)
else:

View File

@ -1,3 +0,0 @@
CLEAN_UP_VALID: float = 0.5
CLEAN_UP_FAIL: float = -0.1
CLEAN_UP_ALL: float = 4.5

View File

@ -1,4 +1,4 @@
from marl_factory_grid.modules.clean_up import constants as d, rewards as r
from marl_factory_grid.modules.clean_up import constants as d
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.rules import Rule
@ -9,9 +9,9 @@ from marl_factory_grid.utils.results import DoneResult
class DoneOnAllDirtCleaned(Rule):
def __init__(self, reward: float = r.CLEAN_UP_ALL):
def __init__(self, reward: float = d.REWARD_CLEAN_UP_ALL):
"""
Defines a 'Done'-condition which tirggers, when there is no more 'Dirt' in the environment.
Defines a 'Done'-condition which triggers, when there is no more 'Dirt' in the environment.
:type reward: float
:parameter reward: Given reward when condition triggers.
@ -29,9 +29,9 @@ class RespawnDirt(Rule):
def __init__(self, respawn_freq: int = 15, respawn_n: int = 5, respawn_amount: float = 1.0):
"""
Defines the spawn pattern of intial and additional 'Dirt'-entitites.
First chooses positions, then trys to spawn dirt until 'respawn_n' or the maximal global amount is reached.
If there is allready some, it is topped up to min(max_local_amount, amount).
Defines the spawn pattern of initial and additional 'Dirt'-entities.
First chooses positions, then tries to spawn dirt until 'respawn_n' or the maximal global amount is reached.
If there is already some, it is topped up to min(max_local_amount, amount).
:type respawn_freq: int
:parameter respawn_freq: In which frequency should this Rule try to spawn new 'Dirt'?

View File

@ -1,16 +1,17 @@
from typing import Union
import marl_factory_grid.modules.destinations.constants
from marl_factory_grid.environment.actions import Action
from marl_factory_grid.utils.results import ActionResult
from marl_factory_grid.modules.destinations import constants as d
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.actions import Action
from marl_factory_grid.modules.destinations import constants as d
from marl_factory_grid.utils.results import ActionResult
class DestAction(Action):
def __init__(self):
"""
Attempts to wait at destination.
"""
super().__init__(d.DESTINATION, d.REWARD_WAIT_VALID, d.REWARD_WAIT_FAIL)
def do(self, entity, state) -> Union[None, ActionResult]:

View File

@ -9,24 +9,37 @@ from marl_factory_grid.utils.utility_classes import RenderEntity
class Destination(Entity):
def was_reached(self):
return self._was_reached
@property
def encoding(self):
return d.DEST_SYMBOL
def __init__(self, *args, action_counts=0, **kwargs):
"""
Represents a destination in the environment that agents aim to reach.
"""
super(Destination, self).__init__(*args, **kwargs)
self._was_reached = False
self.action_counts = action_counts
self._per_agent_actions = defaultdict(lambda: 0)
def do_wait_action(self, agent: Agent):
def do_wait_action(self, agent) -> bool:
"""
Performs a wait action for the given agent at the destination.
:param agent: The agent performing the wait action.
:type agent: Agent
:return: Whether the action was valid or not.
:rtype: bool
"""
self._per_agent_actions[agent.name] += 1
return c.VALID
def has_just_been_reached(self, state):
"""
Checks if the destination has just been reached based on the current state.
"""
if self.was_reached():
return False
agent_at_position = any(state[c.AGENT].by_pos(self.pos))
@ -38,6 +51,9 @@ class Destination(Entity):
return agent_at_position or any(x >= self.action_counts for x in self._per_agent_actions.values())
def agent_did_action(self, agent: Agent):
"""
Internal usage, currently no usage.
"""
return self._per_agent_actions[agent.name] >= self.action_counts
def summarize_state(self) -> dict:
@ -57,3 +73,6 @@ class Destination(Entity):
def unmark_as_reached(self):
self._was_reached = False
def was_reached(self) -> bool:
return self._was_reached

View File

@ -5,13 +5,30 @@ from marl_factory_grid.modules.destinations.entitites import Destination
class Destinations(Collection):
_entity = Destination
var_is_blocking_light = False
var_can_collide = False
var_can_move = False
var_has_position = True
var_can_be_bound = True
@property
def var_is_blocking_light(self):
return False
@property
def var_can_collide(self):
return False
@property
def var_can_move(self):
return False
@property
def var_has_position(self):
return True
@property
def var_can_be_bound(self):
return True
def __init__(self, *args, **kwargs):
"""
A collection of destinations.
"""
super().__init__(*args, **kwargs)
def __repr__(self):

View File

@ -10,18 +10,17 @@ from marl_factory_grid.environment import constants as c
from marl_factory_grid.modules.destinations import constants as d
from marl_factory_grid.modules.destinations.entitites import Destination
ANY = 'any'
ALL = 'all'
SIMULTANOIUS = 'simultanious'
CONDITIONS =[ALL, ANY, SIMULTANOIUS]
ANY = 'any'
ALL = 'all'
SIMULTANEOUS = 'simultanious'
CONDITIONS = [ALL, ANY, SIMULTANEOUS]
class DestinationReachReward(Rule):
def __init__(self, dest_reach_reward=d.REWARD_DEST_REACHED):
"""
This rule introduces the basic functionality, so that targts (Destinations) can be reached and marked as such.
This rule introduces the basic functionality, so that targets (Destinations) can be reached and marked as such.
Additionally, rewards are reported.
:type dest_reach_reward: float
@ -61,7 +60,7 @@ class DoneAtDestinationReach(DestinationReachReward):
This rule triggers and sets the done flag if ALL Destinations have been reached.
:type reward_at_done: float
:param reward_at_done: Specifies the reward, agent get, whenn all destinations are reached.
:param reward_at_done: Specifies the reward agents get when all destinations are reached.
:type dest_reach_reward: float
:param dest_reach_reward: Specify the reward, agents get when reaching a single destination.
"""
@ -77,7 +76,7 @@ class DoneAtDestinationReach(DestinationReachReward):
elif self.condition == ALL:
if all(x.was_reached() for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=self.reward)]
elif self.condition == SIMULTANOIUS:
elif self.condition == SIMULTANEOUS:
if all(x.was_reached() for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=self.reward)]
else:
@ -100,13 +99,13 @@ class DoneAtDestinationReach(DestinationReachReward):
class SpawnDestinationsPerAgent(Rule):
def __init__(self, coords_or_quantity: Dict[str, List[Tuple[int, int] | int]]):
"""
Special rule, that spawn distinations, that are bound to a single agent a fixed set of positions.
Usefull for introducing specialists, etc. ..
Special rule that spawns destinations which are bound to a single agent, at a fixed set of positions.
Useful for introducing specialists, etc. ..
!!! This rule does not introduce any reward or done condition.
:param coords_or_quantity: Please provide a dictionary with agent names as keys; and a list of possible
destiantion coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]}
destination coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]}
"""
super(Rule, self).__init__()
self.per_agent_positions = dict()

View File

@ -1,16 +1,18 @@
from typing import Union
from marl_factory_grid.environment.actions import Action
from marl_factory_grid.modules.doors import constants as d
from marl_factory_grid.modules.doors.entitites import Door
from marl_factory_grid.modules.doors import constants as d, rewards as r
from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils.results import ActionResult
class DoorUse(Action):
def __init__(self, **kwargs):
super().__init__(d.ACTION_DOOR_USE, r.USE_DOOR_VALID, r.USE_DOOR_FAIL, **kwargs)
"""
Attempts to interact with door (open/close it) and returns an action result if successful.
"""
super().__init__(d.ACTION_DOOR_USE, d.REWARD_USE_DOOR_VALID, d.REWARD_USE_DOOR_FAIL, **kwargs)
def do(self, entity, state) -> Union[None, ActionResult]:
# Check if agent really is standing on a door:
@ -26,6 +28,6 @@ class DoorUse(Action):
except AttributeError:
pass
if not valid:
# When he doesn't stand necxxt to a door tell me.
# When he doesn't stand next to a door tell me.
state.print(f'{entity.name} just tried to use a door at {entity.pos}, but there is none.')
return self.get_result(valid, entity)

View File

@ -16,3 +16,7 @@ STATE_OPEN = 'open' # Identifier to compare door-is-
# Actions
ACTION_DOOR_USE = 'use_door' # Identifier for door-action
# Rewards
REWARD_USE_DOOR_VALID: float = -0.00 # Reward for successful door use
REWARD_USE_DOOR_FAIL: float = -0.01 # Reward for unsuccessful door use

View File

@ -1,3 +1,5 @@
from typing import Union
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.utils import Result
from marl_factory_grid.utils.utility_classes import RenderEntity
@ -16,6 +18,9 @@ class DoorIndicator(Entity):
return []
def __init__(self, *args, **kwargs):
"""
Is added around a door for agents to see.
"""
super().__init__(*args, **kwargs)
self.__delattr__('move')
@ -39,23 +44,38 @@ class Door(Entity):
return d.VALUE_CLOSED_DOOR if self.is_closed else d.VALUE_OPEN_DOOR
@property
def str_state(self):
def str_state(self) -> str:
"""
Internal Usage
"""
return 'open' if self.is_open else 'closed'
@property
def is_closed(self):
def is_closed(self) -> bool:
return self._state == d.STATE_CLOSED
@property
def is_open(self):
def is_open(self) -> bool:
return self._state == d.STATE_OPEN
@property
def time_to_close(self):
"""
:returns: The time it takes for the door to close.
:rtype: float
"""
return self._time_to_close
def __init__(self, *args, closed_on_init=True, auto_close_interval=10, **kwargs):
"""
A door entity that can be opened or closed by agents or rules.
:param closed_on_init: Whether the door spawns as open or closed.
:type closed_on_init: bool
:param auto_close_interval: After how many steps the door should automatically close itself.
:type auto_close_interval: int
"""
self._state = d.STATE_CLOSED
super(Door, self).__init__(*args, **kwargs)
self._auto_close_interval = auto_close_interval
@ -74,14 +94,17 @@ class Door(Entity):
name, state = 'door_open' if self.is_open else 'door_closed', 'blank'
return RenderEntity(name, self.pos, 1, 'none', state, self.u_int + 1)
def use(self):
def use(self) -> bool:
"""
Internal usage
"""
if self._state == d.STATE_OPEN:
self._close()
else:
self._open()
return c.VALID
def tick(self, state):
def tick(self, state) -> Union[Result, None]:
# Check if no entity is standing in the door
if not any(e for e in state.entities.by_pos(self.pos) if e.var_can_collide or e.var_is_blocking_pos):
# if len(state.entities.pos_dict[self.pos]) <= 2: #can collide can block
@ -99,23 +122,38 @@ class Door(Entity):
self._reset_timer()
return Result(f"{d.DOOR}_reset", c.VALID, entity=self)
def _open(self):
def _open(self) -> bool:
"""
Internal Usage
"""
self._state = d.STATE_OPEN
self._reset_timer()
return True
def _close(self):
def _close(self) -> bool:
"""
Internal Usage
"""
self._state = d.STATE_CLOSED
return True
def _decrement_timer(self):
def _decrement_timer(self) -> bool:
"""
Internal Usage
"""
self._time_to_close -= 1
return True
def _reset_timer(self):
def _reset_timer(self) -> bool:
"""
Internal Usage
"""
self._time_to_close = self._auto_close_interval
return True
def reset(self):
"""
Internal Usage
"""
self._close()
self._reset_timer()

View File

@ -1,6 +1,9 @@
from typing import List
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.modules.doors import constants as d
from marl_factory_grid.modules.doors.entitites import Door
from marl_factory_grid.utils import Result
class Doors(Collection):
@ -13,18 +16,21 @@ class Doors(Collection):
return True
def __init__(self, *args, **kwargs):
"""
A collection of doors that can tick and reset all doors.
"""
super(Doors, self).__init__(*args, can_collide=True, **kwargs)
def tick_doors(self, state):
def tick_doors(self, state) -> List[Result]:
results = list()
for door in self:
assert(isinstance(door, Door))
assert isinstance(door, Door)
tick_result = door.tick(state)
if tick_result is not None:
results.append(tick_result)
# TODO: Should return a Result object, not a random dict.
return results
def reset(self):
for door in self:
assert isinstance(door, Door)
door.reset()

View File

@ -1,2 +0,0 @@
USE_DOOR_VALID: float = -0.00
USE_DOOR_FAIL: float = -0.01

View File

@ -3,20 +3,40 @@ from typing import Union
from marl_factory_grid.environment.actions import Action
from marl_factory_grid.utils.results import ActionResult
from marl_factory_grid.modules.items import constants as i, rewards as r
from marl_factory_grid.modules.items import constants as i
from marl_factory_grid.environment import constants as c
class ItemAction(Action):
def __init__(self, failed_dropoff_reward: float | None = None, valid_dropoff_reward: float | None = None, **kwargs):
super().__init__(i.ITEM_ACTION, r.PICK_UP_FAIL, r.PICK_UP_VALID, **kwargs)
self.failed_drop_off_reward = failed_dropoff_reward if failed_dropoff_reward is not None else r.DROP_OFF_FAIL
self.valid_drop_off_reward = valid_dropoff_reward if valid_dropoff_reward is not None else r.DROP_OFF_FAIL
"""
Allows an entity to pick up or drop off items in the environment.
def get_dropoff_result(self, validity, entity):
:param failed_dropoff_reward: The reward assigned when a drop-off action fails. Default is None.
:type failed_dropoff_reward: float | None
:param valid_dropoff_reward: The reward assigned when a drop-off action is successful. Default is None.
:type valid_dropoff_reward: float | None
"""
super().__init__(i.ITEM_ACTION, i.REWARD_PICK_UP_FAIL, i.REWARD_PICK_UP_VALID, **kwargs)
self.failed_drop_off_reward = failed_dropoff_reward if failed_dropoff_reward is not None else i.REWARD_DROP_OFF_FAIL
self.valid_drop_off_reward = valid_dropoff_reward if valid_dropoff_reward is not None else i.REWARD_DROP_OFF_VALID
def get_dropoff_result(self, validity, entity) -> ActionResult:
"""
Generates an ActionResult for a drop-off action based on its validity.
:param validity: Whether the drop-off action is valid.
:type validity: bool
:param entity: The entity performing the action.
:type entity: Entity
:return: ActionResult for the drop-off action.
:rtype: ActionResult
"""
reward = self.valid_drop_off_reward if validity else self.failed_drop_off_reward
return ActionResult(self.__name__, validity, reward=reward, entity=entity)
return ActionResult(self.__class__.__name__, validity, reward=reward, entity=entity)
def do(self, entity, state) -> Union[None, ActionResult]:
inventory = state[i.INVENTORY].by_entity(entity)

View File

@ -6,3 +6,9 @@ INVENTORY = 'Inventories'
DROP_OFF = 'DropOffLocations'
ITEM_ACTION = 'ITEMACTION'
# Rewards
REWARD_DROP_OFF_VALID: float = 0.1
REWARD_DROP_OFF_FAIL: float = -0.1
REWARD_PICK_UP_FAIL: float = -0.1
REWARD_PICK_UP_VALID: float = 0.1

View File

@ -8,39 +8,52 @@ from marl_factory_grid.modules.items import constants as i
class Item(Entity):
@property
def encoding(self):
return 1
def __init__(self, *args, **kwargs):
"""
An item that can be picked up or dropped by agents. If picked up, it enters the agents inventory.
"""
super().__init__(*args, **kwargs)
def render(self):
return RenderEntity(i.ITEM, self.pos) if self.pos != c.VALUE_NO_POS else None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@property
def encoding(self):
# Edit this if you want items to be drawn in the ops differently
return 1
class DropOffLocation(Entity):
def render(self):
return RenderEntity(i.DROP_OFF, self.pos)
@property
def encoding(self):
return i.SYMBOL_DROP_OFF
def __init__(self, *args, storage_size_until_full: int = 5, **kwargs):
@property
def is_full(self) -> bool:
"""
Checks whether the drop-off location is full or whether another item can be dropped here.
"""
return False if not self.storage.maxlen else self.storage.maxlen == len(self.storage)
def __init__(self, *args, storage_size_until_full=5, **kwargs):
"""
Represents a drop-off location in the environment that agents aim to drop items at.
:param storage_size_until_full: The number of items that can be dropped here until it is considered full.
:type storage_size_until_full: int
"""
super(DropOffLocation, self).__init__(*args, **kwargs)
self.storage = deque(maxlen=storage_size_until_full or None)
def place_item(self, item: Item):
def place_item(self, item: Item) -> bool:
"""
If the storage of the drop-off location is not full, the item is placed. Otherwise, a RuntimeWarning is raised.
"""
if self.is_full:
raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.")
return bc.NOT_VALID
return c.NOT_VALID
else:
self.storage.append(item)
return c.VALID
@property
def is_full(self):
return False if not self.storage.maxlen else self.storage.maxlen == len(self.storage)
def render(self):
return RenderEntity(i.DROP_OFF, self.pos)

View File

@ -1,3 +1,5 @@
from typing import Dict, Any
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.environment.groups.collection import Collection
@ -16,14 +18,17 @@ class Items(Collection):
return True
@property
def is_blocking_light(self):
def var_is_blocking_light(self):
return False
@property
def can_collide(self):
def var_can_collide(self):
return False
def __init__(self, *args, **kwargs):
"""
A collection of items that triggers their spawn.
"""
super().__init__(*args, **kwargs)
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, **entity_kwargs) -> [Result]:
@ -51,61 +56,87 @@ class Inventory(IsBoundMixin, Collection):
def obs_tag(self):
return self.name
def __init__(self, agent: Agent, *args, **kwargs):
@property
def name(self):
return f'{self.__class__.__name__}[{self._bound_entity.name}]'
def __init__(self, agent, *args, **kwargs):
"""
An inventory that can hold items picked up by the agent this is bound to.
:param agent: The agent this inventory is bound to and belongs to.
:type agent: Agent
"""
super(Inventory, self).__init__(*args, **kwargs)
self._collection = None
self.bind(agent)
def __repr__(self):
return f'{self.__class__.__name__}#{self._bound_entity.name}({dict(self._data)})'
def summarize_states(self, **kwargs):
attr_dict = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
attr_dict.update(dict(items=[val.summarize_state(**kwargs) for key, val in self.items()]))
attr_dict.update(dict(name=self.name, belongs_to=self._bound_entity.name))
return attr_dict
def pop(self):
def pop(self) -> Item:
"""
Removes and returns the first item in the inventory.
"""
item_to_pop = self[0]
self.delete_env_object(item_to_pop)
return item_to_pop
def set_collection(self, collection):
"""
No usage
"""
self._collection = collection
def clear_temp_state(self):
# Entites need this, but inventories have no state....
"""
Entities need this, but inventories have no state.
"""
pass
class Inventories(Objects):
_entity = Inventory
var_can_move = False
var_has_position = False
symbol = None
@property
def spawn_rule(self):
def var_can_move(self):
return False
@property
def var_has_position(self):
return False
@property
def spawn_rule(self) -> dict[Any, dict[str, Any]]:
"""
:returns: a dict containing the specified spawn rule and its arguments.
:rtype: dict(dict(collection=self, coords_or_quantity=None))
"""
return {c.SPAWN_ENTITY_RULE: dict(collection=self, coords_or_quantity=None)}
def __init__(self, size: int, *args, **kwargs):
"""
TODO
"""
super(Inventories, self).__init__(*args, **kwargs)
self.size = size
self._obs = None
self._lazy_eval_transforms = []
def spawn(self, agents, *args, **kwargs):
def spawn(self, agents, *args, **kwargs) -> [Result]:
self.add_items([self._entity(agent, self.size, *args, **kwargs) for _, agent in enumerate(agents)])
return [Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))]
def trigger_spawn(self, state, *args, **kwargs) -> [Result]:
return self.spawn(state[c.AGENT], *args, **kwargs)
def idx_by_entity(self, entity):
try:
return next((idx for idx, inv in enumerate(self) if inv.belongs_to_entity(entity)))
except StopIteration:
return None
def by_entity(self, entity):
try:
return next((inv for inv in self if inv.belongs_to_entity(entity)))
@ -136,6 +167,9 @@ class DropOffLocations(Collection):
return True
def __init__(self, *args, **kwargs):
"""
A Collection of Drop-off locations that can trigger their spawn.
"""
super(DropOffLocations, self).__init__(*args, **kwargs)
@staticmethod

View File

@ -1,4 +0,0 @@
DROP_OFF_VALID: float = 0.1
DROP_OFF_FAIL: float = -0.1
PICK_UP_FAIL: float = -0.1
PICK_UP_VALID: float = 0.1

View File

@ -9,6 +9,16 @@ from marl_factory_grid.modules.items import constants as i
class RespawnItems(Rule):
def __init__(self, n_items: int = 5, respawn_freq: int = 15, n_locations: int = 5):
"""
Defines the respawning behaviour of items.
:param n_items: Specifies how many items should respawn.
:type n_items: int
:param respawn_freq: Specifies how often items should respawn.
:type respawn_freq: int
:param n_locations: Specifies at how many locations items should be able to respawn.
:type n_locations: int
"""
super().__init__()
self.spawn_frequency = respawn_freq
self._next_item_spawn = respawn_freq

View File

@ -1,17 +1,18 @@
from typing import Union
import marl_factory_grid.modules.machines.constants
from marl_factory_grid.environment.actions import Action
from marl_factory_grid.utils.results import ActionResult
from marl_factory_grid.modules.machines import constants as m
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.actions import Action
from marl_factory_grid.modules.machines import constants as m
from marl_factory_grid.utils import helpers as h
from marl_factory_grid.utils.results import ActionResult
class MachineAction(Action):
def __init__(self):
"""
Attempts to maintain the machine and returns an action result if successful.
"""
super().__init__(m.MACHINE_ACTION, m.MAINTAIN_VALID, m.MAINTAIN_FAIL)
def do(self, entity, state) -> Union[None, ActionResult]:

View File

@ -13,6 +13,14 @@ class Machine(Entity):
return self._encodings[self.status]
def __init__(self, *args, work_interval: int = 10, pause_interval: int = 15, **kwargs):
"""
Represents a machine entity that the maintainer will try to maintain.
:param work_interval: How long should the machine work before pausing.
:type work_interval: int
:param pause_interval: How long should the machine pause before continuing to work.
:type pause_interval: int
"""
super(Machine, self).__init__(*args, **kwargs)
self._intervals = dict({m.STATE_IDLE: pause_interval, m.STATE_WORK: work_interval})
self._encodings = dict({m.STATE_IDLE: pause_interval, m.STATE_WORK: work_interval})
@ -21,7 +29,10 @@ class Machine(Entity):
self.health = 100
self._counter = 0
def maintain(self):
def maintain(self) -> bool:
"""
Attempts to maintain the machine by increasing its health.
"""
if self.status == m.STATE_WORK:
return c.NOT_VALID
if self.health <= 98:
@ -31,6 +42,15 @@ class Machine(Entity):
return c.NOT_VALID
def tick(self, state):
"""
Updates the machine's mode (work, pause) depending on its current counter and whether an agent is currently on
its position. If no agent is standing on the machine's position, it decrements its own health.
:param state: The current game state.
:type state: GameState
:return: The result of the tick operation on the machine.
:rtype: TickResult | None
"""
others = state.entities.pos_dict[self.pos]
if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in others]):
return TickResult(identifier=self.name, validity=c.VALID, entity=self)
@ -48,6 +68,9 @@ class Machine(Entity):
return None
def reset_counter(self):
"""
Internal Usage
"""
self._counter = self._intervals[self.status]
def render(self):

View File

@ -20,5 +20,8 @@ class Machines(Collection):
return True
def __init__(self, *args, **kwargs):
"""
A Collection of Machines.
"""
super(Machines, self).__init__(*args, **kwargs)

View File

@ -15,7 +15,15 @@ from ..doors import DoorUse
class Maintainer(Entity):
def __init__(self, objective: str, action: Action, *args, **kwargs):
def __init__(self, objective, action, *args, **kwargs):
"""
Represents the maintainer entity that aims to maintain machines.
:param objective: The maintainer's objective, e.g., "Machines".
:type objective: str
:param action: The default action to be performed by the maintainer.
:type action: Action
"""
super().__init__(*args, **kwargs)
self.action = action
self.actions = [x() for x in ALL_BASEACTIONS] + [DoorUse()]
@ -26,6 +34,16 @@ class Maintainer(Entity):
self._last_serviced = 'None'
def tick(self, state):
"""
If there is an objective at the current position, the maintainer performs its action on the objective.
If the objective has changed since the last servicing, the maintainer performs the action and updates
the last serviced objective. Otherwise, it calculates a move action and performs it.
:param state: The current game state.
:type state: GameState
:return: The result of the action performed by the maintainer.
:rtype: ActionResult
"""
if found_objective := h.get_first(state[self.objective].by_pos(self.pos)):
if found_objective.name != self._last_serviced:
result = self.action.do(self, state)
@ -40,9 +58,24 @@ class Maintainer(Entity):
return result
def set_state(self, action_result):
"""
Updates the maintainers own status with an action result.
"""
self._status = action_result
def get_move_action(self, state) -> Action:
"""
Retrieves the next move action for the agent.
If a path is not already determined, the agent calculates the shortest path to its objective, considering doors
and obstacles. If a closed door is found in the calculated path, the agent attempts to open it.
:param state: The current state of the environment.
:type state: GameState
:return: The chosen move action for the agent.
:rtype: Action
"""
if self._path is None or not len(self._path):
if not self._next:
self._next = list(state[self.objective].values()) + [Floor(*state.random_free_position)]
@ -70,17 +103,27 @@ class Maintainer(Entity):
raise EnvironmentError
return action_obj
def calculate_route(self, entity, floortile_graph):
def calculate_route(self, entity, floortile_graph) -> list:
"""
:returns: The shortest path to the entity, excluding the maintainer's own (source) position but including the target position.
:rtype: list
"""
route = nx.shortest_path(floortile_graph, self.pos, entity.pos)
return route[1:]
def _closed_door_in_path(self, state):
"""
Internal Use
"""
if self._path:
return h.get_first(state[do.DOORS].by_pos(self._path[0]), lambda x: x.is_closed)
else:
return None
def _predict_move(self, state):
def _predict_move(self, state) -> Action:
"""
Internal Use
"""
next_pos = self._path[0]
if any(x for x in state.entities.pos_dict[next_pos] if x.var_can_collide) > 0:
action = c.NOOP

View File

@ -9,12 +9,26 @@ from ..machines.actions import MachineAction
class Maintainers(Collection):
_entity = Maintainer
var_can_collide = True
var_can_move = True
var_is_blocking_light = False
var_has_position = True
@property
def var_can_collide(self):
return True
@property
def var_can_move(self):
return True
@property
def var_is_blocking_light(self):
return False
@property
def var_has_position(self):
return True
def __init__(self, *args, **kwargs):
"""
A collection of maintainers
"""
super().__init__(*args, **kwargs)
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):

View File

@ -9,6 +9,9 @@ from . import constants as M
class MoveMaintainers(Rule):
def __init__(self):
"""
This rule is responsible for moving the maintainers at every step of the environment.
"""
super().__init__()
def tick_step(self, state) -> List[TickResult]:
@ -22,6 +25,9 @@ class MoveMaintainers(Rule):
class DoneAtMaintainerCollision(Rule):
def __init__(self):
"""
When active, this rule stops the environment after a maintainer reports a collision with another entity.
"""
super().__init__()
def on_check_done(self, state) -> List[DoneResult]:

View File

@ -47,7 +47,7 @@ class AgentSingleZonePlacement(Rule):
class IndividualDestinationZonePlacement(Rule):
def __init__(self):
raise NotImplementedError("This is rpetty new, and needs to be debugged, after the zones")
raise NotImplementedError("This is pretty new, and needs to be debugged, after the zones")
super().__init__()
def on_reset(self, state):

View File

@ -1,8 +1,7 @@
import importlib
from collections import defaultdict
from pathlib import PurePath, Path
from typing import Union, Dict, List, Iterable, Callable
from typing import Union, Dict, List, Iterable, Callable, Any
import numpy as np
from numpy.typing import ArrayLike
@ -21,23 +20,22 @@ This file is used for:
In this file they are defined to be used across the entire package.
"""
LEVELS_DIR = 'levels' # for use in studies and experiments
STEPS_START = 1 # Define where to the stepcount; which is the first step
LEVELS_DIR = 'levels' # for use in studies and experiments
STEPS_START = 1 # Define where to the stepcount; which is the first step
IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignored when loading monitor files
IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignored when loading monitor files
'train_step', 'step', 'index', 'dirt_amount', 'dirty_pos_count', 'terminal_observation',
'episode']
POS_MASK = np.asarray([[[-1, -1], [0, -1], [1, -1]],
[[-1, 0], [0, 0], [1, 0]],
[[-1, 1], [0, 1], [1, 1]]])
[[-1, 0], [0, 0], [1, 0]],
[[-1, 1], [0, 1], [1, 1]]])
MOVEMAP = defaultdict(lambda: (0, 0),
MOVEMAP = defaultdict(lambda: (0, 0),
{c.NORTH: (-1, 0), c.NORTHEAST: (-1, 1),
c.EAST: (0, 1), c.SOUTHEAST: (1, 1),
c.SOUTH: (1, 0), c.SOUTHWEST: (1, -1),
c.WEST: (0, -1), c.NORTHWEST: (-1, -1)
c.EAST: (0, 1), c.SOUTHEAST: (1, 1),
c.SOUTH: (1, 0), c.SOUTHWEST: (1, -1),
c.WEST: (0, -1), c.NORTHWEST: (-1, -1)
}
)
@ -80,7 +78,19 @@ class ObservationTranslator:
self._this_named_obs_space = this_named_observation_space
self._per_agent_named_obs_space = list(per_agent_named_obs_spaces)
def translate_observation(self, agent_idx: int, obs):
def translate_observation(self, agent_idx, obs) -> ArrayLike:
"""
Translates the observation of the given agent.
:param agent_idx: Agent identifier.
:type agent_idx: int
:param obs: The observation to be translated.
:type obs: ArrayLike
:return: The translated observation.
:rtype: ArrayLike
"""
target_obs_space = self._per_agent_named_obs_space[agent_idx]
translation = dict()
for name, idxs in target_obs_space.items():
@ -98,7 +108,10 @@ class ObservationTranslator:
translation = dict(sorted(translation.items()))
return np.concatenate(list(translation.values()), axis=-3)
def translate_observations(self, observations: List[ArrayLike]):
def translate_observations(self, observations) -> List[ArrayLike]:
"""
Internal Usage
"""
return [self.translate_observation(idx, observation) for idx, observation in enumerate(observations)]
def __call__(self, observations):
@ -129,11 +142,26 @@ class ActionTranslator:
self._per_agent_idx_actions = [{idx: a for a, idx in x.items()} for x in self._per_agent_named_action_space]
def translate_action(self, agent_idx: int, action: int):
"""
Translates the action of the given agent into the target action space.

The action index is first resolved to its name via the agent's own index-to-name mapping, then looked up by
that name in the target named action space.

:param agent_idx: Agent identifier.
:type agent_idx: int
:param action: The action to be translated.
:type action: int
:return: The entry of the target named action space for the resolved action name.
:rtype: presumably int (the per-agent spaces map names to indices) -- TODO confirm for the target space
"""
# index -> name, using the inverted per-agent action space
named_action = self._per_agent_idx_actions[agent_idx][action]
# name -> value in the target action space
translated_action = self._target_named_action_space[named_action]
return translated_action
def translate_actions(self, actions: List[int]):
"""
Internal Usage
"""
return [self.translate_action(idx, action) for idx, action in enumerate(actions)]
def __call__(self, actions):
@ -179,6 +207,13 @@ def one_hot_level(level, symbol: str):
def is_move(action_name: str) -> bool:
    """
    Check if the given action name corresponds to a movement action.

    :param action_name: The name of the action to check.
    :type action_name: str
    :return: True if the action is a key of MOVEMAP (i.e. a movement action), False otherwise.
    """
    # Plain membership test on the mapping; this does not trigger the defaultdict factory.
    return action_name in MOVEMAP
@ -208,7 +243,18 @@ def asset_str(agent):
def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
"""Locate an object by name or dotted path, importing as necessary."""
"""
Locate an object by name or dotted path.
:param class_name: The class name to be imported
:type class_name: str
:param folder_path: The path to the module containing the class.
:type folder_path: Union[str, PurePath]
:return: The imported module class.
:raises AttributeError: If the specified class is not found in the provided folder path.
"""
import sys
sys.path.append("../../environment")
folder_path = Path(folder_path).resolve()
@ -220,15 +266,15 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
for module_path in module_paths:
module_parts = [x.replace('.py', '') for idx, x in enumerate(module_path.parts) if idx >= package_pos]
mod = importlib.import_module('.'.join(module_parts))
all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
all_found_modules.extend([x for x in dir(mod) if (not (x.startswith('__') or len(x) <= 2) and x.istitle())
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
'TickResult', 'ActionResult', 'Action', 'Agent',
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any', 'Factory',
'Move8']])
try:
model_class = mod.__getattribute__(class_name)
return model_class
module_class = mod.__getattribute__(class_name)
return module_class
except AttributeError:
continue
raise AttributeError(f'Class "{class_name}" was not found in "{folder_path.name}"', list(set(all_found_modules)))
@ -244,9 +290,33 @@ def add_pos_name(name_str, bound_e):
return name_str
def get_first(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True):
def get_first(iterable: Iterable, filter_by: Callable[[Any], bool] = lambda _: True) -> Union[Any, None]:
    """
    Get the first element from an iterable that satisfies the specified condition.

    :param iterable: The iterable to search.
    :type iterable: Iterable
    :param filter_by: A function that filters elements, defaults to lambda _: True.
    :type filter_by: Callable[[Any], bool]
    :return: The first element that satisfies the condition, or None if none is found.
    :rtype: Any
    """
    # The generator is consumed lazily, so the search stops at the first match.
    return next((x for x in iterable if filter_by(x)), None)
def get_first_index(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True):
def get_first_index(iterable: Iterable, filter_by: Callable[[Any], bool] = lambda _: True) -> Union[int, None]:
    """
    Get the index of the first element from an iterable that satisfies the specified condition.

    :param iterable: The iterable to search.
    :type iterable: Iterable
    :param filter_by: A function that filters elements, defaults to lambda _: True.
    :type filter_by: Callable[[Any], bool]
    :return: The index of the first element that satisfies the condition, or None if none is found.
    :rtype: Optional[int]
    """
    # enumerate() yields the running position; the search stops at the first match.
    return next((idx for idx, x in enumerate(iterable) if filter_by(x)), None)

View File

@ -15,9 +15,24 @@ class LevelParser(object):
@property
def pomdp_d(self):
"""
Internal Usage
"""
return self.pomdp_r * 2 + 1
def __init__(self, level_file_path: PathLike, entity_parse_dict: Dict[Entities, dict], pomdp_r=0):
"""
Parses a level file and creates the initial state of the environment.
:param level_file_path: Path to the level file.
:type level_file_path: PathLike
:param entity_parse_dict: Dictionary specifying how to parse different entities.
:type entity_parse_dict: Dict[Entities, dict]
:param pomdp_r: The POMDP radius. Defaults to 0.
:type pomdp_r: int
"""
self.pomdp_r = pomdp_r
self.e_p_dict = entity_parse_dict
self._parsed_level = h.parse_level(Path(level_file_path))
@ -25,14 +40,30 @@ class LevelParser(object):
self.level_shape = level_array.shape
self.size = self.pomdp_r ** 2 if self.pomdp_r else np.prod(self.level_shape)
def get_coordinates_for_symbol(self, symbol, negate=False):
def get_coordinates_for_symbol(self, symbol, negate=False) -> np.ndarray:
"""
Get the coordinates for a given symbol in the parsed level.
:param symbol: The symbol to search for.
:param negate: If True, get coordinates not matching the symbol. Defaults to False.
:return: Array of coordinates.
:rtype: np.ndarray
"""
level_array = h.one_hot_level(self._parsed_level, symbol)
if negate:
return np.argwhere(level_array != c.VALUE_OCCUPIED_CELL)
else:
return np.argwhere(level_array == c.VALUE_OCCUPIED_CELL)
def do_init(self):
def do_init(self) -> Entities:
"""
Initialize the environment map state by creating entities such as Walls, Agents or Machines according to the
entity parse dict.
:return: A dict of all parsed entities with their positions.
:rtype: Entities
"""
# Global Entities
list_of_all_positions = ([tuple(f) for f in self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True)])
entities = Entities(list_of_all_positions)

View File

@ -18,12 +18,24 @@ class OBSBuilder(object):
@property
def pomdp_d(self):
"""
TODO
:return:
"""
if self.pomdp_r:
return (self.pomdp_r * 2) + 1
else:
return 0
def __init__(self, level_shape: np.size, state: Gamestate, pomdp_r: int):
"""
TODO
:return:
"""
self.all_obs = dict()
self.ray_caster = dict()

View File

@ -7,6 +7,12 @@ from numba import njit
class RayCaster:
def __init__(self, agent, pomdp_r, degs=360):
"""
TODO
:return:
"""
self.agent = agent
self.pomdp_r = pomdp_r
self.n_rays = 100 # (self.pomdp_r + 1) * 8

View File

@ -33,6 +33,12 @@ class Renderer:
lvl_padded_shape: Union[Tuple[int, int], None] = None,
cell_size: int = 40, fps: int = 7, factor: float = 0.9,
grid_lines: bool = True, view_radius: int = 2):
"""
TODO
:return:
"""
# TODO: Custom assets paths
self.grid_h, self.grid_w = lvl_shape
self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape

View File

@ -3,13 +3,16 @@ from dataclasses import dataclass
from marl_factory_grid.environment.entity.object import Object
TYPE_VALUE = 'value'
TYPE_VALUE = 'value'
TYPE_REWARD = 'reward'
TYPES = [TYPE_VALUE, TYPE_REWARD]
@dataclass
class InfoObject:
"""
Data class representing information about an entity or the global environment.
"""
identifier: str
val_type: str
value: Union[float, int]
@ -17,6 +20,16 @@ class InfoObject:
@dataclass
class Result:
"""
A generic result class representing outcomes of operations or actions.
Attributes:
- identifier: A unique identifier for the result.
- validity: A boolean indicating whether the operation or action was successful.
- reward: The reward associated with the result, if applicable.
- value: The value associated with the result, if applicable.
- entity: The entity associated with the result, if applicable.
"""
identifier: str
validity: bool
reward: Union[float, None] = None
@ -24,6 +37,11 @@ class Result:
entity: Object = None
def get_infos(self):
"""
Get information about the result.
:return: A list of InfoObject representing different types of information.
"""
n = self.entity.name if self.entity is not None else "Global"
# Return multiple Info Dicts
return [InfoObject(identifier=f'{n}_{self.identifier}',
@ -38,16 +56,37 @@ class Result:
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value}{entity})'
@dataclass
class TickResult(Result):
pass
@dataclass
class ActionResult(Result):
def __init__(self, *args, action_introduced_collision: bool = False, **kwargs):
"""
A specific Result class representing outcomes of actions.
:param action_introduced_collision: Whether the action introduced a collision between agents or other entities.
These need to be able to collide.
"""
super().__init__(*args, **kwargs)
self.action_introduced_collision = action_introduced_collision
pass
@dataclass
class DoneResult(Result):
"""
A specific Result class representing the completion of an action or operation.
"""
pass
@dataclass
class State(Result):
# TODO: change identifier to action/last_action
pass
@dataclass
class TickResult(Result):
"""
A specific Result class representing outcomes of tick operations.
"""
pass

View File

@ -1,5 +1,4 @@
from itertools import islice
from itertools import islice
from typing import List, Tuple
import numpy as np
@ -15,6 +14,12 @@ from marl_factory_grid.utils.results import Result
class StepRules:
def __init__(self, *args):
"""
TODO
:return:
"""
if args:
self.rules = list(args)
else:
@ -80,6 +85,12 @@ class Gamestate(object):
return [y for x in self.entities for y in x if x.var_can_move]
def __init__(self, entities, agents_conf, rules: List[Rule], tests: [Test], lvl_shape, env_seed=69, verbose=False):
"""
TODO
:return:
"""
self.lvl_shape = lvl_shape
self.entities = entities
self.curr_step = 0

View File

@ -2,38 +2,68 @@ import importlib
import inspect
from os import PathLike
from pathlib import Path
from typing import Union
import yaml
from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils.helpers import locate_and_import_class
ACTION = 'Action'
GENERAL = 'General'
ENTITIES = 'Objects'
ACTION = 'Action'
GENERAL = 'General'
ENTITIES = 'Objects'
OBSERVATIONS = 'Observations'
RULES = 'Rule'
ASSETS = 'Assets'
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Agent', 'GlobalPositions', 'Walls',
'TemplateRule', 'Entities', 'EnvObjects', 'Zones', ]
RULES = 'Rule'
TESTS = 'Tests'
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Agent', 'GlobalPositions', 'Walls',
'TemplateRule', 'Entities', 'EnvObjects', 'Zones', ]
class ConfigExplainer:
def __init__(self, custom_path: Union[None, PathLike] = None):
self.base_path = Path(__file__).parent.parent.resolve()
self.custom_path = custom_path
self.searchspace = [ACTION, GENERAL, ENTITIES, OBSERVATIONS, RULES, ASSETS]
def __init__(self, custom_path: None | PathLike = None):
"""
This utility serves as a helper for debugging and exploring available modules and classes.
Does not do anything unless told.
The functions get_xxxxx() retrieves and returns the information and save_xxxxx() dumps them to disk.
def explain_module(self, class_to_explain):
get_all() and save_all() help to get a general overview.
When provided with a custom path, your own modules become available.
:param custom_path: Path to your custom module folder.
"""
self.base_path = Path(__file__).parent.parent.resolve()
self.custom_path = Path(custom_path) if custom_path is not None else custom_path
self.searchspace = [ACTION, GENERAL, ENTITIES, OBSERVATIONS, RULES, TESTS]
@staticmethod
def _explain_module(class_to_explain):
"""
INTERNAL USE ONLY
"""
parameters = inspect.signature(class_to_explain).parameters
explained = {class_to_explain.__name__:
{key: val.default for key, val in parameters.items() if key not in EXCLUDED}
}
return explained
def _get_by_identifier(self, identifier):
"""
INTERNAL USE ONLY
"""
entities_base_cls = locate_and_import_class(identifier, self.base_path)
module_paths = [x.resolve() for x in self.base_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
found_entities = self._load_and_compare(entities_base_cls, module_paths)
if self.custom_path is not None:
module_paths = [x.resolve() for x in self.custom_path.rglob('*.py') if x.is_file()
and '__init__' not in x.name]
found_entities.update(self._load_and_compare(entities_base_cls, module_paths))
return found_entities
def _load_and_compare(self, compare_class, paths):
"""
INTERNAL USE ONLY
"""
conf = {}
package_pos = next(idx for idx, x in enumerate(Path(__file__).resolve().parts) if x == 'marl_factory_grid')
for module_path in paths:
@ -44,40 +74,97 @@ class ConfigExplainer:
mod = mods.__getattribute__(key)
try:
if issubclass(mod, compare_class) and mod != compare_class:
conf.update(self.explain_module(mod))
conf.update(self._explain_module(mod))
except TypeError:
pass
return conf
def save_actions(self, output_conf_file: PathLike = Path('../../quickstart') / 'explained_actions.yml'):
self._save_to_file(self.get_entities(), output_conf_file, ACTION)
@staticmethod
def _save_to_file(data: dict, filepath: PathLike, tag: str = ''):
"""
INTERNAL USE ONLY
"""
filepath = Path(filepath)
yaml.Dumper.ignore_aliases = lambda *args: True
with filepath.open('w') as f:
yaml.dump(data, f, encoding='utf-8')
print(f'Example config {"for " + tag + " " if tag else " "}dumped')
print(f'See file: {filepath}')
def get_actions(self):
def get_actions(self) -> list[str]:
"""
Retrieve all actions from module folders.
:returns: A list of all available actions.
"""
actions = self._get_by_identifier(ACTION)
assert all(not x for x in actions.values()), 'Please only provide Names, no Mappings.'
actions = list(actions.keys())
actions.extend([c.MOVE8, c.MOVE4])
# TODO: Print to file!
return actions
def save_entities(self, output_conf_file: PathLike = Path('../../quickstart') / 'explained_entities.yml'):
self._save_to_file(self.get_entities(), output_conf_file, ENTITIES)
def get_all(self) -> dict[str]:
"""
Retrieve all available configurations from module folders.
:returns: A dictionary of all available configurations.
"""
config_dict = {
'General': self.get_general_section(),
'Agents': self.get_agent_section(),
'Entities': self.get_entities(),
'Rules': self.get_rules()
}
return config_dict
def get_entities(self):
"""
Retrieve all entities from module folders.
:returns: A dict mapping each available entity class name to its constructor parameters and defaults.
"""
entities = self._get_by_identifier(ENTITIES)
return entities
def save_rules(self, output_conf_file: PathLike = Path('../../quickstart') / 'explained_rules.yml'):
self._save_to_file(self.get_entities(), output_conf_file, RULES)
@staticmethod
def get_general_section():
"""
Build the general section.
def get_rules(self):
:returns: A dict with example values for the general section.
"""
general = {'level_name': 'rooms', 'env_seed': 69, 'verbose': False,
'pomdp_r': 3, 'individual_rewards': True, 'tests': False}
return general
def get_agent_section(self) -> dict:
    """
    Build the Agent section and retrieve all available actions and observations from module folders.

    :returns: A dict holding a single example agent ('ExampleAgentName') with its 'Actions' and 'Observations'.
    """
    # No trailing comma after the dict(...) call: it would silently wrap the section in a one-element tuple.
    agents = dict(
        ExampleAgentName=dict(
            Actions=self.get_actions(),
            Observations=self.get_observations()))
    return agents
def get_rules(self) -> dict[str]:
"""
Retrieve all rules from module folders.
:returns: All available rules.
"""
rules = self._get_by_identifier(RULES)
return rules
def get_assets(self):
pass
def get_observations(self) -> list[str]:
"""
Retrieve all agent observations from module folders.
def get_observations(self):
:returns: A list of all available observations.
"""
names = [c.ALL, c.COMBINED, c.SELF, c.OTHERS, "Agent['ExampleAgentName']"]
for key, val in self.get_entities().items():
try:
@ -95,45 +182,47 @@ class ConfigExplainer:
names.extend(e)
return names
def _get_by_identifier(self, identifier):
entities_base_cls = locate_and_import_class(identifier, self.base_path)
module_paths = [x.resolve() for x in self.base_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
found_entities = self._load_and_compare(entities_base_cls, module_paths)
if self.custom_path is not None:
module_paths = [x.resolve() for x in self.custom_path.rglob('*.py') if x.is_file()
and '__init__' not in x.name]
found_entities.update(self._load_and_compare(entities_base_cls, module_paths))
return found_entities
def save_actions(self, output_conf_file: PathLike = Path('../../quickstart') / 'actions.yml'):
    """
    Write all available actions to a file.

    :param output_conf_file: File to write to. Defaults to ../../quickstart/actions.yml
    """
    # Dump the actions themselves; the previous code wrote the entities under the ACTION tag.
    self._save_to_file(self.get_actions(), output_conf_file, ACTION)
def save_all(self, output_conf_file: PathLike = Path('../../quickstart') / 'explained.yml'):
def save_entities(self, output_conf_file: PathLike = Path('../../quickstart') / 'entities.yml'):
"""
Write all available entities to a file.
:param output_conf_file: File to write to. Defaults to ../../quickstart/entities.yml
"""
self._save_to_file(self.get_entities(), output_conf_file, ENTITIES)
def save_observations(self, output_conf_file: PathLike = Path('../../quickstart') / 'observations.yml'):
    """
    Write all available observations to a file.

    :param output_conf_file: File to write to. Defaults to ../../quickstart/observations.yml
    """
    # Dump the observations themselves; the previous code wrote the entities under the OBSERVATIONS tag.
    self._save_to_file(self.get_observations(), output_conf_file, OBSERVATIONS)
def save_rules(self, output_conf_file: PathLike = Path('../../quickstart') / 'rules.yml'):
    """
    Write all available rules to a file.

    :param output_conf_file: File to write to. Defaults to ../../quickstart/rules.yml
    """
    # Dump the rules themselves; the previous code wrote the entities under the RULES tag.
    self._save_to_file(self.get_rules(), output_conf_file, RULES)
def save_all(self, output_conf_file: PathLike = Path('../../quickstart') / 'all.yml'):
"""
Write all available keywords to a file.
:param output_conf_file: File to write to. Defaults to ../../quickstart/all.yml
"""
self._save_to_file(self.get_all(), output_conf_file, 'ALL')
def get_all(self):
config_dict = {GENERAL: {'level_name': 'rooms', 'env_seed': 69, 'verbose': False,
'pomdp_r': 3, 'individual_rewards': True},
'Agents': dict(
ExampleAgentName=dict(
Actions=self.get_actions(),
Observations=self.get_observations())),
'Entities': self.get_entities(),
'Rules': self.get_rules(),
'Assets': self.get_assets()}
return config_dict
def _save_to_file(self, data: dict, filepath: PathLike, tag: str = ''):
filepath = Path(filepath)
yaml.Dumper.ignore_aliases = lambda *args: True
with filepath.open('w') as f:
yaml.dump(data, f, encoding='utf-8')
print(f'Example config {"for " + tag + " " if tag else " "}dumped')
print(f'See file: {filepath}')
if __name__ == '__main__':
ce = ConfigExplainer()
ce.get_actions()
ce.get_entities()
ce.get_rules()
ce.get_observations()
ce.get_assets()
# ce.get_actions()
# ce.get_entities()
# ce.get_rules()
# ce.get_observations()
all_conf = ce.get_all()
ce.save_all()

View File

@ -18,6 +18,10 @@ class MarlFrameStack(gym.ObservationWrapper):
@dataclass
class RenderEntity:
"""
This class defines the interface to communicate with the Renderer. Name and pos are used to load an asset file
named name.png and place it at the given pos.
"""
name: str
pos: np.array
value: float = 1
@ -30,6 +34,10 @@ class RenderEntity:
@dataclass
class Floor:
"""
This class defines Entity like Floor-Objects, which do not come with the overhead.
Solely used for field-of-view calculation.
"""
@property
def encoding(self):

View File

@ -29,7 +29,7 @@ if __name__ == '__main__':
ce.save_all(run_path / 'all_out.yaml')
# Path to config File
path = Path('marl_factory_grid/configs/default_config.yaml')
path = Path('marl_factory_grid/configs/clean_and_bring.yaml')
# Env Init
factory = Factory(path)

View File

@ -5,7 +5,7 @@ long_description = (this_directory / "README.md").read_text()
setup(name='Marl-Factory-Grid',
version='0.2.0',
version='0.2.3',
description='A framework to research MARL agents in various setings.',
author='Steffen Illium',
author_email='steffen.illium@ifi.lmu.de',