new rules, new spawn logic, small fixes, default and narrow corridor debugged

This commit is contained in:
Steffen Illium 2023-11-09 17:50:20 +01:00
parent 9b9c6e0385
commit 06a5130b25
67 changed files with 768 additions and 921 deletions

View File

@ -56,7 +56,7 @@ Just define what your environment needs in a *yaml*-configfile like:
- Items
Rules:
Defaults: {}
Collision:
WatchCollisions:
done_at_collisions: !!bool True
ItemRespawn:
spawn_freq: 5

View File

@ -1,6 +1 @@
from .environment import *
from .modules import *
from .utils import *
from .quickstart import init

View File

@ -30,4 +30,3 @@ class TSPTargetAgent(TSPBaseAgent):
except (StopIteration, UnboundLocalError):
print('Will not happen')
return action_obj

View File

@ -26,12 +26,16 @@ def points_to_graph(coordiniates, allow_euclidean_connections=True, allow_manhat
assert allow_euclidean_connections or allow_manhattan_connections
possible_connections = itertools.combinations(coordiniates, 2)
graph = nx.Graph()
for a, b in possible_connections:
diff = np.linalg.norm(np.asarray(a)-np.asarray(b))
if allow_manhattan_connections and allow_euclidean_connections and diff <= np.sqrt(2):
graph.add_edge(a, b)
elif not allow_manhattan_connections and allow_euclidean_connections and diff == np.sqrt(2):
graph.add_edge(a, b)
elif allow_manhattan_connections and not allow_euclidean_connections and diff == 1:
graph.add_edge(a, b)
if allow_manhattan_connections and allow_euclidean_connections:
graph.add_edges_from(
(a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) <= np.sqrt(2)
)
elif not allow_manhattan_connections and allow_euclidean_connections:
graph.add_edges_from(
(a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) == np.sqrt(2)
)
elif allow_manhattan_connections and not allow_euclidean_connections:
graph.add_edges_from(
(a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) == 1
)
return graph

View File

@ -22,26 +22,41 @@ Agents:
- Inventory
- DropOffLocations
- Maintainers
# This is special for agents, as each one is different and can, e.g., act as an adversary.
Positions:
- (16, 7)
- (16, 6)
- (16, 3)
- (16, 4)
- (16, 5)
Entities:
Batteries:
initial_charge: 0.8
per_action_costs: 0.02
ChargePods: {}
Destinations: {}
ChargePods:
coords_or_quantity: 2
Destinations:
coords_or_quantity: 1
spawn_mode: GROUPED
DirtPiles:
coords_or_quantity: 10
initial_amount: 2
clean_amount: 1
dirt_spawn_r_var: 0.1
initial_amount: 2
initial_dirt_ratio: 0.05
max_global_amount: 20
max_local_amount: 5
Doors: {}
DropOffLocations: {}
Doors:
DropOffLocations:
coords_or_quantity: 1
max_dropoff_storage_size: 0
GlobalPositions: {}
Inventories: {}
Items: {}
Machines: {}
Maintainers: {}
Items:
coords_or_quantity: 5
Machines:
coords_or_quantity: 2
Maintainers:
coords_or_quantity: 1
Zones: {}
General:
@ -49,32 +64,31 @@ General:
individual_rewards: true
level_name: large
pomdp_r: 3
verbose: false
verbose: True
tests: false
Rules:
SpawnAgents: {}
DoneAtBatteryDischarge: {}
Collision:
done_at_collisions: false
AssignGlobalPositions: {}
DoneAtDestinationReachAny: {}
DestinationReachReward: {}
SpawnDestinations:
n_dests: 1
spawn_mode: GROUPED
DoneOnAllDirtCleaned: {}
SpawnDirt:
spawn_freq: 15
# Environment Dynamics
EntitiesSmearDirtOnMove:
smear_ratio: 0.2
DoorAutoClose:
close_frequency: 10
ItemRules:
max_dropoff_storage_size: 0
n_items: 5
n_locations: 5
spawn_frequency: 15
MaxStepsReached:
MoveMaintainers:
# Respawn Stuff
RespawnDirt:
respawn_freq: 15
RespawnItems:
respawn_freq: 15
# Utilities
WatchCollisions:
done_at_collisions: false
# Done Conditions
DoneAtDestinationReachAny:
DoneOnAllDirtCleaned:
DoneAtBatteryDischarge:
DoneAtMaintainerCollision:
DoneAtMaxStepsReached:
max_steps: 500
# AgentSingleZonePlacement:
# n_zones: 4

View File

@ -1,3 +1,10 @@
General:
env_seed: 69
individual_rewards: true
level_name: narrow_corridor
pomdp_r: 0
verbose: true
Agents:
Wolfgang:
Actions:
@ -10,6 +17,7 @@ Agents:
Positions:
- (2, 1)
- (2, 5)
is_blocking_pos: true
Karl-Heinz:
Actions:
- Noop
@ -21,26 +29,30 @@ Agents:
Positions:
- (2, 1)
- (2, 5)
is_blocking_pos: true
Entities:
Destinations: {}
General:
env_seed: 69
individual_rewards: true
level_name: narrow_corridor
pomdp_r: 0
verbose: true
Rules:
SpawnAgents: {}
Collision:
done_at_collisions: false
FixedDestinationSpawn:
per_agent_positions:
Destinations:
ignore_blocking: true
spawnrule:
SpawnDestinationsPerAgent:
coords_or_quantity:
Wolfgang:
- (2, 1)
- (2, 5)
Karl-Heinz:
- (2, 1)
- (2, 5)
DestinationReachAll: {}
# Whether you want to provide a numeric Position observation.
# GlobalPositions:
# normalized: false
Rules:
# Utilities
WatchCollisions:
done_at_collisions: false
# Done Conditions
# DoneAtDestinationReachAny:
DoneAtDestinationReachAll:
DoneAtMaxStepsReached:
max_steps: 500

View File

@ -48,9 +48,9 @@ class Move(Action, abc.ABC):
reward = r.MOVEMENTS_VALID if move_validity else r.MOVEMENTS_FAIL
return ActionResult(entity=entity, identifier=self._identifier, validity=move_validity, reward=reward)
else: # There is no place to go, propably collision
# This is currently handeld by the Collision rule, so that it can be switched on and off by conf.yml
# This is currently handled by the WatchCollisions rule, so that it can be switched on and off by conf.yml
# return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION)
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=0)
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID)
def _calc_new_pos(self, pos):
x_diff, y_diff = MOVEMAP[self._identifier]

View File

@ -10,6 +10,7 @@ AGENT = 'Agent' # Identifier of Agent-objects an
OTHERS = 'Other'
COMBINED = 'Combined'
GLOBALPOSITIONS = 'GlobalPositions' # Identifier of the global position slice
SPAWN_ENTITY_RULE = 'SpawnEntity'
# Attributes
IS_BLOCKING_LIGHT = 'var_is_blocking_light'
@ -29,7 +30,7 @@ VALUE_NO_POS = (-9999, -9999) # Invalid Position value used in the e
ACTION = 'action' # Identifier of Action-objects and groups (groups).
COLLISION = 'Collision' # Identifier to use in the context of collitions.
COLLISION = 'Collisions' # Identifier to use in the context of collitions.
# LAST_POS = 'LAST_POS' # Identifier for retrieving an entity's last pos.
VALIDITY = 'VALIDITY' # Identifiert for retrieving the Validity of Action, Tick, etc. ...
@ -54,3 +55,5 @@ NOOP = 'Noop'
# Result Identifier
MOVEMENTS_VALID = 'motion_valid'
MOVEMENTS_FAIL = 'motion_not_valid'
DEFAULT_PATH = 'environment'
MODULE_PATH = 'modules'

View File

@ -12,14 +12,6 @@ from marl_factory_grid.environment import constants as c
class Agent(Entity):
@property
def var_is_blocking_light(self):
return False
@property
def var_can_move(self):
return True
@property
def var_is_paralyzed(self):
return len(self._paralyzed)
@ -28,14 +20,6 @@ class Agent(Entity):
def paralyze_reasons(self):
return [x for x in self._paralyzed]
@property
def var_is_blocking_pos(self):
return False
@property
def var_has_position(self):
return True
@property
def obs_tag(self):
return self.name
@ -48,10 +32,6 @@ class Agent(Entity):
def observations(self):
return self._observations
@property
def var_can_collide(self):
return True
def step_result(self):
pass
@ -60,16 +40,21 @@ class Agent(Entity):
return self._collection
@property
def state(self):
return self._state or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0)
def var_is_blocking_pos(self):
return self._is_blocking_pos
def __init__(self, actions: List[Action], observations: List[str], *args, **kwargs):
@property
def state(self):
return self._state or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID)
def __init__(self, actions: List[Action], observations: List[str], *args, is_blocking_pos=False, **kwargs):
super(Agent, self).__init__(*args, **kwargs)
self._paralyzed = set()
self.step_result = dict()
self._actions = actions
self._observations = observations
self._state: Union[Result, None] = None
self._is_blocking_pos = is_blocking_pos
# noinspection PyAttributeOutsideInit
def clear_temp_state(self):

View File

@ -14,7 +14,7 @@ class Entity(_Object, abc.ABC):
@property
def state(self):
return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0)
return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID)
@property
def var_has_position(self):
@ -60,6 +60,10 @@ class Entity(_Object, abc.ABC):
def pos(self):
return self._pos
def set_pos(self, pos):
    """Store ``pos`` as this entity's current position.

    ``pos`` has to be a tuple with exactly two elements; anything else trips the assertion below.
    """
    is_two_tuple = isinstance(pos, tuple) and len(pos) == 2
    assert is_two_tuple
    self._pos = pos
@property
def last_pos(self):
try:
@ -84,7 +88,7 @@ class Entity(_Object, abc.ABC):
for observer in self.observers:
observer.notify_del_entity(self)
self._view_directory = curr_pos[0] - next_pos[0], curr_pos[1] - next_pos[1]
self._pos = next_pos
self.set_pos(next_pos)
for observer in self.observers:
observer.notify_add_entity(self)
return valid
@ -93,7 +97,7 @@ class Entity(_Object, abc.ABC):
def __init__(self, pos, bind_to=None, **kwargs):
super().__init__(**kwargs)
self._status = None
self._pos = pos
self.set_pos(pos)
self._last_pos = pos
if bind_to:
try:
@ -109,8 +113,9 @@ class Entity(_Object, abc.ABC):
def render(self):
return RenderEntity(self.__class__.__name__.lower(), self.pos)
def __repr__(self):
return super(Entity, self).__repr__() + f'(@{self.pos})'
@abc.abstractmethod
def render(self):
return RenderEntity(self.__class__.__name__.lower(), self.pos)
@property
def obs_tag(self):
@ -149,4 +154,4 @@ class Entity(_Object, abc.ABC):
except StopIteration:
pass
except ValueError:
print()
pass

View File

@ -1,24 +0,0 @@
# noinspection PyAttributeOutsideInit
class BoundEntityMixin:
    """Mixin for objects that can be attached to ("bound to") a single other entity.

    The bound entity is kept in ``_bound_entity``; ``None`` means "not bound".
    """

    # Class-level default so that every accessor below also works on an instance that was never
    # bound. This is the same value `unbind` writes, i.e. "never bound" and "unbound" look alike.
    _bound_entity = None

    @property
    def bound_entity(self):
        """The entity this object is currently bound to, or ``None``."""
        return self._bound_entity

    @property
    def name(self):
        """``'<ClassName>(<bound entity name>)'`` while bound, otherwise ``None``."""
        if self.bound_entity:
            return f'{self.__class__.__name__}({self.bound_entity.name})'
        # No bound entity -> no name can be derived here.
        return None

    def belongs_to_entity(self, entity):
        """Return whether ``entity`` compares equal to the currently bound entity."""
        return entity == self.bound_entity

    def bind_to(self, entity):
        """Bind this object to ``entity``, replacing any previous binding."""
        self._bound_entity = entity

    def unbind(self):
        """Drop the current binding."""
        self._bound_entity = None

View File

@ -13,10 +13,6 @@ class _Object:
def __bool__(self):
return True
@property
def var_has_position(self):
return False
@property
def var_can_be_bound(self):
try:
@ -30,22 +26,14 @@ class _Object:
@property
def name(self):
if self._str_ident is not None:
name = f'{self.__class__.__name__}[{self._str_ident}]'
else:
name = f'{self.__class__.__name__}#{self.u_int}'
if self.bound_entity:
name = h.add_bound_name(name, self.bound_entity)
if self.var_has_position:
name = h.add_pos_name(name, self)
return name
return f'{self.__class__.__name__}[{self.identifier}]'
@property
def identifier(self):
if self._str_ident is not None:
return self._str_ident
else:
return self.name
return self.u_int
def reset_uid(self):
self._u_idx = defaultdict(lambda: 0)
@ -62,7 +50,15 @@ class _Object:
print(f'Following kwargs were passed, but ignored: {kwargs}')
def __repr__(self):
return f'{self.name}'
name = self.name
if self.bound_entity:
name = h.add_bound_name(name, self.bound_entity)
try:
if self.var_has_position:
name = h.add_pos_name(name, self)
except (AttributeError):
pass
return name
def __eq__(self, other) -> bool:
return other == self.identifier
@ -88,7 +84,7 @@ class _Object:
def summarize_state(self):
return dict()
def bind(self, entity):
def bind_to(self, entity):
# noinspection PyAttributeOutsideInit
self._bound_entity = entity
return c.VALID
@ -100,9 +96,6 @@ class _Object:
def bound_entity(self):
return self._bound_entity
def bind_to(self, entity):
self._bound_entity = entity
def unbind(self):
self._bound_entity = None

View File

@ -24,7 +24,7 @@ class PlaceHolder(_Object):
@property
def name(self):
return "PlaceHolder"
return self.__class__.__name__
class GlobalPosition(_Object):
@ -36,7 +36,8 @@ class GlobalPosition(_Object):
else:
return self.bound_entity.pos
def __init__(self, level_shape, *args, normalized: bool = True, **kwargs):
def __init__(self, agent, level_shape, *args, normalized: bool = True, **kwargs):
super(GlobalPosition, self).__init__(*args, **kwargs)
self.bind_to(agent)
self._normalized = normalized
self._shape = level_shape

View File

@ -5,13 +5,8 @@ from marl_factory_grid.utils.utility_classes import RenderEntity
class Wall(Entity):
@property
def var_has_position(self):
return True
@property
def var_can_collide(self):
return True
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@property
def encoding(self):
@ -19,11 +14,3 @@ class Wall(Entity):
def render(self):
return RenderEntity(c.WALL, self.pos)
@property
def var_is_blocking_pos(self):
return True
@property
def var_is_blocking_light(self):
return True

View File

@ -87,11 +87,14 @@ class Factory(gym.Env):
entities = self.map.do_init()
# Init rules
rules = self.conf.load_rules()
env_rules = self.conf.load_env_rules()
entity_rules = self.conf.load_entity_spawn_rules(entities)
env_rules.extend(entity_rules)
# Parse the agent conf
parsed_agents_conf = self.conf.parse_agents_conf()
self.state = Gamestate(entities, parsed_agents_conf, rules, self.conf.env_seed, self.conf.verbose)
self.state = Gamestate(entities, parsed_agents_conf, env_rules, self.map.level_shape,
self.conf.env_seed, self.conf.verbose)
# All is set up, trigger entity init with variable pos
# All is set up, trigger additional init (after agent entity spawn etc)

View File

@ -1,10 +1,15 @@
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.environment.rules import SpawnAgents
class Agents(Collection):
_entity = Agent
@property
def spawn_rule(self):
return {SpawnAgents.__name__: {}}
@property
def var_is_blocking_light(self):
return False

View File

@ -1,18 +1,25 @@
from typing import List, Tuple, Union
from typing import List, Tuple, Union, Dict
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.groups.objects import _Objects
# noinspection PyProtectedMember
from marl_factory_grid.environment.entity.object import _Object
import marl_factory_grid.environment.constants as c
from marl_factory_grid.utils.results import Result
class Collection(_Objects):
_entity = _Object # entity?
symbol = None
@property
def var_is_blocking_light(self):
return False
@property
def var_is_blocking_pos(self):
return False
@property
def var_can_collide(self):
return False
@ -23,29 +30,61 @@ class Collection(_Objects):
@property
def var_has_position(self):
return False
# @property
# def var_has_bound(self):
# return False # batteries, globalpos, inventories true
@property
def var_can_be_bound(self):
return False
return True
@property
def encodings(self):
return [x.encoding for x in self]
def __init__(self, size, *args, **kwargs):
super(Collection, self).__init__(*args, **kwargs)
self.size = size
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): # woihn mit den args
if isinstance(coords_or_quantity, int):
self.add_items([self._entity() for _ in range(coords_or_quantity)])
@property
def spawn_rule(self):
"""Prevent SpawnRule creation if Objects are spawned by map, Doors e.g."""
if self.symbol:
return None
elif self._spawnrule:
return self._spawnrule
else:
self.add_items([self._entity(pos) for pos in coords_or_quantity])
return {c.SPAWN_ENTITY_RULE: dict(collection=self, coords_or_quantity=self._coords_or_quantity)}
def __init__(self, size, *args, coords_or_quantity: int = None, ignore_blocking=False,
spawnrule: Union[None, Dict[str, dict]] = None,
**kwargs):
super(Collection, self).__init__(*args, **kwargs)
self._coords_or_quantity = coords_or_quantity
self.size = size
self._spawnrule = spawnrule
self._ignore_blocking = ignore_blocking
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, ignore_blocking=False, **entity_kwargs):
    """Spawn entities into this collection and report the outcome as a ``Result``.

    A falsy ``coords_or_quantity`` falls back to the value stored on the collection.
    For collections with positions, an integer is first turned into that many coordinates:
    taken from the front of ``state.entities.floorlist`` when blocking is ignored (by argument or by
    the collection's own setting), otherwise from ``state.get_n_random_free_positions``.
    Collections without positions accept an integer only and raise ``ValueError`` otherwise.
    """
    coords_or_quantity = coords_or_quantity or self._coords_or_quantity
    quantity_given = isinstance(coords_or_quantity, int)

    if not self.var_has_position:
        if not quantity_given:
            raise ValueError(f'{self._entity.__name__} has no position!')
        self.spawn(coords_or_quantity, *entity_args, **entity_kwargs)
        state.print(f'{coords_or_quantity} new {self.name} have been spawned randomly.')
        return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=coords_or_quantity)

    if quantity_given:
        # Resolve the requested amount into concrete coordinates before spawning.
        if ignore_blocking or self._ignore_blocking:
            coords_or_quantity = state.entities.floorlist[:coords_or_quantity]
        else:
            coords_or_quantity = state.get_n_random_free_positions(coords_or_quantity)
    self.spawn(coords_or_quantity, *entity_args, **entity_kwargs)
    n_spawned = len(coords_or_quantity)
    state.print(f'{n_spawned} new {self.name} have been spawned at {coords_or_quantity}')
    return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=n_spawned)
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args, **entity_kwargs):
    """Create new ``_entity`` instances and add them to this collection.

    Collections with positions need an iterable of coordinates (one entity per coordinate, the
    coordinate being the first constructor argument); collections without positions need an integer
    amount. The wrong kind of argument raises ``ValueError``. Returns ``c.VALID``.
    """
    quantity_given = isinstance(coords_or_quantity, int)
    if self.var_has_position:
        if quantity_given:
            raise ValueError(f'{self._entity.__name__} should have a position!')
        new_entities = [self._entity(pos, *entity_args, **entity_kwargs) for pos in coords_or_quantity]
    else:
        if not quantity_given:
            raise ValueError(f'{self._entity.__name__} has no position!')
        new_entities = [self._entity(*entity_args, **entity_kwargs) for _ in range(coords_or_quantity)]
    self.add_items(new_entities)
    return c.VALID
def despawn(self, items: List[_Object]):
@ -115,7 +154,7 @@ class Collection(_Objects):
except StopIteration:
pass
except ValueError:
print()
pass
@property
def positions(self):

View File

@ -1,6 +1,6 @@
from collections import defaultdict
from operator import itemgetter
from random import shuffle, random
from random import shuffle
from typing import Dict
from marl_factory_grid.environment.groups.objects import _Objects
@ -12,10 +12,10 @@ class Entities(_Objects):
@staticmethod
def neighboring_positions(pos):
return (POS_MASK + pos).reshape(-1, 2)
return [tuple(x) for x in (POS_MASK + pos).reshape(-1, 2)]
def get_entities_near_pos(self, pos):
return [y for x in itemgetter(*(tuple(x) for x in self.neighboring_positions(pos)))(self.pos_dict) for y in x]
return [y for x in itemgetter(*self.neighboring_positions(pos))(self.pos_dict) for y in x]
def render(self):
return [y for x in self for y in x.render() if x is not None]
@ -35,8 +35,9 @@ class Entities(_Objects):
super().__init__()
def guests_that_can_collide(self, pos):
return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
return [x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
@property
def empty_positions(self):
empty_positions = [key for key in self.floorlist if not self.pos_dict[key]]
shuffle(empty_positions)
@ -48,11 +49,23 @@ class Entities(_Objects):
shuffle(empty_positions)
return empty_positions
def is_blocked(self):
return[key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])]
@property
def blocked_positions(self):
    """All keys of ``pos_dict`` that hold at least one position-blocking guest, in shuffled order."""
    blocked = []
    for position, guests in self.pos_dict.items():
        if any(guest.var_is_blocking_pos for guest in guests):
            blocked.append(position)
    shuffle(blocked)
    return blocked
def is_not_blocked(self):
return[key for key, val in self.pos_dict.items() if not all([x.var_is_blocking_pos for x in val])]
@property
def free_positions_generator(self):
    """Lazily yield every ``floorlist`` position that has no colliding and no position-blocking guest."""

    def is_free(position):
        return not any(guest.var_can_collide or guest.var_is_blocking_pos for guest in self.pos_dict[position])

    return (position for position in self.floorlist if is_free(position))
@property
def free_positions_list(self):
    """The positions produced by ``free_positions_generator``, collected into a list."""
    return list(self.free_positions_generator)
def iter_entities(self):
return iter((x for sublist in self.values() for x in sublist))
@ -92,3 +105,6 @@ class Entities(_Objects):
@property
def positions(self):
return [k for k, v in self.pos_dict.items() for _ in v]
def is_occupied(self, pos):
    """Return ``True`` if any guest at ``pos`` can collide or blocks the position."""
    return any(guest.var_can_collide or guest.var_is_blocking_pos for guest in self.pos_dict[pos])

View File

@ -4,10 +4,6 @@ from marl_factory_grid.environment import constants as c
# noinspection PyUnresolvedReferences,PyTypeChecker
class IsBoundMixin:
@property
def name(self):
return f'{self.__class__.__name__}({self._bound_entity.name})'
def __repr__(self):
return f'{self.__class__.__name__}#{self._bound_entity.name}({self._data})'

View File

@ -5,11 +5,16 @@ import numpy as np
from marl_factory_grid.environment.entity.object import _Object
import marl_factory_grid.environment.constants as c
from marl_factory_grid.utils import helpers as h
class _Objects:
_entity = _Object
@property
def var_can_be_bound(self):
return False
@property
def observers(self):
return self._observers
@ -148,12 +153,12 @@ class _Objects:
def by_entity(self, entity):
try:
return next((x for x in self if x.belongs_to_entity(entity)))
return h.get_first(self, filter_by=lambda x: x.belongs_to_entity(entity))
except (StopIteration, AttributeError):
return None
def idx_by_entity(self, entity):
try:
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
return h.get_first_index(self, filter_by=lambda x: x.belongs_to_entity(entity))
except (StopIteration, AttributeError):
return None

View File

@ -1,7 +1,10 @@
from typing import List, Union
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.util import GlobalPosition
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.utils.results import Result
from marl_factory_grid.utils.states import Gamestate
class Combined(Collection):
@ -36,17 +39,17 @@ class GlobalPositions(Collection):
_entity = GlobalPosition
@property
def var_is_blocking_light(self):
return False
@property
def var_can_collide(self):
return False
@property
def var_can_be_bound(self):
return True
var_is_blocking_light = False
var_can_be_bound = True
var_can_collide = False
var_has_position = False
def __init__(self, *args, **kwargs):
super(GlobalPositions, self).__init__(*args, **kwargs)
def spawn(self, agents, level_shape, *args, **kwargs):
self.add_items([self._entity(agent, level_shape, *args, **kwargs) for agent in agents])
return [Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))]
def trigger_spawn(self, state: Gamestate, *args, **kwargs) -> [Result]:
return self.spawn(state[c.AGENT], state.lvl_shape, *args, **kwargs)

View File

@ -7,9 +7,12 @@ class Walls(Collection):
_entity = Wall
symbol = c.SYMBOL_WALL
@property
def var_has_position(self):
return True
var_can_collide = True
var_is_blocking_light = True
var_can_move = False
var_has_position = True
var_can_be_bound = False
var_is_blocking_pos = True
def __init__(self, *args, **kwargs):
super(Walls, self).__init__(*args, **kwargs)

View File

@ -1,6 +1,6 @@
import abc
from random import shuffle
from typing import List
from typing import List, Collection, Union
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.utils import helpers as h
@ -39,6 +39,29 @@ class Rule(abc.ABC):
return []
class SpawnEntity(Rule):
    """Rule that triggers the initial spawn of one entity collection during ``on_init``."""

    @property
    def _collection(self) -> Collection:
        # Return the collection this rule was built for. The previous body called `Collection()`;
        # with `Collection` being the abstract `typing.Collection` alias imported by this module,
        # that call raises TypeError on every access.
        # NOTE(review): assumes no caller relied on getting a fresh, empty container here - confirm.
        return self.collection

    @property
    def name(self):
        """Rule name, made unique per collection: ``'<RuleClass>(<collection name>)'``."""
        return f'{self.__class__.__name__}({self.collection.name})'

    def __init__(self, collection, coords_or_quantity, ignore_blocking=False):
        """
        :param collection: The entity collection to populate; must offer ``trigger_spawn``.
        :param coords_or_quantity: Stored on the rule. `on_init` does not forward it, so the
                                   collection's own setting decides what is spawned.
        :param ignore_blocking: Passed on to ``collection.trigger_spawn``.
        """
        super().__init__()
        self.coords_or_quantity = coords_or_quantity
        self.collection = collection
        self.ignore_blocking = ignore_blocking

    def on_init(self, state, lvl_map) -> [TickResult]:
        """Let the collection spawn its entities, log the outcome and return the spawn result(s)."""
        results = self.collection.trigger_spawn(state, ignore_blocking=self.ignore_blocking)
        # Only collections with positions can report where their entities ended up.
        pos_str = f' on: {[x.pos for x in self.collection]}' if self.collection.var_has_position else ''
        state.print(f'Initial {self.collection.__class__.__name__} were spawned{pos_str}')
        return results
class SpawnAgents(Rule):
def __init__(self):
@ -46,14 +69,14 @@ class SpawnAgents(Rule):
pass
def on_init(self, state, lvl_map):
agent_conf = state.agents_conf
# agents = Agents(lvl_map.size)
agents = state[c.AGENT]
empty_positions = state.entities.empty_positions()[:len(agent_conf)]
for agent_name in agent_conf:
actions = agent_conf[agent_name]['actions'].copy()
observations = agent_conf[agent_name]['observations'].copy()
positions = agent_conf[agent_name]['positions'].copy()
empty_positions = state.entities.empty_positions[:len(state.agents_conf)]
for agent_name, agent_conf in state.agents_conf.items():
actions = agent_conf['actions'].copy()
observations = agent_conf['observations'].copy()
positions = agent_conf['positions'].copy()
other = agent_conf['other'].copy()
if positions:
shuffle(positions)
while True:
@ -61,18 +84,18 @@ class SpawnAgents(Rule):
pos = positions.pop()
except IndexError:
raise ValueError(f'It was not possible to spawn an Agent on the available position: '
f'\n{agent_name[agent_name]["positions"].copy()}')
if agents.by_pos(pos) and state.check_pos_validity(pos):
f'\n{agent_conf["positions"].copy()}')
if bool(agents.by_pos(pos)) or not state.check_pos_validity(pos):
continue
else:
agents.add_item(Agent(actions, observations, pos, str_ident=agent_name))
agents.add_item(Agent(actions, observations, pos, str_ident=agent_name, **other))
break
else:
agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name))
agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name, **other))
pass
class MaxStepsReached(Rule):
class DoneAtMaxStepsReached(Rule):
def __init__(self, max_steps: int = 500):
super().__init__()
@ -83,8 +106,8 @@ class MaxStepsReached(Rule):
def on_check_done(self, state):
if self.max_steps <= state.curr_step:
return [DoneResult(validity=c.VALID, identifier=self.name, reward=0)]
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
return [DoneResult(validity=c.VALID, identifier=self.name)]
return [DoneResult(validity=c.NOT_VALID, identifier=self.name)]
class AssignGlobalPositions(Rule):
@ -101,7 +124,7 @@ class AssignGlobalPositions(Rule):
return []
class Collision(Rule):
class WatchCollisions(Rule):
def __init__(self, done_at_collisions: bool = False):
super().__init__()
@ -132,4 +155,4 @@ class Collision(Rule):
move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT])
if inter_entity_collision_detected or move_failed:
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)]
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
return [DoneResult(validity=c.NOT_VALID, identifier=self.name)]

View File

@ -1,4 +1,4 @@
from .actions import BtryCharge
from .entitites import Pod, Battery
from .entitites import ChargePod, Battery
from .groups import ChargePods, Batteries
from .rules import DoneAtBatteryDischarge, BatteryDecharge

View File

@ -6,6 +6,7 @@ from marl_factory_grid.utils.results import ActionResult
from marl_factory_grid.modules.batteries import constants as b
from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils import helpers as h
class BtryCharge(Action):
@ -14,8 +15,8 @@ class BtryCharge(Action):
super().__init__(b.ACTION_CHARGE)
def do(self, entity, state) -> Union[None, ActionResult]:
if charge_pod := state[b.CHARGE_PODS].by_pos(entity.pos):
valid = charge_pod.charge_battery(state[b.BATTERIES].by_entity(entity))
if charge_pod := h.get_first(state[b.CHARGE_PODS].by_pos(entity.pos)):
valid = h.get_first(charge_pod.charge_battery(state[b.BATTERIES].by_entity(entity)))
if valid:
state.print(f'{entity.name} just charged batteries at {charge_pod.name}.')
else:

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.9 KiB

View File

@ -50,7 +50,7 @@ class Battery(_Object):
return summary
class Pod(Entity):
class ChargePod(Entity):
@property
def encoding(self):
@ -58,7 +58,7 @@ class Pod(Entity):
def __init__(self, *args, charge_rate: float = 0.4,
multi_charge: bool = False, **kwargs):
super(Pod, self).__init__(*args, **kwargs)
super(ChargePod, self).__init__(*args, **kwargs)
self.charge_rate = charge_rate
self.multi_charge = multi_charge

View File

@ -1,52 +1,36 @@
from typing import Union, List, Tuple
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.modules.batteries.entitites import Pod, Battery
from marl_factory_grid.modules.batteries.entitites import ChargePod, Battery
from marl_factory_grid.utils.results import Result
class Batteries(Collection):
_entity = Battery
@property
def var_is_blocking_light(self):
return False
@property
def var_can_collide(self):
return False
@property
def var_can_move(self):
return False
@property
def var_has_position(self):
return False
@property
def var_can_be_bound(self):
return True
var_has_position = False
var_can_be_bound = True
@property
def obs_tag(self):
return self.__class__.__name__
def __init__(self, *args, **kwargs):
super(Batteries, self).__init__(*args, **kwargs)
def __init__(self, size, initial_charge_level: float=1.0, *args, **kwargs):
super(Batteries, self).__init__(size, *args, **kwargs)
self.initial_charge_level = initial_charge_level
def spawn(self, agents, initial_charge_level):
batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)]
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], agents, *entity_args, **entity_kwargs):
batteries = [self._entity(self.initial_charge_level, agent) for _, agent in enumerate(agents)]
self.add_items(batteries)
# def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): has no pos
# agents = entity_args[0]
# initial_charge_level = entity_args[1]
# batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)]
# self.add_items(batteries)
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, **entity_kwargs):
self.spawn(0, state[c.AGENT])
return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))
class ChargePods(Collection):
_entity = Pod
_entity = ChargePod
def __init__(self, *args, **kwargs):
super(ChargePods, self).__init__(*args, **kwargs)

View File

@ -49,10 +49,6 @@ class BatteryDecharge(Rule):
self.per_action_costs = per_action_costs
self.initial_charge = initial_charge
def on_init(self, state, lvl_map): # on reset?
assert len(state[c.AGENT]), "There are no agents, did you already spawn them?"
state[b.BATTERIES].spawn(state[c.AGENT], self.initial_charge)
def tick_step(self, state) -> List[TickResult]:
# Decharge
batteries = state[b.BATTERIES]
@ -66,7 +62,7 @@ class BatteryDecharge(Rule):
batteries.by_entity(agent).decharge(energy_consumption)
results.append(TickResult(self.name, reward=0, entity=agent, validity=c.VALID))
results.append(TickResult(self.name, entity=agent, validity=c.VALID))
return results
@ -82,13 +78,13 @@ class BatteryDecharge(Rule):
if self.paralyze_agents_on_discharge:
btry.bound_entity.paralyze(self.name)
results.append(
TickResult("Paralyzed", entity=btry.bound_entity, reward=0, validity=c.VALID)
TickResult("Paralyzed", entity=btry.bound_entity, validity=c.VALID)
)
state.print(f'{btry.bound_entity.name} has just been paralyzed!')
if btry.bound_entity.var_is_paralyzed and not btry.is_discharged:
btry.bound_entity.de_paralyze(self.name)
results.append(
TickResult("De-Paralyzed", entity=btry.bound_entity, reward=0, validity=c.VALID)
TickResult("De-Paralyzed", entity=btry.bound_entity, validity=c.VALID)
)
state.print(f'{btry.bound_entity.name} has just been de-paralyzed!')
return results
@ -132,7 +128,7 @@ class DoneAtBatteryDischarge(BatteryDecharge):
if any_discharged or all_discharged:
return [DoneResult(self.name, validity=c.VALID, reward=self.reward_discharge_done)]
else:
return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)]
return [DoneResult(self.name, validity=c.NOT_VALID)]
class SpawnChargePods(Rule):
@ -155,7 +151,7 @@ class SpawnChargePods(Rule):
def on_init(self, state, lvl_map):
pod_collection = state[b.CHARGE_PODS]
empty_positions = state.entities.empty_positions()
empty_positions = state.entities.empty_positions
pods = pod_collection.from_coordinates(empty_positions, entity_kwargs=dict(
multi_charge=self.multi_charge, charge_rate=self.charge_rate)
)

View File

@ -1,4 +1,4 @@
from .actions import CleanUp
from .entitites import DirtPile
from .groups import DirtPiles
from .rules import SpawnDirt, EntitiesSmearDirtOnMove, DoneOnAllDirtCleaned
from .rules import EntitiesSmearDirtOnMove, DoneOnAllDirtCleaned

View File

@ -7,22 +7,6 @@ from marl_factory_grid.modules.clean_up import constants as d
class DirtPile(Entity):
@property
def var_can_collide(self):
return False
@property
def var_can_move(self):
return False
@property
def var_is_blocking_light(self):
return False
@property
def var_has_position(self):
return True
@property
def amount(self):
return self._amount

View File

@ -9,68 +9,55 @@ from marl_factory_grid.modules.clean_up.entitites import DirtPile
class DirtPiles(Collection):
_entity = DirtPile
@property
def var_is_blocking_light(self):
return False
var_is_blocking_light = False
var_can_collide = False
var_can_move = False
var_has_position = True
@property
def var_can_collide(self):
return False
@property
def var_can_move(self):
return False
@property
def var_has_position(self):
return True
@property
def amount(self):
def global_amount(self):
return sum([dirt.amount for dirt in self])
def __init__(self, *args,
max_local_amount=5,
clean_amount=1,
max_global_amount: int = 20, **kwargs):
max_global_amount: int = 20,
coords_or_quantity=10,
initial_amount=2,
amount_var=0.2,
n_var=0.2,
**kwargs):
super(DirtPiles, self).__init__(*args, **kwargs)
self.amount_var = amount_var
self.n_var = n_var
self.clean_amount = clean_amount
self.max_global_amount = max_global_amount
self.max_local_amount = max_local_amount
self.coords_or_quantity = coords_or_quantity
self.initial_amount = initial_amount
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):
amount_s = entity_args[0]
def trigger_spawn(self, state, coords_or_quantity=0, amount=0) -> [Result]:
coords_or_quantity = coords_or_quantity if coords_or_quantity else self.coords_or_quantity
n_new = int(abs(coords_or_quantity + (state.rng.uniform(-self.n_var, self.n_var))))
n_new = state.get_n_random_free_positions(n_new)
amounts = [amount if amount else (self.initial_amount + state.rng.uniform(-self.amount_var, self.amount_var))
for _ in range(coords_or_quantity)]
spawn_counter = 0
for idx, pos in enumerate(coords_or_quantity):
if not self.amount > self.max_global_amount:
amount = amount_s[idx] if isinstance(amount_s, list) else amount_s
for idx, (pos, a) in enumerate(zip(n_new, amounts)):
if not self.global_amount > self.max_global_amount:
if dirt := self.by_pos(pos):
dirt = next(dirt.iter())
new_value = dirt.amount + amount
new_value = dirt.amount + a
dirt.set_new_amount(new_value)
else:
dirt = DirtPile(pos, amount=amount)
self.add_item(dirt)
super().spawn([pos], amount=a)
spawn_counter += 1
else:
return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, reward=0,
value=spawn_counter)
return Result(identifier=f'{self.name}_spawn', validity=c.VALID, reward=0, value=spawn_counter)
return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, value=spawn_counter)
def trigger_dirt_spawn(self, n, amount, state, n_var=0.2, amount_var=0.2) -> Result:
free_for_dirt = [x for x in state.entities.floorlist if len(state.entities.pos_dict[x]) == 0 or (
len(state.entities.pos_dict[x]) >= 1 and isinstance(next(y for y in x), DirtPile))]
# free_for_dirt = [x for x in state[c.FLOOR]
# if len(x.guests) == 0 or (
# len(x.guests) == 1 and
# isinstance(next(y for y in x.guests), DirtPile))]
state.rng.shuffle(free_for_dirt)
new_spawn = int(abs(n + (state.rng.uniform(-n_var, n_var))))
new_amount_s = [abs(amount + (amount*state.rng.uniform(-amount_var, amount_var))) for _ in range(new_spawn)]
n_dirty_positions = free_for_dirt[:new_spawn]
return self.spawn(n_dirty_positions, new_amount_s)
return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=spawn_counter)
def __repr__(self):
s = super(DirtPiles, self).__repr__()
return f'{s[:-1]}, {self.amount})'
return f'{s[:-1]}, {self.global_amount}]'

View File

@ -22,58 +22,37 @@ class DoneOnAllDirtCleaned(Rule):
def on_check_done(self, state) -> [DoneResult]:
if len(state[d.DIRT]) == 0 and state.curr_step:
return [DoneResult(validity=c.VALID, identifier=self.name, reward=self.reward)]
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
return [DoneResult(validity=c.NOT_VALID, identifier=self.name)]
class SpawnDirt(Rule):
class RespawnDirt(Rule):
def __init__(self, initial_n: int = 5, initial_amount: float = 1.3,
respawn_n: int = 3, respawn_amount: float = 0.8,
n_var: float = 0.2, amount_var: float = 0.2, spawn_freq: int = 15):
def __init__(self, respawn_freq: int = 15, respawn_n: int = 5, respawn_amount: float = 1.0):
"""
Defines the spawn pattern of intial and additional 'Dirt'-entitites.
First chooses positions, then trys to spawn dirt until 'respawn_n' or the maximal global amount is reached.
If there is allready some, it is topped up to min(max_local_amount, amount).
:type spawn_freq: int
:parameter spawn_freq: In which frequency should this Rule try to spawn new 'Dirt'?
:type respawn_freq: int
:parameter respawn_freq: In which frequency should this Rule try to spawn new 'Dirt'?
:type respawn_n: int
:parameter respawn_n: How many respawn positions are considered.
:type initial_n: int
:parameter initial_n: How much initial positions are considered.
:type amount_var: float
:parameter amount_var: Variance of amount to spawn.
:type n_var: float
:parameter n_var: Variance of n to spawn.
:type respawn_amount: float
:parameter respawn_amount: Defines how much dirt 'amount' is placed every 'spawn_freq' ticks.
:type initial_amount: float
:parameter initial_amount: Defines how much dirt 'amount' is initially placed.
"""
super().__init__()
self.amount_var = amount_var
self.n_var = n_var
self.respawn_amount = respawn_amount
self.respawn_n = respawn_n
self.initial_amount = initial_amount
self.initial_n = initial_n
self.spawn_freq = spawn_freq
self._next_dirt_spawn = spawn_freq
def on_init(self, state, lvl_map) -> str:
result = state[d.DIRT].trigger_dirt_spawn(self.initial_n, self.initial_amount, state,
n_var=self.n_var, amount_var=self.amount_var)
state.print(f'Initial Dirt was spawned on: {[x.pos for x in state[d.DIRT]]}')
return result
self.respawn_amount = respawn_amount
self.respawn_freq = respawn_freq
self._next_dirt_spawn = respawn_freq
def tick_step(self, state):
collection = state[d.DIRT]
if self._next_dirt_spawn < 0:
pass # No DirtPile Spawn
elif not self._next_dirt_spawn:
result = [state[d.DIRT].trigger_dirt_spawn(self.respawn_n, self.respawn_amount, state,
n_var=self.n_var, amount_var=self.amount_var)]
self._next_dirt_spawn = self.spawn_freq
result = [collection.trigger_spawn(state, coords_or_quantity=self.respawn_n, amount=self.respawn_amount)]
self._next_dirt_spawn = self.respawn_freq
else:
self._next_dirt_spawn -= 1
result = []
@ -99,8 +78,8 @@ class EntitiesSmearDirtOnMove(Rule):
for entity in state.moving_entites:
if is_move(entity.state.identifier) and entity.state.validity == c.VALID:
if old_pos_dirt := state[d.DIRT].by_pos(entity.last_pos):
old_pos_dirt = next(iter(old_pos_dirt))
if smeared_dirt := round(old_pos_dirt.amount * self.smear_ratio, 2):
if state[d.DIRT].spawn(entity.pos, amount=smeared_dirt):
results.append(TickResult(identifier=self.name, entity=entity,
reward=0, validity=c.VALID))
results.append(TickResult(identifier=self.name, entity=entity, validity=c.VALID))
return results

View File

@ -1,4 +1,7 @@
from .actions import DestAction
from .entitites import Destination
from .groups import Destinations
from .rules import DoneAtDestinationReachAll, SpawnDestinations
from .rules import (DoneAtDestinationReachAll,
DoneAtDestinationReachAny,
SpawnDestinationsPerAgent,
DestinationReachReward)

View File

@ -9,30 +9,6 @@ from marl_factory_grid.utils.utility_classes import RenderEntity
class Destination(Entity):
@property
def var_can_move(self):
return False
@property
def var_can_collide(self):
return False
@property
def var_has_position(self):
return True
@property
def var_is_blocking_pos(self):
return False
@property
def var_is_blocking_light(self):
return False
@property
def var_can_be_bound(self):
return True
def was_reached(self):
return self._was_reached

View File

@ -7,37 +7,14 @@ from marl_factory_grid.modules.destinations import constants as d
class Destinations(Collection):
_entity = Destination
@property
def var_is_blocking_light(self):
return False
@property
def var_can_collide(self):
return False
@property
def var_can_move(self):
return False
@property
def var_has_position(self):
return True
var_is_blocking_light = False
var_can_collide = False
var_can_move = False
var_has_position = True
var_can_be_bound = True
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __repr__(self):
return super(Destinations, self).__repr__()
@staticmethod
def trigger_destination_spawn(n_dests, state):
coordinates = state.entities.floorlist[:n_dests]
if destinations := [Destination(pos) for pos in coordinates]:
state[d.DESTINATION].add_items(destinations)
state.print(f'{n_dests} new destinations have been spawned')
return c.VALID
else:
state.print('No Destiantions are spawning, limit is reached.')
return c.NOT_VALID

View File

@ -2,8 +2,8 @@ import ast
from random import shuffle
from typing import List, Dict, Tuple
import marl_factory_grid.modules.destinations.constants
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils import helpers as h
from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c
@ -54,7 +54,7 @@ class DoneAtDestinationReachAll(DestinationReachReward):
"""
This rule triggers and sets the done flag if ALL Destinations have been reached.
:type reward_at_done: object
:type reward_at_done: float
:param reward_at_done: Specifies the reward, agent get, whenn all destinations are reached.
:type dest_reach_reward: float
:param dest_reach_reward: Specify the reward, agents get when reaching a single destination.
@ -65,7 +65,7 @@ class DoneAtDestinationReachAll(DestinationReachReward):
def on_check_done(self, state) -> List[DoneResult]:
if all(x.was_reached() for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=self.reward)]
return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)]
return [DoneResult(self.name, validity=c.NOT_VALID)]
class DoneAtDestinationReachAny(DestinationReachReward):
@ -75,7 +75,7 @@ class DoneAtDestinationReachAny(DestinationReachReward):
This rule triggers and sets the done flag if ANY Destinations has been reached.
!!! IMPORTANT: 'reward_at_done' is shared between the agents; 'dest_reach_reward' is bound to a specific one.
:type reward_at_done: object
:type reward_at_done: float
:param reward_at_done: Specifies the reward, all agent get, when any destinations has been reached.
Default {d.REWARD_DEST_DONE}
:type dest_reach_reward: float
@ -87,67 +87,29 @@ class DoneAtDestinationReachAny(DestinationReachReward):
def on_check_done(self, state) -> List[DoneResult]:
if any(x.was_reached() for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=marl_factory_grid.modules.destinations.constants.REWARD_DEST_REACHED)]
return [DoneResult(self.name, validity=c.VALID, reward=d.REWARD_DEST_REACHED)]
return []
class SpawnDestinations(Rule):
def __init__(self, n_dests: int = 1, spawn_mode: str = d.MODE_GROUPED):
f"""
Defines how destinations are initially spawned and respawned in addition.
!!! This rule introduces no kind of reward or Env.-Done condition!
:type n_dests: int
:param n_dests: How many destiantions should be maintained (and initally spawnewd) on the map?
:type spawn_mode: str
:param spawn_mode: One of {d.SPAWN_MODES}. {d.MODE_GROUPED}: Always wait for all Dstiantions do be gone,
then respawn after the given time. {d.MODE_SINGLE}: Just spawn every destination,
that has been reached, after the given time
"""
super(SpawnDestinations, self).__init__()
self.n_dests = n_dests
self.spawn_mode = spawn_mode
def on_init(self, state, lvl_map):
# noinspection PyAttributeOutsideInit
state[d.DESTINATION].trigger_destination_spawn(self.n_dests, state)
pass
def tick_pre_step(self, state) -> List[TickResult]:
pass
def tick_step(self, state) -> List[TickResult]:
if n_dest_spawn := max(0, self.n_dests - len(state[d.DESTINATION])):
if self.spawn_mode == d.MODE_GROUPED and n_dest_spawn == self.n_dests:
validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state)
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
elif self.spawn_mode == d.MODE_SINGLE and n_dest_spawn:
validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state)
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
else:
pass
class SpawnDestinationsPerAgent(Rule):
def __init__(self, per_agent_positions: Dict[str, List[Tuple[int, int]]]):
def __init__(self, coords_or_quantity: Dict[str, List[Tuple[int, int]]]):
"""
Special rule, that spawn distinations, that are bound to a single agent a fixed set of positions.
Usefull for introducing specialists, etc. ..
!!! This rule does not introduce any reward or done condition.
:type per_agent_positions: Dict[str, List[Tuple[int, int]]
:param per_agent_positions: Please provide a dictionary with agent names as keys; and a list of possible
:type coords_or_quantity: Dict[str, List[Tuple[int, int]]
:param coords_or_quantity: Please provide a dictionary with agent names as keys; and a list of possible
destiantion coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]}
"""
super(Rule, self).__init__()
self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in per_agent_positions.items()}
self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in coords_or_quantity.items()}
def on_init(self, state, lvl_map):
for (agent_name, position_list) in self.per_agent_positions.items():
agent = next(x for x in state[c.AGENT] if agent_name in x.name) # Fixme: Ugly AF
agent = h.get_first(state[c.AGENT], lambda x: agent_name in x.name)
assert agent
position_list = position_list.copy()
shuffle(position_list)
while True:
@ -155,7 +117,7 @@ class SpawnDestinationsPerAgent(Rule):
pos = position_list.pop()
except IndexError:
print(f"Could not spawn Destinations at: {self.per_agent_positions[agent_name]}")
print(f'Check your agent palcement: {state[c.AGENT]} ... Exit ...')
print(f'Check your agent placement: {state[c.AGENT]} ... Exit ...')
exit(9999)
if (not pos == agent.pos) and (not state[d.DESTINATION].by_pos(pos)):
destination = Destination(pos, bind_to=agent)

View File

@ -1,4 +1,5 @@
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.utils import Result
from marl_factory_grid.utils.utility_classes import RenderEntity
from marl_factory_grid.environment import constants as c
@ -41,21 +42,6 @@ class Door(Entity):
def str_state(self):
return 'open' if self.is_open else 'closed'
def __init__(self, *args, closed_on_init=True, auto_close_interval=10, **kwargs):
self._status = d.STATE_CLOSED
super(Door, self).__init__(*args, **kwargs)
self.auto_close_interval = auto_close_interval
self.time_to_close = 0
if not closed_on_init:
self._open()
else:
self._close()
def summarize_state(self):
state_dict = super().summarize_state()
state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close))
return state_dict
@property
def is_closed(self):
return self._status == d.STATE_CLOSED
@ -68,6 +54,25 @@ class Door(Entity):
def status(self):
return self._status
@property
def time_to_close(self):
return self._time_to_close
def __init__(self, *args, closed_on_init=True, auto_close_interval=10, **kwargs):
self._status = d.STATE_CLOSED
super(Door, self).__init__(*args, **kwargs)
self._auto_close_interval = auto_close_interval
self._time_to_close = 0
if not closed_on_init:
self._open()
else:
self._close()
def summarize_state(self):
state_dict = super().summarize_state()
state_dict.update(state=str(self.str_state), time_to_close=self.time_to_close)
return state_dict
def render(self):
name, state = 'door_open' if self.is_open else 'door_closed', 'blank'
return RenderEntity(name, self.pos, 1, 'none', state, self.u_int + 1)
@ -80,18 +85,35 @@ class Door(Entity):
return c.VALID
def tick(self, state):
if self.is_open and len(state.entities.pos_dict[self.pos]) == 2 and self.time_to_close:
self.time_to_close -= 1
return c.NOT_VALID
elif self.is_open and not self.time_to_close and len(state.entities.pos_dict[self.pos]) == 2:
# Check if no entity is standing in the door
if len(state.entities.pos_dict[self.pos]) <= 2:
if self.is_open and self.time_to_close:
self._decrement_timer()
return Result(f"{d.DOOR}_tick", c.VALID, entity=self)
elif self.is_open and not self.time_to_close:
self.use()
return c.VALID
return Result(f"{d.DOOR}_closed", c.VALID, entity=self)
else:
return c.NOT_VALID
# No one is in door, but it is closed... Nothing to do....
return None
else:
# Entity is standing in the door, reset timer
self._reset_timer()
return Result(f"{d.DOOR}_reset", c.VALID, entity=self)
def _open(self):
self._status = d.STATE_OPEN
self.time_to_close = self.auto_close_interval
self._reset_timer()
return True
def _close(self):
self._status = d.STATE_CLOSED
return True
def _decrement_timer(self):
self._time_to_close -= 1
return True
def _reset_timer(self):
self._time_to_close = self._auto_close_interval
return True

View File

@ -18,8 +18,10 @@ class Doors(Collection):
super(Doors, self).__init__(*args, can_collide=True, **kwargs)
def tick_doors(self, state):
result_dict = dict()
results = list()
for door in self:
did_tick = door.tick(state)
result_dict.update({door.name: did_tick})
return result_dict
tick_result = door.tick(state)
if tick_result is not None:
results.append(tick_result)
# TODO: Should return a Result object, not a random dict.
return results

View File

@ -19,10 +19,10 @@ class DoorAutoClose(Rule):
def tick_step(self, state):
if doors := state[d.DOORS]:
doors_tick_result = doors.tick_doors(state)
doors_that_ticked = [key for key, val in doors_tick_result.items() if val]
state.print(f'{doors_that_ticked} were auto-closed'
if doors_that_ticked else 'No Doors were auto-closed')
doors_tick_results = doors.tick_doors(state)
doors_that_closed = [x.entity.name for x in doors_tick_results if 'closed' in x.identifier]
door_str = doors_that_closed if doors_that_closed else "No Doors"
state.print(f'{door_str} were auto-closed')
return [TickResult(self.name, validity=c.VALID, value=1)]
state.print('There are no doors, but you loaded the corresponding Module')
return []

View File

@ -1,4 +1,3 @@
from .actions import ItemAction
from .entitites import Item, DropOffLocation
from .groups import DropOffLocations, Items, Inventory, Inventories
from .rules import ItemRules

View File

@ -29,7 +29,7 @@ class ItemAction(Action):
elif items := state[i.ITEM].by_pos(entity.pos):
item = items[0]
item.change_parent_collection(inventory)
item.set_pos_to(c.VALUE_NO_POS)
item.set_pos(c.VALUE_NO_POS)
state.print(f'{entity.name} just picked up an item at {entity.pos}')
return ActionResult(entity=entity, identifier=self._identifier, validity=c.VALID, reward=r.PICK_UP_VALID)

View File

@ -8,16 +8,11 @@ from marl_factory_grid.modules.items import constants as i
class Item(Entity):
@property
def var_can_collide(self):
return False
def render(self):
return RenderEntity(i.ITEM, self.pos) if self.pos != c.VALUE_NO_POS else None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._auto_despawn = -1
@property
def auto_despawn(self):
@ -31,9 +26,6 @@ class Item(Entity):
def set_auto_despawn(self, auto_despawn):
self._auto_despawn = auto_despawn
def set_pos_to(self, no_pos):
self._pos = no_pos
def summarize_state(self) -> dict:
super_summarization = super(Item, self).summarize_state()
super_summarization.update(dict(auto_despawn=self.auto_despawn))
@ -42,21 +34,6 @@ class Item(Entity):
class DropOffLocation(Entity):
@property
def var_can_collide(self):
return False
@property
def var_can_move(self):
return False
@property
def var_is_blocking_light(self):
return False
@property
def var_has_position(self):
return True
def render(self):
return RenderEntity(i.DROP_OFF, self.pos)

View File

@ -8,6 +8,7 @@ from marl_factory_grid.environment.groups.objects import _Objects
from marl_factory_grid.environment.groups.mixins import IsBoundMixin
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.modules.items.entitites import Item, DropOffLocation
from marl_factory_grid.utils.results import Result
class Items(Collection):
@ -15,7 +16,7 @@ class Items(Collection):
@property
def var_has_position(self):
return False
return True
@property
def is_blocking_light(self):
@ -28,18 +29,18 @@ class Items(Collection):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@staticmethod
def trigger_item_spawn(state, n_items, spawn_frequency):
if item_to_spawns := max(0, (n_items - len(state[i.ITEM]))):
position_list = [x for x in state.entities.floorlist]
shuffle(position_list)
position_list = state.entities.floorlist[:item_to_spawns]
state[i.ITEM].spawn(position_list)
state.print(f'{item_to_spawns} new items have been spawned; next spawn in {spawn_frequency}')
return len(position_list)
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, **entity_kwargs) -> [Result]:
coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity
assert coords_or_quantity
if item_to_spawns := max(0, (coords_or_quantity - len(self))):
return super().trigger_spawn(state,
*entity_args,
coords_or_quantity=item_to_spawns,
**entity_kwargs)
else:
state.print('No Items are spawning, limit is reached.')
return 0
return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, value=coords_or_quantity)
class Inventory(IsBoundMixin, Collection):
@ -76,9 +77,15 @@ class Inventory(IsBoundMixin, Collection):
class Inventories(_Objects):
_entity = Inventory
var_can_move = False
var_has_position = False
symbol = None
@property
def var_can_move(self):
return False
def spawn_rule(self):
return {c.SPAWN_ENTITY_RULE: dict(collection=self, coords_or_quantity=None)}
def __init__(self, size: int, *args, **kwargs):
super(Inventories, self).__init__(*args, **kwargs)
@ -86,10 +93,12 @@ class Inventories(_Objects):
self._obs = None
self._lazy_eval_transforms = []
def spawn(self, agents):
inventories = [self._entity(agent, self.size, )
for _, agent in enumerate(agents)]
self.add_items(inventories)
def spawn(self, agents, *args, **kwargs):
self.add_items([self._entity(agent, self.size, *args, **kwargs) for _, agent in enumerate(agents)])
return [Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))]
def trigger_spawn(self, state, *args, **kwargs) -> [Result]:
return self.spawn(state[c.AGENT], *args, **kwargs)
def idx_by_entity(self, entity):
try:
@ -106,9 +115,6 @@ class Inventories(_Objects):
def summarize_states(self, **kwargs):
return [val.summarize_states(**kwargs) for key, val in self.items()]
@staticmethod
def trigger_inventory_spawn(state):
state[i.INVENTORY].spawn(state[c.AGENT])
class DropOffLocations(Collection):
@ -135,7 +141,7 @@ class DropOffLocations(Collection):
@staticmethod
def trigger_drop_off_location_spawn(state, n_locations):
empty_positions = state.entities.empty_positions()[:n_locations]
empty_positions = state.entities.empty_positions[:n_locations]
do_entites = state[i.DROP_OFF]
drop_offs = [DropOffLocation(pos) for pos in empty_positions]
do_entites.add_items(drop_offs)

View File

@ -6,52 +6,28 @@ from marl_factory_grid.utils.results import TickResult
from marl_factory_grid.modules.items import constants as i
class ItemRules(Rule):
class RespawnItems(Rule):
def __init__(self, n_items: int = 5, spawn_frequency: int = 15,
n_locations: int = 5, max_dropoff_storage_size: int = 0):
def __init__(self, n_items: int = 5, respawn_freq: int = 15, n_locations: int = 5):
super().__init__()
self.spawn_frequency = spawn_frequency
self._next_item_spawn = spawn_frequency
self.spawn_frequency = respawn_freq
self._next_item_spawn = respawn_freq
self.n_items = n_items
self.max_dropoff_storage_size = max_dropoff_storage_size
self.n_locations = n_locations
def on_init(self, state, lvl_map):
state[i.DROP_OFF].trigger_drop_off_location_spawn(state, self.n_locations)
self._next_item_spawn = self.spawn_frequency
state[i.INVENTORY].trigger_inventory_spawn(state)
state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency)
def tick_step(self, state):
for item in list(state[i.ITEM].values()):
if item.auto_despawn >= 1:
item.set_auto_despawn(item.auto_despawn - 1)
elif not item.auto_despawn:
state[i.ITEM].delete_env_object(item)
else:
pass
if not self._next_item_spawn:
state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency)
state[i.ITEM].trigger_spawn(state, self.n_items, self.spawn_frequency)
else:
self._next_item_spawn = max(0, self._next_item_spawn - 1)
return []
def tick_post_step(self, state) -> List[TickResult]:
for item in list(state[i.ITEM].values()):
if item.auto_despawn >= 1:
item.set_auto_despawn(item.auto_despawn-1)
elif not item.auto_despawn:
state[i.ITEM].delete_env_object(item)
else:
pass
if not self._next_item_spawn:
if spawned_items := state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency):
return [TickResult(self.name, validity=c.VALID, value=spawned_items, entity=None)]
if spawned_items := state[i.ITEM].trigger_spawn(state, self.n_items, self.spawn_frequency):
return [TickResult(self.name, validity=c.VALID, value=spawned_items.value)]
else:
return [TickResult(self.name, validity=c.NOT_VALID, value=0, entity=None)]
return [TickResult(self.name, validity=c.NOT_VALID, value=0)]
else:
self._next_item_spawn = max(0, self._next_item_spawn-1)
return []

View File

@ -1,3 +1,2 @@
from .entitites import Machine
from .groups import Machines
from .rules import MachineRule

View File

@ -5,6 +5,7 @@ from marl_factory_grid.utils.results import ActionResult
from marl_factory_grid.modules.machines import constants as m, rewards as r
from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils import helpers as h
class MachineAction(Action):
@ -13,13 +14,10 @@ class MachineAction(Action):
super().__init__(m.MACHINE_ACTION)
def do(self, entity, state) -> Union[None, ActionResult]:
if machine := state[m.MACHINES].by_pos(entity.pos):
if machine := h.get_first(state[m.MACHINES].by_pos(entity.pos)):
if valid := machine.maintain():
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_VALID)
else:
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_FAIL)
else:
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.MAINTAIN_FAIL)

View File

@ -8,22 +8,6 @@ from . import constants as m
class Machine(Entity):
@property
def var_can_collide(self):
return False
@property
def var_can_move(self):
return False
@property
def var_is_blocking_light(self):
return False
@property
def var_has_position(self):
return True
@property
def encoding(self):
return self._encodings[self.status]
@ -46,12 +30,12 @@ class Machine(Entity):
else:
return c.NOT_VALID
def tick(self):
def tick(self, state):
# if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]):
if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.state.entities.pos_dict[self.pos]]):
return TickResult(identifier=self.name, validity=c.VALID, reward=0, entity=self)
if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in state.entities.pos_dict[self.pos]]):
return TickResult(identifier=self.name, validity=c.VALID, entity=self)
# elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]):
elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.state.entities.pos_dict[self.pos]]):
elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in state.entities.pos_dict[self.pos]]):
self.status = m.STATE_WORK
self.reset_counter()
return None

View File

@ -1,28 +0,0 @@
from typing import List
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c
from marl_factory_grid.modules.machines import constants as m
from marl_factory_grid.modules.machines.entitites import Machine
class MachineRule(Rule):
def __init__(self, n_machines: int = 2):
super(MachineRule, self).__init__()
self.n_machines = n_machines
def on_init(self, state, lvl_map):
state[m.MACHINES].spawn(state.entities.empty_positions())
def tick_pre_step(self, state) -> List[TickResult]:
pass
def tick_step(self, state) -> List[TickResult]:
pass
def tick_post_step(self, state) -> List[TickResult]:
pass
def on_check_done(self, state) -> List[DoneResult]:
pass

View File

@ -1,48 +1,35 @@
from random import shuffle
import networkx as nx
import numpy as np
from ...algorithms.static.utils import points_to_graph
from ...environment import constants as c
from ...environment.actions import Action, ALL_BASEACTIONS
from ...environment.entity.entity import Entity
from ..doors import constants as do
from ..maintenance import constants as mi
from ...utils.helpers import MOVEMAP
from ...utils.utility_classes import RenderEntity
from ...utils.states import Gamestate
from ...utils import helpers as h
from ...utils.utility_classes import RenderEntity, Floor
from ..doors import DoorUse
class Maintainer(Entity):
@property
def var_can_collide(self):
return True
@property
def var_can_move(self):
return False
@property
def var_is_blocking_light(self):
return False
@property
def var_has_position(self):
return True
def __init__(self, state: Gamestate, objective: str, action: Action, *args, **kwargs):
def __init__(self, objective: str, action: Action, *args, **kwargs):
super().__init__(*args, **kwargs)
self.action = action
self.actions = [x() for x in ALL_BASEACTIONS]
self.actions = [x() for x in ALL_BASEACTIONS] + [DoorUse()]
self.objective = objective
self._path = None
self._next = []
self._last = []
self._last_serviced = 'None'
self._floortile_graph = points_to_graph(state.entities.floorlist)
self._floortile_graph = None
def tick(self, state):
if found_objective := state[self.objective].by_pos(self.pos):
if found_objective := h.get_first(state[self.objective].by_pos(self.pos)):
if found_objective.name != self._last_serviced:
self.action.do(self, state)
self._last_serviced = found_objective.name
@ -54,24 +41,27 @@ class Maintainer(Entity):
return action.do(self, state)
def get_move_action(self, state) -> Action:
if not self._floortile_graph:
state.print("Generating Floorgraph....")
self._floortile_graph = points_to_graph(state.entities.floorlist)
if self._path is None or not self._path:
if not self._next:
self._next = list(state[self.objective].values())
self._next = list(state[self.objective].values()) + [Floor(*state.random_free_position)]
shuffle(self._next)
self._last = []
self._last.append(self._next.pop())
state.print("Calculating shortest path....")
self._path = self.calculate_route(self._last[-1])
if door := self._door_is_close(state):
if door.is_closed:
if door := self._closed_door_in_path(state):
state.print(f"{self} found {door} that is closed. Attempt to open.")
# Translate the action_object to an integer to have the same output as any other model
action = do.ACTION_DOOR_USE
else:
action = self._predict_move(state)
else:
action = self._predict_move(state)
# Translate the action_object to an integer to have the same output as any other model
try:
action_obj = next(x for x in self.actions if x.name == action)
action_obj = h.get_first(self.actions, lambda x: x.name == action)
except (StopIteration, UnboundLocalError):
print('Will not happen')
raise EnvironmentError
@ -81,11 +71,10 @@ class Maintainer(Entity):
route = nx.shortest_path(self._floortile_graph, self.pos, entity.pos)
return route[1:]
def _door_is_close(self, state):
state.print("Found a door that is close.")
try:
return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name)
except StopIteration:
def _closed_door_in_path(self, state):
    """Look for a closed door on the next position of the current path.

    :param state: Object indexable by ``do.DOORS``; the result must offer ``by_pos(pos)``.
    :return: The first entity at ``self._path[0]`` whose ``is_closed`` is truthy,
             or ``None`` if there is no path or no such entity.
    """
    # Without a (non-empty) path there is no "next position" to inspect.
    if not self._path:
        return None
    upcoming_pos = self._path[0]
    doors_ahead = state[do.DOORS].by_pos(upcoming_pos)
    return h.get_first(doors_ahead, lambda door: door.is_closed)
def _predict_move(self, state):
@ -96,7 +85,7 @@ class Maintainer(Entity):
next_pos = self._path.pop(0)
diff = np.subtract(next_pos, self.pos)
# Retrieve action based on the pos dif (like in: What do I have to do to get there?)
action = next(action for action, pos_diff in MOVEMAP.items() if np.all(diff == pos_diff))
action = next(action for action, pos_diff in h.MOVEMAP.items() if np.all(diff == pos_diff))
return action
def render(self):

View File

@ -1,4 +1,4 @@
from typing import Union, List, Tuple
from typing import Union, List, Tuple, Dict
from marl_factory_grid.environment.groups.collection import Collection
from .entities import Maintainer
@ -10,25 +10,21 @@ from ...utils.states import Gamestate
class Maintainers(Collection):
_entity = Maintainer
@property
def var_can_collide(self):
return True
var_can_collide = True
var_can_move = True
var_is_blocking_light = False
var_has_position = True
@property
def var_can_move(self):
return True
@property
def var_is_blocking_light(self):
return False
@property
def var_has_position(self):
return True
def __init__(self, size, *args, coords_or_quantity: int = None,
spawnrule: Union[None, Dict[str, dict]] = None,
**kwargs):
super(Collection, self).__init__(*args, **kwargs)
self._coords_or_quantity = coords_or_quantity
self.size = size
self._spawnrule = spawnrule
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):
state = entity_args[0]
self.add_items([self._entity(state, mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity])
self.add_items([self._entity(mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity])

View File

@ -4,29 +4,24 @@ from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c
from . import rewards as r
from . import constants as M
from marl_factory_grid.utils.states import Gamestate
class MaintenanceRule(Rule):
class MoveMaintainers(Rule):
def __init__(self, n_maintainer: int = 1, *args, **kwargs):
super(MaintenanceRule, self).__init__(*args, **kwargs)
self.n_maintainer = n_maintainer
def on_init(self, state: Gamestate, lvl_map):
state[M.MAINTAINERS].spawn(state.entities.empty_positions[:self.n_maintainer], state)
pass
def tick_pre_step(self, state) -> List[TickResult]:
pass
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def tick_step(self, state) -> List[TickResult]:
for maintainer in state[M.MAINTAINERS]:
maintainer.tick(state)
# Todo: Return a Result Object.
return []
def tick_post_step(self, state) -> List[TickResult]:
pass
class DoneAtMaintainerCollision(Rule):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def on_check_done(self, state) -> List[DoneResult]:
agents = list(state[c.AGENT].values())

View File

@ -1,8 +1,8 @@
from random import choices, choice
from . import constants as z, Zone
from .. import Destination
from ..destinations import constants as d
from ... import Destination
from ...environment.rules import Rule
from ...environment import constants as c

View File

@ -0,0 +1,3 @@
from . import helpers as h
from . import helpers
from .results import Result, DoneResult, ActionResult, TickResult

View File

@ -1,28 +1,24 @@
import ast
from collections import defaultdict
from os import PathLike
from pathlib import Path
from typing import Union
from typing import Union, List
import yaml
from marl_factory_grid.environment.groups.agents import Agents
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.helpers import locate_and_import_class
from marl_factory_grid.environment.constants import DEFAULT_PATH, MODULE_PATH
from marl_factory_grid.environment import constants as c
DEFAULT_PATH = 'environment'
MODULE_PATH = 'modules'
class FactoryConfigParser(object):
default_entites = []
default_rules = ['MaxStepsReached', 'Collision']
default_rules = ['DoneAtMaxStepsReached', 'WatchCollision']
default_actions = [c.MOVE8, c.NOOP]
default_observations = [c.WALLS, c.AGENT]
def __init__(self, config_path, custom_modules_path: Union[None, PathLike] = None):
def __init__(self, config_path, custom_modules_path: Union[PathLike] = None):
self.config_path = Path(config_path)
self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path
self.config = yaml.safe_load(self.config_path.open())
@ -46,6 +42,10 @@ class FactoryConfigParser(object):
def rules(self):
return self.config['Rules']
@property
def tests(self):
    """Return the ``Tests`` section of the loaded config, or an empty list if it is absent."""
    fallback = []
    return self.config.get('Tests', fallback)
@property
def agents(self):
return self.config['Agents']
@ -61,7 +61,6 @@ class FactoryConfigParser(object):
return self.config[item]
def load_entities(self):
# entites = Entities()
entity_classes = dict()
entities = []
if c.DEFAULTS in self.entities:
@ -69,28 +68,40 @@ class FactoryConfigParser(object):
entities.extend(x for x in self.entities if x != c.DEFAULTS)
for entity in entities:
e1 = e2 = e3 = None
try:
folder_path = Path(__file__).parent.parent / DEFAULT_PATH
entity_class = locate_and_import_class(entity, folder_path)
except AttributeError as e1:
except AttributeError as e:
e1 = e
try:
folder_path = Path(__file__).parent.parent / MODULE_PATH
entity_class = locate_and_import_class(entity, folder_path)
except AttributeError as e2:
module_path = Path(__file__).parent.parent / MODULE_PATH
entity_class = locate_and_import_class(entity, module_path)
except AttributeError as e:
e2 = e
if self.custom_modules_path:
try:
folder_path = self.custom_modules_path
entity_class = locate_and_import_class(entity, folder_path)
except AttributeError as e3:
ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x]
entity_class = locate_and_import_class(entity, self.custom_modules_path)
except AttributeError as e:
e3 = e
pass
if (e1 and e2) or e3:
ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]]
print('##############################################################')
print('### Error ### Error ### Error ### Error ### Error ###')
print()
print('##############################################################')
print(f'Class "{entity}" was not found in "{module_path.name}"')
print(f'Class "{entity}" was not found in "{folder_path.name}"')
print('##############################################################')
if self.custom_modules_path:
print(f'Class "{entity}" was not found in "{self.custom_modules_path}"')
print('Possible Entitys are:', str(ents))
print()
print('##############################################################')
print('Goodbye')
print()
exit()
# raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are:', str(ents))
print('##############################################################')
print('### Error ### Error ### Error ### Error ### Error ###')
print('##############################################################')
exit(-99999)
entity_kwargs = self.entities.get(entity, {})
entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
@ -128,31 +139,86 @@ class FactoryConfigParser(object):
observations.extend(self.default_observations)
observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS)
positions = [ast.literal_eval(x) for x in self.agents[name].get('Positions', [])]
parsed_agents_conf[name] = dict(actions=parsed_actions, observations=observations, positions=positions)
other_kwargs = {k: v for k, v in self.agents[name].items() if k not in
['Actions', 'Observations', 'Positions']}
parsed_agents_conf[name] = dict(
actions=parsed_actions, observations=observations, positions=positions, other=other_kwargs
)
return parsed_agents_conf
def load_rules(self):
# entites = Entities()
rules_classes = dict()
rules = []
def load_env_rules(self) -> List[Rule]:
rules = self.rules.copy()
if c.DEFAULTS in self.rules:
for rule in self.default_rules:
if rule not in rules:
rules.append(rule)
rules.extend(x for x in self.rules if x != c.DEFAULTS)
rules.append({rule: {}})
for rule in rules:
return self._load_smth(rules, Rule)
def load_env_tests(self) -> List[Rule]:
    """Instantiate everything listed in the config's ``Tests`` section.

    :return: Whatever ``self._load_smth`` builds for ``self.tests``.
    """
    configured_tests = self.tests
    # ``None`` is handed over as the required base class.
    return self._load_smth(configured_tests, None)  # Test
def _load_smth(self, config, class_obj):
    """Locate and instantiate every class named in ``config``.

    Each name is searched with ``locate_and_import_class`` in the package's
    ``DEFAULT_PATH`` folder, then in ``MODULE_PATH`` and finally, if configured, in
    ``self.custom_modules_path``. When several locations provide the class, the
    last one searched wins. Only if *no* location provides it are the candidates
    printed and the process terminated.

    :param config: Mapping ``name -> kwargs`` (``kwargs`` may be ``None``). An iterable
                   of names and/or single-entry ``{name: kwargs}`` mappings is accepted too.
    :param class_obj: Base class a located class has to derive from in order to be
                      instantiated; ``None`` disables this check.
    :return: List of instantiated objects in the order given by ``config``.
    :raises SystemExit: With code ``-99999`` if a name cannot be found anywhere.
    """
    package_root = Path(__file__).parent.parent
    folder_path = package_root / DEFAULT_PATH
    module_path = package_root / MODULE_PATH
    search_paths = [folder_path, module_path]
    if self.custom_modules_path:
        search_paths.append(self.custom_modules_path)

    # Normalise the accepted config shapes into (name, kwargs) pairs.
    if isinstance(config, dict):
        entries = list(config.items())
    else:
        entries = []
        for entry in config:
            if isinstance(entry, dict):
                entries.extend(entry.items())
            else:
                entries.append((entry, None))

    rules = list()
    for rule, rule_kwargs in entries:
        if rule == c.DEFAULTS:
            # The defaults marker is a config switch, not the name of a class.
            continue
        rule_class = None
        errors = []
        for search_path in search_paths:
            try:
                rule_class = locate_and_import_class(rule, search_path)
            except AttributeError as e:
                errors.append(e)
        if rule_class is None:
            # NOTE(review): assumes the lookup error carries the list of available
            # class names as its second argument - confirm in locate_and_import_class.
            ents = [y for x in errors if len(x.args) > 1 for y in x.args[1]]
            print('### Error ### Error ### Error ### Error ### Error ###')
            print('')
            print(f'Class "{rule}" was not found in "{module_path.name}"')
            print(f'Class "{rule}" was not found in "{folder_path.name}"')
            if self.custom_modules_path:
                print(f'Class "{rule}" was not found in "{self.custom_modules_path}"')
            print('Possible Entitys are:', str(ents))
            print('')
            print('Goodbye')
            print('')
            raise SystemExit(-99999)
        if class_obj is None or issubclass(rule_class, class_obj):
            rules.append(rule_class(**(rule_kwargs or {})))
    return rules
def load_entity_spawn_rules(self, entities) -> List[Rule]:
rules = list()
rules_dicts = list()
for e in entities:
try:
if spawn_rule := e.spawn_rule:
rules_dicts.append(spawn_rule)
except AttributeError:
pass
for rule_dict in rules_dicts:
for rule_name, rule_kwargs in rule_dict.items():
try:
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
rule_class = locate_and_import_class(rule_name, folder_path)
except AttributeError:
try:
folder_path = (Path(__file__).parent.parent / MODULE_PATH)
rule_class = locate_and_import_class(rule, folder_path)
rule_class = locate_and_import_class(rule_name, folder_path)
except AttributeError:
rule_class = locate_and_import_class(rule, self.custom_modules_path)
# Fixme This check does not work!
# assert isinstance(rule_class, Rule), f'{rule_class.__name__} is no valid "Rule".'
rule_kwargs = self.rules.get(rule, {})
rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}})
return rules_classes
rule_class = locate_and_import_class(rule_name, self.custom_modules_path)
rules.append(rule_class(**rule_kwargs))
return rules

View File

@ -2,7 +2,7 @@ import importlib
from collections import defaultdict
from pathlib import PurePath, Path
from typing import Union, Dict, List
from typing import Union, Dict, List, Iterable, Callable
import numpy as np
from numpy.typing import ArrayLike
@ -222,7 +222,7 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
mod = importlib.import_module('.'.join(module_parts))
all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin',
'TickResult', 'ActionResult', 'Action', 'Agent',
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any'
]])
@ -240,7 +240,13 @@ def add_bound_name(name_str, bound_e):
def add_pos_name(name_str, bound_e):
if bound_e.var_has_position:
return f'{name_str}({bound_e.pos})'
return f'{name_str}@{bound_e.pos}'
return name_str
def get_first(iterable: Iterable, filter_by: Callable[[object], bool] = lambda _: True):
    """Return the first element of ``iterable`` that is accepted by ``filter_by``.

    :param iterable: Any iterable; it is consumed only up to the first match.
    :param filter_by: Predicate called with each element. By default every element is accepted.
    :return: The first element for which ``filter_by`` is truthy, or ``None`` when there is
             no such element. ``StopIteration`` is never raised.
    """
    return next((x for x in iterable if filter_by(x)), None)
def get_first_index(iterable: Iterable, filter_by: Callable[[object], bool] = lambda _: True):
    """Return the position of the first element of ``iterable`` accepted by ``filter_by``.

    :param iterable: Any iterable; it is consumed only up to the first match.
    :param filter_by: Predicate called with each element. By default every element is accepted.
    :return: Zero-based index of the first element for which ``filter_by`` is truthy, or
             ``None`` when there is no such element. ``StopIteration`` is never raised.
    """
    return next((idx for idx, x in enumerate(iterable) if filter_by(x)), None)

View File

@ -47,6 +47,7 @@ class LevelParser(object):
# All other
for es_name in self.e_p_dict:
e_class, e_kwargs = self.e_p_dict[es_name]['class'], self.e_p_dict[es_name]['kwargs']
e_kwargs = e_kwargs if e_kwargs else {}
if hasattr(e_class, 'symbol') and e_class.symbol is not None:
symbols = e_class.symbol

View File

@ -1,17 +1,17 @@
import math
import re
from collections import defaultdict
from itertools import product
from typing import Dict, List
import numpy as np
from numba import njit
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.object import _Object
from marl_factory_grid.environment.groups.utils import Combined
import marl_factory_grid.utils.helpers as h
from marl_factory_grid.utils.states import Gamestate
from marl_factory_grid.utils.utility_classes import Floor
from marl_factory_grid.utils.ray_caster import RayCaster
from marl_factory_grid.utils.states import Gamestate
from marl_factory_grid.utils import helpers as h
class OBSBuilder(object):
@ -77,11 +77,13 @@ class OBSBuilder(object):
def place_entity_in_observation(self, obs_array, agent, e):
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
if not min([y, x]) < 0:
try:
obs_array[x, y] += e.encoding
except IndexError:
# Seemed to be visible but is out of range
pass
pass
def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
assert self._curr_env_step == state.curr_step, (
@ -121,18 +123,24 @@ class OBSBuilder(object):
e = self.all_obs[l_name]
except KeyError:
try:
# Look for bound entity names!
pattern = re.compile(f'{re.escape(l_name)}(.*){re.escape(agent.name)}')
name = next((x for x in self.all_obs if pattern.search(x)), None)
# Look for bound entity REPRs!
pattern = re.compile(f'{re.escape(l_name)}'
f'{re.escape("[")}(.*){re.escape("]")}'
f'{re.escape("(")}{re.escape(agent.name)}{re.escape(")")}')
name = next((key for key, val in self.all_obs.items()
if pattern.search(str(val)) and isinstance(val, _Object)), None)
e = self.all_obs[name]
except KeyError:
try:
e = next(v for k, v in self.all_obs.items() if l_name in k and agent.name in k)
except StopIteration:
raise KeyError(
f'Check for spelling errors! \n '
f'No combination of "{l_name} and {agent.name}" could not be found in:\n '
f'{list(dict(self.all_obs).keys())}')
print(f'# Check for spelling errors!')
print(f'# No combination of "{l_name}" and "{agent.name}" could not be found in:')
print(f'# {list(dict(self.all_obs).keys())}')
print('#')
print('# exiting...')
print('#')
exit(-99999)
try:
positional = e.var_has_position
@ -161,15 +169,14 @@ class OBSBuilder(object):
try:
light_map = np.zeros(self.obs_shape)
visible_floor = self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False)
if self.pomdp_r:
for f in set(visible_floor):
self.place_entity_in_observation(light_map, agent, f)
else:
for f in set(visible_floor):
light_map[f.x, f.y] += f.encoding
# else:
# for f in set(visible_floor):
# light_map[f.x, f.y] += f.encoding
self.curr_lightmaps[agent.name] = light_map
except (KeyError, ValueError):
print()
pass
return obs, self.obs_layers[agent.name]
@ -185,7 +192,7 @@ class OBSBuilder(object):
for obs_str in agent.observations:
if isinstance(obs_str, dict):
obs_str, vals = next(obs_str.items().__iter__())
obs_str, vals = h.get_first(obs_str.items())
else:
vals = None
if obs_str == c.SELF:
@ -214,129 +221,3 @@ class OBSBuilder(object):
obs_layers.append(obs_str)
self.obs_layers[agent.name] = obs_layers
self.curr_lightmaps[agent.name] = np.zeros(self.obs_shape)
class RayCaster:
def __init__(self, agent, pomdp_r, degs=360):
self.agent = agent
self.pomdp_r = pomdp_r
self.n_rays = (self.pomdp_r + 1) * 8
self.degs = degs
self.ray_targets = self.build_ray_targets()
self.obs_shape_cube = np.array([self.pomdp_r, self.pomdp_r])
self._cache_dict = {}
def __repr__(self):
return f'{self.__class__.__name__}({self.agent.name})'
def build_ray_targets(self):
north = np.array([0, -1]) * self.pomdp_r
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
rot_M = [
[[math.cos(theta), -math.sin(theta)],
[math.sin(theta), math.cos(theta)]] for theta in thetas
]
rot_M = np.stack(rot_M, 0)
rot_M = np.unique(np.round(rot_M @ north), axis=0)
return rot_M.astype(int)
def ray_block_cache(self, key, callback):
if key not in self._cache_dict:
self._cache_dict[key] = callback()
return self._cache_dict[key]
def visible_entities(self, pos_dict, reset_cache=True):
visible = list()
if reset_cache:
self._cache_dict = {}
for ray in self.get_rays():
rx, ry = ray[0]
for x, y in ray:
cx, cy = x - rx, y - ry
entities_hit = pos_dict[(x, y)]
hits = self.ray_block_cache((x, y),
lambda: any(True for e in entities_hit if e.var_is_blocking_light)
)
diag_hits = all([
self.ray_block_cache(
key,
lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) and bool(
pos_dict[key]))
for key in ((x, y - cy), (x - cx, y))
]) if (cx != 0 and cy != 0) else False
visible += entities_hit if not diag_hits else []
if hits or diag_hits:
break
rx, ry = x, y
return visible
def get_rays(self):
a_pos = self.agent.pos
outline = self.ray_targets + a_pos
return self.bresenham_loop(a_pos, outline)
# todo do this once and cache the points!
def get_fov_outline(self) -> np.ndarray:
return self.ray_targets + self.agent.pos
def get_square_outline(self):
agent = self.agent
x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \
+ list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
return outline
@staticmethod
@njit
def bresenham_loop(a_pos, points):
results = []
for end in points:
x1, y1 = a_pos
x2, y2 = end
dx = x2 - x1
dy = y2 - y1
# Determine how steep the line is
is_steep = abs(dy) > abs(dx)
# Rotate line
if is_steep:
x1, y1 = y1, x1
x2, y2 = y2, x2
# Swap start and end points if necessary and store swap state
swapped = False
if x1 > x2:
x1, x2 = x2, x1
y1, y2 = y2, y1
swapped = True
# Recalculate differentials
dx = x2 - x1
dy = y2 - y1
# Calculate error
error = int(dx / 2.0)
ystep = 1 if y1 < y2 else -1
# Iterate over bounding box generating points between start and end
y = y1
points = []
for x in range(int(x1), int(x2) + 1):
coord = [y, x] if is_steep else [x, y]
points.append(coord)
error -= abs(dy)
if error < 0:
y += ystep
error += dx
# Reverse the list if the coordinates were swapped
if swapped:
points.reverse()
results.append(points)
return results

View File

@ -39,8 +39,9 @@ class RayCaster:
if reset_cache:
self._cache_dict = dict()
for ray in self.get_rays():
for ray in self.get_rays(): # Do not check, just trust.
rx, ry = ray[0]
# self.ray_block_cache(ray[0], lambda: False) We do not do that, because of doors etc...
for x, y in ray:
cx, cy = x - rx, y - ry
@ -52,7 +53,8 @@ class RayCaster:
diag_hits = all([
self.ray_block_cache(
key,
lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light))
lambda: any(True for e in pos_dict[key] if e.var_is_blocking_light))
# lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light))
for key in ((x, y-cy), (x-cx, y))
]) if (cx != 0 and cy != 0) else False

View File

@ -31,7 +31,7 @@ class Renderer:
def __init__(self, lvl_shape: Tuple[int, int] = (16, 16),
lvl_padded_shape: Union[Tuple[int, int], None] = None,
cell_size: int = 40, fps: int = 7,
cell_size: int = 40, fps: int = 7, factor: float = 0.9,
grid_lines: bool = True, view_radius: int = 2):
# TODO: Custom asset paths
self.grid_h, self.grid_w = lvl_shape
@ -45,7 +45,7 @@ class Renderer:
self.screen = pygame.display.set_mode(self.screen_size)
self.clock = pygame.time.Clock()
assets = list(self.ASSETS.rglob('*.png'))
self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets}
self.assets = {path.stem: self.load_asset(str(path), factor) for path in assets}
self.fill_bg()
now = time.time()
@ -110,22 +110,22 @@ class Renderer:
pygame.quit()
sys.exit()
self.fill_bg()
blits = deque()
for entity in [x for x in entities]:
bp = self.blit_params(entity)
blits.append(bp)
if entity.name.lower() == AGENT:
# First all others
blits = deque(self.blit_params(x) for x in entities if not x.name.lower() == AGENT)
# Then Agents, so that agents are rendered on top.
for agent in (x for x in entities if x.name.lower() == AGENT):
agent_blit = self.blit_params(agent)
if self.view_radius > 0:
vis_rects = self.visibility_rects(bp, entity.aux)
vis_rects = self.visibility_rects(agent_blit, agent.aux)
blits.extendleft(vis_rects)
if entity.state != BLANK:
agent_state_blits = self.blit_params(
RenderEntity(entity.state, (entity.pos[0] + 0.12, entity.pos[1]), 0.48, SCALE)
if agent.state != BLANK:
state_blit = self.blit_params(
RenderEntity(agent.state, (agent.pos[0] + 0.12, agent.pos[1]), 0.48, SCALE)
)
textsurface = self.font.render(str(entity.id), False, (0, 0, 0))
text_blit = dict(source=textsurface, dest=(bp['dest'].center[0]-.07*self.cell_size,
bp['dest'].center[1]))
blits += [agent_state_blits, text_blit]
textsurface = self.font.render(str(agent.id), False, (0, 0, 0))
text_blit = dict(source=textsurface, dest=(agent_blit['dest'].center[0]-.07*self.cell_size,
agent_blit['dest'].center[1]))
blits += [agent_blit, state_blit, text_blit]
for blit in blits:
self.screen.blit(**blit)

View File

@ -28,7 +28,10 @@ class Result:
def __repr__(self):
    """Build a one-line summary: class name, identifier, validity and any set details.

    The reward, value and acting-entity fragments are only appended when the
    respective attribute is not ``None``.
    """
    valid = "not " if not self.validity else ""
    reward = f" | Reward: {self.reward}" if self.reward is not None else ""
    value = f" | Value: {self.value}" if self.value is not None else ""
    entity = f" | by: {self.entity.name}" if self.entity is not None else ""
    return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value}{entity})'
@dataclass

View File

@ -1,3 +1,4 @@
from itertools import islice
from typing import List, Dict, Tuple
import numpy as np
@ -59,14 +60,15 @@ class Gamestate(object):
def moving_entites(self):
return [y for x in self.entities for y in x if x.var_can_move]
def __init__(self, entities, agents_conf, rules: Dict[str, dict], env_seed=69, verbose=False):
def __init__(self, entities, agents_conf, rules: List[Rule], lvl_shape, env_seed=69, verbose=False):
self.lvl_shape = lvl_shape
self.entities = entities
self.curr_step = 0
self.curr_actions = None
self.agents_conf = agents_conf
self.verbose = verbose
self.rng = np.random.default_rng(env_seed)
self.rules = StepRules(*(v['class'](**v['kwargs']) for v in rules.values()))
self.rules = StepRules(*rules)
def __getitem__(self, item):
return self.entities[item]
@ -80,6 +82,13 @@ class Gamestate(object):
def __repr__(self):
return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})'
@property
def random_free_position(self):
    """Return a single free position, i.e. the first one ``get_n_random_free_positions(1)`` yields.

    :raises IndexError: If no free position is available.
    """
    candidates = self.get_n_random_free_positions(1)
    return candidates[0]
def get_n_random_free_positions(self, n):
    """Take up to ``n`` positions from ``self.entities.free_positions_generator``.

    :param n: Maximum number of positions to draw.
    :return: List with at most ``n`` positions; shorter if the generator runs dry.
    """
    # NOTE(review): any randomness has to come from the generator itself - confirm there.
    free_positions = self.entities.free_positions_generator
    first_n = islice(free_positions, n)
    return list(first_n)
def tick(self, actions) -> List[Result]:
results = list()
self.curr_step += 1
@ -115,8 +124,7 @@ class Gamestate(object):
return results
def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]:
positions = [pos for pos, entity_list_for_position in self.entities.pos_dict.items()
if any([e.var_can_collide for e in entity_list_for_position])]
positions = [pos for pos, entities in self.entities.pos_dict.items() if len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2)]
return positions
def check_move_validity(self, moving_entity, position):

View File

@ -135,4 +135,3 @@ if __name__ == '__main__':
ce.get_observations()
ce.get_assets()
all_conf = ce.get_all()
print()

View File

@ -52,3 +52,6 @@ class Floor:
def __hash__(self):
return hash(self.name)
def __repr__(self):
    """Render the tile as the word ``Floor`` directly followed by its position."""
    return "Floor{}".format(self.pos)

View File

@ -6,6 +6,7 @@ import yaml
from marl_factory_grid.environment.factory import Factory
from marl_factory_grid.utils.logging.envmonitor import EnvMonitor
from marl_factory_grid.utils.logging.recorder import EnvRecorder
from marl_factory_grid.utils import helpers as h
from marl_factory_grid.modules.doors import constants as d
@ -61,7 +62,7 @@ if __name__ == '__main__':
if render:
env.render()
try:
door = next(x for x in env.unwrapped.unwrapped[d.DOORS] if x.is_open)
door = h.get_first([x for x in env.unwrapped.unwrapped[d.DOORS] if x.is_open])
print('openDoor found')
except StopIteration:
pass

View File

@ -1,8 +1,8 @@
from algorithms.utils import Checkpointer
from pathlib import Path
from algorithms.utils import load_yaml_file, add_env_props, instantiate_class, load_class
#from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC
# from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC
for i in range(0, 5):

View File

@ -0,0 +1,43 @@
import configparser
import json
from datetime import datetime, timezone
from pathlib import Path
if __name__ == '__main__':
    # Convert every non-"Interface" section of ``wg0/wg0.conf`` into a ``<section>.json``
    # client description that is written next to the config file.
    conf_path = Path('wg0')
    wg0_conf = configparser.ConfigParser()
    # ``ConfigParser.read`` silently skips files it cannot open; open the file explicitly
    # so that a missing config fails right here with a clear error.
    with (conf_path / 'wg0.conf').open() as conf_file:
        wg0_conf.read_file(conf_file)

    interface = wg0_conf['Interface']
    # Drop the prefix length (e.g. "/24") from the interface address, whatever its size.
    interface_address = interface['Address'].split('/', 1)[0]

    # Iterate all peers
    for client_name in wg0_conf.sections():
        if client_name == 'Interface':
            continue

        json_path = conf_path / f'{client_name}.json'
        # Delete any old conf.json for the current peer
        json_path.unlink(missing_ok=True)

        peer = wg0_conf[client_name]

        # The trailing "Z" declares UTC, so the timestamp has to be taken in UTC as well.
        date_time = datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%S.%f000Z')

        jdict = dict(
            id=client_name,
            # NOTE(review): the private key is filled with the peer's *public* key -
            # presumably a placeholder because the config holds no private key; confirm.
            private_key=peer['PublicKey'],
            public_key=peer['PublicKey'],
            # preshared_key=wg0_conf[client_name_wg0]['PresharedKey'],
            name=client_name,
            email="sysadmin@mobile.ifi.lmu.de",
            # NOTE(review): every peer receives the interface's own address here - verify
            # that this is intended rather than a per-peer address.
            allocated_ips=[interface_address],
            allowed_ips=['10.4.0.0/24', '10.153.199.0/24'],
            extra_allowed_ips=[],
            use_server_dns=True,
            enabled=True,
            created_at=date_time,
            updated_at=date_time
        )

        with json_path.open('w') as f:
            json.dump(jdict, f, indent='\t', separators=(',', ': '))

        print(client_name, ' written...')