new rules, new spawn logic, small fixes, default and narrow corridor debugged

This commit is contained in:
Steffen Illium 2023-11-09 17:50:20 +01:00
parent 9b9c6e0385
commit 06a5130b25
67 changed files with 768 additions and 921 deletions

View File

@ -56,7 +56,7 @@ Just define what your environment needs in a *yaml*-configfile like:
- Items - Items
Rules: Rules:
Defaults: {} Defaults: {}
Collision: WatchCollisions:
done_at_collisions: !!bool True done_at_collisions: !!bool True
ItemRespawn: ItemRespawn:
spawn_freq: 5 spawn_freq: 5

View File

@ -1,6 +1 @@
from .environment import * from .quickstart import init
from .modules import *
from .utils import *
from .quickstart import init

View File

@ -30,4 +30,3 @@ class TSPTargetAgent(TSPBaseAgent):
except (StopIteration, UnboundLocalError): except (StopIteration, UnboundLocalError):
print('Will not happen') print('Will not happen')
return action_obj return action_obj

View File

@ -26,12 +26,16 @@ def points_to_graph(coordiniates, allow_euclidean_connections=True, allow_manhat
assert allow_euclidean_connections or allow_manhattan_connections assert allow_euclidean_connections or allow_manhattan_connections
possible_connections = itertools.combinations(coordiniates, 2) possible_connections = itertools.combinations(coordiniates, 2)
graph = nx.Graph() graph = nx.Graph()
for a, b in possible_connections: if allow_manhattan_connections and allow_euclidean_connections:
diff = np.linalg.norm(np.asarray(a)-np.asarray(b)) graph.add_edges_from(
if allow_manhattan_connections and allow_euclidean_connections and diff <= np.sqrt(2): (a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) <= np.sqrt(2)
graph.add_edge(a, b) )
elif not allow_manhattan_connections and allow_euclidean_connections and diff == np.sqrt(2): elif not allow_manhattan_connections and allow_euclidean_connections:
graph.add_edge(a, b) graph.add_edges_from(
elif allow_manhattan_connections and not allow_euclidean_connections and diff == 1: (a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) == np.sqrt(2)
graph.add_edge(a, b) )
elif allow_manhattan_connections and not allow_euclidean_connections:
graph.add_edges_from(
(a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) == 1
)
return graph return graph

View File

@ -22,26 +22,41 @@ Agents:
- Inventory - Inventory
- DropOffLocations - DropOffLocations
- Maintainers - Maintainers
# This is special for agents, as each one is different and can act as an adversary e.g.
Positions:
- (16, 7)
- (16, 6)
- (16, 3)
- (16, 4)
- (16, 5)
Entities: Entities:
Batteries: Batteries:
initial_charge: 0.8 initial_charge: 0.8
per_action_costs: 0.02 per_action_costs: 0.02
ChargePods: {} ChargePods:
Destinations: {} coords_or_quantity: 2
Destinations:
coords_or_quantity: 1
spawn_mode: GROUPED
DirtPiles: DirtPiles:
coords_or_quantity: 10
initial_amount: 2
clean_amount: 1 clean_amount: 1
dirt_spawn_r_var: 0.1 dirt_spawn_r_var: 0.1
initial_amount: 2
initial_dirt_ratio: 0.05
max_global_amount: 20 max_global_amount: 20
max_local_amount: 5 max_local_amount: 5
Doors: {} Doors:
DropOffLocations: {} DropOffLocations:
coords_or_quantity: 1
max_dropoff_storage_size: 0
GlobalPositions: {} GlobalPositions: {}
Inventories: {} Inventories: {}
Items: {} Items:
Machines: {} coords_or_quantity: 5
Maintainers: {} Machines:
coords_or_quantity: 2
Maintainers:
coords_or_quantity: 1
Zones: {} Zones: {}
General: General:
@ -49,32 +64,31 @@ General:
individual_rewards: true individual_rewards: true
level_name: large level_name: large
pomdp_r: 3 pomdp_r: 3
verbose: false verbose: True
tests: false
Rules: Rules:
SpawnAgents: {} # Environment Dynamics
DoneAtBatteryDischarge: {}
Collision:
done_at_collisions: false
AssignGlobalPositions: {}
DoneAtDestinationReachAny: {}
DestinationReachReward: {}
SpawnDestinations:
n_dests: 1
spawn_mode: GROUPED
DoneOnAllDirtCleaned: {}
SpawnDirt:
spawn_freq: 15
EntitiesSmearDirtOnMove: EntitiesSmearDirtOnMove:
smear_ratio: 0.2 smear_ratio: 0.2
DoorAutoClose: DoorAutoClose:
close_frequency: 10 close_frequency: 10
ItemRules: MoveMaintainers:
max_dropoff_storage_size: 0
n_items: 5 # Respawn Stuff
n_locations: 5 RespawnDirt:
spawn_frequency: 15 respawn_freq: 15
MaxStepsReached: RespawnItems:
respawn_freq: 15
# Utilities
WatchCollisions:
done_at_collisions: false
# Done Conditions
DoneAtDestinationReachAny:
DoneOnAllDirtCleaned:
DoneAtBatteryDischarge:
DoneAtMaintainerCollision:
DoneAtMaxStepsReached:
max_steps: 500 max_steps: 500
# AgentSingleZonePlacement:
# n_zones: 4

View File

@ -1,3 +1,10 @@
General:
env_seed: 69
individual_rewards: true
level_name: narrow_corridor
pomdp_r: 0
verbose: true
Agents: Agents:
Wolfgang: Wolfgang:
Actions: Actions:
@ -10,6 +17,7 @@ Agents:
Positions: Positions:
- (2, 1) - (2, 1)
- (2, 5) - (2, 5)
is_blocking_pos: true
Karl-Heinz: Karl-Heinz:
Actions: Actions:
- Noop - Noop
@ -21,26 +29,30 @@ Agents:
Positions: Positions:
- (2, 1) - (2, 1)
- (2, 5) - (2, 5)
Entities: is_blocking_pos: true
Destinations: {}
General: Entities:
env_seed: 69 Destinations:
individual_rewards: true ignore_blocking: true
level_name: narrow_corridor spawnrule:
pomdp_r: 0 SpawnDestinationsPerAgent:
verbose: true coords_or_quantity:
Wolfgang:
- (2, 1)
- (2, 5)
Karl-Heinz:
- (2, 1)
- (2, 5)
# Whether you want to provide a numeric Position observation.
# GlobalPositions:
# normalized: false
Rules: Rules:
SpawnAgents: {} # Utilities
Collision: WatchCollisions:
done_at_collisions: false done_at_collisions: false
FixedDestinationSpawn: # Done Conditions
per_agent_positions: # DoneAtDestinationReachAny:
Wolfgang: DoneAtDestinationReachAll:
- (2, 1) DoneAtMaxStepsReached:
- (2, 5) max_steps: 500
Karl-Heinz:
- (2, 1)
- (2, 5)
DestinationReachAll: {}

View File

@ -48,9 +48,9 @@ class Move(Action, abc.ABC):
reward = r.MOVEMENTS_VALID if move_validity else r.MOVEMENTS_FAIL reward = r.MOVEMENTS_VALID if move_validity else r.MOVEMENTS_FAIL
return ActionResult(entity=entity, identifier=self._identifier, validity=move_validity, reward=reward) return ActionResult(entity=entity, identifier=self._identifier, validity=move_validity, reward=reward)
else: # There is no place to go, probably collision else: # There is no place to go, probably collision
# This is currently handled by the Collision rule, so that it can be switched on and off by conf.yml # This is currently handled by the WatchCollisions rule, so that it can be switched on and off by conf.yml
# return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION) # return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION)
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=0) return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID)
def _calc_new_pos(self, pos): def _calc_new_pos(self, pos):
x_diff, y_diff = MOVEMAP[self._identifier] x_diff, y_diff = MOVEMAP[self._identifier]

View File

@ -10,6 +10,7 @@ AGENT = 'Agent' # Identifier of Agent-objects an
OTHERS = 'Other' OTHERS = 'Other'
COMBINED = 'Combined' COMBINED = 'Combined'
GLOBALPOSITIONS = 'GlobalPositions' # Identifier of the global position slice GLOBALPOSITIONS = 'GlobalPositions' # Identifier of the global position slice
SPAWN_ENTITY_RULE = 'SpawnEntity'
# Attributes # Attributes
IS_BLOCKING_LIGHT = 'var_is_blocking_light' IS_BLOCKING_LIGHT = 'var_is_blocking_light'
@ -29,7 +30,7 @@ VALUE_NO_POS = (-9999, -9999) # Invalid Position value used in the e
ACTION = 'action' # Identifier of Action-objects and groups (groups). ACTION = 'action' # Identifier of Action-objects and groups (groups).
COLLISION = 'Collision' # Identifier to use in the context of collisions. COLLISION = 'Collisions' # Identifier to use in the context of collisions.
# LAST_POS = 'LAST_POS' # Identifier for retrieving an entity's last pos. # LAST_POS = 'LAST_POS' # Identifier for retrieving an entity's last pos.
VALIDITY = 'VALIDITY' # Identifier for retrieving the Validity of Action, Tick, etc. ... VALIDITY = 'VALIDITY' # Identifier for retrieving the Validity of Action, Tick, etc. ...
@ -54,3 +55,5 @@ NOOP = 'Noop'
# Result Identifier # Result Identifier
MOVEMENTS_VALID = 'motion_valid' MOVEMENTS_VALID = 'motion_valid'
MOVEMENTS_FAIL = 'motion_not_valid' MOVEMENTS_FAIL = 'motion_not_valid'
DEFAULT_PATH = 'environment'
MODULE_PATH = 'modules'

View File

@ -12,14 +12,6 @@ from marl_factory_grid.environment import constants as c
class Agent(Entity): class Agent(Entity):
@property
def var_is_blocking_light(self):
return False
@property
def var_can_move(self):
return True
@property @property
def var_is_paralyzed(self): def var_is_paralyzed(self):
return len(self._paralyzed) return len(self._paralyzed)
@ -28,14 +20,6 @@ class Agent(Entity):
def paralyze_reasons(self): def paralyze_reasons(self):
return [x for x in self._paralyzed] return [x for x in self._paralyzed]
@property
def var_is_blocking_pos(self):
return False
@property
def var_has_position(self):
return True
@property @property
def obs_tag(self): def obs_tag(self):
return self.name return self.name
@ -48,10 +32,6 @@ class Agent(Entity):
def observations(self): def observations(self):
return self._observations return self._observations
@property
def var_can_collide(self):
return True
def step_result(self): def step_result(self):
pass pass
@ -60,16 +40,21 @@ class Agent(Entity):
return self._collection return self._collection
@property @property
def state(self): def var_is_blocking_pos(self):
return self._state or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0) return self._is_blocking_pos
def __init__(self, actions: List[Action], observations: List[str], *args, **kwargs): @property
def state(self):
return self._state or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID)
def __init__(self, actions: List[Action], observations: List[str], *args, is_blocking_pos=False, **kwargs):
super(Agent, self).__init__(*args, **kwargs) super(Agent, self).__init__(*args, **kwargs)
self._paralyzed = set() self._paralyzed = set()
self.step_result = dict() self.step_result = dict()
self._actions = actions self._actions = actions
self._observations = observations self._observations = observations
self._state: Union[Result, None] = None self._state: Union[Result, None] = None
self._is_blocking_pos = is_blocking_pos
# noinspection PyAttributeOutsideInit # noinspection PyAttributeOutsideInit
def clear_temp_state(self): def clear_temp_state(self):

View File

@ -14,7 +14,7 @@ class Entity(_Object, abc.ABC):
@property @property
def state(self): def state(self):
return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0) return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID)
@property @property
def var_has_position(self): def var_has_position(self):
@ -60,6 +60,10 @@ class Entity(_Object, abc.ABC):
def pos(self): def pos(self):
return self._pos return self._pos
def set_pos(self, pos):
assert isinstance(pos, tuple) and len(pos) == 2
self._pos = pos
@property @property
def last_pos(self): def last_pos(self):
try: try:
@ -84,7 +88,7 @@ class Entity(_Object, abc.ABC):
for observer in self.observers: for observer in self.observers:
observer.notify_del_entity(self) observer.notify_del_entity(self)
self._view_directory = curr_pos[0] - next_pos[0], curr_pos[1] - next_pos[1] self._view_directory = curr_pos[0] - next_pos[0], curr_pos[1] - next_pos[1]
self._pos = next_pos self.set_pos(next_pos)
for observer in self.observers: for observer in self.observers:
observer.notify_add_entity(self) observer.notify_add_entity(self)
return valid return valid
@ -93,7 +97,7 @@ class Entity(_Object, abc.ABC):
def __init__(self, pos, bind_to=None, **kwargs): def __init__(self, pos, bind_to=None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self._status = None self._status = None
self._pos = pos self.set_pos(pos)
self._last_pos = pos self._last_pos = pos
if bind_to: if bind_to:
try: try:
@ -109,8 +113,9 @@ class Entity(_Object, abc.ABC):
def render(self): def render(self):
return RenderEntity(self.__class__.__name__.lower(), self.pos) return RenderEntity(self.__class__.__name__.lower(), self.pos)
def __repr__(self): @abc.abstractmethod
return super(Entity, self).__repr__() + f'(@{self.pos})' def render(self):
return RenderEntity(self.__class__.__name__.lower(), self.pos)
@property @property
def obs_tag(self): def obs_tag(self):
@ -149,4 +154,4 @@ class Entity(_Object, abc.ABC):
except StopIteration: except StopIteration:
pass pass
except ValueError: except ValueError:
print() pass

View File

@ -1,24 +0,0 @@
# noinspection PyAttributeOutsideInit
class BoundEntityMixin:
@property
def bound_entity(self):
return self._bound_entity
@property
def name(self):
if self.bound_entity:
return f'{self.__class__.__name__}({self.bound_entity.name})'
else:
pass
def belongs_to_entity(self, entity):
return entity == self.bound_entity
def bind_to(self, entity):
self._bound_entity = entity
def unbind(self):
self._bound_entity = None

View File

@ -13,10 +13,6 @@ class _Object:
def __bool__(self): def __bool__(self):
return True return True
@property
def var_has_position(self):
return False
@property @property
def var_can_be_bound(self): def var_can_be_bound(self):
try: try:
@ -30,22 +26,14 @@ class _Object:
@property @property
def name(self): def name(self):
if self._str_ident is not None: return f'{self.__class__.__name__}[{self.identifier}]'
name = f'{self.__class__.__name__}[{self._str_ident}]'
else:
name = f'{self.__class__.__name__}#{self.u_int}'
if self.bound_entity:
name = h.add_bound_name(name, self.bound_entity)
if self.var_has_position:
name = h.add_pos_name(name, self)
return name
@property @property
def identifier(self): def identifier(self):
if self._str_ident is not None: if self._str_ident is not None:
return self._str_ident return self._str_ident
else: else:
return self.name return self.u_int
def reset_uid(self): def reset_uid(self):
self._u_idx = defaultdict(lambda: 0) self._u_idx = defaultdict(lambda: 0)
@ -62,7 +50,15 @@ class _Object:
print(f'Following kwargs were passed, but ignored: {kwargs}') print(f'Following kwargs were passed, but ignored: {kwargs}')
def __repr__(self): def __repr__(self):
return f'{self.name}' name = self.name
if self.bound_entity:
name = h.add_bound_name(name, self.bound_entity)
try:
if self.var_has_position:
name = h.add_pos_name(name, self)
except (AttributeError):
pass
return name
def __eq__(self, other) -> bool: def __eq__(self, other) -> bool:
return other == self.identifier return other == self.identifier
@ -88,7 +84,7 @@ class _Object:
def summarize_state(self): def summarize_state(self):
return dict() return dict()
def bind(self, entity): def bind_to(self, entity):
# noinspection PyAttributeOutsideInit # noinspection PyAttributeOutsideInit
self._bound_entity = entity self._bound_entity = entity
return c.VALID return c.VALID
@ -100,9 +96,6 @@ class _Object:
def bound_entity(self): def bound_entity(self):
return self._bound_entity return self._bound_entity
def bind_to(self, entity):
self._bound_entity = entity
def unbind(self): def unbind(self):
self._bound_entity = None self._bound_entity = None

View File

@ -24,7 +24,7 @@ class PlaceHolder(_Object):
@property @property
def name(self): def name(self):
return "PlaceHolder" return self.__class__.__name__
class GlobalPosition(_Object): class GlobalPosition(_Object):
@ -36,7 +36,8 @@ class GlobalPosition(_Object):
else: else:
return self.bound_entity.pos return self.bound_entity.pos
def __init__(self, level_shape, *args, normalized: bool = True, **kwargs): def __init__(self, agent, level_shape, *args, normalized: bool = True, **kwargs):
super(GlobalPosition, self).__init__(*args, **kwargs) super(GlobalPosition, self).__init__(*args, **kwargs)
self.bind_to(agent)
self._normalized = normalized self._normalized = normalized
self._shape = level_shape self._shape = level_shape

View File

@ -5,13 +5,8 @@ from marl_factory_grid.utils.utility_classes import RenderEntity
class Wall(Entity): class Wall(Entity):
@property def __init__(self, *args, **kwargs):
def var_has_position(self): super().__init__(*args, **kwargs)
return True
@property
def var_can_collide(self):
return True
@property @property
def encoding(self): def encoding(self):
@ -19,11 +14,3 @@ class Wall(Entity):
def render(self): def render(self):
return RenderEntity(c.WALL, self.pos) return RenderEntity(c.WALL, self.pos)
@property
def var_is_blocking_pos(self):
return True
@property
def var_is_blocking_light(self):
return True

View File

@ -87,11 +87,14 @@ class Factory(gym.Env):
entities = self.map.do_init() entities = self.map.do_init()
# Init rules # Init rules
rules = self.conf.load_rules() env_rules = self.conf.load_env_rules()
entity_rules = self.conf.load_entity_spawn_rules(entities)
env_rules.extend(entity_rules)
# Parse the agent conf # Parse the agent conf
parsed_agents_conf = self.conf.parse_agents_conf() parsed_agents_conf = self.conf.parse_agents_conf()
self.state = Gamestate(entities, parsed_agents_conf, rules, self.conf.env_seed, self.conf.verbose) self.state = Gamestate(entities, parsed_agents_conf, env_rules, self.map.level_shape,
self.conf.env_seed, self.conf.verbose)
# All is set up, trigger entity init with variable pos # All is set up, trigger entity init with variable pos
# All is set up, trigger additional init (after agent entity spawn etc) # All is set up, trigger additional init (after agent entity spawn etc)

View File

@ -1,10 +1,15 @@
from marl_factory_grid.environment.entity.agent import Agent from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.environment.groups.collection import Collection from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.environment.rules import SpawnAgents
class Agents(Collection): class Agents(Collection):
_entity = Agent _entity = Agent
@property
def spawn_rule(self):
return {SpawnAgents.__name__: {}}
@property @property
def var_is_blocking_light(self): def var_is_blocking_light(self):
return False return False

View File

@ -1,18 +1,25 @@
from typing import List, Tuple, Union from typing import List, Tuple, Union, Dict
from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.groups.objects import _Objects from marl_factory_grid.environment.groups.objects import _Objects
# noinspection PyProtectedMember
from marl_factory_grid.environment.entity.object import _Object from marl_factory_grid.environment.entity.object import _Object
import marl_factory_grid.environment.constants as c import marl_factory_grid.environment.constants as c
from marl_factory_grid.utils.results import Result
class Collection(_Objects): class Collection(_Objects):
_entity = _Object # entity? _entity = _Object # entity?
symbol = None
@property @property
def var_is_blocking_light(self): def var_is_blocking_light(self):
return False return False
@property
def var_is_blocking_pos(self):
return False
@property @property
def var_can_collide(self): def var_can_collide(self):
return False return False
@ -23,29 +30,61 @@ class Collection(_Objects):
@property @property
def var_has_position(self): def var_has_position(self):
return False return True
# @property
# def var_has_bound(self):
# return False # batteries, globalpos, inventories true
@property
def var_can_be_bound(self):
return False
@property @property
def encodings(self): def encodings(self):
return [x.encoding for x in self] return [x.encoding for x in self]
def __init__(self, size, *args, **kwargs): @property
super(Collection, self).__init__(*args, **kwargs) def spawn_rule(self):
self.size = size """Prevent SpawnRule creation if Objects are spawned by map, Doors e.g."""
if self.symbol:
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): # where should the args go? return None
if isinstance(coords_or_quantity, int): elif self._spawnrule:
self.add_items([self._entity() for _ in range(coords_or_quantity)]) return self._spawnrule
else: else:
self.add_items([self._entity(pos) for pos in coords_or_quantity]) return {c.SPAWN_ENTITY_RULE: dict(collection=self, coords_or_quantity=self._coords_or_quantity)}
def __init__(self, size, *args, coords_or_quantity: int = None, ignore_blocking=False,
spawnrule: Union[None, Dict[str, dict]] = None,
**kwargs):
super(Collection, self).__init__(*args, **kwargs)
self._coords_or_quantity = coords_or_quantity
self.size = size
self._spawnrule = spawnrule
self._ignore_blocking = ignore_blocking
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, ignore_blocking=False, **entity_kwargs):
coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity
if self.var_has_position:
if isinstance(coords_or_quantity, int):
if ignore_blocking or self._ignore_blocking:
coords_or_quantity = state.entities.floorlist[:coords_or_quantity]
else:
coords_or_quantity = state.get_n_random_free_positions(coords_or_quantity)
self.spawn(coords_or_quantity, *entity_args, **entity_kwargs)
state.print(f'{len(coords_or_quantity)} new {self.name} have been spawned at {coords_or_quantity}')
return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(coords_or_quantity))
else:
if isinstance(coords_or_quantity, int):
self.spawn(coords_or_quantity, *entity_args, **entity_kwargs)
state.print(f'{coords_or_quantity} new {self.name} have been spawned randomly.')
return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=coords_or_quantity)
else:
raise ValueError(f'{self._entity.__name__} has no position!')
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args, **entity_kwargs):
if self.var_has_position:
if isinstance(coords_or_quantity, int):
raise ValueError(f'{self._entity.__name__} should have a position!')
else:
self.add_items([self._entity(pos, *entity_args, **entity_kwargs) for pos in coords_or_quantity])
else:
if isinstance(coords_or_quantity, int):
self.add_items([self._entity(*entity_args, **entity_kwargs) for _ in range(coords_or_quantity)])
else:
raise ValueError(f'{self._entity.__name__} has no position!')
return c.VALID return c.VALID
def despawn(self, items: List[_Object]): def despawn(self, items: List[_Object]):
@ -115,7 +154,7 @@ class Collection(_Objects):
except StopIteration: except StopIteration:
pass pass
except ValueError: except ValueError:
print() pass
@property @property
def positions(self): def positions(self):

View File

@ -1,6 +1,6 @@
from collections import defaultdict from collections import defaultdict
from operator import itemgetter from operator import itemgetter
from random import shuffle, random from random import shuffle
from typing import Dict from typing import Dict
from marl_factory_grid.environment.groups.objects import _Objects from marl_factory_grid.environment.groups.objects import _Objects
@ -12,10 +12,10 @@ class Entities(_Objects):
@staticmethod @staticmethod
def neighboring_positions(pos): def neighboring_positions(pos):
return (POS_MASK + pos).reshape(-1, 2) return [tuple(x) for x in (POS_MASK + pos).reshape(-1, 2)]
def get_entities_near_pos(self, pos): def get_entities_near_pos(self, pos):
return [y for x in itemgetter(*(tuple(x) for x in self.neighboring_positions(pos)))(self.pos_dict) for y in x] return [y for x in itemgetter(*self.neighboring_positions(pos))(self.pos_dict) for y in x]
def render(self): def render(self):
return [y for x in self for y in x.render() if x is not None] return [y for x in self for y in x.render() if x is not None]
@ -35,8 +35,9 @@ class Entities(_Objects):
super().__init__() super().__init__()
def guests_that_can_collide(self, pos): def guests_that_can_collide(self, pos):
return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide] return [x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
@property
def empty_positions(self): def empty_positions(self):
empty_positions = [key for key in self.floorlist if not self.pos_dict[key]] empty_positions = [key for key in self.floorlist if not self.pos_dict[key]]
shuffle(empty_positions) shuffle(empty_positions)
@ -48,11 +49,23 @@ class Entities(_Objects):
shuffle(empty_positions) shuffle(empty_positions)
return empty_positions return empty_positions
def is_blocked(self): @property
return[key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])] def blocked_positions(self):
blocked_positions = [key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])]
shuffle(blocked_positions)
return blocked_positions
def is_not_blocked(self): @property
return[key for key, val in self.pos_dict.items() if not all([x.var_is_blocking_pos for x in val])] def free_positions_generator(self):
generator = (
key for key in self.floorlist if all(not x.var_can_collide and not x.var_is_blocking_pos
for x in self.pos_dict[key])
)
return generator
@property
def free_positions_list(self):
return [x for x in self.free_positions_generator]
def iter_entities(self): def iter_entities(self):
return iter((x for sublist in self.values() for x in sublist)) return iter((x for sublist in self.values() for x in sublist))
@ -92,3 +105,6 @@ class Entities(_Objects):
@property @property
def positions(self): def positions(self):
return [k for k, v in self.pos_dict.items() for _ in v] return [k for k, v in self.pos_dict.items() for _ in v]
def is_occupied(self, pos):
return len([x for x in self.pos_dict[pos] if x.var_can_collide or x.var_is_blocking_pos]) >= 1

View File

@ -4,10 +4,6 @@ from marl_factory_grid.environment import constants as c
# noinspection PyUnresolvedReferences,PyTypeChecker # noinspection PyUnresolvedReferences,PyTypeChecker
class IsBoundMixin: class IsBoundMixin:
@property
def name(self):
return f'{self.__class__.__name__}({self._bound_entity.name})'
def __repr__(self): def __repr__(self):
return f'{self.__class__.__name__}#{self._bound_entity.name}({self._data})' return f'{self.__class__.__name__}#{self._bound_entity.name}({self._data})'

View File

@ -5,11 +5,16 @@ import numpy as np
from marl_factory_grid.environment.entity.object import _Object from marl_factory_grid.environment.entity.object import _Object
import marl_factory_grid.environment.constants as c import marl_factory_grid.environment.constants as c
from marl_factory_grid.utils import helpers as h
class _Objects: class _Objects:
_entity = _Object _entity = _Object
@property
def var_can_be_bound(self):
return False
@property @property
def observers(self): def observers(self):
return self._observers return self._observers
@ -148,12 +153,12 @@ class _Objects:
def by_entity(self, entity): def by_entity(self, entity):
try: try:
return next((x for x in self if x.belongs_to_entity(entity))) return h.get_first(self, filter_by=lambda x: x.belongs_to_entity(entity))
except (StopIteration, AttributeError): except (StopIteration, AttributeError):
return None return None
def idx_by_entity(self, entity): def idx_by_entity(self, entity):
try: try:
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity))) return h.get_first_index(self, filter_by=lambda x: x.belongs_to_entity(entity))
except (StopIteration, AttributeError): except (StopIteration, AttributeError):
return None return None

View File

@ -1,7 +1,10 @@
from typing import List, Union from typing import List, Union
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.util import GlobalPosition from marl_factory_grid.environment.entity.util import GlobalPosition
from marl_factory_grid.environment.groups.collection import Collection from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.utils.results import Result
from marl_factory_grid.utils.states import Gamestate
class Combined(Collection): class Combined(Collection):
@ -36,17 +39,17 @@ class GlobalPositions(Collection):
_entity = GlobalPosition _entity = GlobalPosition
@property var_is_blocking_light = False
def var_is_blocking_light(self): var_can_be_bound = True
return False var_can_collide = False
var_has_position = False
@property
def var_can_collide(self):
return False
@property
def var_can_be_bound(self):
return True
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(GlobalPositions, self).__init__(*args, **kwargs) super(GlobalPositions, self).__init__(*args, **kwargs)
def spawn(self, agents, level_shape, *args, **kwargs):
self.add_items([self._entity(agent, level_shape, *args, **kwargs) for agent in agents])
return [Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))]
def trigger_spawn(self, state: Gamestate, *args, **kwargs) -> [Result]:
return self.spawn(state[c.AGENT], state.lvl_shape, *args, **kwargs)

View File

@ -7,9 +7,12 @@ class Walls(Collection):
_entity = Wall _entity = Wall
symbol = c.SYMBOL_WALL symbol = c.SYMBOL_WALL
@property var_can_collide = True
def var_has_position(self): var_is_blocking_light = True
return True var_can_move = False
var_has_position = True
var_can_be_bound = False
var_is_blocking_pos = True
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(Walls, self).__init__(*args, **kwargs) super(Walls, self).__init__(*args, **kwargs)

View File

@ -1,6 +1,6 @@
import abc import abc
from random import shuffle from random import shuffle
from typing import List from typing import List, Collection, Union
from marl_factory_grid.environment.entity.agent import Agent from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.utils import helpers as h from marl_factory_grid.utils import helpers as h
@ -39,6 +39,29 @@ class Rule(abc.ABC):
return [] return []
class SpawnEntity(Rule):
@property
def _collection(self) -> Collection:
return Collection()
@property
def name(self):
return f'{self.__class__.__name__}({self.collection.name})'
def __init__(self, collection, coords_or_quantity, ignore_blocking=False):
super().__init__()
self.coords_or_quantity = coords_or_quantity
self.collection = collection
self.ignore_blocking = ignore_blocking
def on_init(self, state, lvl_map) -> [TickResult]:
results = self.collection.trigger_spawn(state, ignore_blocking=self.ignore_blocking)
pos_str = f' on: {[x.pos for x in self.collection]}' if self.collection.var_has_position else ''
state.print(f'Initial {self.collection.__class__.__name__} were spawned{pos_str}')
return results
class SpawnAgents(Rule): class SpawnAgents(Rule):
def __init__(self): def __init__(self):
@ -46,14 +69,14 @@ class SpawnAgents(Rule):
pass pass
def on_init(self, state, lvl_map): def on_init(self, state, lvl_map):
agent_conf = state.agents_conf
# agents = Agents(lvl_map.size) # agents = Agents(lvl_map.size)
agents = state[c.AGENT] agents = state[c.AGENT]
empty_positions = state.entities.empty_positions()[:len(agent_conf)] empty_positions = state.entities.empty_positions[:len(state.agents_conf)]
for agent_name in agent_conf: for agent_name, agent_conf in state.agents_conf.items():
actions = agent_conf[agent_name]['actions'].copy() actions = agent_conf['actions'].copy()
observations = agent_conf[agent_name]['observations'].copy() observations = agent_conf['observations'].copy()
positions = agent_conf[agent_name]['positions'].copy() positions = agent_conf['positions'].copy()
other = agent_conf['other'].copy()
if positions: if positions:
shuffle(positions) shuffle(positions)
while True: while True:
@ -61,18 +84,18 @@ class SpawnAgents(Rule):
pos = positions.pop() pos = positions.pop()
except IndexError: except IndexError:
raise ValueError(f'It was not possible to spawn an Agent on the available position: ' raise ValueError(f'It was not possible to spawn an Agent on the available position: '
f'\n{agent_name[agent_name]["positions"].copy()}') f'\n{agent_conf["positions"].copy()}')
if agents.by_pos(pos) and state.check_pos_validity(pos): if bool(agents.by_pos(pos)) or not state.check_pos_validity(pos):
continue continue
else: else:
agents.add_item(Agent(actions, observations, pos, str_ident=agent_name)) agents.add_item(Agent(actions, observations, pos, str_ident=agent_name, **other))
break break
else: else:
agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name)) agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name, **other))
pass pass
class MaxStepsReached(Rule): class DoneAtMaxStepsReached(Rule):
def __init__(self, max_steps: int = 500): def __init__(self, max_steps: int = 500):
super().__init__() super().__init__()
@ -83,8 +106,8 @@ class MaxStepsReached(Rule):
def on_check_done(self, state): def on_check_done(self, state):
if self.max_steps <= state.curr_step: if self.max_steps <= state.curr_step:
return [DoneResult(validity=c.VALID, identifier=self.name, reward=0)] return [DoneResult(validity=c.VALID, identifier=self.name)]
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)] return [DoneResult(validity=c.NOT_VALID, identifier=self.name)]
class AssignGlobalPositions(Rule): class AssignGlobalPositions(Rule):
@ -101,7 +124,7 @@ class AssignGlobalPositions(Rule):
return [] return []
class Collision(Rule): class WatchCollisions(Rule):
def __init__(self, done_at_collisions: bool = False): def __init__(self, done_at_collisions: bool = False):
super().__init__() super().__init__()
@ -132,4 +155,4 @@ class Collision(Rule):
move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT]) move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT])
if inter_entity_collision_detected or move_failed: if inter_entity_collision_detected or move_failed:
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)] return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)]
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)] return [DoneResult(validity=c.NOT_VALID, identifier=self.name)]

View File

@ -1,4 +1,4 @@
from .actions import BtryCharge from .actions import BtryCharge
from .entitites import Pod, Battery from .entitites import ChargePod, Battery
from .groups import ChargePods, Batteries from .groups import ChargePods, Batteries
from .rules import DoneAtBatteryDischarge, BatteryDecharge from .rules import DoneAtBatteryDischarge, BatteryDecharge

View File

@ -6,6 +6,7 @@ from marl_factory_grid.utils.results import ActionResult
from marl_factory_grid.modules.batteries import constants as b from marl_factory_grid.modules.batteries import constants as b
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils import helpers as h
class BtryCharge(Action): class BtryCharge(Action):
@ -14,8 +15,8 @@ class BtryCharge(Action):
super().__init__(b.ACTION_CHARGE) super().__init__(b.ACTION_CHARGE)
def do(self, entity, state) -> Union[None, ActionResult]: def do(self, entity, state) -> Union[None, ActionResult]:
if charge_pod := state[b.CHARGE_PODS].by_pos(entity.pos): if charge_pod := h.get_first(state[b.CHARGE_PODS].by_pos(entity.pos)):
valid = charge_pod.charge_battery(state[b.BATTERIES].by_entity(entity)) valid = h.get_first(charge_pod.charge_battery(state[b.BATTERIES].by_entity(entity)))
if valid: if valid:
state.print(f'{entity.name} just charged batteries at {charge_pod.name}.') state.print(f'{entity.name} just charged batteries at {charge_pod.name}.')
else: else:

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.9 KiB

View File

@ -50,7 +50,7 @@ class Battery(_Object):
return summary return summary
class Pod(Entity): class ChargePod(Entity):
@property @property
def encoding(self): def encoding(self):
@ -58,7 +58,7 @@ class Pod(Entity):
def __init__(self, *args, charge_rate: float = 0.4, def __init__(self, *args, charge_rate: float = 0.4,
multi_charge: bool = False, **kwargs): multi_charge: bool = False, **kwargs):
super(Pod, self).__init__(*args, **kwargs) super(ChargePod, self).__init__(*args, **kwargs)
self.charge_rate = charge_rate self.charge_rate = charge_rate
self.multi_charge = multi_charge self.multi_charge = multi_charge

View File

@ -1,52 +1,36 @@
from typing import Union, List, Tuple from typing import Union, List, Tuple
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.groups.collection import Collection from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.modules.batteries.entitites import Pod, Battery from marl_factory_grid.modules.batteries.entitites import ChargePod, Battery
from marl_factory_grid.utils.results import Result
class Batteries(Collection): class Batteries(Collection):
_entity = Battery _entity = Battery
@property var_has_position = False
def var_is_blocking_light(self): var_can_be_bound = True
return False
@property
def var_can_collide(self):
return False
@property
def var_can_move(self):
return False
@property
def var_has_position(self):
return False
@property
def var_can_be_bound(self):
return True
@property @property
def obs_tag(self): def obs_tag(self):
return self.__class__.__name__ return self.__class__.__name__
def __init__(self, *args, **kwargs): def __init__(self, size, initial_charge_level: float=1.0, *args, **kwargs):
super(Batteries, self).__init__(*args, **kwargs) super(Batteries, self).__init__(size, *args, **kwargs)
self.initial_charge_level = initial_charge_level
def spawn(self, agents, initial_charge_level): def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], agents, *entity_args, **entity_kwargs):
batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)] batteries = [self._entity(self.initial_charge_level, agent) for _, agent in enumerate(agents)]
self.add_items(batteries) self.add_items(batteries)
# def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): hat keine pos def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, **entity_kwargs):
# agents = entity_args[0] self.spawn(0, state[c.AGENT])
# initial_charge_level = entity_args[1] return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))
# batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)]
# self.add_items(batteries)
class ChargePods(Collection): class ChargePods(Collection):
_entity = Pod _entity = ChargePod
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(ChargePods, self).__init__(*args, **kwargs) super(ChargePods, self).__init__(*args, **kwargs)

View File

@ -49,10 +49,6 @@ class BatteryDecharge(Rule):
self.per_action_costs = per_action_costs self.per_action_costs = per_action_costs
self.initial_charge = initial_charge self.initial_charge = initial_charge
def on_init(self, state, lvl_map): # on reset?
assert len(state[c.AGENT]), "There are no agents, did you already spawn them?"
state[b.BATTERIES].spawn(state[c.AGENT], self.initial_charge)
def tick_step(self, state) -> List[TickResult]: def tick_step(self, state) -> List[TickResult]:
# Decharge # Decharge
batteries = state[b.BATTERIES] batteries = state[b.BATTERIES]
@ -66,7 +62,7 @@ class BatteryDecharge(Rule):
batteries.by_entity(agent).decharge(energy_consumption) batteries.by_entity(agent).decharge(energy_consumption)
results.append(TickResult(self.name, reward=0, entity=agent, validity=c.VALID)) results.append(TickResult(self.name, entity=agent, validity=c.VALID))
return results return results
@ -82,13 +78,13 @@ class BatteryDecharge(Rule):
if self.paralyze_agents_on_discharge: if self.paralyze_agents_on_discharge:
btry.bound_entity.paralyze(self.name) btry.bound_entity.paralyze(self.name)
results.append( results.append(
TickResult("Paralyzed", entity=btry.bound_entity, reward=0, validity=c.VALID) TickResult("Paralyzed", entity=btry.bound_entity, validity=c.VALID)
) )
state.print(f'{btry.bound_entity.name} has just been paralyzed!') state.print(f'{btry.bound_entity.name} has just been paralyzed!')
if btry.bound_entity.var_is_paralyzed and not btry.is_discharged: if btry.bound_entity.var_is_paralyzed and not btry.is_discharged:
btry.bound_entity.de_paralyze(self.name) btry.bound_entity.de_paralyze(self.name)
results.append( results.append(
TickResult("De-Paralyzed", entity=btry.bound_entity, reward=0, validity=c.VALID) TickResult("De-Paralyzed", entity=btry.bound_entity, validity=c.VALID)
) )
state.print(f'{btry.bound_entity.name} has just been de-paralyzed!') state.print(f'{btry.bound_entity.name} has just been de-paralyzed!')
return results return results
@ -132,7 +128,7 @@ class DoneAtBatteryDischarge(BatteryDecharge):
if any_discharged or all_discharged: if any_discharged or all_discharged:
return [DoneResult(self.name, validity=c.VALID, reward=self.reward_discharge_done)] return [DoneResult(self.name, validity=c.VALID, reward=self.reward_discharge_done)]
else: else:
return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)] return [DoneResult(self.name, validity=c.NOT_VALID)]
class SpawnChargePods(Rule): class SpawnChargePods(Rule):
@ -155,7 +151,7 @@ class SpawnChargePods(Rule):
def on_init(self, state, lvl_map): def on_init(self, state, lvl_map):
pod_collection = state[b.CHARGE_PODS] pod_collection = state[b.CHARGE_PODS]
empty_positions = state.entities.empty_positions() empty_positions = state.entities.empty_positions
pods = pod_collection.from_coordinates(empty_positions, entity_kwargs=dict( pods = pod_collection.from_coordinates(empty_positions, entity_kwargs=dict(
multi_charge=self.multi_charge, charge_rate=self.charge_rate) multi_charge=self.multi_charge, charge_rate=self.charge_rate)
) )

View File

@ -1,4 +1,4 @@
from .actions import CleanUp from .actions import CleanUp
from .entitites import DirtPile from .entitites import DirtPile
from .groups import DirtPiles from .groups import DirtPiles
from .rules import SpawnDirt, EntitiesSmearDirtOnMove, DoneOnAllDirtCleaned from .rules import EntitiesSmearDirtOnMove, DoneOnAllDirtCleaned

View File

@ -7,22 +7,6 @@ from marl_factory_grid.modules.clean_up import constants as d
class DirtPile(Entity): class DirtPile(Entity):
@property
def var_can_collide(self):
return False
@property
def var_can_move(self):
return False
@property
def var_is_blocking_light(self):
return False
@property
def var_has_position(self):
return True
@property @property
def amount(self): def amount(self):
return self._amount return self._amount

View File

@ -9,68 +9,55 @@ from marl_factory_grid.modules.clean_up.entitites import DirtPile
class DirtPiles(Collection): class DirtPiles(Collection):
_entity = DirtPile _entity = DirtPile
@property var_is_blocking_light = False
def var_is_blocking_light(self): var_can_collide = False
return False var_can_move = False
var_has_position = True
@property @property
def var_can_collide(self): def global_amount(self):
return False
@property
def var_can_move(self):
return False
@property
def var_has_position(self):
return True
@property
def amount(self):
return sum([dirt.amount for dirt in self]) return sum([dirt.amount for dirt in self])
def __init__(self, *args, def __init__(self, *args,
max_local_amount=5, max_local_amount=5,
clean_amount=1, clean_amount=1,
max_global_amount: int = 20, **kwargs): max_global_amount: int = 20,
coords_or_quantity=10,
initial_amount=2,
amount_var=0.2,
n_var=0.2,
**kwargs):
super(DirtPiles, self).__init__(*args, **kwargs) super(DirtPiles, self).__init__(*args, **kwargs)
self.amount_var = amount_var
self.n_var = n_var
self.clean_amount = clean_amount self.clean_amount = clean_amount
self.max_global_amount = max_global_amount self.max_global_amount = max_global_amount
self.max_local_amount = max_local_amount self.max_local_amount = max_local_amount
self.coords_or_quantity = coords_or_quantity
self.initial_amount = initial_amount
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): def trigger_spawn(self, state, coords_or_quantity=0, amount=0) -> [Result]:
amount_s = entity_args[0] coords_or_quantity = coords_or_quantity if coords_or_quantity else self.coords_or_quantity
n_new = int(abs(coords_or_quantity + (state.rng.uniform(-self.n_var, self.n_var))))
n_new = state.get_n_random_free_positions(n_new)
amounts = [amount if amount else (self.initial_amount + state.rng.uniform(-self.amount_var, self.amount_var))
for _ in range(coords_or_quantity)]
spawn_counter = 0 spawn_counter = 0
for idx, pos in enumerate(coords_or_quantity): for idx, (pos, a) in enumerate(zip(n_new, amounts)):
if not self.amount > self.max_global_amount: if not self.global_amount > self.max_global_amount:
amount = amount_s[idx] if isinstance(amount_s, list) else amount_s
if dirt := self.by_pos(pos): if dirt := self.by_pos(pos):
dirt = next(dirt.iter()) dirt = next(dirt.iter())
new_value = dirt.amount + amount new_value = dirt.amount + a
dirt.set_new_amount(new_value) dirt.set_new_amount(new_value)
else: else:
dirt = DirtPile(pos, amount=amount) super().spawn([pos], amount=a)
self.add_item(dirt)
spawn_counter += 1 spawn_counter += 1
else: else:
return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, reward=0, return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, value=spawn_counter)
value=spawn_counter)
return Result(identifier=f'{self.name}_spawn', validity=c.VALID, reward=0, value=spawn_counter)
def trigger_dirt_spawn(self, n, amount, state, n_var=0.2, amount_var=0.2) -> Result: return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=spawn_counter)
free_for_dirt = [x for x in state.entities.floorlist if len(state.entities.pos_dict[x]) == 0 or (
len(state.entities.pos_dict[x]) >= 1 and isinstance(next(y for y in x), DirtPile))]
# free_for_dirt = [x for x in state[c.FLOOR]
# if len(x.guests) == 0 or (
# len(x.guests) == 1 and
# isinstance(next(y for y in x.guests), DirtPile))]
state.rng.shuffle(free_for_dirt)
new_spawn = int(abs(n + (state.rng.uniform(-n_var, n_var))))
new_amount_s = [abs(amount + (amount*state.rng.uniform(-amount_var, amount_var))) for _ in range(new_spawn)]
n_dirty_positions = free_for_dirt[:new_spawn]
return self.spawn(n_dirty_positions, new_amount_s)
def __repr__(self): def __repr__(self):
s = super(DirtPiles, self).__repr__() s = super(DirtPiles, self).__repr__()
return f'{s[:-1]}, {self.amount})' return f'{s[:-1]}, {self.global_amount}]'

View File

@ -22,58 +22,37 @@ class DoneOnAllDirtCleaned(Rule):
def on_check_done(self, state) -> [DoneResult]: def on_check_done(self, state) -> [DoneResult]:
if len(state[d.DIRT]) == 0 and state.curr_step: if len(state[d.DIRT]) == 0 and state.curr_step:
return [DoneResult(validity=c.VALID, identifier=self.name, reward=self.reward)] return [DoneResult(validity=c.VALID, identifier=self.name, reward=self.reward)]
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)] return [DoneResult(validity=c.NOT_VALID, identifier=self.name)]
class SpawnDirt(Rule): class RespawnDirt(Rule):
def __init__(self, initial_n: int = 5, initial_amount: float = 1.3, def __init__(self, respawn_freq: int = 15, respawn_n: int = 5, respawn_amount: float = 1.0):
respawn_n: int = 3, respawn_amount: float = 0.8,
n_var: float = 0.2, amount_var: float = 0.2, spawn_freq: int = 15):
""" """
Defines the spawn pattern of intial and additional 'Dirt'-entitites. Defines the spawn pattern of intial and additional 'Dirt'-entitites.
First chooses positions, then trys to spawn dirt until 'respawn_n' or the maximal global amount is reached. First chooses positions, then trys to spawn dirt until 'respawn_n' or the maximal global amount is reached.
If there is allready some, it is topped up to min(max_local_amount, amount). If there is allready some, it is topped up to min(max_local_amount, amount).
:type spawn_freq: int :type respawn_freq: int
:parameter spawn_freq: In which frequency should this Rule try to spawn new 'Dirt'? :parameter respawn_freq: In which frequency should this Rule try to spawn new 'Dirt'?
:type respawn_n: int :type respawn_n: int
:parameter respawn_n: How many respawn positions are considered. :parameter respawn_n: How many respawn positions are considered.
:type initial_n: int
:parameter initial_n: How much initial positions are considered.
:type amount_var: float
:parameter amount_var: Variance of amount to spawn.
:type n_var: float
:parameter n_var: Variance of n to spawn.
:type respawn_amount: float :type respawn_amount: float
:parameter respawn_amount: Defines how much dirt 'amount' is placed every 'spawn_freq' ticks. :parameter respawn_amount: Defines how much dirt 'amount' is placed every 'spawn_freq' ticks.
:type initial_amount: float
:parameter initial_amount: Defines how much dirt 'amount' is initially placed.
""" """
super().__init__() super().__init__()
self.amount_var = amount_var
self.n_var = n_var
self.respawn_amount = respawn_amount
self.respawn_n = respawn_n self.respawn_n = respawn_n
self.initial_amount = initial_amount self.respawn_amount = respawn_amount
self.initial_n = initial_n self.respawn_freq = respawn_freq
self.spawn_freq = spawn_freq self._next_dirt_spawn = respawn_freq
self._next_dirt_spawn = spawn_freq
def on_init(self, state, lvl_map) -> str:
result = state[d.DIRT].trigger_dirt_spawn(self.initial_n, self.initial_amount, state,
n_var=self.n_var, amount_var=self.amount_var)
state.print(f'Initial Dirt was spawned on: {[x.pos for x in state[d.DIRT]]}')
return result
def tick_step(self, state): def tick_step(self, state):
collection = state[d.DIRT]
if self._next_dirt_spawn < 0: if self._next_dirt_spawn < 0:
pass # No DirtPile Spawn pass # No DirtPile Spawn
elif not self._next_dirt_spawn: elif not self._next_dirt_spawn:
result = [state[d.DIRT].trigger_dirt_spawn(self.respawn_n, self.respawn_amount, state, result = [collection.trigger_spawn(state, coords_or_quantity=self.respawn_n, amount=self.respawn_amount)]
n_var=self.n_var, amount_var=self.amount_var)] self._next_dirt_spawn = self.respawn_freq
self._next_dirt_spawn = self.spawn_freq
else: else:
self._next_dirt_spawn -= 1 self._next_dirt_spawn -= 1
result = [] result = []
@ -99,8 +78,8 @@ class EntitiesSmearDirtOnMove(Rule):
for entity in state.moving_entites: for entity in state.moving_entites:
if is_move(entity.state.identifier) and entity.state.validity == c.VALID: if is_move(entity.state.identifier) and entity.state.validity == c.VALID:
if old_pos_dirt := state[d.DIRT].by_pos(entity.last_pos): if old_pos_dirt := state[d.DIRT].by_pos(entity.last_pos):
old_pos_dirt = next(iter(old_pos_dirt))
if smeared_dirt := round(old_pos_dirt.amount * self.smear_ratio, 2): if smeared_dirt := round(old_pos_dirt.amount * self.smear_ratio, 2):
if state[d.DIRT].spawn(entity.pos, amount=smeared_dirt): if state[d.DIRT].spawn(entity.pos, amount=smeared_dirt):
results.append(TickResult(identifier=self.name, entity=entity, results.append(TickResult(identifier=self.name, entity=entity, validity=c.VALID))
reward=0, validity=c.VALID))
return results return results

View File

@ -1,4 +1,7 @@
from .actions import DestAction from .actions import DestAction
from .entitites import Destination from .entitites import Destination
from .groups import Destinations from .groups import Destinations
from .rules import DoneAtDestinationReachAll, SpawnDestinations from .rules import (DoneAtDestinationReachAll,
DoneAtDestinationReachAny,
SpawnDestinationsPerAgent,
DestinationReachReward)

View File

@ -9,30 +9,6 @@ from marl_factory_grid.utils.utility_classes import RenderEntity
class Destination(Entity): class Destination(Entity):
@property
def var_can_move(self):
return False
@property
def var_can_collide(self):
return False
@property
def var_has_position(self):
return True
@property
def var_is_blocking_pos(self):
return False
@property
def var_is_blocking_light(self):
return False
@property
def var_can_be_bound(self):
return True
def was_reached(self): def was_reached(self):
return self._was_reached return self._was_reached

View File

@ -7,37 +7,14 @@ from marl_factory_grid.modules.destinations import constants as d
class Destinations(Collection): class Destinations(Collection):
_entity = Destination _entity = Destination
@property var_is_blocking_light = False
def var_is_blocking_light(self): var_can_collide = False
return False var_can_move = False
var_has_position = True
@property var_can_be_bound = True
def var_can_collide(self):
return False
@property
def var_can_move(self):
return False
@property
def var_has_position(self):
return True
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def __repr__(self): def __repr__(self):
return super(Destinations, self).__repr__() return super(Destinations, self).__repr__()
@staticmethod
def trigger_destination_spawn(n_dests, state):
coordinates = state.entities.floorlist[:n_dests]
if destinations := [Destination(pos) for pos in coordinates]:
state[d.DESTINATION].add_items(destinations)
state.print(f'{n_dests} new destinations have been spawned')
return c.VALID
else:
state.print('No Destiantions are spawning, limit is reached.')
return c.NOT_VALID

View File

@ -2,8 +2,8 @@ import ast
from random import shuffle from random import shuffle
from typing import List, Dict, Tuple from typing import List, Dict, Tuple
import marl_factory_grid.modules.destinations.constants
from marl_factory_grid.environment.rules import Rule from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils import helpers as h
from marl_factory_grid.utils.results import TickResult, DoneResult from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
@ -54,7 +54,7 @@ class DoneAtDestinationReachAll(DestinationReachReward):
""" """
This rule triggers and sets the done flag if ALL Destinations have been reached. This rule triggers and sets the done flag if ALL Destinations have been reached.
:type reward_at_done: object :type reward_at_done: float
:param reward_at_done: Specifies the reward, agent get, whenn all destinations are reached. :param reward_at_done: Specifies the reward, agent get, whenn all destinations are reached.
:type dest_reach_reward: float :type dest_reach_reward: float
:param dest_reach_reward: Specify the reward, agents get when reaching a single destination. :param dest_reach_reward: Specify the reward, agents get when reaching a single destination.
@ -65,7 +65,7 @@ class DoneAtDestinationReachAll(DestinationReachReward):
def on_check_done(self, state) -> List[DoneResult]: def on_check_done(self, state) -> List[DoneResult]:
if all(x.was_reached() for x in state[d.DESTINATION]): if all(x.was_reached() for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=self.reward)] return [DoneResult(self.name, validity=c.VALID, reward=self.reward)]
return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)] return [DoneResult(self.name, validity=c.NOT_VALID)]
class DoneAtDestinationReachAny(DestinationReachReward): class DoneAtDestinationReachAny(DestinationReachReward):
@ -75,7 +75,7 @@ class DoneAtDestinationReachAny(DestinationReachReward):
This rule triggers and sets the done flag if ANY Destinations has been reached. This rule triggers and sets the done flag if ANY Destinations has been reached.
!!! IMPORTANT: 'reward_at_done' is shared between the agents; 'dest_reach_reward' is bound to a specific one. !!! IMPORTANT: 'reward_at_done' is shared between the agents; 'dest_reach_reward' is bound to a specific one.
:type reward_at_done: object :type reward_at_done: float
:param reward_at_done: Specifies the reward, all agent get, when any destinations has been reached. :param reward_at_done: Specifies the reward, all agent get, when any destinations has been reached.
Default {d.REWARD_DEST_DONE} Default {d.REWARD_DEST_DONE}
:type dest_reach_reward: float :type dest_reach_reward: float
@ -87,67 +87,29 @@ class DoneAtDestinationReachAny(DestinationReachReward):
def on_check_done(self, state) -> List[DoneResult]: def on_check_done(self, state) -> List[DoneResult]:
if any(x.was_reached() for x in state[d.DESTINATION]): if any(x.was_reached() for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=marl_factory_grid.modules.destinations.constants.REWARD_DEST_REACHED)] return [DoneResult(self.name, validity=c.VALID, reward=d.REWARD_DEST_REACHED)]
return [] return []
class SpawnDestinations(Rule):
def __init__(self, n_dests: int = 1, spawn_mode: str = d.MODE_GROUPED):
f"""
Defines how destinations are initially spawned and respawned in addition.
!!! This rule introduces no kind of reward or Env.-Done condition!
:type n_dests: int
:param n_dests: How many destiantions should be maintained (and initally spawnewd) on the map?
:type spawn_mode: str
:param spawn_mode: One of {d.SPAWN_MODES}. {d.MODE_GROUPED}: Always wait for all Dstiantions do be gone,
then respawn after the given time. {d.MODE_SINGLE}: Just spawn every destination,
that has been reached, after the given time
"""
super(SpawnDestinations, self).__init__()
self.n_dests = n_dests
self.spawn_mode = spawn_mode
def on_init(self, state, lvl_map):
# noinspection PyAttributeOutsideInit
state[d.DESTINATION].trigger_destination_spawn(self.n_dests, state)
pass
def tick_pre_step(self, state) -> List[TickResult]:
pass
def tick_step(self, state) -> List[TickResult]:
if n_dest_spawn := max(0, self.n_dests - len(state[d.DESTINATION])):
if self.spawn_mode == d.MODE_GROUPED and n_dest_spawn == self.n_dests:
validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state)
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
elif self.spawn_mode == d.MODE_SINGLE and n_dest_spawn:
validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state)
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
else:
pass
class SpawnDestinationsPerAgent(Rule): class SpawnDestinationsPerAgent(Rule):
def __init__(self, per_agent_positions: Dict[str, List[Tuple[int, int]]]): def __init__(self, coords_or_quantity: Dict[str, List[Tuple[int, int]]]):
""" """
Special rule, that spawn distinations, that are bound to a single agent a fixed set of positions. Special rule, that spawn distinations, that are bound to a single agent a fixed set of positions.
Usefull for introducing specialists, etc. .. Usefull for introducing specialists, etc. ..
!!! This rule does not introduce any reward or done condition. !!! This rule does not introduce any reward or done condition.
:type per_agent_positions: Dict[str, List[Tuple[int, int]] :type coords_or_quantity: Dict[str, List[Tuple[int, int]]
:param per_agent_positions: Please provide a dictionary with agent names as keys; and a list of possible :param coords_or_quantity: Please provide a dictionary with agent names as keys; and a list of possible
destiantion coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]} destiantion coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]}
""" """
super(Rule, self).__init__() super(Rule, self).__init__()
self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in per_agent_positions.items()} self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in coords_or_quantity.items()}
def on_init(self, state, lvl_map): def on_init(self, state, lvl_map):
for (agent_name, position_list) in self.per_agent_positions.items(): for (agent_name, position_list) in self.per_agent_positions.items():
agent = next(x for x in state[c.AGENT] if agent_name in x.name) # Fixme: Ugly AF agent = h.get_first(state[c.AGENT], lambda x: agent_name in x.name)
assert agent
position_list = position_list.copy() position_list = position_list.copy()
shuffle(position_list) shuffle(position_list)
while True: while True:
@ -155,7 +117,7 @@ class SpawnDestinationsPerAgent(Rule):
pos = position_list.pop() pos = position_list.pop()
except IndexError: except IndexError:
print(f"Could not spawn Destinations at: {self.per_agent_positions[agent_name]}") print(f"Could not spawn Destinations at: {self.per_agent_positions[agent_name]}")
print(f'Check your agent palcement: {state[c.AGENT]} ... Exit ...') print(f'Check your agent placement: {state[c.AGENT]} ... Exit ...')
exit(9999) exit(9999)
if (not pos == agent.pos) and (not state[d.DESTINATION].by_pos(pos)): if (not pos == agent.pos) and (not state[d.DESTINATION].by_pos(pos)):
destination = Destination(pos, bind_to=agent) destination = Destination(pos, bind_to=agent)

View File

@ -1,4 +1,5 @@
from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.utils import Result
from marl_factory_grid.utils.utility_classes import RenderEntity from marl_factory_grid.utils.utility_classes import RenderEntity
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
@ -41,21 +42,6 @@ class Door(Entity):
def str_state(self): def str_state(self):
return 'open' if self.is_open else 'closed' return 'open' if self.is_open else 'closed'
def __init__(self, *args, closed_on_init=True, auto_close_interval=10, **kwargs):
self._status = d.STATE_CLOSED
super(Door, self).__init__(*args, **kwargs)
self.auto_close_interval = auto_close_interval
self.time_to_close = 0
if not closed_on_init:
self._open()
else:
self._close()
def summarize_state(self):
state_dict = super().summarize_state()
state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close))
return state_dict
@property @property
def is_closed(self): def is_closed(self):
return self._status == d.STATE_CLOSED return self._status == d.STATE_CLOSED
@ -68,6 +54,25 @@ class Door(Entity):
def status(self): def status(self):
return self._status return self._status
@property
def time_to_close(self):
return self._time_to_close
def __init__(self, *args, closed_on_init=True, auto_close_interval=10, **kwargs):
self._status = d.STATE_CLOSED
super(Door, self).__init__(*args, **kwargs)
self._auto_close_interval = auto_close_interval
self._time_to_close = 0
if not closed_on_init:
self._open()
else:
self._close()
def summarize_state(self):
state_dict = super().summarize_state()
state_dict.update(state=str(self.str_state), time_to_close=self.time_to_close)
return state_dict
def render(self): def render(self):
name, state = 'door_open' if self.is_open else 'door_closed', 'blank' name, state = 'door_open' if self.is_open else 'door_closed', 'blank'
return RenderEntity(name, self.pos, 1, 'none', state, self.u_int + 1) return RenderEntity(name, self.pos, 1, 'none', state, self.u_int + 1)
@ -80,18 +85,35 @@ class Door(Entity):
return c.VALID return c.VALID
def tick(self, state): def tick(self, state):
if self.is_open and len(state.entities.pos_dict[self.pos]) == 2 and self.time_to_close: # Check if no entity is standing in the door
self.time_to_close -= 1 if len(state.entities.pos_dict[self.pos]) <= 2:
return c.NOT_VALID if self.is_open and self.time_to_close:
elif self.is_open and not self.time_to_close and len(state.entities.pos_dict[self.pos]) == 2: self._decrement_timer()
self.use() return Result(f"{d.DOOR}_tick", c.VALID, entity=self)
return c.VALID elif self.is_open and not self.time_to_close:
self.use()
return Result(f"{d.DOOR}_closed", c.VALID, entity=self)
else:
# No one is in door, but it is closed... Nothing to do....
return None
else: else:
return c.NOT_VALID # Entity is standing in the door, reset timer
self._reset_timer()
return Result(f"{d.DOOR}_reset", c.VALID, entity=self)
def _open(self):
    """Mark the door as open and restart the auto-close countdown.

    :return: always ``True``.
    """
    self._status = d.STATE_OPEN
    # ``time_to_close`` is a read-only property; the countdown is restarted
    # through the private helper instead of being assigned directly.
    self._reset_timer()
    return True
def _close(self):
    """Mark the door as closed.

    :return: always ``True``.
    """
    self._status = d.STATE_CLOSED
    return True
def _decrement_timer(self):
self._time_to_close -= 1
return True
def _reset_timer(self):
self._time_to_close = self._auto_close_interval
return True

View File

@ -18,8 +18,10 @@ class Doors(Collection):
super(Doors, self).__init__(*args, can_collide=True, **kwargs) super(Doors, self).__init__(*args, can_collide=True, **kwargs)
def tick_doors(self, state):
    """Tick every door in this collection and collect what the ticks produced.

    :param state: game state handed through to each door's ``tick``.
    :return: list of the non-``None`` values returned by the doors, in
        iteration order.
    """
    # TODO: Should return a single Result object rather than a plain list.
    return [result for door in self if (result := door.tick(state)) is not None]

View File

@ -19,10 +19,10 @@ class DoorAutoClose(Rule):
def tick_step(self, state):
    """Tick all doors and report which of them were auto-closed.

    :param state: game state; ``state[d.DOORS]`` is the door collection.
    :return: a single-element list with a valid ``TickResult`` when doors
        exist, otherwise an empty list.
    """
    doors = state[d.DOORS]
    if not doors:
        state.print('There are no doors, but you loaded the corresponding Module')
        return []

    # A door reports that it closed through a result identifier containing 'closed'.
    doors_that_closed = [
        result.entity.name for result in doors.tick_doors(state) if 'closed' in result.identifier
    ]
    door_str = doors_that_closed if doors_that_closed else "No Doors"
    state.print(f'{door_str} were auto-closed')
    return [TickResult(self.name, validity=c.VALID, value=1)]

View File

@ -1,4 +1,3 @@
from .actions import ItemAction from .actions import ItemAction
from .entitites import Item, DropOffLocation from .entitites import Item, DropOffLocation
from .groups import DropOffLocations, Items, Inventory, Inventories from .groups import DropOffLocations, Items, Inventory, Inventories
from .rules import ItemRules

View File

@ -29,7 +29,7 @@ class ItemAction(Action):
elif items := state[i.ITEM].by_pos(entity.pos): elif items := state[i.ITEM].by_pos(entity.pos):
item = items[0] item = items[0]
item.change_parent_collection(inventory) item.change_parent_collection(inventory)
item.set_pos_to(c.VALUE_NO_POS) item.set_pos(c.VALUE_NO_POS)
state.print(f'{entity.name} just picked up an item at {entity.pos}') state.print(f'{entity.name} just picked up an item at {entity.pos}')
return ActionResult(entity=entity, identifier=self._identifier, validity=c.VALID, reward=r.PICK_UP_VALID) return ActionResult(entity=entity, identifier=self._identifier, validity=c.VALID, reward=r.PICK_UP_VALID)

View File

@ -8,16 +8,11 @@ from marl_factory_grid.modules.items import constants as i
class Item(Entity): class Item(Entity):
@property
def var_can_collide(self):
return False
def render(self): def render(self):
return RenderEntity(i.ITEM, self.pos) if self.pos != c.VALUE_NO_POS else None return RenderEntity(i.ITEM, self.pos) if self.pos != c.VALUE_NO_POS else None
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._auto_despawn = -1
@property @property
def auto_despawn(self): def auto_despawn(self):
@ -31,9 +26,6 @@ class Item(Entity):
def set_auto_despawn(self, auto_despawn): def set_auto_despawn(self, auto_despawn):
self._auto_despawn = auto_despawn self._auto_despawn = auto_despawn
def set_pos_to(self, no_pos):
self._pos = no_pos
def summarize_state(self) -> dict: def summarize_state(self) -> dict:
super_summarization = super(Item, self).summarize_state() super_summarization = super(Item, self).summarize_state()
super_summarization.update(dict(auto_despawn=self.auto_despawn)) super_summarization.update(dict(auto_despawn=self.auto_despawn))
@ -42,21 +34,6 @@ class Item(Entity):
class DropOffLocation(Entity): class DropOffLocation(Entity):
@property
def var_can_collide(self):
return False
@property
def var_can_move(self):
return False
@property
def var_is_blocking_light(self):
return False
@property
def var_has_position(self):
return True
def render(self): def render(self):
return RenderEntity(i.DROP_OFF, self.pos) return RenderEntity(i.DROP_OFF, self.pos)

View File

@ -8,6 +8,7 @@ from marl_factory_grid.environment.groups.objects import _Objects
from marl_factory_grid.environment.groups.mixins import IsBoundMixin from marl_factory_grid.environment.groups.mixins import IsBoundMixin
from marl_factory_grid.environment.entity.agent import Agent from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.modules.items.entitites import Item, DropOffLocation from marl_factory_grid.modules.items.entitites import Item, DropOffLocation
from marl_factory_grid.utils.results import Result
class Items(Collection): class Items(Collection):
@ -15,7 +16,7 @@ class Items(Collection):
@property @property
def var_has_position(self): def var_has_position(self):
return False return True
@property @property
def is_blocking_light(self): def is_blocking_light(self):
@ -28,18 +29,18 @@ class Items(Collection):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, **entity_kwargs) -> [Result]:
    """Spawn as many items as are missing to reach the requested quantity.

    :param state: game state; used for printing and handed to the parent spawn.
    :param entity_args: positional arguments forwarded to the parent ``trigger_spawn``.
    :param coords_or_quantity: target number of items; falls back to
        ``self._coords_or_quantity`` when falsy.
    :param entity_kwargs: keyword arguments forwarded to the parent ``trigger_spawn``.
    :return: whatever the parent ``trigger_spawn`` returns when items are
        missing, otherwise a not-valid ``Result``.
    """
    coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity
    assert coords_or_quantity

    # NOTE(review): the subtraction assumes a quantity (int); a list of
    # coordinates would fail here -- confirm that callers only pass numbers.
    missing_items = max(0, coords_or_quantity - len(self))
    if missing_items:
        return super().trigger_spawn(state,
                                     *entity_args,
                                     coords_or_quantity=missing_items,
                                     **entity_kwargs)

    state.print('No Items are spawning, limit is reached.')
    # NOTE(review): a bare Result is returned here although the annotation
    # announces a list -- confirm which shape callers rely on.
    return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, value=coords_or_quantity)
class Inventory(IsBoundMixin, Collection): class Inventory(IsBoundMixin, Collection):
@ -76,9 +77,15 @@ class Inventory(IsBoundMixin, Collection):
class Inventories(_Objects): class Inventories(_Objects):
_entity = Inventory _entity = Inventory
var_can_move = False
var_has_position = False
symbol = None
@property @property
def var_can_move(self): def spawn_rule(self):
return False return {c.SPAWN_ENTITY_RULE: dict(collection=self, coords_or_quantity=None)}
def __init__(self, size: int, *args, **kwargs): def __init__(self, size: int, *args, **kwargs):
super(Inventories, self).__init__(*args, **kwargs) super(Inventories, self).__init__(*args, **kwargs)
@ -86,10 +93,12 @@ class Inventories(_Objects):
self._obs = None self._obs = None
self._lazy_eval_transforms = [] self._lazy_eval_transforms = []
def spawn(self, agents, *args, **kwargs):
    """Create one inventory per agent and add them to this collection.

    :param agents: iterable of agents; each one gets its own inventory.
    :param args: extra positional arguments for the inventory constructor.
    :param kwargs: extra keyword arguments for the inventory constructor.
    :return: single-element list with a valid ``Result`` whose value is the
        size of this collection after spawning.
    """
    self.add_items([self._entity(agent, self.size, *args, **kwargs) for agent in agents])
    return [Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))]
def trigger_spawn(self, state, *args, **kwargs) -> [Result]:
    """Spawn inventories for the agents found in ``state[c.AGENT]``."""
    agents = state[c.AGENT]
    return self.spawn(agents, *args, **kwargs)
def idx_by_entity(self, entity): def idx_by_entity(self, entity):
try: try:
@ -106,9 +115,6 @@ class Inventories(_Objects):
def summarize_states(self, **kwargs): def summarize_states(self, **kwargs):
return [val.summarize_states(**kwargs) for key, val in self.items()] return [val.summarize_states(**kwargs) for key, val in self.items()]
@staticmethod
def trigger_inventory_spawn(state):
state[i.INVENTORY].spawn(state[c.AGENT])
class DropOffLocations(Collection): class DropOffLocations(Collection):
@ -135,7 +141,7 @@ class DropOffLocations(Collection):
@staticmethod @staticmethod
def trigger_drop_off_location_spawn(state, n_locations): def trigger_drop_off_location_spawn(state, n_locations):
empty_positions = state.entities.empty_positions()[:n_locations] empty_positions = state.entities.empty_positions[:n_locations]
do_entites = state[i.DROP_OFF] do_entites = state[i.DROP_OFF]
drop_offs = [DropOffLocation(pos) for pos in empty_positions] drop_offs = [DropOffLocation(pos) for pos in empty_positions]
do_entites.add_items(drop_offs) do_entites.add_items(drop_offs)

View File

@ -6,52 +6,28 @@ from marl_factory_grid.utils.results import TickResult
from marl_factory_grid.modules.items import constants as i from marl_factory_grid.modules.items import constants as i
class RespawnItems(Rule):
    """Rule that asks the item collection to spawn items once a countdown has elapsed."""

    def __init__(self, n_items: int = 5, respawn_freq: int = 15, n_locations: int = 5):
        """
        :param n_items: quantity of items requested from the item collection.
        :param respawn_freq: initial value of the spawn countdown.
        :param n_locations: stored on the rule; not used by the methods of this class.
        """
        super().__init__()
        self.spawn_frequency = respawn_freq
        self._next_item_spawn = respawn_freq
        self.n_items = n_items
        self.n_locations = n_locations

    def tick_step(self, state):
        """Trigger an item spawn when the countdown is at zero, otherwise count down.

        :return: always an empty list.
        """
        if not self._next_item_spawn:
            # ``coords_or_quantity`` is keyword-only on the collection's
            # ``trigger_spawn``; positional values would be forwarded as
            # entity arguments instead of being used as the quantity.
            state[i.ITEM].trigger_spawn(state, coords_or_quantity=self.n_items)
        else:
            self._next_item_spawn = max(0, self._next_item_spawn - 1)
        # NOTE(review): the countdown is never set back to ``spawn_frequency``
        # after a spawn -- confirm whether this is intended.
        return []

    def tick_post_step(self, state) -> List[TickResult]:
        """Trigger an item spawn when the countdown is at zero and report the outcome.

        :return: a single ``TickResult`` (valid or not valid) when a spawn was
            attempted, otherwise an empty list.
        """
        if self._next_item_spawn:
            self._next_item_spawn = max(0, self._next_item_spawn - 1)
            return []

        spawned_items = state[i.ITEM].trigger_spawn(state, coords_or_quantity=self.n_items)
        if spawned_items:
            # NOTE(review): assumes ``trigger_spawn`` hands back an object with a
            # ``value`` attribute -- confirm against the parent collection.
            return [TickResult(self.name, validity=c.VALID, value=spawned_items.value)]
        return [TickResult(self.name, validity=c.NOT_VALID, value=0)]

View File

@ -1,3 +1,2 @@
from .entitites import Machine from .entitites import Machine
from .groups import Machines from .groups import Machines
from .rules import MachineRule

View File

@ -5,6 +5,7 @@ from marl_factory_grid.utils.results import ActionResult
from marl_factory_grid.modules.machines import constants as m, rewards as r from marl_factory_grid.modules.machines import constants as m, rewards as r
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils import helpers as h
class MachineAction(Action): class MachineAction(Action):
@ -13,13 +14,10 @@ class MachineAction(Action):
super().__init__(m.MACHINE_ACTION) super().__init__(m.MACHINE_ACTION)
def do(self, entity, state) -> Union[None, ActionResult]:
    """Let ``entity`` maintain the machine located at its position.

    :param entity: acting entity; its ``pos`` selects the machine.
    :param state: game state; ``state[m.MACHINES]`` is the machine collection.
    :return: an ``ActionResult`` carrying the outcome of ``maintain()``; not
        valid with the fail reward when there is no machine at the position.
    """
    if machine := h.get_first(state[m.MACHINES].by_pos(entity.pos)):
        valid = machine.maintain()
        reward = r.MAINTAIN_VALID if valid else r.MAINTAIN_FAIL
    else:
        valid, reward = c.NOT_VALID, r.MAINTAIN_FAIL
    return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=reward)

View File

@ -8,22 +8,6 @@ from . import constants as m
class Machine(Entity): class Machine(Entity):
@property
def var_can_collide(self):
return False
@property
def var_can_move(self):
return False
@property
def var_is_blocking_light(self):
return False
@property
def var_has_position(self):
return True
@property @property
def encoding(self): def encoding(self):
return self._encodings[self.status] return self._encodings[self.status]
@ -46,12 +30,12 @@ class Machine(Entity):
else: else:
return c.NOT_VALID return c.NOT_VALID
def tick(self): def tick(self, state):
# if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]): # if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]):
if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.state.entities.pos_dict[self.pos]]): if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in state.entities.pos_dict[self.pos]]):
return TickResult(identifier=self.name, validity=c.VALID, reward=0, entity=self) return TickResult(identifier=self.name, validity=c.VALID, entity=self)
# elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]): # elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]):
elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.state.entities.pos_dict[self.pos]]): elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in state.entities.pos_dict[self.pos]]):
self.status = m.STATE_WORK self.status = m.STATE_WORK
self.reset_counter() self.reset_counter()
return None return None

View File

@ -1,28 +0,0 @@
from typing import List
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c
from marl_factory_grid.modules.machines import constants as m
from marl_factory_grid.modules.machines.entitites import Machine
class MachineRule(Rule):
def __init__(self, n_machines: int = 2):
super(MachineRule, self).__init__()
self.n_machines = n_machines
def on_init(self, state, lvl_map):
state[m.MACHINES].spawn(state.entities.empty_positions())
def tick_pre_step(self, state) -> List[TickResult]:
pass
def tick_step(self, state) -> List[TickResult]:
pass
def tick_post_step(self, state) -> List[TickResult]:
pass
def on_check_done(self, state) -> List[DoneResult]:
pass

View File

@ -1,48 +1,35 @@
from random import shuffle
import networkx as nx import networkx as nx
import numpy as np import numpy as np
from ...algorithms.static.utils import points_to_graph from ...algorithms.static.utils import points_to_graph
from ...environment import constants as c from ...environment import constants as c
from ...environment.actions import Action, ALL_BASEACTIONS from ...environment.actions import Action, ALL_BASEACTIONS
from ...environment.entity.entity import Entity from ...environment.entity.entity import Entity
from ..doors import constants as do from ..doors import constants as do
from ..maintenance import constants as mi from ..maintenance import constants as mi
from ...utils.helpers import MOVEMAP from ...utils import helpers as h
from ...utils.utility_classes import RenderEntity from ...utils.utility_classes import RenderEntity, Floor
from ...utils.states import Gamestate from ..doors import DoorUse
class Maintainer(Entity): class Maintainer(Entity):
def __init__(self, objective: str, action: Action, *args, **kwargs):
    """Create a maintainer.

    :param objective: key of the entity collection in the game state that
        this maintainer looks up (``state[self.objective]``).
    :param action: action object stored as ``self.action``.
    :param args: positional arguments for the parent initialiser.
    :param kwargs: keyword arguments for the parent initialiser.
    """
    super().__init__(*args, **kwargs)
    self.action = action
    # Base actions plus the ability to use doors.
    self.actions = [x() for x in ALL_BASEACTIONS] + [DoorUse()]
    self.objective = objective
    self._path = None
    self._next = []
    self._last = []
    self._last_serviced = 'None'
    # No graph is built here; it starts out as None.
    self._floortile_graph = None
def tick(self, state): def tick(self, state):
if found_objective := state[self.objective].by_pos(self.pos): if found_objective := h.get_first(state[self.objective].by_pos(self.pos)):
if found_objective.name != self._last_serviced: if found_objective.name != self._last_serviced:
self.action.do(self, state) self.action.do(self, state)
self._last_serviced = found_objective.name self._last_serviced = found_objective.name
@ -54,24 +41,27 @@ class Maintainer(Entity):
return action.do(self, state) return action.do(self, state)
def get_move_action(self, state) -> Action: def get_move_action(self, state) -> Action:
if not self._floortile_graph:
state.print("Generating Floorgraph....")
self._floortile_graph = points_to_graph(state.entities.floorlist)
if self._path is None or not self._path: if self._path is None or not self._path:
if not self._next: if not self._next:
self._next = list(state[self.objective].values()) self._next = list(state[self.objective].values()) + [Floor(*state.random_free_position)]
shuffle(self._next)
self._last = [] self._last = []
self._last.append(self._next.pop()) self._last.append(self._next.pop())
state.print("Calculating shortest path....")
self._path = self.calculate_route(self._last[-1]) self._path = self.calculate_route(self._last[-1])
if door := self._door_is_close(state): if door := self._closed_door_in_path(state):
if door.is_closed: state.print(f"{self} found {door} that is closed. Attempt to open.")
# Translate the action_object to an integer to have the same output as any other model # Translate the action_object to an integer to have the same output as any other model
action = do.ACTION_DOOR_USE action = do.ACTION_DOOR_USE
else:
action = self._predict_move(state)
else: else:
action = self._predict_move(state) action = self._predict_move(state)
# Translate the action_object to an integer to have the same output as any other model # Translate the action_object to an integer to have the same output as any other model
try: try:
action_obj = next(x for x in self.actions if x.name == action) action_obj = h.get_first(self.actions, lambda x: x.name == action)
except (StopIteration, UnboundLocalError): except (StopIteration, UnboundLocalError):
print('Will not happen') print('Will not happen')
raise EnvironmentError raise EnvironmentError
@ -81,11 +71,10 @@ class Maintainer(Entity):
route = nx.shortest_path(self._floortile_graph, self.pos, entity.pos) route = nx.shortest_path(self._floortile_graph, self.pos, entity.pos)
return route[1:] return route[1:]
def _closed_door_in_path(self, state):
    """Return a closed door located on the next position of the current path.

    :param state: game state; ``state[do.DOORS]`` is the door collection.
    :return: the result of ``h.get_first`` over the closed doors at
        ``self._path[0]``, or ``None`` when there is no path.
    """
    if not self._path:
        return None
    next_pos = self._path[0]
    return h.get_first(state[do.DOORS].by_pos(next_pos), lambda door: door.is_closed)
def _predict_move(self, state): def _predict_move(self, state):
@ -96,7 +85,7 @@ class Maintainer(Entity):
next_pos = self._path.pop(0) next_pos = self._path.pop(0)
diff = np.subtract(next_pos, self.pos) diff = np.subtract(next_pos, self.pos)
# Retrieve action based on the pos dif (like in: What do I have to do to get there?) # Retrieve action based on the pos dif (like in: What do I have to do to get there?)
action = next(action for action, pos_diff in MOVEMAP.items() if np.all(diff == pos_diff)) action = next(action for action, pos_diff in h.MOVEMAP.items() if np.all(diff == pos_diff))
return action return action
def render(self): def render(self):

View File

@ -1,4 +1,4 @@
from typing import Union, List, Tuple from typing import Union, List, Tuple, Dict
from marl_factory_grid.environment.groups.collection import Collection from marl_factory_grid.environment.groups.collection import Collection
from .entities import Maintainer from .entities import Maintainer
@ -10,25 +10,21 @@ from ...utils.states import Gamestate
class Maintainers(Collection):
    """Collection of ``Maintainer`` entities."""
    _entity = Maintainer

    var_can_collide = True
    var_can_move = True
    var_is_blocking_light = False
    var_has_position = True

    # Only one initialiser is kept: a second ``__init__`` in the same class
    # body replaces the first, so an earlier definition would never run.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):
        """Create one maintainer per given position and add them to the collection.

        :param coords_or_quantity: iterated as a sequence of positions.
            NOTE(review): the annotation also allows an int, which cannot be
            iterated -- confirm that callers always pass positions.
        :param entity_args: accepted but not used by this method.
        """
        self.add_items([self._entity(mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity])

View File

@ -4,29 +4,24 @@ from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from . import rewards as r from . import rewards as r
from . import constants as M from . import constants as M
from marl_factory_grid.utils.states import Gamestate
class MoveMaintainers(Rule):
    """Rule that ticks every maintainer once per step."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def tick_step(self, state) -> List[TickResult]:
        """Call ``tick`` on every maintainer in ``state[M.MAINTAINERS]``.

        :return: always an empty list; the maintainers' own results are dropped.
        """
        for maintainer in state[M.MAINTAINERS]:
            maintainer.tick(state)
        # Todo: Return a Result Object.
        return []
def tick_post_step(self, state) -> List[TickResult]:
pass class DoneAtMaintainerCollision(Rule):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def on_check_done(self, state) -> List[DoneResult]: def on_check_done(self, state) -> List[DoneResult]:
agents = list(state[c.AGENT].values()) agents = list(state[c.AGENT].values())

View File

@ -1,8 +1,8 @@
from random import choices, choice from random import choices, choice
from . import constants as z, Zone from . import constants as z, Zone
from .. import Destination
from ..destinations import constants as d from ..destinations import constants as d
from ... import Destination
from ...environment.rules import Rule from ...environment.rules import Rule
from ...environment import constants as c from ...environment import constants as c

View File

@ -0,0 +1,3 @@
from . import helpers as h
from . import helpers
from .results import Result, DoneResult, ActionResult, TickResult

View File

@ -1,28 +1,24 @@
import ast import ast
from collections import defaultdict
from os import PathLike from os import PathLike
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union, List
import yaml import yaml
from marl_factory_grid.environment.groups.agents import Agents
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.environment.rules import Rule from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.helpers import locate_and_import_class from marl_factory_grid.utils.helpers import locate_and_import_class
from marl_factory_grid.environment.constants import DEFAULT_PATH, MODULE_PATH
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
DEFAULT_PATH = 'environment'
MODULE_PATH = 'modules'
class FactoryConfigParser(object): class FactoryConfigParser(object):
default_entites = [] default_entites = []
default_rules = ['MaxStepsReached', 'Collision'] default_rules = ['DoneAtMaxStepsReached', 'WatchCollision']
default_actions = [c.MOVE8, c.NOOP] default_actions = [c.MOVE8, c.NOOP]
default_observations = [c.WALLS, c.AGENT] default_observations = [c.WALLS, c.AGENT]
def __init__(self, config_path, custom_modules_path: Union[None, PathLike] = None): def __init__(self, config_path, custom_modules_path: Union[PathLike] = None):
self.config_path = Path(config_path) self.config_path = Path(config_path)
self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path
self.config = yaml.safe_load(self.config_path.open()) self.config = yaml.safe_load(self.config_path.open())
@ -46,6 +42,10 @@ class FactoryConfigParser(object):
def rules(self): def rules(self):
return self.config['Rules'] return self.config['Rules']
@property
def tests(self):
return self.config.get('Tests', [])
@property @property
def agents(self): def agents(self):
return self.config['Agents'] return self.config['Agents']
@ -61,7 +61,6 @@ class FactoryConfigParser(object):
return self.config[item] return self.config[item]
def load_entities(self): def load_entities(self):
# entites = Entities()
entity_classes = dict() entity_classes = dict()
entities = [] entities = []
if c.DEFAULTS in self.entities: if c.DEFAULTS in self.entities:
@ -69,28 +68,40 @@ class FactoryConfigParser(object):
entities.extend(x for x in self.entities if x != c.DEFAULTS) entities.extend(x for x in self.entities if x != c.DEFAULTS)
for entity in entities: for entity in entities:
e1 = e2 = e3 = None
try: try:
folder_path = Path(__file__).parent.parent / DEFAULT_PATH folder_path = Path(__file__).parent.parent / DEFAULT_PATH
entity_class = locate_and_import_class(entity, folder_path) entity_class = locate_and_import_class(entity, folder_path)
except AttributeError as e1: except AttributeError as e:
e1 = e
try: try:
folder_path = Path(__file__).parent.parent / MODULE_PATH module_path = Path(__file__).parent.parent / MODULE_PATH
entity_class = locate_and_import_class(entity, folder_path) entity_class = locate_and_import_class(entity, module_path)
except AttributeError as e2: except AttributeError as e:
try: e2 = e
folder_path = self.custom_modules_path if self.custom_modules_path:
entity_class = locate_and_import_class(entity, folder_path) try:
except AttributeError as e3: entity_class = locate_and_import_class(entity, self.custom_modules_path)
ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x] except AttributeError as e:
print('### Error ### Error ### Error ### Error ### Error ###') e3 = e
print() pass
print(f'Class "{entity}" was not found in "{folder_path.name}"') if (e1 and e2) or e3:
print('Possible Entitys are:', str(ents)) ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]]
print() print('##############################################################')
print('Goodbye') print('### Error ### Error ### Error ### Error ### Error ###')
print() print('##############################################################')
exit() print(f'Class "{entity}" was not found in "{module_path.name}"')
# raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are:', str(ents)) print(f'Class "{entity}" was not found in "{folder_path.name}"')
print('##############################################################')
if self.custom_modules_path:
print(f'Class "{entity}" was not found in "{self.custom_modules_path}"')
print('Possible Entitys are:', str(ents))
print('##############################################################')
print('Goodbye')
print('##############################################################')
print('### Error ### Error ### Error ### Error ### Error ###')
print('##############################################################')
exit(-99999)
entity_kwargs = self.entities.get(entity, {}) entity_kwargs = self.entities.get(entity, {})
entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
@ -128,31 +139,86 @@ class FactoryConfigParser(object):
observations.extend(self.default_observations) observations.extend(self.default_observations)
observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS) observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS)
positions = [ast.literal_eval(x) for x in self.agents[name].get('Positions', [])] positions = [ast.literal_eval(x) for x in self.agents[name].get('Positions', [])]
parsed_agents_conf[name] = dict(actions=parsed_actions, observations=observations, positions=positions) other_kwargs = {k: v for k, v in self.agents[name].items() if k not in
['Actions', 'Observations', 'Positions']}
parsed_agents_conf[name] = dict(
actions=parsed_actions, observations=observations, positions=positions, other=other_kwargs
)
return parsed_agents_conf return parsed_agents_conf
def load_env_rules(self) -> List[Rule]:
    """Instantiate the rules named in the ``Rules`` section of the config.

    When the section contains the ``c.DEFAULTS`` key, every entry of
    ``default_rules`` that is not configured explicitly is added without
    arguments. The ``c.DEFAULTS`` key itself is a marker, not a rule, and is
    not forwarded to the loader.

    :return: list of rule instances as produced by ``_load_smth``.
    """
    configured = self.rules or {}
    # The section is a mapping (rule name -> kwargs), so defaults are merged
    # by key; a mapping has no ``append``.
    rules = {name: kwargs for name, kwargs in configured.items() if name != c.DEFAULTS}
    if c.DEFAULTS in configured:
        for name in self.default_rules:
            rules.setdefault(name, {})

    return self._load_smth(rules, Rule)
def load_env_tests(self) -> List[Rule]:
    """Instantiate the classes named in the ``Tests`` section of the config.

    :return: list of instances as produced by ``_load_smth``.
    """
    # ``object`` accepts every class; ``None`` is not a valid second argument
    # for ``issubclass`` and would raise as soon as one test is configured.
    # TODO: narrow this to a dedicated Test base class once one exists.
    return self._load_smth(self.tests, object)
def _load_smth(self, config, class_obj):
rules = list()
rules_names = list()
for rule in config:
e1 = e2 = e3 = None
try: try:
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH) folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
rule_class = locate_and_import_class(rule, folder_path) rule_class = locate_and_import_class(rule, folder_path)
except AttributeError: except AttributeError as e:
e1 = e
try: try:
folder_path = (Path(__file__).parent.parent / MODULE_PATH) module_path = (Path(__file__).parent.parent / MODULE_PATH)
rule_class = locate_and_import_class(rule, folder_path) rule_class = locate_and_import_class(rule, module_path)
except AttributeError as e:
e2 = e
if self.custom_modules_path:
try:
rule_class = locate_and_import_class(rule, self.custom_modules_path)
except AttributeError as e:
e3 = e
pass
if (e1 and e2) or e3:
ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]]
print('### Error ### Error ### Error ### Error ### Error ###')
print('')
print(f'Class "{rule}" was not found in "{module_path.name}"')
print(f'Class "{rule}" was not found in "{folder_path.name}"')
if self.custom_modules_path:
print(f'Class "{rule}" was not found in "{self.custom_modules_path}"')
print('Possible Entitys are:', str(ents))
print('')
print('Goodbye')
print('')
exit(-99999)
if issubclass(rule_class, class_obj):
rule_kwargs = config.get(rule, {})
rules.append(rule_class(**(rule_kwargs or {})))
return rules
def load_entity_spawn_rules(self, entities) -> List[Rule]:
rules = list()
rules_dicts = list()
for e in entities:
try:
if spawn_rule := e.spawn_rule:
rules_dicts.append(spawn_rule)
except AttributeError:
pass
for rule_dict in rules_dicts:
for rule_name, rule_kwargs in rule_dict.items():
try:
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
rule_class = locate_and_import_class(rule_name, folder_path)
except AttributeError: except AttributeError:
rule_class = locate_and_import_class(rule, self.custom_modules_path) try:
# Fixme This check does not work! folder_path = (Path(__file__).parent.parent / MODULE_PATH)
# assert isinstance(rule_class, Rule), f'{rule_class.__name__} is no valid "Rule".' rule_class = locate_and_import_class(rule_name, folder_path)
rule_kwargs = self.rules.get(rule, {}) except AttributeError:
rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}}) rule_class = locate_and_import_class(rule_name, self.custom_modules_path)
return rules_classes rules.append(rule_class(**rule_kwargs))
return rules

View File

@ -2,7 +2,7 @@ import importlib
from collections import defaultdict from collections import defaultdict
from pathlib import PurePath, Path from pathlib import PurePath, Path
from typing import Union, Dict, List from typing import Union, Dict, List, Iterable, Callable
import numpy as np import numpy as np
from numpy.typing import ArrayLike from numpy.typing import ArrayLike
@ -222,7 +222,7 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
mod = importlib.import_module('.'.join(module_parts)) mod = importlib.import_module('.'.join(module_parts))
all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle()) all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union', and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin', 'TickResult', 'ActionResult', 'Action', 'Agent',
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin', 'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any' 'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any'
]]) ]])
@ -240,7 +240,13 @@ def add_bound_name(name_str, bound_e):
def add_pos_name(name_str, bound_e): def add_pos_name(name_str, bound_e):
if bound_e.var_has_position: if bound_e.var_has_position:
return f'{name_str}({bound_e.pos})' return f'{name_str}@{bound_e.pos}'
return name_str return name_str
def get_first(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True):
    """Return the first element of ``iterable`` that ``filter_by`` accepts.

    :param iterable: Any iterable; it is consumed only up to the first match.
    :param filter_by: Predicate called with each element; the default accepts
                      everything, so the very first element is returned.
    :return: The first accepted element, or ``None`` if there is none.
    """
    # NOTE(review): ``any`` in the annotation is the builtin function, not
    # ``typing.Any``; left untouched to keep the signature as-is.
    for candidate in iterable:
        if filter_by(candidate):
            return candidate
    return None
def get_first_index(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True):
    """Return the index of the first element of ``iterable`` that ``filter_by`` accepts.

    :param iterable: Any iterable; it is consumed only up to the first match.
    :param filter_by: Predicate called with each element; the default accepts
                      everything, so index ``0`` is returned for non-empty input.
    :return: Zero-based position of the first accepted element, or ``None`` if
             there is none.
    """
    # NOTE(review): ``any`` in the annotation is the builtin function, not
    # ``typing.Any``; left untouched to keep the signature as-is.
    position = 0
    for candidate in iterable:
        if filter_by(candidate):
            return position
        position += 1
    return None

View File

@ -47,6 +47,7 @@ class LevelParser(object):
# All other # All other
for es_name in self.e_p_dict: for es_name in self.e_p_dict:
e_class, e_kwargs = self.e_p_dict[es_name]['class'], self.e_p_dict[es_name]['kwargs'] e_class, e_kwargs = self.e_p_dict[es_name]['class'], self.e_p_dict[es_name]['kwargs']
e_kwargs = e_kwargs if e_kwargs else {}
if hasattr(e_class, 'symbol') and e_class.symbol is not None: if hasattr(e_class, 'symbol') and e_class.symbol is not None:
symbols = e_class.symbol symbols = e_class.symbol

View File

@ -1,17 +1,17 @@
import math
import re import re
from collections import defaultdict from collections import defaultdict
from itertools import product
from typing import Dict, List from typing import Dict, List
import numpy as np import numpy as np
from numba import njit
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.object import _Object
from marl_factory_grid.environment.groups.utils import Combined from marl_factory_grid.environment.groups.utils import Combined
import marl_factory_grid.utils.helpers as h
from marl_factory_grid.utils.states import Gamestate
from marl_factory_grid.utils.utility_classes import Floor from marl_factory_grid.utils.utility_classes import Floor
from marl_factory_grid.utils.ray_caster import RayCaster
from marl_factory_grid.utils.states import Gamestate
from marl_factory_grid.utils import helpers as h
class OBSBuilder(object): class OBSBuilder(object):
@ -77,11 +77,13 @@ class OBSBuilder(object):
def place_entity_in_observation(self, obs_array, agent, e): def place_entity_in_observation(self, obs_array, agent, e):
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
try: if not min([y, x]) < 0:
obs_array[x, y] += e.encoding try:
except IndexError: obs_array[x, y] += e.encoding
# Seemded to be visible but is out of range except IndexError:
pass # Seemded to be visible but is out of range
pass
pass
def build_for_agent(self, agent, state) -> (List[str], np.ndarray): def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
assert self._curr_env_step == state.curr_step, ( assert self._curr_env_step == state.curr_step, (
@ -121,18 +123,24 @@ class OBSBuilder(object):
e = self.all_obs[l_name] e = self.all_obs[l_name]
except KeyError: except KeyError:
try: try:
# Look for bound entity names! # Look for bound entity REPRs!
pattern = re.compile(f'{re.escape(l_name)}(.*){re.escape(agent.name)}') pattern = re.compile(f'{re.escape(l_name)}'
name = next((x for x in self.all_obs if pattern.search(x)), None) f'{re.escape("[")}(.*){re.escape("]")}'
f'{re.escape("(")}{re.escape(agent.name)}{re.escape(")")}')
name = next((key for key, val in self.all_obs.items()
if pattern.search(str(val)) and isinstance(val, _Object)), None)
e = self.all_obs[name] e = self.all_obs[name]
except KeyError: except KeyError:
try: try:
e = next(v for k, v in self.all_obs.items() if l_name in k and agent.name in k) e = next(v for k, v in self.all_obs.items() if l_name in k and agent.name in k)
except StopIteration: except StopIteration:
raise KeyError( print(f'# Check for spelling errors!')
f'Check for spelling errors! \n ' print(f'# No combination of "{l_name}" and "{agent.name}" could not be found in:')
f'No combination of "{l_name} and {agent.name}" could not be found in:\n ' print(f'# {list(dict(self.all_obs).keys())}')
f'{list(dict(self.all_obs).keys())}') print('#')
print('# exiting...')
print('#')
exit(-99999)
try: try:
positional = e.var_has_position positional = e.var_has_position
@ -161,15 +169,14 @@ class OBSBuilder(object):
try: try:
light_map = np.zeros(self.obs_shape) light_map = np.zeros(self.obs_shape)
visible_floor = self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False) visible_floor = self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False)
if self.pomdp_r:
for f in set(visible_floor): for f in set(visible_floor):
self.place_entity_in_observation(light_map, agent, f) self.place_entity_in_observation(light_map, agent, f)
else: # else:
for f in set(visible_floor): # for f in set(visible_floor):
light_map[f.x, f.y] += f.encoding # light_map[f.x, f.y] += f.encoding
self.curr_lightmaps[agent.name] = light_map self.curr_lightmaps[agent.name] = light_map
except (KeyError, ValueError): except (KeyError, ValueError):
print()
pass pass
return obs, self.obs_layers[agent.name] return obs, self.obs_layers[agent.name]
@ -185,7 +192,7 @@ class OBSBuilder(object):
for obs_str in agent.observations: for obs_str in agent.observations:
if isinstance(obs_str, dict): if isinstance(obs_str, dict):
obs_str, vals = next(obs_str.items().__iter__()) obs_str, vals = h.get_first(obs_str.items())
else: else:
vals = None vals = None
if obs_str == c.SELF: if obs_str == c.SELF:
@ -214,129 +221,3 @@ class OBSBuilder(object):
obs_layers.append(obs_str) obs_layers.append(obs_str)
self.obs_layers[agent.name] = obs_layers self.obs_layers[agent.name] = obs_layers
self.curr_lightmaps[agent.name] = np.zeros(self.obs_shape) self.curr_lightmaps[agent.name] = np.zeros(self.obs_shape)
class RayCaster:
    """Cast rays from an agent's position and collect the entities they reach.

    Ray end points are precomputed once as integer offsets (``ray_targets``) and
    shifted to the agent's current position on every call to ``get_rays``.
    """

    def __init__(self, agent, pomdp_r, degs=360):
        """Store the agent and precompute the ray target offsets.

        :param agent: Object whose ``pos``, ``x``, ``y`` and ``name`` attributes
                      are read by the methods of this class.
        :param pomdp_r: Radius used to scale the ray targets; also determines the
                        number of rays, ``(pomdp_r + 1) * 8``.
        :param degs: Angular span in degrees; ray angles are spread from
                     ``-degs // 2`` to ``degs // 2``.
        """
        self.agent = agent
        self.pomdp_r = pomdp_r
        self.n_rays = (self.pomdp_r + 1) * 8
        self.degs = degs
        self.ray_targets = self.build_ray_targets()
        self.obs_shape_cube = np.array([self.pomdp_r, self.pomdp_r])
        # Per-position memo used by ``ray_block_cache``.
        self._cache_dict = {}

    def __repr__(self):
        return f'{self.__class__.__name__}({self.agent.name})'

    def build_ray_targets(self):
        """Return the unique integer end-point offsets of all rays.

        The vector ``[0, -1] * pomdp_r`` is rotated by ``n_rays`` evenly spaced
        angles, rounded, and de-duplicated row-wise.
        """
        north = np.array([0, -1]) * self.pomdp_r
        thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
        # One 2x2 rotation matrix per angle.
        rot_M = [
            [[math.cos(theta), -math.sin(theta)],
             [math.sin(theta), math.cos(theta)]] for theta in thetas
        ]
        rot_M = np.stack(rot_M, 0)
        # Rounding maps several angles onto the same cell; keep each target once.
        rot_M = np.unique(np.round(rot_M @ north), axis=0)
        return rot_M.astype(int)

    def ray_block_cache(self, key, callback):
        """Return the cached value for ``key``, calling ``callback()`` once to fill it."""
        if key not in self._cache_dict:
            self._cache_dict[key] = callback()
        return self._cache_dict[key]

    def visible_entities(self, pos_dict, reset_cache=True):
        """Collect the entities lying on cells reached by the rays.

        :param pos_dict: Mapping indexed with ``(x, y)`` tuples that yields the
                         entities at that cell. Presumably a ``defaultdict`` so
                         that unknown cells yield an empty list -- TODO confirm.
        :param reset_cache: When true, the blocking cache is emptied first.
        :return: List of entities; entities are appended once per ray that
                 reaches them, so the list can contain repeats.
        """
        visible = list()
        if reset_cache:
            self._cache_dict = {}

        for ray in self.get_rays():
            # (rx, ry) is the previously visited cell on this ray.
            rx, ry = ray[0]
            for x, y in ray:
                # Step taken from the previous cell to this one.
                cx, cy = x - rx, y - ry

                entities_hit = pos_dict[(x, y)]
                # True if anything on this cell blocks light.
                hits = self.ray_block_cache((x, y),
                                            lambda: any(True for e in entities_hit if e.var_is_blocking_light)
                                            )

                # On a diagonal step, look at the two cells the step passes
                # between; the ray is stopped only if both are non-empty and
                # hold nothing but light-blocking entities.
                # NOTE(review): both lookups share ``_cache_dict`` and use
                # position keys, yet cache different predicates ("any blocking"
                # above vs. "all blocking and non-empty" here); whichever runs
                # first for a cell fixes the cached value -- confirm intended.
                diag_hits = all([
                    self.ray_block_cache(
                        key,
                        lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) and bool(
                            pos_dict[key]))
                    for key in ((x, y - cy), (x - cx, y))
                ]) if (cx != 0 and cy != 0) else False

                # A blocking cell itself is still reported; a cell behind a
                # blocked diagonal is not.
                visible += entities_hit if not diag_hits else []
                if hits or diag_hits:
                    break
                rx, ry = x, y
        return visible

    def get_rays(self):
        """Return one list of cells per ray, from the agent's position to each shifted target."""
        a_pos = self.agent.pos
        outline = self.ray_targets + a_pos
        return self.bresenham_loop(a_pos, outline)
        # TODO: do this once and cache the points!

    def get_fov_outline(self) -> np.ndarray:
        """Return the ray target offsets shifted to the agent's current position."""
        return self.ray_targets + self.agent.pos

    def get_square_outline(self):
        """Return the border cells of the square of radius ``pomdp_r`` around the agent.

        The result lists the rows ``y = agent.y +/- pomdp_r`` followed by the
        columns ``x = agent.x +/- pomdp_r`` as ``(x, y)`` tuples; the four corner
        cells therefore appear twice.
        """
        agent = self.agent
        x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
        y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
        outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \
                  + list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
        return outline

    @staticmethod
    @njit
    def bresenham_loop(a_pos, points):
        """Rasterise a line from ``a_pos`` to every entry of ``points``.

        :param a_pos: Start cell ``(x, y)`` shared by all lines.
        :param points: Iterable of end cells ``(x, y)``.
        :return: One list of ``[x, y]`` cells per end point, ordered from
                 ``a_pos`` towards that end point.
        """
        results = []
        for end in points:
            x1, y1 = a_pos
            x2, y2 = end
            dx = x2 - x1
            dy = y2 - y1

            # Determine how steep the line is
            is_steep = abs(dy) > abs(dx)

            # Rotate line
            if is_steep:
                x1, y1 = y1, x1
                x2, y2 = y2, x2

            # Swap start and end points if necessary and store swap state
            swapped = False
            if x1 > x2:
                x1, x2 = x2, x1
                y1, y2 = y2, y1
                swapped = True

            # Recalculate differentials
            dx = x2 - x1
            dy = y2 - y1

            # Calculate error
            error = int(dx / 2.0)
            ystep = 1 if y1 < y2 else -1

            # Iterate over bounding box generating points between start and end
            y = y1
            # Rebinding ``points`` does not affect the running outer loop, which
            # already holds its own iterator over the original argument.
            points = []
            for x in range(int(x1), int(x2) + 1):
                # Undo the steep-line rotation when emitting the cell.
                coord = [y, x] if is_steep else [x, y]
                points.append(coord)
                error -= abs(dy)
                if error < 0:
                    y += ystep
                    error += dx

            # Reverse the list if the coordinates were swapped
            if swapped:
                points.reverse()
            results.append(points)
        return results

View File

@ -39,8 +39,9 @@ class RayCaster:
if reset_cache: if reset_cache:
self._cache_dict = dict() self._cache_dict = dict()
for ray in self.get_rays(): for ray in self.get_rays(): # Do not check, just trust.
rx, ry = ray[0] rx, ry = ray[0]
# self.ray_block_cache(ray[0], lambda: False) We do not do that, because of doors etc...
for x, y in ray: for x, y in ray:
cx, cy = x - rx, y - ry cx, cy = x - rx, y - ry
@ -52,7 +53,8 @@ class RayCaster:
diag_hits = all([ diag_hits = all([
self.ray_block_cache( self.ray_block_cache(
key, key,
lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light)) lambda: any(True for e in pos_dict[key] if e.var_is_blocking_light))
# lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light))
for key in ((x, y-cy), (x-cx, y)) for key in ((x, y-cy), (x-cx, y))
]) if (cx != 0 and cy != 0) else False ]) if (cx != 0 and cy != 0) else False

View File

@ -31,7 +31,7 @@ class Renderer:
def __init__(self, lvl_shape: Tuple[int, int] = (16, 16), def __init__(self, lvl_shape: Tuple[int, int] = (16, 16),
lvl_padded_shape: Union[Tuple[int, int], None] = None, lvl_padded_shape: Union[Tuple[int, int], None] = None,
cell_size: int = 40, fps: int = 7, cell_size: int = 40, fps: int = 7, factor: float = 0.9,
grid_lines: bool = True, view_radius: int = 2): grid_lines: bool = True, view_radius: int = 2):
# TODO: Customn_assets paths # TODO: Customn_assets paths
self.grid_h, self.grid_w = lvl_shape self.grid_h, self.grid_w = lvl_shape
@ -45,7 +45,7 @@ class Renderer:
self.screen = pygame.display.set_mode(self.screen_size) self.screen = pygame.display.set_mode(self.screen_size)
self.clock = pygame.time.Clock() self.clock = pygame.time.Clock()
assets = list(self.ASSETS.rglob('*.png')) assets = list(self.ASSETS.rglob('*.png'))
self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets} self.assets = {path.stem: self.load_asset(str(path), factor) for path in assets}
self.fill_bg() self.fill_bg()
now = time.time() now = time.time()
@ -110,22 +110,22 @@ class Renderer:
pygame.quit() pygame.quit()
sys.exit() sys.exit()
self.fill_bg() self.fill_bg()
blits = deque() # First all others
for entity in [x for x in entities]: blits = deque(self.blit_params(x) for x in entities if not x.name.lower() == AGENT)
bp = self.blit_params(entity) # Then Agents, so that agents are rendered on top.
blits.append(bp) for agent in (x for x in entities if x.name.lower() == AGENT):
if entity.name.lower() == AGENT: agent_blit = self.blit_params(agent)
if self.view_radius > 0: if self.view_radius > 0:
vis_rects = self.visibility_rects(bp, entity.aux) vis_rects = self.visibility_rects(agent_blit, agent.aux)
blits.extendleft(vis_rects) blits.extendleft(vis_rects)
if entity.state != BLANK: if agent.state != BLANK:
agent_state_blits = self.blit_params( state_blit = self.blit_params(
RenderEntity(entity.state, (entity.pos[0] + 0.12, entity.pos[1]), 0.48, SCALE) RenderEntity(agent.state, (agent.pos[0] + 0.12, agent.pos[1]), 0.48, SCALE)
) )
textsurface = self.font.render(str(entity.id), False, (0, 0, 0)) textsurface = self.font.render(str(agent.id), False, (0, 0, 0))
text_blit = dict(source=textsurface, dest=(bp['dest'].center[0]-.07*self.cell_size, text_blit = dict(source=textsurface, dest=(agent_blit['dest'].center[0]-.07*self.cell_size,
bp['dest'].center[1])) agent_blit['dest'].center[1]))
blits += [agent_state_blits, text_blit] blits += [agent_blit, state_blit, text_blit]
for blit in blits: for blit in blits:
self.screen.blit(**blit) self.screen.blit(**blit)

View File

@ -28,7 +28,10 @@ class Result:
def __repr__(self): def __repr__(self):
valid = "not " if not self.validity else "" valid = "not " if not self.validity else ""
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid: {self.reward})' reward = f" | Reward: {self.reward}" if self.reward is not None else ""
value = f" | Value: {self.value}" if self.value is not None else ""
entity = f" | by: {self.entity.name}" if self.entity is not None else ""
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value})'
@dataclass @dataclass

View File

@ -1,3 +1,4 @@
from itertools import islice
from typing import List, Dict, Tuple from typing import List, Dict, Tuple
import numpy as np import numpy as np
@ -59,14 +60,15 @@ class Gamestate(object):
def moving_entites(self): def moving_entites(self):
return [y for x in self.entities for y in x if x.var_can_move] return [y for x in self.entities for y in x if x.var_can_move]
def __init__(self, entities, agents_conf, rules: Dict[str, dict], env_seed=69, verbose=False): def __init__(self, entities, agents_conf, rules: List[Rule], lvl_shape, env_seed=69, verbose=False):
self.lvl_shape = lvl_shape
self.entities = entities self.entities = entities
self.curr_step = 0 self.curr_step = 0
self.curr_actions = None self.curr_actions = None
self.agents_conf = agents_conf self.agents_conf = agents_conf
self.verbose = verbose self.verbose = verbose
self.rng = np.random.default_rng(env_seed) self.rng = np.random.default_rng(env_seed)
self.rules = StepRules(*(v['class'](**v['kwargs']) for v in rules.values())) self.rules = StepRules(*rules)
def __getitem__(self, item): def __getitem__(self, item):
return self.entities[item] return self.entities[item]
@ -80,6 +82,13 @@ class Gamestate(object):
def __repr__(self): def __repr__(self):
return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})' return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})'
@property
def random_free_position(self):
    """Return a single free position.

    This is the first element of ``get_n_random_free_positions(1)``; the
    indexing raises ``IndexError`` if that list comes back empty.
    """
    return self.get_n_random_free_positions(1)[0]
def get_n_random_free_positions(self, n):
    """Return up to ``n`` items taken from ``self.entities.free_positions_generator``.

    ``islice`` stops early when the generator is exhausted, so the returned
    list may hold fewer than ``n`` entries.

    NOTE(review): the name implies the generator yields positions in random
    order -- confirm against ``free_positions_generator``.
    """
    return list(islice(self.entities.free_positions_generator, n))
def tick(self, actions) -> List[Result]: def tick(self, actions) -> List[Result]:
results = list() results = list()
self.curr_step += 1 self.curr_step += 1
@ -115,8 +124,7 @@ class Gamestate(object):
return results return results
def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]: def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]:
positions = [pos for pos, entity_list_for_position in self.entities.pos_dict.items() positions = [pos for pos, entities in self.entities.pos_dict.items() if len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2)]
if any([e.var_can_collide for e in entity_list_for_position])]
return positions return positions
def check_move_validity(self, moving_entity, position): def check_move_validity(self, moving_entity, position):

View File

@ -135,4 +135,3 @@ if __name__ == '__main__':
ce.get_observations() ce.get_observations()
ce.get_assets() ce.get_assets()
all_conf = ce.get_all() all_conf = ce.get_all()
print()

View File

@ -52,3 +52,6 @@ class Floor:
def __hash__(self): def __hash__(self):
return hash(self.name) return hash(self.name)
def __repr__(self):
    """Return ``"Floor"`` immediately followed by the formatted ``self.pos``."""
    return f"Floor{self.pos}"

View File

@ -6,6 +6,7 @@ import yaml
from marl_factory_grid.environment.factory import Factory from marl_factory_grid.environment.factory import Factory
from marl_factory_grid.utils.logging.envmonitor import EnvMonitor from marl_factory_grid.utils.logging.envmonitor import EnvMonitor
from marl_factory_grid.utils.logging.recorder import EnvRecorder from marl_factory_grid.utils.logging.recorder import EnvRecorder
from marl_factory_grid.utils import helpers as h
from marl_factory_grid.modules.doors import constants as d from marl_factory_grid.modules.doors import constants as d
@ -61,7 +62,7 @@ if __name__ == '__main__':
if render: if render:
env.render() env.render()
try: try:
door = next(x for x in env.unwrapped.unwrapped[d.DOORS] if x.is_open) door = h.get_first([x for x in env.unwrapped.unwrapped[d.DOORS] if x.is_open])
print('openDoor found') print('openDoor found')
except StopIteration: except StopIteration:
pass pass

View File

@ -1,8 +1,8 @@
from algorithms.utils import Checkpointer from algorithms.utils import Checkpointer
from pathlib import Path from pathlib import Path
from algorithms.utils import load_yaml_file, add_env_props, instantiate_class, load_class from algorithms.utils import load_yaml_file, add_env_props, instantiate_class, load_class
#from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC
# from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC
for i in range(0, 5): for i in range(0, 5):

View File

@ -0,0 +1,43 @@
import configparser
import json
from datetime import datetime
from pathlib import Path
if __name__ == '__main__':
    # Convert ``wg0/wg0.conf`` into one ``<section>.json`` client file per
    # non-``Interface`` section, written into the same ``wg0`` directory.
    from datetime import timezone  # local import: only this entry point needs it

    conf_path = Path('wg0')
    conf_file = conf_path / 'wg0.conf'
    wg0_conf = configparser.ConfigParser()
    # ConfigParser.read() silently skips files it cannot open and returns the
    # list of files it actually parsed, so fail here with a clear message
    # instead of an opaque KeyError on 'Interface' below.
    if not wg0_conf.read(conf_file):
        raise FileNotFoundError(f'Could not read {conf_file}')

    interface = wg0_conf['Interface']
    # Iterate all peers
    for client_name in wg0_conf.sections():
        if client_name == 'Interface':
            continue
        client_file = conf_path / f'{client_name}.json'
        # Delete any old conf.json for the current peer
        client_file.unlink(missing_ok=True)

        peer = wg0_conf[client_name]
        # The trailing 'Z' labels the timestamp as UTC, so take the time in UTC
        # rather than naive local time. '%f' yields microseconds; the literal
        # '000' pads them to nine fractional digits.
        date_time = datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%S.%f000Z')

        jdict = dict(
            id=client_name,
            # NOTE(review): the private key is filled from the peer's *public*
            # key -- confirm this placeholder is intended.
            private_key=peer['PublicKey'],
            public_key=peer['PublicKey'],
            # preshared_key=wg0_conf[client_name_wg0]['PresharedKey'],
            name=client_name,
            email="sysadmin@mobile.ifi.lmu.de",
            # NOTE(review): every client receives the Interface section's own
            # address here -- confirm this is wanted rather than a per-peer one.
            allocated_ips=[interface['Address'].replace('/24', '')],
            allowed_ips=['10.4.0.0/24', '10.153.199.0/24'],
            extra_allowed_ips=[],
            use_server_dns=True,
            enabled=True,
            created_at=date_time,
            updated_at=date_time
        )

        # The file is only written, never read back, so plain 'w' suffices.
        with client_file.open('w') as f:
            json.dump(jdict, f, indent='\t', separators=(',', ': '))
        print(client_name, ' written...')