mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-22 14:56:43 +02:00
new rules, new spawn logic, small fixes, default and narrow corridor debugged
This commit is contained in:
parent
9b9c6e0385
commit
06a5130b25
@ -56,7 +56,7 @@ Just define what your environment needs in a *yaml*-configfile like:
|
|||||||
- Items
|
- Items
|
||||||
Rules:
|
Rules:
|
||||||
Defaults: {}
|
Defaults: {}
|
||||||
Collision:
|
WatchCollisions:
|
||||||
done_at_collisions: !!bool True
|
done_at_collisions: !!bool True
|
||||||
ItemRespawn:
|
ItemRespawn:
|
||||||
spawn_freq: 5
|
spawn_freq: 5
|
||||||
|
@ -1,6 +1 @@
|
|||||||
from .environment import *
|
from .quickstart import init
|
||||||
from .modules import *
|
|
||||||
from .utils import *
|
|
||||||
|
|
||||||
from .quickstart import init
|
|
||||||
|
|
@ -30,4 +30,3 @@ class TSPTargetAgent(TSPBaseAgent):
|
|||||||
except (StopIteration, UnboundLocalError):
|
except (StopIteration, UnboundLocalError):
|
||||||
print('Will not happen')
|
print('Will not happen')
|
||||||
return action_obj
|
return action_obj
|
||||||
|
|
||||||
|
@ -26,12 +26,16 @@ def points_to_graph(coordiniates, allow_euclidean_connections=True, allow_manhat
|
|||||||
assert allow_euclidean_connections or allow_manhattan_connections
|
assert allow_euclidean_connections or allow_manhattan_connections
|
||||||
possible_connections = itertools.combinations(coordiniates, 2)
|
possible_connections = itertools.combinations(coordiniates, 2)
|
||||||
graph = nx.Graph()
|
graph = nx.Graph()
|
||||||
for a, b in possible_connections:
|
if allow_manhattan_connections and allow_euclidean_connections:
|
||||||
diff = np.linalg.norm(np.asarray(a)-np.asarray(b))
|
graph.add_edges_from(
|
||||||
if allow_manhattan_connections and allow_euclidean_connections and diff <= np.sqrt(2):
|
(a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) <= np.sqrt(2)
|
||||||
graph.add_edge(a, b)
|
)
|
||||||
elif not allow_manhattan_connections and allow_euclidean_connections and diff == np.sqrt(2):
|
elif not allow_manhattan_connections and allow_euclidean_connections:
|
||||||
graph.add_edge(a, b)
|
graph.add_edges_from(
|
||||||
elif allow_manhattan_connections and not allow_euclidean_connections and diff == 1:
|
(a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) == np.sqrt(2)
|
||||||
graph.add_edge(a, b)
|
)
|
||||||
|
elif allow_manhattan_connections and not allow_euclidean_connections:
|
||||||
|
graph.add_edges_from(
|
||||||
|
(a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) == 1
|
||||||
|
)
|
||||||
return graph
|
return graph
|
||||||
|
@ -22,26 +22,41 @@ Agents:
|
|||||||
- Inventory
|
- Inventory
|
||||||
- DropOffLocations
|
- DropOffLocations
|
||||||
- Maintainers
|
- Maintainers
|
||||||
|
# This is special for agents, as each one is differten and can act as an adversary e.g.
|
||||||
|
Positions:
|
||||||
|
- (16, 7)
|
||||||
|
- (16, 6)
|
||||||
|
- (16, 3)
|
||||||
|
- (16, 4)
|
||||||
|
- (16, 5)
|
||||||
Entities:
|
Entities:
|
||||||
Batteries:
|
Batteries:
|
||||||
initial_charge: 0.8
|
initial_charge: 0.8
|
||||||
per_action_costs: 0.02
|
per_action_costs: 0.02
|
||||||
ChargePods: {}
|
ChargePods:
|
||||||
Destinations: {}
|
coords_or_quantity: 2
|
||||||
|
Destinations:
|
||||||
|
coords_or_quantity: 1
|
||||||
|
spawn_mode: GROUPED
|
||||||
DirtPiles:
|
DirtPiles:
|
||||||
|
coords_or_quantity: 10
|
||||||
|
initial_amount: 2
|
||||||
clean_amount: 1
|
clean_amount: 1
|
||||||
dirt_spawn_r_var: 0.1
|
dirt_spawn_r_var: 0.1
|
||||||
initial_amount: 2
|
|
||||||
initial_dirt_ratio: 0.05
|
|
||||||
max_global_amount: 20
|
max_global_amount: 20
|
||||||
max_local_amount: 5
|
max_local_amount: 5
|
||||||
Doors: {}
|
Doors:
|
||||||
DropOffLocations: {}
|
DropOffLocations:
|
||||||
|
coords_or_quantity: 1
|
||||||
|
max_dropoff_storage_size: 0
|
||||||
GlobalPositions: {}
|
GlobalPositions: {}
|
||||||
Inventories: {}
|
Inventories: {}
|
||||||
Items: {}
|
Items:
|
||||||
Machines: {}
|
coords_or_quantity: 5
|
||||||
Maintainers: {}
|
Machines:
|
||||||
|
coords_or_quantity: 2
|
||||||
|
Maintainers:
|
||||||
|
coords_or_quantity: 1
|
||||||
Zones: {}
|
Zones: {}
|
||||||
|
|
||||||
General:
|
General:
|
||||||
@ -49,32 +64,31 @@ General:
|
|||||||
individual_rewards: true
|
individual_rewards: true
|
||||||
level_name: large
|
level_name: large
|
||||||
pomdp_r: 3
|
pomdp_r: 3
|
||||||
verbose: false
|
verbose: True
|
||||||
|
tests: false
|
||||||
|
|
||||||
Rules:
|
Rules:
|
||||||
SpawnAgents: {}
|
# Environment Dynamics
|
||||||
DoneAtBatteryDischarge: {}
|
|
||||||
Collision:
|
|
||||||
done_at_collisions: false
|
|
||||||
AssignGlobalPositions: {}
|
|
||||||
DoneAtDestinationReachAny: {}
|
|
||||||
DestinationReachReward: {}
|
|
||||||
SpawnDestinations:
|
|
||||||
n_dests: 1
|
|
||||||
spawn_mode: GROUPED
|
|
||||||
DoneOnAllDirtCleaned: {}
|
|
||||||
SpawnDirt:
|
|
||||||
spawn_freq: 15
|
|
||||||
EntitiesSmearDirtOnMove:
|
EntitiesSmearDirtOnMove:
|
||||||
smear_ratio: 0.2
|
smear_ratio: 0.2
|
||||||
DoorAutoClose:
|
DoorAutoClose:
|
||||||
close_frequency: 10
|
close_frequency: 10
|
||||||
ItemRules:
|
MoveMaintainers:
|
||||||
max_dropoff_storage_size: 0
|
|
||||||
n_items: 5
|
# Respawn Stuff
|
||||||
n_locations: 5
|
RespawnDirt:
|
||||||
spawn_frequency: 15
|
respawn_freq: 15
|
||||||
MaxStepsReached:
|
RespawnItems:
|
||||||
|
respawn_freq: 15
|
||||||
|
|
||||||
|
# Utilities
|
||||||
|
WatchCollisions:
|
||||||
|
done_at_collisions: false
|
||||||
|
|
||||||
|
# Done Conditions
|
||||||
|
DoneAtDestinationReachAny:
|
||||||
|
DoneOnAllDirtCleaned:
|
||||||
|
DoneAtBatteryDischarge:
|
||||||
|
DoneAtMaintainerCollision:
|
||||||
|
DoneAtMaxStepsReached:
|
||||||
max_steps: 500
|
max_steps: 500
|
||||||
# AgentSingleZonePlacement:
|
|
||||||
# n_zones: 4
|
|
||||||
|
@ -1,3 +1,10 @@
|
|||||||
|
General:
|
||||||
|
env_seed: 69
|
||||||
|
individual_rewards: true
|
||||||
|
level_name: narrow_corridor
|
||||||
|
pomdp_r: 0
|
||||||
|
verbose: true
|
||||||
|
|
||||||
Agents:
|
Agents:
|
||||||
Wolfgang:
|
Wolfgang:
|
||||||
Actions:
|
Actions:
|
||||||
@ -10,6 +17,7 @@ Agents:
|
|||||||
Positions:
|
Positions:
|
||||||
- (2, 1)
|
- (2, 1)
|
||||||
- (2, 5)
|
- (2, 5)
|
||||||
|
is_blocking_pos: true
|
||||||
Karl-Heinz:
|
Karl-Heinz:
|
||||||
Actions:
|
Actions:
|
||||||
- Noop
|
- Noop
|
||||||
@ -21,26 +29,30 @@ Agents:
|
|||||||
Positions:
|
Positions:
|
||||||
- (2, 1)
|
- (2, 1)
|
||||||
- (2, 5)
|
- (2, 5)
|
||||||
Entities:
|
is_blocking_pos: true
|
||||||
Destinations: {}
|
|
||||||
|
|
||||||
General:
|
Entities:
|
||||||
env_seed: 69
|
Destinations:
|
||||||
individual_rewards: true
|
ignore_blocking: true
|
||||||
level_name: narrow_corridor
|
spawnrule:
|
||||||
pomdp_r: 0
|
SpawnDestinationsPerAgent:
|
||||||
verbose: true
|
coords_or_quantity:
|
||||||
|
Wolfgang:
|
||||||
|
- (2, 1)
|
||||||
|
- (2, 5)
|
||||||
|
Karl-Heinz:
|
||||||
|
- (2, 1)
|
||||||
|
- (2, 5)
|
||||||
|
# Whether you want to provide a numeric Position observation.
|
||||||
|
# GlobalPositions:
|
||||||
|
# normalized: false
|
||||||
|
|
||||||
Rules:
|
Rules:
|
||||||
SpawnAgents: {}
|
# Utilities
|
||||||
Collision:
|
WatchCollisions:
|
||||||
done_at_collisions: false
|
done_at_collisions: false
|
||||||
FixedDestinationSpawn:
|
# Done Conditions
|
||||||
per_agent_positions:
|
# DoneAtDestinationReachAny:
|
||||||
Wolfgang:
|
DoneAtDestinationReachAll:
|
||||||
- (2, 1)
|
DoneAtMaxStepsReached:
|
||||||
- (2, 5)
|
max_steps: 500
|
||||||
Karl-Heinz:
|
|
||||||
- (2, 1)
|
|
||||||
- (2, 5)
|
|
||||||
DestinationReachAll: {}
|
|
||||||
|
@ -48,9 +48,9 @@ class Move(Action, abc.ABC):
|
|||||||
reward = r.MOVEMENTS_VALID if move_validity else r.MOVEMENTS_FAIL
|
reward = r.MOVEMENTS_VALID if move_validity else r.MOVEMENTS_FAIL
|
||||||
return ActionResult(entity=entity, identifier=self._identifier, validity=move_validity, reward=reward)
|
return ActionResult(entity=entity, identifier=self._identifier, validity=move_validity, reward=reward)
|
||||||
else: # There is no place to go, propably collision
|
else: # There is no place to go, propably collision
|
||||||
# This is currently handeld by the Collision rule, so that it can be switched on and off by conf.yml
|
# This is currently handeld by the WatchCollisions rule, so that it can be switched on and off by conf.yml
|
||||||
# return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION)
|
# return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION)
|
||||||
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=0)
|
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID)
|
||||||
|
|
||||||
def _calc_new_pos(self, pos):
|
def _calc_new_pos(self, pos):
|
||||||
x_diff, y_diff = MOVEMAP[self._identifier]
|
x_diff, y_diff = MOVEMAP[self._identifier]
|
||||||
|
@ -10,6 +10,7 @@ AGENT = 'Agent' # Identifier of Agent-objects an
|
|||||||
OTHERS = 'Other'
|
OTHERS = 'Other'
|
||||||
COMBINED = 'Combined'
|
COMBINED = 'Combined'
|
||||||
GLOBALPOSITIONS = 'GlobalPositions' # Identifier of the global position slice
|
GLOBALPOSITIONS = 'GlobalPositions' # Identifier of the global position slice
|
||||||
|
SPAWN_ENTITY_RULE = 'SpawnEntity'
|
||||||
|
|
||||||
# Attributes
|
# Attributes
|
||||||
IS_BLOCKING_LIGHT = 'var_is_blocking_light'
|
IS_BLOCKING_LIGHT = 'var_is_blocking_light'
|
||||||
@ -29,7 +30,7 @@ VALUE_NO_POS = (-9999, -9999) # Invalid Position value used in the e
|
|||||||
|
|
||||||
|
|
||||||
ACTION = 'action' # Identifier of Action-objects and groups (groups).
|
ACTION = 'action' # Identifier of Action-objects and groups (groups).
|
||||||
COLLISION = 'Collision' # Identifier to use in the context of collitions.
|
COLLISION = 'Collisions' # Identifier to use in the context of collitions.
|
||||||
# LAST_POS = 'LAST_POS' # Identifiert for retrieving an enitites last pos.
|
# LAST_POS = 'LAST_POS' # Identifiert for retrieving an enitites last pos.
|
||||||
VALIDITY = 'VALIDITY' # Identifiert for retrieving the Validity of Action, Tick, etc. ...
|
VALIDITY = 'VALIDITY' # Identifiert for retrieving the Validity of Action, Tick, etc. ...
|
||||||
|
|
||||||
@ -54,3 +55,5 @@ NOOP = 'Noop'
|
|||||||
# Result Identifier
|
# Result Identifier
|
||||||
MOVEMENTS_VALID = 'motion_valid'
|
MOVEMENTS_VALID = 'motion_valid'
|
||||||
MOVEMENTS_FAIL = 'motion_not_valid'
|
MOVEMENTS_FAIL = 'motion_not_valid'
|
||||||
|
DEFAULT_PATH = 'environment'
|
||||||
|
MODULE_PATH = 'modules'
|
||||||
|
@ -12,14 +12,6 @@ from marl_factory_grid.environment import constants as c
|
|||||||
|
|
||||||
class Agent(Entity):
|
class Agent(Entity):
|
||||||
|
|
||||||
@property
|
|
||||||
def var_is_blocking_light(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_can_move(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def var_is_paralyzed(self):
|
def var_is_paralyzed(self):
|
||||||
return len(self._paralyzed)
|
return len(self._paralyzed)
|
||||||
@ -28,14 +20,6 @@ class Agent(Entity):
|
|||||||
def paralyze_reasons(self):
|
def paralyze_reasons(self):
|
||||||
return [x for x in self._paralyzed]
|
return [x for x in self._paralyzed]
|
||||||
|
|
||||||
@property
|
|
||||||
def var_is_blocking_pos(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_has_position(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def obs_tag(self):
|
def obs_tag(self):
|
||||||
return self.name
|
return self.name
|
||||||
@ -48,10 +32,6 @@ class Agent(Entity):
|
|||||||
def observations(self):
|
def observations(self):
|
||||||
return self._observations
|
return self._observations
|
||||||
|
|
||||||
@property
|
|
||||||
def var_can_collide(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
def step_result(self):
|
def step_result(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -60,16 +40,21 @@ class Agent(Entity):
|
|||||||
return self._collection
|
return self._collection
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def state(self):
|
def var_is_blocking_pos(self):
|
||||||
return self._state or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0)
|
return self._is_blocking_pos
|
||||||
|
|
||||||
def __init__(self, actions: List[Action], observations: List[str], *args, **kwargs):
|
@property
|
||||||
|
def state(self):
|
||||||
|
return self._state or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID)
|
||||||
|
|
||||||
|
def __init__(self, actions: List[Action], observations: List[str], *args, is_blocking_pos=False, **kwargs):
|
||||||
super(Agent, self).__init__(*args, **kwargs)
|
super(Agent, self).__init__(*args, **kwargs)
|
||||||
self._paralyzed = set()
|
self._paralyzed = set()
|
||||||
self.step_result = dict()
|
self.step_result = dict()
|
||||||
self._actions = actions
|
self._actions = actions
|
||||||
self._observations = observations
|
self._observations = observations
|
||||||
self._state: Union[Result, None] = None
|
self._state: Union[Result, None] = None
|
||||||
|
self._is_blocking_pos = is_blocking_pos
|
||||||
|
|
||||||
# noinspection PyAttributeOutsideInit
|
# noinspection PyAttributeOutsideInit
|
||||||
def clear_temp_state(self):
|
def clear_temp_state(self):
|
||||||
|
@ -14,7 +14,7 @@ class Entity(_Object, abc.ABC):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def state(self):
|
def state(self):
|
||||||
return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0)
|
return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def var_has_position(self):
|
def var_has_position(self):
|
||||||
@ -60,6 +60,10 @@ class Entity(_Object, abc.ABC):
|
|||||||
def pos(self):
|
def pos(self):
|
||||||
return self._pos
|
return self._pos
|
||||||
|
|
||||||
|
def set_pos(self, pos):
|
||||||
|
assert isinstance(pos, tuple) and len(pos) == 2
|
||||||
|
self._pos = pos
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def last_pos(self):
|
def last_pos(self):
|
||||||
try:
|
try:
|
||||||
@ -84,7 +88,7 @@ class Entity(_Object, abc.ABC):
|
|||||||
for observer in self.observers:
|
for observer in self.observers:
|
||||||
observer.notify_del_entity(self)
|
observer.notify_del_entity(self)
|
||||||
self._view_directory = curr_pos[0] - next_pos[0], curr_pos[1] - next_pos[1]
|
self._view_directory = curr_pos[0] - next_pos[0], curr_pos[1] - next_pos[1]
|
||||||
self._pos = next_pos
|
self.set_pos(next_pos)
|
||||||
for observer in self.observers:
|
for observer in self.observers:
|
||||||
observer.notify_add_entity(self)
|
observer.notify_add_entity(self)
|
||||||
return valid
|
return valid
|
||||||
@ -93,7 +97,7 @@ class Entity(_Object, abc.ABC):
|
|||||||
def __init__(self, pos, bind_to=None, **kwargs):
|
def __init__(self, pos, bind_to=None, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self._status = None
|
self._status = None
|
||||||
self._pos = pos
|
self.set_pos(pos)
|
||||||
self._last_pos = pos
|
self._last_pos = pos
|
||||||
if bind_to:
|
if bind_to:
|
||||||
try:
|
try:
|
||||||
@ -109,8 +113,9 @@ class Entity(_Object, abc.ABC):
|
|||||||
def render(self):
|
def render(self):
|
||||||
return RenderEntity(self.__class__.__name__.lower(), self.pos)
|
return RenderEntity(self.__class__.__name__.lower(), self.pos)
|
||||||
|
|
||||||
def __repr__(self):
|
@abc.abstractmethod
|
||||||
return super(Entity, self).__repr__() + f'(@{self.pos})'
|
def render(self):
|
||||||
|
return RenderEntity(self.__class__.__name__.lower(), self.pos)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def obs_tag(self):
|
def obs_tag(self):
|
||||||
@ -149,4 +154,4 @@ class Entity(_Object, abc.ABC):
|
|||||||
except StopIteration:
|
except StopIteration:
|
||||||
pass
|
pass
|
||||||
except ValueError:
|
except ValueError:
|
||||||
print()
|
pass
|
||||||
|
@ -1,24 +0,0 @@
|
|||||||
|
|
||||||
|
|
||||||
# noinspection PyAttributeOutsideInit
|
|
||||||
class BoundEntityMixin:
|
|
||||||
|
|
||||||
@property
|
|
||||||
def bound_entity(self):
|
|
||||||
return self._bound_entity
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self):
|
|
||||||
if self.bound_entity:
|
|
||||||
return f'{self.__class__.__name__}({self.bound_entity.name})'
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def belongs_to_entity(self, entity):
|
|
||||||
return entity == self.bound_entity
|
|
||||||
|
|
||||||
def bind_to(self, entity):
|
|
||||||
self._bound_entity = entity
|
|
||||||
|
|
||||||
def unbind(self):
|
|
||||||
self._bound_entity = None
|
|
@ -13,10 +13,6 @@ class _Object:
|
|||||||
def __bool__(self):
|
def __bool__(self):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@property
|
|
||||||
def var_has_position(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def var_can_be_bound(self):
|
def var_can_be_bound(self):
|
||||||
try:
|
try:
|
||||||
@ -30,22 +26,14 @@ class _Object:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
if self._str_ident is not None:
|
return f'{self.__class__.__name__}[{self.identifier}]'
|
||||||
name = f'{self.__class__.__name__}[{self._str_ident}]'
|
|
||||||
else:
|
|
||||||
name = f'{self.__class__.__name__}#{self.u_int}'
|
|
||||||
if self.bound_entity:
|
|
||||||
name = h.add_bound_name(name, self.bound_entity)
|
|
||||||
if self.var_has_position:
|
|
||||||
name = h.add_pos_name(name, self)
|
|
||||||
return name
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def identifier(self):
|
def identifier(self):
|
||||||
if self._str_ident is not None:
|
if self._str_ident is not None:
|
||||||
return self._str_ident
|
return self._str_ident
|
||||||
else:
|
else:
|
||||||
return self.name
|
return self.u_int
|
||||||
|
|
||||||
def reset_uid(self):
|
def reset_uid(self):
|
||||||
self._u_idx = defaultdict(lambda: 0)
|
self._u_idx = defaultdict(lambda: 0)
|
||||||
@ -62,7 +50,15 @@ class _Object:
|
|||||||
print(f'Following kwargs were passed, but ignored: {kwargs}')
|
print(f'Following kwargs were passed, but ignored: {kwargs}')
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'{self.name}'
|
name = self.name
|
||||||
|
if self.bound_entity:
|
||||||
|
name = h.add_bound_name(name, self.bound_entity)
|
||||||
|
try:
|
||||||
|
if self.var_has_position:
|
||||||
|
name = h.add_pos_name(name, self)
|
||||||
|
except (AttributeError):
|
||||||
|
pass
|
||||||
|
return name
|
||||||
|
|
||||||
def __eq__(self, other) -> bool:
|
def __eq__(self, other) -> bool:
|
||||||
return other == self.identifier
|
return other == self.identifier
|
||||||
@ -88,7 +84,7 @@ class _Object:
|
|||||||
def summarize_state(self):
|
def summarize_state(self):
|
||||||
return dict()
|
return dict()
|
||||||
|
|
||||||
def bind(self, entity):
|
def bind_to(self, entity):
|
||||||
# noinspection PyAttributeOutsideInit
|
# noinspection PyAttributeOutsideInit
|
||||||
self._bound_entity = entity
|
self._bound_entity = entity
|
||||||
return c.VALID
|
return c.VALID
|
||||||
@ -100,9 +96,6 @@ class _Object:
|
|||||||
def bound_entity(self):
|
def bound_entity(self):
|
||||||
return self._bound_entity
|
return self._bound_entity
|
||||||
|
|
||||||
def bind_to(self, entity):
|
|
||||||
self._bound_entity = entity
|
|
||||||
|
|
||||||
def unbind(self):
|
def unbind(self):
|
||||||
self._bound_entity = None
|
self._bound_entity = None
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ class PlaceHolder(_Object):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return "PlaceHolder"
|
return self.__class__.__name__
|
||||||
|
|
||||||
|
|
||||||
class GlobalPosition(_Object):
|
class GlobalPosition(_Object):
|
||||||
@ -36,7 +36,8 @@ class GlobalPosition(_Object):
|
|||||||
else:
|
else:
|
||||||
return self.bound_entity.pos
|
return self.bound_entity.pos
|
||||||
|
|
||||||
def __init__(self, level_shape, *args, normalized: bool = True, **kwargs):
|
def __init__(self, agent, level_shape, *args, normalized: bool = True, **kwargs):
|
||||||
super(GlobalPosition, self).__init__(*args, **kwargs)
|
super(GlobalPosition, self).__init__(*args, **kwargs)
|
||||||
|
self.bind_to(agent)
|
||||||
self._normalized = normalized
|
self._normalized = normalized
|
||||||
self._shape = level_shape
|
self._shape = level_shape
|
||||||
|
@ -5,13 +5,8 @@ from marl_factory_grid.utils.utility_classes import RenderEntity
|
|||||||
|
|
||||||
class Wall(Entity):
|
class Wall(Entity):
|
||||||
|
|
||||||
@property
|
def __init__(self, *args, **kwargs):
|
||||||
def var_has_position(self):
|
super().__init__(*args, **kwargs)
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_can_collide(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
@ -19,11 +14,3 @@ class Wall(Entity):
|
|||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
return RenderEntity(c.WALL, self.pos)
|
return RenderEntity(c.WALL, self.pos)
|
||||||
|
|
||||||
@property
|
|
||||||
def var_is_blocking_pos(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_is_blocking_light(self):
|
|
||||||
return True
|
|
||||||
|
@ -87,11 +87,14 @@ class Factory(gym.Env):
|
|||||||
entities = self.map.do_init()
|
entities = self.map.do_init()
|
||||||
|
|
||||||
# Init rules
|
# Init rules
|
||||||
rules = self.conf.load_rules()
|
env_rules = self.conf.load_env_rules()
|
||||||
|
entity_rules = self.conf.load_entity_spawn_rules(entities)
|
||||||
|
env_rules.extend(entity_rules)
|
||||||
|
|
||||||
# Parse the agent conf
|
# Parse the agent conf
|
||||||
parsed_agents_conf = self.conf.parse_agents_conf()
|
parsed_agents_conf = self.conf.parse_agents_conf()
|
||||||
self.state = Gamestate(entities, parsed_agents_conf, rules, self.conf.env_seed, self.conf.verbose)
|
self.state = Gamestate(entities, parsed_agents_conf, env_rules, self.map.level_shape,
|
||||||
|
self.conf.env_seed, self.conf.verbose)
|
||||||
|
|
||||||
# All is set up, trigger entity init with variable pos
|
# All is set up, trigger entity init with variable pos
|
||||||
# All is set up, trigger additional init (after agent entity spawn etc)
|
# All is set up, trigger additional init (after agent entity spawn etc)
|
||||||
|
@ -1,10 +1,15 @@
|
|||||||
from marl_factory_grid.environment.entity.agent import Agent
|
from marl_factory_grid.environment.entity.agent import Agent
|
||||||
from marl_factory_grid.environment.groups.collection import Collection
|
from marl_factory_grid.environment.groups.collection import Collection
|
||||||
|
from marl_factory_grid.environment.rules import SpawnAgents
|
||||||
|
|
||||||
|
|
||||||
class Agents(Collection):
|
class Agents(Collection):
|
||||||
_entity = Agent
|
_entity = Agent
|
||||||
|
|
||||||
|
@property
|
||||||
|
def spawn_rule(self):
|
||||||
|
return {SpawnAgents.__name__: {}}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def var_is_blocking_light(self):
|
def var_is_blocking_light(self):
|
||||||
return False
|
return False
|
||||||
|
@ -1,18 +1,25 @@
|
|||||||
from typing import List, Tuple, Union
|
from typing import List, Tuple, Union, Dict
|
||||||
|
|
||||||
from marl_factory_grid.environment.entity.entity import Entity
|
from marl_factory_grid.environment.entity.entity import Entity
|
||||||
from marl_factory_grid.environment.groups.objects import _Objects
|
from marl_factory_grid.environment.groups.objects import _Objects
|
||||||
|
# noinspection PyProtectedMember
|
||||||
from marl_factory_grid.environment.entity.object import _Object
|
from marl_factory_grid.environment.entity.object import _Object
|
||||||
import marl_factory_grid.environment.constants as c
|
import marl_factory_grid.environment.constants as c
|
||||||
|
from marl_factory_grid.utils.results import Result
|
||||||
|
|
||||||
|
|
||||||
class Collection(_Objects):
|
class Collection(_Objects):
|
||||||
_entity = _Object # entity?
|
_entity = _Object # entity?
|
||||||
|
symbol = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def var_is_blocking_light(self):
|
def var_is_blocking_light(self):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def var_is_blocking_pos(self):
|
||||||
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def var_can_collide(self):
|
def var_can_collide(self):
|
||||||
return False
|
return False
|
||||||
@ -23,29 +30,61 @@ class Collection(_Objects):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def var_has_position(self):
|
def var_has_position(self):
|
||||||
return False
|
return True
|
||||||
|
|
||||||
# @property
|
|
||||||
# def var_has_bound(self):
|
|
||||||
# return False # batteries, globalpos, inventories true
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_can_be_bound(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encodings(self):
|
def encodings(self):
|
||||||
return [x.encoding for x in self]
|
return [x.encoding for x in self]
|
||||||
|
|
||||||
def __init__(self, size, *args, **kwargs):
|
@property
|
||||||
super(Collection, self).__init__(*args, **kwargs)
|
def spawn_rule(self):
|
||||||
self.size = size
|
"""Prevent SpawnRule creation if Objects are spawned by map, Doors e.g."""
|
||||||
|
if self.symbol:
|
||||||
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): # woihn mit den args
|
return None
|
||||||
if isinstance(coords_or_quantity, int):
|
elif self._spawnrule:
|
||||||
self.add_items([self._entity() for _ in range(coords_or_quantity)])
|
return self._spawnrule
|
||||||
else:
|
else:
|
||||||
self.add_items([self._entity(pos) for pos in coords_or_quantity])
|
return {c.SPAWN_ENTITY_RULE: dict(collection=self, coords_or_quantity=self._coords_or_quantity)}
|
||||||
|
|
||||||
|
def __init__(self, size, *args, coords_or_quantity: int = None, ignore_blocking=False,
|
||||||
|
spawnrule: Union[None, Dict[str, dict]] = None,
|
||||||
|
**kwargs):
|
||||||
|
super(Collection, self).__init__(*args, **kwargs)
|
||||||
|
self._coords_or_quantity = coords_or_quantity
|
||||||
|
self.size = size
|
||||||
|
self._spawnrule = spawnrule
|
||||||
|
self._ignore_blocking = ignore_blocking
|
||||||
|
|
||||||
|
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, ignore_blocking=False, **entity_kwargs):
|
||||||
|
coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity
|
||||||
|
if self.var_has_position:
|
||||||
|
if isinstance(coords_or_quantity, int):
|
||||||
|
if ignore_blocking or self._ignore_blocking:
|
||||||
|
coords_or_quantity = state.entities.floorlist[:coords_or_quantity]
|
||||||
|
else:
|
||||||
|
coords_or_quantity = state.get_n_random_free_positions(coords_or_quantity)
|
||||||
|
self.spawn(coords_or_quantity, *entity_args, **entity_kwargs)
|
||||||
|
state.print(f'{len(coords_or_quantity)} new {self.name} have been spawned at {coords_or_quantity}')
|
||||||
|
return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(coords_or_quantity))
|
||||||
|
else:
|
||||||
|
if isinstance(coords_or_quantity, int):
|
||||||
|
self.spawn(coords_or_quantity, *entity_args, **entity_kwargs)
|
||||||
|
state.print(f'{coords_or_quantity} new {self.name} have been spawned randomly.')
|
||||||
|
return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=coords_or_quantity)
|
||||||
|
else:
|
||||||
|
raise ValueError(f'{self._entity.__name__} has no position!')
|
||||||
|
|
||||||
|
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args, **entity_kwargs):
|
||||||
|
if self.var_has_position:
|
||||||
|
if isinstance(coords_or_quantity, int):
|
||||||
|
raise ValueError(f'{self._entity.__name__} should have a position!')
|
||||||
|
else:
|
||||||
|
self.add_items([self._entity(pos, *entity_args, **entity_kwargs) for pos in coords_or_quantity])
|
||||||
|
else:
|
||||||
|
if isinstance(coords_or_quantity, int):
|
||||||
|
self.add_items([self._entity(*entity_args, **entity_kwargs) for _ in range(coords_or_quantity)])
|
||||||
|
else:
|
||||||
|
raise ValueError(f'{self._entity.__name__} has no position!')
|
||||||
return c.VALID
|
return c.VALID
|
||||||
|
|
||||||
def despawn(self, items: List[_Object]):
|
def despawn(self, items: List[_Object]):
|
||||||
@ -115,7 +154,7 @@ class Collection(_Objects):
|
|||||||
except StopIteration:
|
except StopIteration:
|
||||||
pass
|
pass
|
||||||
except ValueError:
|
except ValueError:
|
||||||
print()
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def positions(self):
|
def positions(self):
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from random import shuffle, random
|
from random import shuffle
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from marl_factory_grid.environment.groups.objects import _Objects
|
from marl_factory_grid.environment.groups.objects import _Objects
|
||||||
@ -12,10 +12,10 @@ class Entities(_Objects):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def neighboring_positions(pos):
|
def neighboring_positions(pos):
|
||||||
return (POS_MASK + pos).reshape(-1, 2)
|
return [tuple(x) for x in (POS_MASK + pos).reshape(-1, 2)]
|
||||||
|
|
||||||
def get_entities_near_pos(self, pos):
|
def get_entities_near_pos(self, pos):
|
||||||
return [y for x in itemgetter(*(tuple(x) for x in self.neighboring_positions(pos)))(self.pos_dict) for y in x]
|
return [y for x in itemgetter(*self.neighboring_positions(pos))(self.pos_dict) for y in x]
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
return [y for x in self for y in x.render() if x is not None]
|
return [y for x in self for y in x.render() if x is not None]
|
||||||
@ -35,8 +35,9 @@ class Entities(_Objects):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def guests_that_can_collide(self, pos):
|
def guests_that_can_collide(self, pos):
|
||||||
return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
|
return [x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
|
||||||
|
|
||||||
|
@property
|
||||||
def empty_positions(self):
|
def empty_positions(self):
|
||||||
empty_positions = [key for key in self.floorlist if not self.pos_dict[key]]
|
empty_positions = [key for key in self.floorlist if not self.pos_dict[key]]
|
||||||
shuffle(empty_positions)
|
shuffle(empty_positions)
|
||||||
@ -48,11 +49,23 @@ class Entities(_Objects):
|
|||||||
shuffle(empty_positions)
|
shuffle(empty_positions)
|
||||||
return empty_positions
|
return empty_positions
|
||||||
|
|
||||||
def is_blocked(self):
|
@property
|
||||||
return[key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])]
|
def blocked_positions(self):
|
||||||
|
blocked_positions = [key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])]
|
||||||
|
shuffle(blocked_positions)
|
||||||
|
return blocked_positions
|
||||||
|
|
||||||
def is_not_blocked(self):
|
@property
|
||||||
return[key for key, val in self.pos_dict.items() if not all([x.var_is_blocking_pos for x in val])]
|
def free_positions_generator(self):
|
||||||
|
generator = (
|
||||||
|
key for key in self.floorlist if all(not x.var_can_collide and not x.var_is_blocking_pos
|
||||||
|
for x in self.pos_dict[key])
|
||||||
|
)
|
||||||
|
return generator
|
||||||
|
|
||||||
|
@property
|
||||||
|
def free_positions_list(self):
|
||||||
|
return [x for x in self.free_positions_generator]
|
||||||
|
|
||||||
def iter_entities(self):
|
def iter_entities(self):
|
||||||
return iter((x for sublist in self.values() for x in sublist))
|
return iter((x for sublist in self.values() for x in sublist))
|
||||||
@ -92,3 +105,6 @@ class Entities(_Objects):
|
|||||||
@property
|
@property
|
||||||
def positions(self):
|
def positions(self):
|
||||||
return [k for k, v in self.pos_dict.items() for _ in v]
|
return [k for k, v in self.pos_dict.items() for _ in v]
|
||||||
|
|
||||||
|
def is_occupied(self, pos):
|
||||||
|
return len([x for x in self.pos_dict[pos] if x.var_can_collide or x.var_is_blocking_pos]) >= 1
|
||||||
|
@ -4,10 +4,6 @@ from marl_factory_grid.environment import constants as c
|
|||||||
# noinspection PyUnresolvedReferences,PyTypeChecker
|
# noinspection PyUnresolvedReferences,PyTypeChecker
|
||||||
class IsBoundMixin:
|
class IsBoundMixin:
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self):
|
|
||||||
return f'{self.__class__.__name__}({self._bound_entity.name})'
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'{self.__class__.__name__}#{self._bound_entity.name}({self._data})'
|
return f'{self.__class__.__name__}#{self._bound_entity.name}({self._data})'
|
||||||
|
|
||||||
|
@ -5,11 +5,16 @@ import numpy as np
|
|||||||
|
|
||||||
from marl_factory_grid.environment.entity.object import _Object
|
from marl_factory_grid.environment.entity.object import _Object
|
||||||
import marl_factory_grid.environment.constants as c
|
import marl_factory_grid.environment.constants as c
|
||||||
|
from marl_factory_grid.utils import helpers as h
|
||||||
|
|
||||||
|
|
||||||
class _Objects:
|
class _Objects:
|
||||||
_entity = _Object
|
_entity = _Object
|
||||||
|
|
||||||
|
@property
|
||||||
|
def var_can_be_bound(self):
|
||||||
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def observers(self):
|
def observers(self):
|
||||||
return self._observers
|
return self._observers
|
||||||
@ -148,12 +153,12 @@ class _Objects:
|
|||||||
|
|
||||||
def by_entity(self, entity):
|
def by_entity(self, entity):
|
||||||
try:
|
try:
|
||||||
return next((x for x in self if x.belongs_to_entity(entity)))
|
return h.get_first(self, filter_by=lambda x: x.belongs_to_entity(entity))
|
||||||
except (StopIteration, AttributeError):
|
except (StopIteration, AttributeError):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def idx_by_entity(self, entity):
|
def idx_by_entity(self, entity):
|
||||||
try:
|
try:
|
||||||
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
|
return h.get_first_index(self, filter_by=lambda x: x.belongs_to_entity(entity))
|
||||||
except (StopIteration, AttributeError):
|
except (StopIteration, AttributeError):
|
||||||
return None
|
return None
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
|
from marl_factory_grid.environment import constants as c
|
||||||
from marl_factory_grid.environment.entity.util import GlobalPosition
|
from marl_factory_grid.environment.entity.util import GlobalPosition
|
||||||
from marl_factory_grid.environment.groups.collection import Collection
|
from marl_factory_grid.environment.groups.collection import Collection
|
||||||
|
from marl_factory_grid.utils.results import Result
|
||||||
|
from marl_factory_grid.utils.states import Gamestate
|
||||||
|
|
||||||
|
|
||||||
class Combined(Collection):
|
class Combined(Collection):
|
||||||
@ -36,17 +39,17 @@ class GlobalPositions(Collection):
|
|||||||
|
|
||||||
_entity = GlobalPosition
|
_entity = GlobalPosition
|
||||||
|
|
||||||
@property
|
var_is_blocking_light = False
|
||||||
def var_is_blocking_light(self):
|
var_can_be_bound = True
|
||||||
return False
|
var_can_collide = False
|
||||||
|
var_has_position = False
|
||||||
@property
|
|
||||||
def var_can_collide(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_can_be_bound(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(GlobalPositions, self).__init__(*args, **kwargs)
|
super(GlobalPositions, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def spawn(self, agents, level_shape, *args, **kwargs):
|
||||||
|
self.add_items([self._entity(agent, level_shape, *args, **kwargs) for agent in agents])
|
||||||
|
return [Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))]
|
||||||
|
|
||||||
|
def trigger_spawn(self, state: Gamestate, *args, **kwargs) -> [Result]:
|
||||||
|
return self.spawn(state[c.AGENT], state.lvl_shape, *args, **kwargs)
|
||||||
|
@ -7,9 +7,12 @@ class Walls(Collection):
|
|||||||
_entity = Wall
|
_entity = Wall
|
||||||
symbol = c.SYMBOL_WALL
|
symbol = c.SYMBOL_WALL
|
||||||
|
|
||||||
@property
|
var_can_collide = True
|
||||||
def var_has_position(self):
|
var_is_blocking_light = True
|
||||||
return True
|
var_can_move = False
|
||||||
|
var_has_position = True
|
||||||
|
var_can_be_bound = False
|
||||||
|
var_is_blocking_pos = True
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(Walls, self).__init__(*args, **kwargs)
|
super(Walls, self).__init__(*args, **kwargs)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import abc
|
import abc
|
||||||
from random import shuffle
|
from random import shuffle
|
||||||
from typing import List
|
from typing import List, Collection, Union
|
||||||
|
|
||||||
from marl_factory_grid.environment.entity.agent import Agent
|
from marl_factory_grid.environment.entity.agent import Agent
|
||||||
from marl_factory_grid.utils import helpers as h
|
from marl_factory_grid.utils import helpers as h
|
||||||
@ -39,6 +39,29 @@ class Rule(abc.ABC):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class SpawnEntity(Rule):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _collection(self) -> Collection:
|
||||||
|
return Collection()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return f'{self.__class__.__name__}({self.collection.name})'
|
||||||
|
|
||||||
|
def __init__(self, collection, coords_or_quantity, ignore_blocking=False):
|
||||||
|
super().__init__()
|
||||||
|
self.coords_or_quantity = coords_or_quantity
|
||||||
|
self.collection = collection
|
||||||
|
self.ignore_blocking = ignore_blocking
|
||||||
|
|
||||||
|
def on_init(self, state, lvl_map) -> [TickResult]:
|
||||||
|
results = self.collection.trigger_spawn(state, ignore_blocking=self.ignore_blocking)
|
||||||
|
pos_str = f' on: {[x.pos for x in self.collection]}' if self.collection.var_has_position else ''
|
||||||
|
state.print(f'Initial {self.collection.__class__.__name__} were spawned{pos_str}')
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
class SpawnAgents(Rule):
|
class SpawnAgents(Rule):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -46,14 +69,14 @@ class SpawnAgents(Rule):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def on_init(self, state, lvl_map):
|
def on_init(self, state, lvl_map):
|
||||||
agent_conf = state.agents_conf
|
|
||||||
# agents = Agents(lvl_map.size)
|
# agents = Agents(lvl_map.size)
|
||||||
agents = state[c.AGENT]
|
agents = state[c.AGENT]
|
||||||
empty_positions = state.entities.empty_positions()[:len(agent_conf)]
|
empty_positions = state.entities.empty_positions[:len(state.agents_conf)]
|
||||||
for agent_name in agent_conf:
|
for agent_name, agent_conf in state.agents_conf.items():
|
||||||
actions = agent_conf[agent_name]['actions'].copy()
|
actions = agent_conf['actions'].copy()
|
||||||
observations = agent_conf[agent_name]['observations'].copy()
|
observations = agent_conf['observations'].copy()
|
||||||
positions = agent_conf[agent_name]['positions'].copy()
|
positions = agent_conf['positions'].copy()
|
||||||
|
other = agent_conf['other'].copy()
|
||||||
if positions:
|
if positions:
|
||||||
shuffle(positions)
|
shuffle(positions)
|
||||||
while True:
|
while True:
|
||||||
@ -61,18 +84,18 @@ class SpawnAgents(Rule):
|
|||||||
pos = positions.pop()
|
pos = positions.pop()
|
||||||
except IndexError:
|
except IndexError:
|
||||||
raise ValueError(f'It was not possible to spawn an Agent on the available position: '
|
raise ValueError(f'It was not possible to spawn an Agent on the available position: '
|
||||||
f'\n{agent_name[agent_name]["positions"].copy()}')
|
f'\n{agent_conf["positions"].copy()}')
|
||||||
if agents.by_pos(pos) and state.check_pos_validity(pos):
|
if bool(agents.by_pos(pos)) or not state.check_pos_validity(pos):
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
agents.add_item(Agent(actions, observations, pos, str_ident=agent_name))
|
agents.add_item(Agent(actions, observations, pos, str_ident=agent_name, **other))
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name))
|
agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name, **other))
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MaxStepsReached(Rule):
|
class DoneAtMaxStepsReached(Rule):
|
||||||
|
|
||||||
def __init__(self, max_steps: int = 500):
|
def __init__(self, max_steps: int = 500):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -83,8 +106,8 @@ class MaxStepsReached(Rule):
|
|||||||
|
|
||||||
def on_check_done(self, state):
|
def on_check_done(self, state):
|
||||||
if self.max_steps <= state.curr_step:
|
if self.max_steps <= state.curr_step:
|
||||||
return [DoneResult(validity=c.VALID, identifier=self.name, reward=0)]
|
return [DoneResult(validity=c.VALID, identifier=self.name)]
|
||||||
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
|
return [DoneResult(validity=c.NOT_VALID, identifier=self.name)]
|
||||||
|
|
||||||
|
|
||||||
class AssignGlobalPositions(Rule):
|
class AssignGlobalPositions(Rule):
|
||||||
@ -101,7 +124,7 @@ class AssignGlobalPositions(Rule):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
class Collision(Rule):
|
class WatchCollisions(Rule):
|
||||||
|
|
||||||
def __init__(self, done_at_collisions: bool = False):
|
def __init__(self, done_at_collisions: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -132,4 +155,4 @@ class Collision(Rule):
|
|||||||
move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT])
|
move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT])
|
||||||
if inter_entity_collision_detected or move_failed:
|
if inter_entity_collision_detected or move_failed:
|
||||||
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)]
|
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)]
|
||||||
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
|
return [DoneResult(validity=c.NOT_VALID, identifier=self.name)]
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from .actions import BtryCharge
|
from .actions import BtryCharge
|
||||||
from .entitites import Pod, Battery
|
from .entitites import ChargePod, Battery
|
||||||
from .groups import ChargePods, Batteries
|
from .groups import ChargePods, Batteries
|
||||||
from .rules import DoneAtBatteryDischarge, BatteryDecharge
|
from .rules import DoneAtBatteryDischarge, BatteryDecharge
|
||||||
|
@ -6,6 +6,7 @@ from marl_factory_grid.utils.results import ActionResult
|
|||||||
|
|
||||||
from marl_factory_grid.modules.batteries import constants as b
|
from marl_factory_grid.modules.batteries import constants as b
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
|
from marl_factory_grid.utils import helpers as h
|
||||||
|
|
||||||
|
|
||||||
class BtryCharge(Action):
|
class BtryCharge(Action):
|
||||||
@ -14,8 +15,8 @@ class BtryCharge(Action):
|
|||||||
super().__init__(b.ACTION_CHARGE)
|
super().__init__(b.ACTION_CHARGE)
|
||||||
|
|
||||||
def do(self, entity, state) -> Union[None, ActionResult]:
|
def do(self, entity, state) -> Union[None, ActionResult]:
|
||||||
if charge_pod := state[b.CHARGE_PODS].by_pos(entity.pos):
|
if charge_pod := h.get_first(state[b.CHARGE_PODS].by_pos(entity.pos)):
|
||||||
valid = charge_pod.charge_battery(state[b.BATTERIES].by_entity(entity))
|
valid = h.get_first(charge_pod.charge_battery(state[b.BATTERIES].by_entity(entity)))
|
||||||
if valid:
|
if valid:
|
||||||
state.print(f'{entity.name} just charged batteries at {charge_pod.name}.')
|
state.print(f'{entity.name} just charged batteries at {charge_pod.name}.')
|
||||||
else:
|
else:
|
||||||
|
BIN
marl_factory_grid/modules/batteries/chargepods.png
Normal file
BIN
marl_factory_grid/modules/batteries/chargepods.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 7.9 KiB |
@ -50,7 +50,7 @@ class Battery(_Object):
|
|||||||
return summary
|
return summary
|
||||||
|
|
||||||
|
|
||||||
class Pod(Entity):
|
class ChargePod(Entity):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
@ -58,7 +58,7 @@ class Pod(Entity):
|
|||||||
|
|
||||||
def __init__(self, *args, charge_rate: float = 0.4,
|
def __init__(self, *args, charge_rate: float = 0.4,
|
||||||
multi_charge: bool = False, **kwargs):
|
multi_charge: bool = False, **kwargs):
|
||||||
super(Pod, self).__init__(*args, **kwargs)
|
super(ChargePod, self).__init__(*args, **kwargs)
|
||||||
self.charge_rate = charge_rate
|
self.charge_rate = charge_rate
|
||||||
self.multi_charge = multi_charge
|
self.multi_charge = multi_charge
|
||||||
|
|
||||||
|
@ -1,52 +1,36 @@
|
|||||||
from typing import Union, List, Tuple
|
from typing import Union, List, Tuple
|
||||||
|
|
||||||
|
from marl_factory_grid.environment import constants as c
|
||||||
from marl_factory_grid.environment.groups.collection import Collection
|
from marl_factory_grid.environment.groups.collection import Collection
|
||||||
from marl_factory_grid.modules.batteries.entitites import Pod, Battery
|
from marl_factory_grid.modules.batteries.entitites import ChargePod, Battery
|
||||||
|
from marl_factory_grid.utils.results import Result
|
||||||
|
|
||||||
|
|
||||||
class Batteries(Collection):
|
class Batteries(Collection):
|
||||||
_entity = Battery
|
_entity = Battery
|
||||||
|
|
||||||
@property
|
var_has_position = False
|
||||||
def var_is_blocking_light(self):
|
var_can_be_bound = True
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_can_collide(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_can_move(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_has_position(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_can_be_bound(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def obs_tag(self):
|
def obs_tag(self):
|
||||||
return self.__class__.__name__
|
return self.__class__.__name__
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, size, initial_charge_level: float=1.0, *args, **kwargs):
|
||||||
super(Batteries, self).__init__(*args, **kwargs)
|
super(Batteries, self).__init__(size, *args, **kwargs)
|
||||||
|
self.initial_charge_level = initial_charge_level
|
||||||
|
|
||||||
def spawn(self, agents, initial_charge_level):
|
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], agents, *entity_args, **entity_kwargs):
|
||||||
batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)]
|
batteries = [self._entity(self.initial_charge_level, agent) for _, agent in enumerate(agents)]
|
||||||
self.add_items(batteries)
|
self.add_items(batteries)
|
||||||
|
|
||||||
# def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): hat keine pos
|
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, **entity_kwargs):
|
||||||
# agents = entity_args[0]
|
self.spawn(0, state[c.AGENT])
|
||||||
# initial_charge_level = entity_args[1]
|
return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))
|
||||||
# batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)]
|
|
||||||
# self.add_items(batteries)
|
|
||||||
|
|
||||||
|
|
||||||
class ChargePods(Collection):
|
class ChargePods(Collection):
|
||||||
_entity = Pod
|
_entity = ChargePod
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(ChargePods, self).__init__(*args, **kwargs)
|
super(ChargePods, self).__init__(*args, **kwargs)
|
||||||
|
@ -49,10 +49,6 @@ class BatteryDecharge(Rule):
|
|||||||
self.per_action_costs = per_action_costs
|
self.per_action_costs = per_action_costs
|
||||||
self.initial_charge = initial_charge
|
self.initial_charge = initial_charge
|
||||||
|
|
||||||
def on_init(self, state, lvl_map): # on reset?
|
|
||||||
assert len(state[c.AGENT]), "There are no agents, did you already spawn them?"
|
|
||||||
state[b.BATTERIES].spawn(state[c.AGENT], self.initial_charge)
|
|
||||||
|
|
||||||
def tick_step(self, state) -> List[TickResult]:
|
def tick_step(self, state) -> List[TickResult]:
|
||||||
# Decharge
|
# Decharge
|
||||||
batteries = state[b.BATTERIES]
|
batteries = state[b.BATTERIES]
|
||||||
@ -66,7 +62,7 @@ class BatteryDecharge(Rule):
|
|||||||
|
|
||||||
batteries.by_entity(agent).decharge(energy_consumption)
|
batteries.by_entity(agent).decharge(energy_consumption)
|
||||||
|
|
||||||
results.append(TickResult(self.name, reward=0, entity=agent, validity=c.VALID))
|
results.append(TickResult(self.name, entity=agent, validity=c.VALID))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@ -82,13 +78,13 @@ class BatteryDecharge(Rule):
|
|||||||
if self.paralyze_agents_on_discharge:
|
if self.paralyze_agents_on_discharge:
|
||||||
btry.bound_entity.paralyze(self.name)
|
btry.bound_entity.paralyze(self.name)
|
||||||
results.append(
|
results.append(
|
||||||
TickResult("Paralyzed", entity=btry.bound_entity, reward=0, validity=c.VALID)
|
TickResult("Paralyzed", entity=btry.bound_entity, validity=c.VALID)
|
||||||
)
|
)
|
||||||
state.print(f'{btry.bound_entity.name} has just been paralyzed!')
|
state.print(f'{btry.bound_entity.name} has just been paralyzed!')
|
||||||
if btry.bound_entity.var_is_paralyzed and not btry.is_discharged:
|
if btry.bound_entity.var_is_paralyzed and not btry.is_discharged:
|
||||||
btry.bound_entity.de_paralyze(self.name)
|
btry.bound_entity.de_paralyze(self.name)
|
||||||
results.append(
|
results.append(
|
||||||
TickResult("De-Paralyzed", entity=btry.bound_entity, reward=0, validity=c.VALID)
|
TickResult("De-Paralyzed", entity=btry.bound_entity, validity=c.VALID)
|
||||||
)
|
)
|
||||||
state.print(f'{btry.bound_entity.name} has just been de-paralyzed!')
|
state.print(f'{btry.bound_entity.name} has just been de-paralyzed!')
|
||||||
return results
|
return results
|
||||||
@ -132,7 +128,7 @@ class DoneAtBatteryDischarge(BatteryDecharge):
|
|||||||
if any_discharged or all_discharged:
|
if any_discharged or all_discharged:
|
||||||
return [DoneResult(self.name, validity=c.VALID, reward=self.reward_discharge_done)]
|
return [DoneResult(self.name, validity=c.VALID, reward=self.reward_discharge_done)]
|
||||||
else:
|
else:
|
||||||
return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)]
|
return [DoneResult(self.name, validity=c.NOT_VALID)]
|
||||||
|
|
||||||
|
|
||||||
class SpawnChargePods(Rule):
|
class SpawnChargePods(Rule):
|
||||||
@ -155,7 +151,7 @@ class SpawnChargePods(Rule):
|
|||||||
|
|
||||||
def on_init(self, state, lvl_map):
|
def on_init(self, state, lvl_map):
|
||||||
pod_collection = state[b.CHARGE_PODS]
|
pod_collection = state[b.CHARGE_PODS]
|
||||||
empty_positions = state.entities.empty_positions()
|
empty_positions = state.entities.empty_positions
|
||||||
pods = pod_collection.from_coordinates(empty_positions, entity_kwargs=dict(
|
pods = pod_collection.from_coordinates(empty_positions, entity_kwargs=dict(
|
||||||
multi_charge=self.multi_charge, charge_rate=self.charge_rate)
|
multi_charge=self.multi_charge, charge_rate=self.charge_rate)
|
||||||
)
|
)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from .actions import CleanUp
|
from .actions import CleanUp
|
||||||
from .entitites import DirtPile
|
from .entitites import DirtPile
|
||||||
from .groups import DirtPiles
|
from .groups import DirtPiles
|
||||||
from .rules import SpawnDirt, EntitiesSmearDirtOnMove, DoneOnAllDirtCleaned
|
from .rules import EntitiesSmearDirtOnMove, DoneOnAllDirtCleaned
|
||||||
|
@ -7,22 +7,6 @@ from marl_factory_grid.modules.clean_up import constants as d
|
|||||||
|
|
||||||
class DirtPile(Entity):
|
class DirtPile(Entity):
|
||||||
|
|
||||||
@property
|
|
||||||
def var_can_collide(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_can_move(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_is_blocking_light(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_has_position(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def amount(self):
|
def amount(self):
|
||||||
return self._amount
|
return self._amount
|
||||||
|
@ -9,68 +9,55 @@ from marl_factory_grid.modules.clean_up.entitites import DirtPile
|
|||||||
class DirtPiles(Collection):
|
class DirtPiles(Collection):
|
||||||
_entity = DirtPile
|
_entity = DirtPile
|
||||||
|
|
||||||
@property
|
var_is_blocking_light = False
|
||||||
def var_is_blocking_light(self):
|
var_can_collide = False
|
||||||
return False
|
var_can_move = False
|
||||||
|
var_has_position = True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def var_can_collide(self):
|
def global_amount(self):
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_can_move(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_has_position(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def amount(self):
|
|
||||||
return sum([dirt.amount for dirt in self])
|
return sum([dirt.amount for dirt in self])
|
||||||
|
|
||||||
def __init__(self, *args,
|
def __init__(self, *args,
|
||||||
max_local_amount=5,
|
max_local_amount=5,
|
||||||
clean_amount=1,
|
clean_amount=1,
|
||||||
max_global_amount: int = 20, **kwargs):
|
max_global_amount: int = 20,
|
||||||
|
coords_or_quantity=10,
|
||||||
|
initial_amount=2,
|
||||||
|
amount_var=0.2,
|
||||||
|
n_var=0.2,
|
||||||
|
**kwargs):
|
||||||
super(DirtPiles, self).__init__(*args, **kwargs)
|
super(DirtPiles, self).__init__(*args, **kwargs)
|
||||||
|
self.amount_var = amount_var
|
||||||
|
self.n_var = n_var
|
||||||
self.clean_amount = clean_amount
|
self.clean_amount = clean_amount
|
||||||
self.max_global_amount = max_global_amount
|
self.max_global_amount = max_global_amount
|
||||||
self.max_local_amount = max_local_amount
|
self.max_local_amount = max_local_amount
|
||||||
|
self.coords_or_quantity = coords_or_quantity
|
||||||
|
self.initial_amount = initial_amount
|
||||||
|
|
||||||
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):
|
def trigger_spawn(self, state, coords_or_quantity=0, amount=0) -> [Result]:
|
||||||
amount_s = entity_args[0]
|
coords_or_quantity = coords_or_quantity if coords_or_quantity else self.coords_or_quantity
|
||||||
|
n_new = int(abs(coords_or_quantity + (state.rng.uniform(-self.n_var, self.n_var))))
|
||||||
|
n_new = state.get_n_random_free_positions(n_new)
|
||||||
|
|
||||||
|
amounts = [amount if amount else (self.initial_amount + state.rng.uniform(-self.amount_var, self.amount_var))
|
||||||
|
for _ in range(coords_or_quantity)]
|
||||||
spawn_counter = 0
|
spawn_counter = 0
|
||||||
for idx, pos in enumerate(coords_or_quantity):
|
for idx, (pos, a) in enumerate(zip(n_new, amounts)):
|
||||||
if not self.amount > self.max_global_amount:
|
if not self.global_amount > self.max_global_amount:
|
||||||
amount = amount_s[idx] if isinstance(amount_s, list) else amount_s
|
|
||||||
if dirt := self.by_pos(pos):
|
if dirt := self.by_pos(pos):
|
||||||
dirt = next(dirt.iter())
|
dirt = next(dirt.iter())
|
||||||
new_value = dirt.amount + amount
|
new_value = dirt.amount + a
|
||||||
dirt.set_new_amount(new_value)
|
dirt.set_new_amount(new_value)
|
||||||
else:
|
else:
|
||||||
dirt = DirtPile(pos, amount=amount)
|
super().spawn([pos], amount=a)
|
||||||
self.add_item(dirt)
|
|
||||||
spawn_counter += 1
|
spawn_counter += 1
|
||||||
else:
|
else:
|
||||||
return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, reward=0,
|
return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, value=spawn_counter)
|
||||||
value=spawn_counter)
|
|
||||||
return Result(identifier=f'{self.name}_spawn', validity=c.VALID, reward=0, value=spawn_counter)
|
|
||||||
|
|
||||||
def trigger_dirt_spawn(self, n, amount, state, n_var=0.2, amount_var=0.2) -> Result:
|
return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=spawn_counter)
|
||||||
free_for_dirt = [x for x in state.entities.floorlist if len(state.entities.pos_dict[x]) == 0 or (
|
|
||||||
len(state.entities.pos_dict[x]) >= 1 and isinstance(next(y for y in x), DirtPile))]
|
|
||||||
# free_for_dirt = [x for x in state[c.FLOOR]
|
|
||||||
# if len(x.guests) == 0 or (
|
|
||||||
# len(x.guests) == 1 and
|
|
||||||
# isinstance(next(y for y in x.guests), DirtPile))]
|
|
||||||
state.rng.shuffle(free_for_dirt)
|
|
||||||
|
|
||||||
new_spawn = int(abs(n + (state.rng.uniform(-n_var, n_var))))
|
|
||||||
new_amount_s = [abs(amount + (amount*state.rng.uniform(-amount_var, amount_var))) for _ in range(new_spawn)]
|
|
||||||
n_dirty_positions = free_for_dirt[:new_spawn]
|
|
||||||
return self.spawn(n_dirty_positions, new_amount_s)
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
s = super(DirtPiles, self).__repr__()
|
s = super(DirtPiles, self).__repr__()
|
||||||
return f'{s[:-1]}, {self.amount})'
|
return f'{s[:-1]}, {self.global_amount}]'
|
||||||
|
@ -22,58 +22,37 @@ class DoneOnAllDirtCleaned(Rule):
|
|||||||
def on_check_done(self, state) -> [DoneResult]:
|
def on_check_done(self, state) -> [DoneResult]:
|
||||||
if len(state[d.DIRT]) == 0 and state.curr_step:
|
if len(state[d.DIRT]) == 0 and state.curr_step:
|
||||||
return [DoneResult(validity=c.VALID, identifier=self.name, reward=self.reward)]
|
return [DoneResult(validity=c.VALID, identifier=self.name, reward=self.reward)]
|
||||||
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
|
return [DoneResult(validity=c.NOT_VALID, identifier=self.name)]
|
||||||
|
|
||||||
|
|
||||||
class SpawnDirt(Rule):
|
class RespawnDirt(Rule):
|
||||||
|
|
||||||
def __init__(self, initial_n: int = 5, initial_amount: float = 1.3,
|
def __init__(self, respawn_freq: int = 15, respawn_n: int = 5, respawn_amount: float = 1.0):
|
||||||
respawn_n: int = 3, respawn_amount: float = 0.8,
|
|
||||||
n_var: float = 0.2, amount_var: float = 0.2, spawn_freq: int = 15):
|
|
||||||
"""
|
"""
|
||||||
Defines the spawn pattern of intial and additional 'Dirt'-entitites.
|
Defines the spawn pattern of intial and additional 'Dirt'-entitites.
|
||||||
First chooses positions, then trys to spawn dirt until 'respawn_n' or the maximal global amount is reached.
|
First chooses positions, then trys to spawn dirt until 'respawn_n' or the maximal global amount is reached.
|
||||||
If there is allready some, it is topped up to min(max_local_amount, amount).
|
If there is allready some, it is topped up to min(max_local_amount, amount).
|
||||||
|
|
||||||
:type spawn_freq: int
|
:type respawn_freq: int
|
||||||
:parameter spawn_freq: In which frequency should this Rule try to spawn new 'Dirt'?
|
:parameter respawn_freq: In which frequency should this Rule try to spawn new 'Dirt'?
|
||||||
:type respawn_n: int
|
:type respawn_n: int
|
||||||
:parameter respawn_n: How many respawn positions are considered.
|
:parameter respawn_n: How many respawn positions are considered.
|
||||||
:type initial_n: int
|
|
||||||
:parameter initial_n: How much initial positions are considered.
|
|
||||||
:type amount_var: float
|
|
||||||
:parameter amount_var: Variance of amount to spawn.
|
|
||||||
:type n_var: float
|
|
||||||
:parameter n_var: Variance of n to spawn.
|
|
||||||
:type respawn_amount: float
|
:type respawn_amount: float
|
||||||
:parameter respawn_amount: Defines how much dirt 'amount' is placed every 'spawn_freq' ticks.
|
:parameter respawn_amount: Defines how much dirt 'amount' is placed every 'spawn_freq' ticks.
|
||||||
:type initial_amount: float
|
|
||||||
:parameter initial_amount: Defines how much dirt 'amount' is initially placed.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.amount_var = amount_var
|
|
||||||
self.n_var = n_var
|
|
||||||
self.respawn_amount = respawn_amount
|
|
||||||
self.respawn_n = respawn_n
|
self.respawn_n = respawn_n
|
||||||
self.initial_amount = initial_amount
|
self.respawn_amount = respawn_amount
|
||||||
self.initial_n = initial_n
|
self.respawn_freq = respawn_freq
|
||||||
self.spawn_freq = spawn_freq
|
self._next_dirt_spawn = respawn_freq
|
||||||
self._next_dirt_spawn = spawn_freq
|
|
||||||
|
|
||||||
def on_init(self, state, lvl_map) -> str:
|
|
||||||
result = state[d.DIRT].trigger_dirt_spawn(self.initial_n, self.initial_amount, state,
|
|
||||||
n_var=self.n_var, amount_var=self.amount_var)
|
|
||||||
state.print(f'Initial Dirt was spawned on: {[x.pos for x in state[d.DIRT]]}')
|
|
||||||
return result
|
|
||||||
|
|
||||||
def tick_step(self, state):
|
def tick_step(self, state):
|
||||||
|
collection = state[d.DIRT]
|
||||||
if self._next_dirt_spawn < 0:
|
if self._next_dirt_spawn < 0:
|
||||||
pass # No DirtPile Spawn
|
pass # No DirtPile Spawn
|
||||||
elif not self._next_dirt_spawn:
|
elif not self._next_dirt_spawn:
|
||||||
result = [state[d.DIRT].trigger_dirt_spawn(self.respawn_n, self.respawn_amount, state,
|
result = [collection.trigger_spawn(state, coords_or_quantity=self.respawn_n, amount=self.respawn_amount)]
|
||||||
n_var=self.n_var, amount_var=self.amount_var)]
|
self._next_dirt_spawn = self.respawn_freq
|
||||||
self._next_dirt_spawn = self.spawn_freq
|
|
||||||
else:
|
else:
|
||||||
self._next_dirt_spawn -= 1
|
self._next_dirt_spawn -= 1
|
||||||
result = []
|
result = []
|
||||||
@ -99,8 +78,8 @@ class EntitiesSmearDirtOnMove(Rule):
|
|||||||
for entity in state.moving_entites:
|
for entity in state.moving_entites:
|
||||||
if is_move(entity.state.identifier) and entity.state.validity == c.VALID:
|
if is_move(entity.state.identifier) and entity.state.validity == c.VALID:
|
||||||
if old_pos_dirt := state[d.DIRT].by_pos(entity.last_pos):
|
if old_pos_dirt := state[d.DIRT].by_pos(entity.last_pos):
|
||||||
|
old_pos_dirt = next(iter(old_pos_dirt))
|
||||||
if smeared_dirt := round(old_pos_dirt.amount * self.smear_ratio, 2):
|
if smeared_dirt := round(old_pos_dirt.amount * self.smear_ratio, 2):
|
||||||
if state[d.DIRT].spawn(entity.pos, amount=smeared_dirt):
|
if state[d.DIRT].spawn(entity.pos, amount=smeared_dirt):
|
||||||
results.append(TickResult(identifier=self.name, entity=entity,
|
results.append(TickResult(identifier=self.name, entity=entity, validity=c.VALID))
|
||||||
reward=0, validity=c.VALID))
|
|
||||||
return results
|
return results
|
||||||
|
@ -1,4 +1,7 @@
|
|||||||
from .actions import DestAction
|
from .actions import DestAction
|
||||||
from .entitites import Destination
|
from .entitites import Destination
|
||||||
from .groups import Destinations
|
from .groups import Destinations
|
||||||
from .rules import DoneAtDestinationReachAll, SpawnDestinations
|
from .rules import (DoneAtDestinationReachAll,
|
||||||
|
DoneAtDestinationReachAny,
|
||||||
|
SpawnDestinationsPerAgent,
|
||||||
|
DestinationReachReward)
|
||||||
|
@ -9,30 +9,6 @@ from marl_factory_grid.utils.utility_classes import RenderEntity
|
|||||||
|
|
||||||
class Destination(Entity):
|
class Destination(Entity):
|
||||||
|
|
||||||
@property
|
|
||||||
def var_can_move(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_can_collide(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_has_position(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_is_blocking_pos(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_is_blocking_light(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_can_be_bound(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
def was_reached(self):
|
def was_reached(self):
|
||||||
return self._was_reached
|
return self._was_reached
|
||||||
|
|
||||||
|
@ -7,37 +7,14 @@ from marl_factory_grid.modules.destinations import constants as d
|
|||||||
class Destinations(Collection):
|
class Destinations(Collection):
|
||||||
_entity = Destination
|
_entity = Destination
|
||||||
|
|
||||||
@property
|
var_is_blocking_light = False
|
||||||
def var_is_blocking_light(self):
|
var_can_collide = False
|
||||||
return False
|
var_can_move = False
|
||||||
|
var_has_position = True
|
||||||
@property
|
var_can_be_bound = True
|
||||||
def var_can_collide(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_can_move(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_has_position(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return super(Destinations, self).__repr__()
|
return super(Destinations, self).__repr__()
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def trigger_destination_spawn(n_dests, state):
|
|
||||||
coordinates = state.entities.floorlist[:n_dests]
|
|
||||||
if destinations := [Destination(pos) for pos in coordinates]:
|
|
||||||
state[d.DESTINATION].add_items(destinations)
|
|
||||||
state.print(f'{n_dests} new destinations have been spawned')
|
|
||||||
return c.VALID
|
|
||||||
else:
|
|
||||||
state.print('No Destiantions are spawning, limit is reached.')
|
|
||||||
return c.NOT_VALID
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,8 +2,8 @@ import ast
|
|||||||
from random import shuffle
|
from random import shuffle
|
||||||
from typing import List, Dict, Tuple
|
from typing import List, Dict, Tuple
|
||||||
|
|
||||||
import marl_factory_grid.modules.destinations.constants
|
|
||||||
from marl_factory_grid.environment.rules import Rule
|
from marl_factory_grid.environment.rules import Rule
|
||||||
|
from marl_factory_grid.utils import helpers as h
|
||||||
from marl_factory_grid.utils.results import TickResult, DoneResult
|
from marl_factory_grid.utils.results import TickResult, DoneResult
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
|
|
||||||
@ -54,7 +54,7 @@ class DoneAtDestinationReachAll(DestinationReachReward):
|
|||||||
"""
|
"""
|
||||||
This rule triggers and sets the done flag if ALL Destinations have been reached.
|
This rule triggers and sets the done flag if ALL Destinations have been reached.
|
||||||
|
|
||||||
:type reward_at_done: object
|
:type reward_at_done: float
|
||||||
:param reward_at_done: Specifies the reward, agent get, whenn all destinations are reached.
|
:param reward_at_done: Specifies the reward, agent get, whenn all destinations are reached.
|
||||||
:type dest_reach_reward: float
|
:type dest_reach_reward: float
|
||||||
:param dest_reach_reward: Specify the reward, agents get when reaching a single destination.
|
:param dest_reach_reward: Specify the reward, agents get when reaching a single destination.
|
||||||
@ -65,7 +65,7 @@ class DoneAtDestinationReachAll(DestinationReachReward):
|
|||||||
def on_check_done(self, state) -> List[DoneResult]:
|
def on_check_done(self, state) -> List[DoneResult]:
|
||||||
if all(x.was_reached() for x in state[d.DESTINATION]):
|
if all(x.was_reached() for x in state[d.DESTINATION]):
|
||||||
return [DoneResult(self.name, validity=c.VALID, reward=self.reward)]
|
return [DoneResult(self.name, validity=c.VALID, reward=self.reward)]
|
||||||
return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)]
|
return [DoneResult(self.name, validity=c.NOT_VALID)]
|
||||||
|
|
||||||
|
|
||||||
class DoneAtDestinationReachAny(DestinationReachReward):
|
class DoneAtDestinationReachAny(DestinationReachReward):
|
||||||
@ -75,7 +75,7 @@ class DoneAtDestinationReachAny(DestinationReachReward):
|
|||||||
This rule triggers and sets the done flag if ANY Destinations has been reached.
|
This rule triggers and sets the done flag if ANY Destinations has been reached.
|
||||||
!!! IMPORTANT: 'reward_at_done' is shared between the agents; 'dest_reach_reward' is bound to a specific one.
|
!!! IMPORTANT: 'reward_at_done' is shared between the agents; 'dest_reach_reward' is bound to a specific one.
|
||||||
|
|
||||||
:type reward_at_done: object
|
:type reward_at_done: float
|
||||||
:param reward_at_done: Specifies the reward, all agent get, when any destinations has been reached.
|
:param reward_at_done: Specifies the reward, all agent get, when any destinations has been reached.
|
||||||
Default {d.REWARD_DEST_DONE}
|
Default {d.REWARD_DEST_DONE}
|
||||||
:type dest_reach_reward: float
|
:type dest_reach_reward: float
|
||||||
@ -87,67 +87,29 @@ class DoneAtDestinationReachAny(DestinationReachReward):
|
|||||||
|
|
||||||
def on_check_done(self, state) -> List[DoneResult]:
|
def on_check_done(self, state) -> List[DoneResult]:
|
||||||
if any(x.was_reached() for x in state[d.DESTINATION]):
|
if any(x.was_reached() for x in state[d.DESTINATION]):
|
||||||
return [DoneResult(self.name, validity=c.VALID, reward=marl_factory_grid.modules.destinations.constants.REWARD_DEST_REACHED)]
|
return [DoneResult(self.name, validity=c.VALID, reward=d.REWARD_DEST_REACHED)]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
class SpawnDestinations(Rule):
|
|
||||||
|
|
||||||
def __init__(self, n_dests: int = 1, spawn_mode: str = d.MODE_GROUPED):
|
|
||||||
f"""
|
|
||||||
Defines how destinations are initially spawned and respawned in addition.
|
|
||||||
!!! This rule introduces no kind of reward or Env.-Done condition!
|
|
||||||
|
|
||||||
:type n_dests: int
|
|
||||||
:param n_dests: How many destiantions should be maintained (and initally spawnewd) on the map?
|
|
||||||
:type spawn_mode: str
|
|
||||||
:param spawn_mode: One of {d.SPAWN_MODES}. {d.MODE_GROUPED}: Always wait for all Dstiantions do be gone,
|
|
||||||
then respawn after the given time. {d.MODE_SINGLE}: Just spawn every destination,
|
|
||||||
that has been reached, after the given time
|
|
||||||
|
|
||||||
"""
|
|
||||||
super(SpawnDestinations, self).__init__()
|
|
||||||
self.n_dests = n_dests
|
|
||||||
self.spawn_mode = spawn_mode
|
|
||||||
|
|
||||||
def on_init(self, state, lvl_map):
|
|
||||||
# noinspection PyAttributeOutsideInit
|
|
||||||
state[d.DESTINATION].trigger_destination_spawn(self.n_dests, state)
|
|
||||||
pass
|
|
||||||
|
|
||||||
def tick_pre_step(self, state) -> List[TickResult]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def tick_step(self, state) -> List[TickResult]:
|
|
||||||
if n_dest_spawn := max(0, self.n_dests - len(state[d.DESTINATION])):
|
|
||||||
if self.spawn_mode == d.MODE_GROUPED and n_dest_spawn == self.n_dests:
|
|
||||||
validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state)
|
|
||||||
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
|
|
||||||
elif self.spawn_mode == d.MODE_SINGLE and n_dest_spawn:
|
|
||||||
validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state)
|
|
||||||
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SpawnDestinationsPerAgent(Rule):
|
class SpawnDestinationsPerAgent(Rule):
|
||||||
def __init__(self, per_agent_positions: Dict[str, List[Tuple[int, int]]]):
|
def __init__(self, coords_or_quantity: Dict[str, List[Tuple[int, int]]]):
|
||||||
"""
|
"""
|
||||||
Special rule, that spawn distinations, that are bound to a single agent a fixed set of positions.
|
Special rule, that spawn distinations, that are bound to a single agent a fixed set of positions.
|
||||||
Usefull for introducing specialists, etc. ..
|
Usefull for introducing specialists, etc. ..
|
||||||
|
|
||||||
!!! This rule does not introduce any reward or done condition.
|
!!! This rule does not introduce any reward or done condition.
|
||||||
|
|
||||||
:type per_agent_positions: Dict[str, List[Tuple[int, int]]
|
:type coords_or_quantity: Dict[str, List[Tuple[int, int]]
|
||||||
:param per_agent_positions: Please provide a dictionary with agent names as keys; and a list of possible
|
:param coords_or_quantity: Please provide a dictionary with agent names as keys; and a list of possible
|
||||||
destiantion coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]}
|
destiantion coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]}
|
||||||
"""
|
"""
|
||||||
super(Rule, self).__init__()
|
super(Rule, self).__init__()
|
||||||
self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in per_agent_positions.items()}
|
self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in coords_or_quantity.items()}
|
||||||
|
|
||||||
def on_init(self, state, lvl_map):
|
def on_init(self, state, lvl_map):
|
||||||
for (agent_name, position_list) in self.per_agent_positions.items():
|
for (agent_name, position_list) in self.per_agent_positions.items():
|
||||||
agent = next(x for x in state[c.AGENT] if agent_name in x.name) # Fixme: Ugly AF
|
agent = h.get_first(state[c.AGENT], lambda x: agent_name in x.name)
|
||||||
|
assert agent
|
||||||
position_list = position_list.copy()
|
position_list = position_list.copy()
|
||||||
shuffle(position_list)
|
shuffle(position_list)
|
||||||
while True:
|
while True:
|
||||||
@ -155,7 +117,7 @@ class SpawnDestinationsPerAgent(Rule):
|
|||||||
pos = position_list.pop()
|
pos = position_list.pop()
|
||||||
except IndexError:
|
except IndexError:
|
||||||
print(f"Could not spawn Destinations at: {self.per_agent_positions[agent_name]}")
|
print(f"Could not spawn Destinations at: {self.per_agent_positions[agent_name]}")
|
||||||
print(f'Check your agent palcement: {state[c.AGENT]} ... Exit ...')
|
print(f'Check your agent placement: {state[c.AGENT]} ... Exit ...')
|
||||||
exit(9999)
|
exit(9999)
|
||||||
if (not pos == agent.pos) and (not state[d.DESTINATION].by_pos(pos)):
|
if (not pos == agent.pos) and (not state[d.DESTINATION].by_pos(pos)):
|
||||||
destination = Destination(pos, bind_to=agent)
|
destination = Destination(pos, bind_to=agent)
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from marl_factory_grid.environment.entity.entity import Entity
|
from marl_factory_grid.environment.entity.entity import Entity
|
||||||
|
from marl_factory_grid.utils import Result
|
||||||
from marl_factory_grid.utils.utility_classes import RenderEntity
|
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
|
|
||||||
@ -41,21 +42,6 @@ class Door(Entity):
|
|||||||
def str_state(self):
|
def str_state(self):
|
||||||
return 'open' if self.is_open else 'closed'
|
return 'open' if self.is_open else 'closed'
|
||||||
|
|
||||||
def __init__(self, *args, closed_on_init=True, auto_close_interval=10, **kwargs):
|
|
||||||
self._status = d.STATE_CLOSED
|
|
||||||
super(Door, self).__init__(*args, **kwargs)
|
|
||||||
self.auto_close_interval = auto_close_interval
|
|
||||||
self.time_to_close = 0
|
|
||||||
if not closed_on_init:
|
|
||||||
self._open()
|
|
||||||
else:
|
|
||||||
self._close()
|
|
||||||
|
|
||||||
def summarize_state(self):
|
|
||||||
state_dict = super().summarize_state()
|
|
||||||
state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close))
|
|
||||||
return state_dict
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_closed(self):
|
def is_closed(self):
|
||||||
return self._status == d.STATE_CLOSED
|
return self._status == d.STATE_CLOSED
|
||||||
@ -68,6 +54,25 @@ class Door(Entity):
|
|||||||
def status(self):
|
def status(self):
|
||||||
return self._status
|
return self._status
|
||||||
|
|
||||||
|
@property
|
||||||
|
def time_to_close(self):
|
||||||
|
return self._time_to_close
|
||||||
|
|
||||||
|
def __init__(self, *args, closed_on_init=True, auto_close_interval=10, **kwargs):
|
||||||
|
self._status = d.STATE_CLOSED
|
||||||
|
super(Door, self).__init__(*args, **kwargs)
|
||||||
|
self._auto_close_interval = auto_close_interval
|
||||||
|
self._time_to_close = 0
|
||||||
|
if not closed_on_init:
|
||||||
|
self._open()
|
||||||
|
else:
|
||||||
|
self._close()
|
||||||
|
|
||||||
|
def summarize_state(self):
|
||||||
|
state_dict = super().summarize_state()
|
||||||
|
state_dict.update(state=str(self.str_state), time_to_close=self.time_to_close)
|
||||||
|
return state_dict
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
name, state = 'door_open' if self.is_open else 'door_closed', 'blank'
|
name, state = 'door_open' if self.is_open else 'door_closed', 'blank'
|
||||||
return RenderEntity(name, self.pos, 1, 'none', state, self.u_int + 1)
|
return RenderEntity(name, self.pos, 1, 'none', state, self.u_int + 1)
|
||||||
@ -80,18 +85,35 @@ class Door(Entity):
|
|||||||
return c.VALID
|
return c.VALID
|
||||||
|
|
||||||
def tick(self, state):
|
def tick(self, state):
|
||||||
if self.is_open and len(state.entities.pos_dict[self.pos]) == 2 and self.time_to_close:
|
# Check if no entity is standing in the door
|
||||||
self.time_to_close -= 1
|
if len(state.entities.pos_dict[self.pos]) <= 2:
|
||||||
return c.NOT_VALID
|
if self.is_open and self.time_to_close:
|
||||||
elif self.is_open and not self.time_to_close and len(state.entities.pos_dict[self.pos]) == 2:
|
self._decrement_timer()
|
||||||
self.use()
|
return Result(f"{d.DOOR}_tick", c.VALID, entity=self)
|
||||||
return c.VALID
|
elif self.is_open and not self.time_to_close:
|
||||||
|
self.use()
|
||||||
|
return Result(f"{d.DOOR}_closed", c.VALID, entity=self)
|
||||||
|
else:
|
||||||
|
# No one is in door, but it is closed... Nothing to do....
|
||||||
|
return None
|
||||||
else:
|
else:
|
||||||
return c.NOT_VALID
|
# Entity is standing in the door, reset timer
|
||||||
|
self._reset_timer()
|
||||||
|
return Result(f"{d.DOOR}_reset", c.VALID, entity=self)
|
||||||
|
|
||||||
def _open(self):
|
def _open(self):
|
||||||
self._status = d.STATE_OPEN
|
self._status = d.STATE_OPEN
|
||||||
self.time_to_close = self.auto_close_interval
|
self._reset_timer()
|
||||||
|
return True
|
||||||
|
|
||||||
def _close(self):
|
def _close(self):
|
||||||
self._status = d.STATE_CLOSED
|
self._status = d.STATE_CLOSED
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _decrement_timer(self):
|
||||||
|
self._time_to_close -= 1
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _reset_timer(self):
|
||||||
|
self._time_to_close = self._auto_close_interval
|
||||||
|
return True
|
||||||
|
@ -18,8 +18,10 @@ class Doors(Collection):
|
|||||||
super(Doors, self).__init__(*args, can_collide=True, **kwargs)
|
super(Doors, self).__init__(*args, can_collide=True, **kwargs)
|
||||||
|
|
||||||
def tick_doors(self, state):
|
def tick_doors(self, state):
|
||||||
result_dict = dict()
|
results = list()
|
||||||
for door in self:
|
for door in self:
|
||||||
did_tick = door.tick(state)
|
tick_result = door.tick(state)
|
||||||
result_dict.update({door.name: did_tick})
|
if tick_result is not None:
|
||||||
return result_dict
|
results.append(tick_result)
|
||||||
|
# TODO: Should return a Result object, not a random dict.
|
||||||
|
return results
|
||||||
|
@ -19,10 +19,10 @@ class DoorAutoClose(Rule):
|
|||||||
|
|
||||||
def tick_step(self, state):
|
def tick_step(self, state):
|
||||||
if doors := state[d.DOORS]:
|
if doors := state[d.DOORS]:
|
||||||
doors_tick_result = doors.tick_doors(state)
|
doors_tick_results = doors.tick_doors(state)
|
||||||
doors_that_ticked = [key for key, val in doors_tick_result.items() if val]
|
doors_that_closed = [x.entity.name for x in doors_tick_results if 'closed' in x.identifier]
|
||||||
state.print(f'{doors_that_ticked} were auto-closed'
|
door_str = doors_that_closed if doors_that_closed else "No Doors"
|
||||||
if doors_that_ticked else 'No Doors were auto-closed')
|
state.print(f'{door_str} were auto-closed')
|
||||||
return [TickResult(self.name, validity=c.VALID, value=1)]
|
return [TickResult(self.name, validity=c.VALID, value=1)]
|
||||||
state.print('There are no doors, but you loaded the corresponding Module')
|
state.print('There are no doors, but you loaded the corresponding Module')
|
||||||
return []
|
return []
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
from .actions import ItemAction
|
from .actions import ItemAction
|
||||||
from .entitites import Item, DropOffLocation
|
from .entitites import Item, DropOffLocation
|
||||||
from .groups import DropOffLocations, Items, Inventory, Inventories
|
from .groups import DropOffLocations, Items, Inventory, Inventories
|
||||||
from .rules import ItemRules
|
|
||||||
|
@ -29,7 +29,7 @@ class ItemAction(Action):
|
|||||||
elif items := state[i.ITEM].by_pos(entity.pos):
|
elif items := state[i.ITEM].by_pos(entity.pos):
|
||||||
item = items[0]
|
item = items[0]
|
||||||
item.change_parent_collection(inventory)
|
item.change_parent_collection(inventory)
|
||||||
item.set_pos_to(c.VALUE_NO_POS)
|
item.set_pos(c.VALUE_NO_POS)
|
||||||
state.print(f'{entity.name} just picked up an item at {entity.pos}')
|
state.print(f'{entity.name} just picked up an item at {entity.pos}')
|
||||||
return ActionResult(entity=entity, identifier=self._identifier, validity=c.VALID, reward=r.PICK_UP_VALID)
|
return ActionResult(entity=entity, identifier=self._identifier, validity=c.VALID, reward=r.PICK_UP_VALID)
|
||||||
|
|
||||||
|
@ -8,16 +8,11 @@ from marl_factory_grid.modules.items import constants as i
|
|||||||
|
|
||||||
class Item(Entity):
|
class Item(Entity):
|
||||||
|
|
||||||
@property
|
|
||||||
def var_can_collide(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
return RenderEntity(i.ITEM, self.pos) if self.pos != c.VALUE_NO_POS else None
|
return RenderEntity(i.ITEM, self.pos) if self.pos != c.VALUE_NO_POS else None
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._auto_despawn = -1
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def auto_despawn(self):
|
def auto_despawn(self):
|
||||||
@ -31,9 +26,6 @@ class Item(Entity):
|
|||||||
def set_auto_despawn(self, auto_despawn):
|
def set_auto_despawn(self, auto_despawn):
|
||||||
self._auto_despawn = auto_despawn
|
self._auto_despawn = auto_despawn
|
||||||
|
|
||||||
def set_pos_to(self, no_pos):
|
|
||||||
self._pos = no_pos
|
|
||||||
|
|
||||||
def summarize_state(self) -> dict:
|
def summarize_state(self) -> dict:
|
||||||
super_summarization = super(Item, self).summarize_state()
|
super_summarization = super(Item, self).summarize_state()
|
||||||
super_summarization.update(dict(auto_despawn=self.auto_despawn))
|
super_summarization.update(dict(auto_despawn=self.auto_despawn))
|
||||||
@ -42,21 +34,6 @@ class Item(Entity):
|
|||||||
|
|
||||||
class DropOffLocation(Entity):
|
class DropOffLocation(Entity):
|
||||||
|
|
||||||
@property
|
|
||||||
def var_can_collide(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_can_move(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_is_blocking_light(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_has_position(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
return RenderEntity(i.DROP_OFF, self.pos)
|
return RenderEntity(i.DROP_OFF, self.pos)
|
||||||
|
@ -8,6 +8,7 @@ from marl_factory_grid.environment.groups.objects import _Objects
|
|||||||
from marl_factory_grid.environment.groups.mixins import IsBoundMixin
|
from marl_factory_grid.environment.groups.mixins import IsBoundMixin
|
||||||
from marl_factory_grid.environment.entity.agent import Agent
|
from marl_factory_grid.environment.entity.agent import Agent
|
||||||
from marl_factory_grid.modules.items.entitites import Item, DropOffLocation
|
from marl_factory_grid.modules.items.entitites import Item, DropOffLocation
|
||||||
|
from marl_factory_grid.utils.results import Result
|
||||||
|
|
||||||
|
|
||||||
class Items(Collection):
|
class Items(Collection):
|
||||||
@ -15,7 +16,7 @@ class Items(Collection):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def var_has_position(self):
|
def var_has_position(self):
|
||||||
return False
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_blocking_light(self):
|
def is_blocking_light(self):
|
||||||
@ -28,18 +29,18 @@ class Items(Collection):
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
@staticmethod
|
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, **entity_kwargs) -> [Result]:
|
||||||
def trigger_item_spawn(state, n_items, spawn_frequency):
|
coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity
|
||||||
if item_to_spawns := max(0, (n_items - len(state[i.ITEM]))):
|
assert coords_or_quantity
|
||||||
position_list = [x for x in state.entities.floorlist]
|
|
||||||
shuffle(position_list)
|
if item_to_spawns := max(0, (coords_or_quantity - len(self))):
|
||||||
position_list = state.entities.floorlist[:item_to_spawns]
|
return super().trigger_spawn(state,
|
||||||
state[i.ITEM].spawn(position_list)
|
*entity_args,
|
||||||
state.print(f'{item_to_spawns} new items have been spawned; next spawn in {spawn_frequency}')
|
coords_or_quantity=item_to_spawns,
|
||||||
return len(position_list)
|
**entity_kwargs)
|
||||||
else:
|
else:
|
||||||
state.print('No Items are spawning, limit is reached.')
|
state.print('No Items are spawning, limit is reached.')
|
||||||
return 0
|
return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, value=coords_or_quantity)
|
||||||
|
|
||||||
|
|
||||||
class Inventory(IsBoundMixin, Collection):
|
class Inventory(IsBoundMixin, Collection):
|
||||||
@ -76,9 +77,15 @@ class Inventory(IsBoundMixin, Collection):
|
|||||||
class Inventories(_Objects):
|
class Inventories(_Objects):
|
||||||
_entity = Inventory
|
_entity = Inventory
|
||||||
|
|
||||||
|
var_can_move = False
|
||||||
|
var_has_position = False
|
||||||
|
|
||||||
|
|
||||||
|
symbol = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def var_can_move(self):
|
def spawn_rule(self):
|
||||||
return False
|
return {c.SPAWN_ENTITY_RULE: dict(collection=self, coords_or_quantity=None)}
|
||||||
|
|
||||||
def __init__(self, size: int, *args, **kwargs):
|
def __init__(self, size: int, *args, **kwargs):
|
||||||
super(Inventories, self).__init__(*args, **kwargs)
|
super(Inventories, self).__init__(*args, **kwargs)
|
||||||
@ -86,10 +93,12 @@ class Inventories(_Objects):
|
|||||||
self._obs = None
|
self._obs = None
|
||||||
self._lazy_eval_transforms = []
|
self._lazy_eval_transforms = []
|
||||||
|
|
||||||
def spawn(self, agents):
|
def spawn(self, agents, *args, **kwargs):
|
||||||
inventories = [self._entity(agent, self.size, )
|
self.add_items([self._entity(agent, self.size, *args, **kwargs) for _, agent in enumerate(agents)])
|
||||||
for _, agent in enumerate(agents)]
|
return [Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))]
|
||||||
self.add_items(inventories)
|
|
||||||
|
def trigger_spawn(self, state, *args, **kwargs) -> [Result]:
|
||||||
|
return self.spawn(state[c.AGENT], *args, **kwargs)
|
||||||
|
|
||||||
def idx_by_entity(self, entity):
|
def idx_by_entity(self, entity):
|
||||||
try:
|
try:
|
||||||
@ -106,9 +115,6 @@ class Inventories(_Objects):
|
|||||||
def summarize_states(self, **kwargs):
|
def summarize_states(self, **kwargs):
|
||||||
return [val.summarize_states(**kwargs) for key, val in self.items()]
|
return [val.summarize_states(**kwargs) for key, val in self.items()]
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def trigger_inventory_spawn(state):
|
|
||||||
state[i.INVENTORY].spawn(state[c.AGENT])
|
|
||||||
|
|
||||||
|
|
||||||
class DropOffLocations(Collection):
|
class DropOffLocations(Collection):
|
||||||
@ -135,7 +141,7 @@ class DropOffLocations(Collection):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def trigger_drop_off_location_spawn(state, n_locations):
|
def trigger_drop_off_location_spawn(state, n_locations):
|
||||||
empty_positions = state.entities.empty_positions()[:n_locations]
|
empty_positions = state.entities.empty_positions[:n_locations]
|
||||||
do_entites = state[i.DROP_OFF]
|
do_entites = state[i.DROP_OFF]
|
||||||
drop_offs = [DropOffLocation(pos) for pos in empty_positions]
|
drop_offs = [DropOffLocation(pos) for pos in empty_positions]
|
||||||
do_entites.add_items(drop_offs)
|
do_entites.add_items(drop_offs)
|
||||||
|
@ -6,52 +6,28 @@ from marl_factory_grid.utils.results import TickResult
|
|||||||
from marl_factory_grid.modules.items import constants as i
|
from marl_factory_grid.modules.items import constants as i
|
||||||
|
|
||||||
|
|
||||||
class ItemRules(Rule):
|
class RespawnItems(Rule):
|
||||||
|
|
||||||
def __init__(self, n_items: int = 5, spawn_frequency: int = 15,
|
def __init__(self, n_items: int = 5, respawn_freq: int = 15, n_locations: int = 5):
|
||||||
n_locations: int = 5, max_dropoff_storage_size: int = 0):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.spawn_frequency = spawn_frequency
|
self.spawn_frequency = respawn_freq
|
||||||
self._next_item_spawn = spawn_frequency
|
self._next_item_spawn = respawn_freq
|
||||||
self.n_items = n_items
|
self.n_items = n_items
|
||||||
self.max_dropoff_storage_size = max_dropoff_storage_size
|
|
||||||
self.n_locations = n_locations
|
self.n_locations = n_locations
|
||||||
|
|
||||||
def on_init(self, state, lvl_map):
|
|
||||||
state[i.DROP_OFF].trigger_drop_off_location_spawn(state, self.n_locations)
|
|
||||||
self._next_item_spawn = self.spawn_frequency
|
|
||||||
state[i.INVENTORY].trigger_inventory_spawn(state)
|
|
||||||
state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency)
|
|
||||||
|
|
||||||
def tick_step(self, state):
|
def tick_step(self, state):
|
||||||
for item in list(state[i.ITEM].values()):
|
|
||||||
if item.auto_despawn >= 1:
|
|
||||||
item.set_auto_despawn(item.auto_despawn - 1)
|
|
||||||
elif not item.auto_despawn:
|
|
||||||
state[i.ITEM].delete_env_object(item)
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if not self._next_item_spawn:
|
if not self._next_item_spawn:
|
||||||
state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency)
|
state[i.ITEM].trigger_spawn(state, self.n_items, self.spawn_frequency)
|
||||||
else:
|
else:
|
||||||
self._next_item_spawn = max(0, self._next_item_spawn - 1)
|
self._next_item_spawn = max(0, self._next_item_spawn - 1)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def tick_post_step(self, state) -> List[TickResult]:
|
def tick_post_step(self, state) -> List[TickResult]:
|
||||||
for item in list(state[i.ITEM].values()):
|
|
||||||
if item.auto_despawn >= 1:
|
|
||||||
item.set_auto_despawn(item.auto_despawn-1)
|
|
||||||
elif not item.auto_despawn:
|
|
||||||
state[i.ITEM].delete_env_object(item)
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if not self._next_item_spawn:
|
if not self._next_item_spawn:
|
||||||
if spawned_items := state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency):
|
if spawned_items := state[i.ITEM].trigger_spawn(state, self.n_items, self.spawn_frequency):
|
||||||
return [TickResult(self.name, validity=c.VALID, value=spawned_items, entity=None)]
|
return [TickResult(self.name, validity=c.VALID, value=spawned_items.value)]
|
||||||
else:
|
else:
|
||||||
return [TickResult(self.name, validity=c.NOT_VALID, value=0, entity=None)]
|
return [TickResult(self.name, validity=c.NOT_VALID, value=0)]
|
||||||
else:
|
else:
|
||||||
self._next_item_spawn = max(0, self._next_item_spawn-1)
|
self._next_item_spawn = max(0, self._next_item_spawn-1)
|
||||||
return []
|
return []
|
||||||
|
@ -1,3 +1,2 @@
|
|||||||
from .entitites import Machine
|
from .entitites import Machine
|
||||||
from .groups import Machines
|
from .groups import Machines
|
||||||
from .rules import MachineRule
|
|
||||||
|
@ -5,6 +5,7 @@ from marl_factory_grid.utils.results import ActionResult
|
|||||||
|
|
||||||
from marl_factory_grid.modules.machines import constants as m, rewards as r
|
from marl_factory_grid.modules.machines import constants as m, rewards as r
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
|
from marl_factory_grid.utils import helpers as h
|
||||||
|
|
||||||
|
|
||||||
class MachineAction(Action):
|
class MachineAction(Action):
|
||||||
@ -13,13 +14,10 @@ class MachineAction(Action):
|
|||||||
super().__init__(m.MACHINE_ACTION)
|
super().__init__(m.MACHINE_ACTION)
|
||||||
|
|
||||||
def do(self, entity, state) -> Union[None, ActionResult]:
|
def do(self, entity, state) -> Union[None, ActionResult]:
|
||||||
if machine := state[m.MACHINES].by_pos(entity.pos):
|
if machine := h.get_first(state[m.MACHINES].by_pos(entity.pos)):
|
||||||
if valid := machine.maintain():
|
if valid := machine.maintain():
|
||||||
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_VALID)
|
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_VALID)
|
||||||
else:
|
else:
|
||||||
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_FAIL)
|
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_FAIL)
|
||||||
else:
|
else:
|
||||||
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.MAINTAIN_FAIL)
|
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.MAINTAIN_FAIL)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,22 +8,6 @@ from . import constants as m
|
|||||||
|
|
||||||
class Machine(Entity):
|
class Machine(Entity):
|
||||||
|
|
||||||
@property
|
|
||||||
def var_can_collide(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_can_move(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_is_blocking_light(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_has_position(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
return self._encodings[self.status]
|
return self._encodings[self.status]
|
||||||
@ -46,12 +30,12 @@ class Machine(Entity):
|
|||||||
else:
|
else:
|
||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
|
|
||||||
def tick(self):
|
def tick(self, state):
|
||||||
# if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]):
|
# if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]):
|
||||||
if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.state.entities.pos_dict[self.pos]]):
|
if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in state.entities.pos_dict[self.pos]]):
|
||||||
return TickResult(identifier=self.name, validity=c.VALID, reward=0, entity=self)
|
return TickResult(identifier=self.name, validity=c.VALID, entity=self)
|
||||||
# elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]):
|
# elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]):
|
||||||
elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.state.entities.pos_dict[self.pos]]):
|
elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in state.entities.pos_dict[self.pos]]):
|
||||||
self.status = m.STATE_WORK
|
self.status = m.STATE_WORK
|
||||||
self.reset_counter()
|
self.reset_counter()
|
||||||
return None
|
return None
|
||||||
|
@ -1,28 +0,0 @@
|
|||||||
from typing import List
|
|
||||||
from marl_factory_grid.environment.rules import Rule
|
|
||||||
from marl_factory_grid.utils.results import TickResult, DoneResult
|
|
||||||
from marl_factory_grid.environment import constants as c
|
|
||||||
from marl_factory_grid.modules.machines import constants as m
|
|
||||||
from marl_factory_grid.modules.machines.entitites import Machine
|
|
||||||
|
|
||||||
|
|
||||||
class MachineRule(Rule):
|
|
||||||
|
|
||||||
def __init__(self, n_machines: int = 2):
|
|
||||||
super(MachineRule, self).__init__()
|
|
||||||
self.n_machines = n_machines
|
|
||||||
|
|
||||||
def on_init(self, state, lvl_map):
|
|
||||||
state[m.MACHINES].spawn(state.entities.empty_positions())
|
|
||||||
|
|
||||||
def tick_pre_step(self, state) -> List[TickResult]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def tick_step(self, state) -> List[TickResult]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def tick_post_step(self, state) -> List[TickResult]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_check_done(self, state) -> List[DoneResult]:
|
|
||||||
pass
|
|
@ -1,48 +1,35 @@
|
|||||||
|
from random import shuffle
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
from ...algorithms.static.utils import points_to_graph
|
from ...algorithms.static.utils import points_to_graph
|
||||||
from ...environment import constants as c
|
from ...environment import constants as c
|
||||||
from ...environment.actions import Action, ALL_BASEACTIONS
|
from ...environment.actions import Action, ALL_BASEACTIONS
|
||||||
from ...environment.entity.entity import Entity
|
from ...environment.entity.entity import Entity
|
||||||
from ..doors import constants as do
|
from ..doors import constants as do
|
||||||
from ..maintenance import constants as mi
|
from ..maintenance import constants as mi
|
||||||
from ...utils.helpers import MOVEMAP
|
from ...utils import helpers as h
|
||||||
from ...utils.utility_classes import RenderEntity
|
from ...utils.utility_classes import RenderEntity, Floor
|
||||||
from ...utils.states import Gamestate
|
from ..doors import DoorUse
|
||||||
|
|
||||||
|
|
||||||
class Maintainer(Entity):
|
class Maintainer(Entity):
|
||||||
|
|
||||||
@property
|
def __init__(self, objective: str, action: Action, *args, **kwargs):
|
||||||
def var_can_collide(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_can_move(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_is_blocking_light(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_has_position(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
def __init__(self, state: Gamestate, objective: str, action: Action, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.action = action
|
self.action = action
|
||||||
self.actions = [x() for x in ALL_BASEACTIONS]
|
self.actions = [x() for x in ALL_BASEACTIONS] + [DoorUse()]
|
||||||
self.objective = objective
|
self.objective = objective
|
||||||
self._path = None
|
self._path = None
|
||||||
self._next = []
|
self._next = []
|
||||||
self._last = []
|
self._last = []
|
||||||
self._last_serviced = 'None'
|
self._last_serviced = 'None'
|
||||||
self._floortile_graph = points_to_graph(state.entities.floorlist)
|
self._floortile_graph = None
|
||||||
|
|
||||||
def tick(self, state):
|
def tick(self, state):
|
||||||
if found_objective := state[self.objective].by_pos(self.pos):
|
if found_objective := h.get_first(state[self.objective].by_pos(self.pos)):
|
||||||
if found_objective.name != self._last_serviced:
|
if found_objective.name != self._last_serviced:
|
||||||
self.action.do(self, state)
|
self.action.do(self, state)
|
||||||
self._last_serviced = found_objective.name
|
self._last_serviced = found_objective.name
|
||||||
@ -54,24 +41,27 @@ class Maintainer(Entity):
|
|||||||
return action.do(self, state)
|
return action.do(self, state)
|
||||||
|
|
||||||
def get_move_action(self, state) -> Action:
|
def get_move_action(self, state) -> Action:
|
||||||
|
if not self._floortile_graph:
|
||||||
|
state.print("Generating Floorgraph....")
|
||||||
|
self._floortile_graph = points_to_graph(state.entities.floorlist)
|
||||||
if self._path is None or not self._path:
|
if self._path is None or not self._path:
|
||||||
if not self._next:
|
if not self._next:
|
||||||
self._next = list(state[self.objective].values())
|
self._next = list(state[self.objective].values()) + [Floor(*state.random_free_position)]
|
||||||
|
shuffle(self._next)
|
||||||
self._last = []
|
self._last = []
|
||||||
self._last.append(self._next.pop())
|
self._last.append(self._next.pop())
|
||||||
|
state.print("Calculating shortest path....")
|
||||||
self._path = self.calculate_route(self._last[-1])
|
self._path = self.calculate_route(self._last[-1])
|
||||||
|
|
||||||
if door := self._door_is_close(state):
|
if door := self._closed_door_in_path(state):
|
||||||
if door.is_closed:
|
state.print(f"{self} found {door} that is closed. Attempt to open.")
|
||||||
# Translate the action_object to an integer to have the same output as any other model
|
# Translate the action_object to an integer to have the same output as any other model
|
||||||
action = do.ACTION_DOOR_USE
|
action = do.ACTION_DOOR_USE
|
||||||
else:
|
|
||||||
action = self._predict_move(state)
|
|
||||||
else:
|
else:
|
||||||
action = self._predict_move(state)
|
action = self._predict_move(state)
|
||||||
# Translate the action_object to an integer to have the same output as any other model
|
# Translate the action_object to an integer to have the same output as any other model
|
||||||
try:
|
try:
|
||||||
action_obj = next(x for x in self.actions if x.name == action)
|
action_obj = h.get_first(self.actions, lambda x: x.name == action)
|
||||||
except (StopIteration, UnboundLocalError):
|
except (StopIteration, UnboundLocalError):
|
||||||
print('Will not happen')
|
print('Will not happen')
|
||||||
raise EnvironmentError
|
raise EnvironmentError
|
||||||
@ -81,11 +71,10 @@ class Maintainer(Entity):
|
|||||||
route = nx.shortest_path(self._floortile_graph, self.pos, entity.pos)
|
route = nx.shortest_path(self._floortile_graph, self.pos, entity.pos)
|
||||||
return route[1:]
|
return route[1:]
|
||||||
|
|
||||||
def _door_is_close(self, state):
|
def _closed_door_in_path(self, state):
|
||||||
state.print("Found a door that is close.")
|
if self._path:
|
||||||
try:
|
return h.get_first(state[do.DOORS].by_pos(self._path[0]), lambda x: x.is_closed)
|
||||||
return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name)
|
else:
|
||||||
except StopIteration:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _predict_move(self, state):
|
def _predict_move(self, state):
|
||||||
@ -96,7 +85,7 @@ class Maintainer(Entity):
|
|||||||
next_pos = self._path.pop(0)
|
next_pos = self._path.pop(0)
|
||||||
diff = np.subtract(next_pos, self.pos)
|
diff = np.subtract(next_pos, self.pos)
|
||||||
# Retrieve action based on the pos dif (like in: What do I have to do to get there?)
|
# Retrieve action based on the pos dif (like in: What do I have to do to get there?)
|
||||||
action = next(action for action, pos_diff in MOVEMAP.items() if np.all(diff == pos_diff))
|
action = next(action for action, pos_diff in h.MOVEMAP.items() if np.all(diff == pos_diff))
|
||||||
return action
|
return action
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Union, List, Tuple
|
from typing import Union, List, Tuple, Dict
|
||||||
|
|
||||||
from marl_factory_grid.environment.groups.collection import Collection
|
from marl_factory_grid.environment.groups.collection import Collection
|
||||||
from .entities import Maintainer
|
from .entities import Maintainer
|
||||||
@ -10,25 +10,21 @@ from ...utils.states import Gamestate
|
|||||||
class Maintainers(Collection):
|
class Maintainers(Collection):
|
||||||
_entity = Maintainer
|
_entity = Maintainer
|
||||||
|
|
||||||
@property
|
var_can_collide = True
|
||||||
def var_can_collide(self):
|
var_can_move = True
|
||||||
return True
|
var_is_blocking_light = False
|
||||||
|
var_has_position = True
|
||||||
|
|
||||||
@property
|
def __init__(self, size, *args, coords_or_quantity: int = None,
|
||||||
def var_can_move(self):
|
spawnrule: Union[None, Dict[str, dict]] = None,
|
||||||
return True
|
**kwargs):
|
||||||
|
super(Collection, self).__init__(*args, **kwargs)
|
||||||
@property
|
self._coords_or_quantity = coords_or_quantity
|
||||||
def var_is_blocking_light(self):
|
self.size = size
|
||||||
return False
|
self._spawnrule = spawnrule
|
||||||
|
|
||||||
@property
|
|
||||||
def var_has_position(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):
|
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):
|
||||||
state = entity_args[0]
|
self.add_items([self._entity(mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity])
|
||||||
self.add_items([self._entity(state, mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity])
|
|
||||||
|
@ -4,29 +4,24 @@ from marl_factory_grid.utils.results import TickResult, DoneResult
|
|||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
from . import rewards as r
|
from . import rewards as r
|
||||||
from . import constants as M
|
from . import constants as M
|
||||||
from marl_factory_grid.utils.states import Gamestate
|
|
||||||
|
|
||||||
|
|
||||||
class MaintenanceRule(Rule):
|
class MoveMaintainers(Rule):
|
||||||
|
|
||||||
def __init__(self, n_maintainer: int = 1, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(MaintenanceRule, self).__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.n_maintainer = n_maintainer
|
|
||||||
|
|
||||||
def on_init(self, state: Gamestate, lvl_map):
|
|
||||||
state[M.MAINTAINERS].spawn(state.entities.empty_positions[:self.n_maintainer], state)
|
|
||||||
pass
|
|
||||||
|
|
||||||
def tick_pre_step(self, state) -> List[TickResult]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def tick_step(self, state) -> List[TickResult]:
|
def tick_step(self, state) -> List[TickResult]:
|
||||||
for maintainer in state[M.MAINTAINERS]:
|
for maintainer in state[M.MAINTAINERS]:
|
||||||
maintainer.tick(state)
|
maintainer.tick(state)
|
||||||
|
# Todo: Return a Result Object.
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def tick_post_step(self, state) -> List[TickResult]:
|
|
||||||
pass
|
class DoneAtMaintainerCollision(Rule):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def on_check_done(self, state) -> List[DoneResult]:
|
def on_check_done(self, state) -> List[DoneResult]:
|
||||||
agents = list(state[c.AGENT].values())
|
agents = list(state[c.AGENT].values())
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
from random import choices, choice
|
from random import choices, choice
|
||||||
|
|
||||||
from . import constants as z, Zone
|
from . import constants as z, Zone
|
||||||
|
from .. import Destination
|
||||||
from ..destinations import constants as d
|
from ..destinations import constants as d
|
||||||
from ... import Destination
|
|
||||||
from ...environment.rules import Rule
|
from ...environment.rules import Rule
|
||||||
from ...environment import constants as c
|
from ...environment import constants as c
|
||||||
|
|
||||||
|
@ -0,0 +1,3 @@
|
|||||||
|
from . import helpers as h
|
||||||
|
from . import helpers
|
||||||
|
from .results import Result, DoneResult, ActionResult, TickResult
|
@ -1,28 +1,24 @@
|
|||||||
import ast
|
import ast
|
||||||
from collections import defaultdict
|
|
||||||
from os import PathLike
|
from os import PathLike
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union, List
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from marl_factory_grid.environment.groups.agents import Agents
|
|
||||||
from marl_factory_grid.environment.entity.agent import Agent
|
|
||||||
from marl_factory_grid.environment.rules import Rule
|
from marl_factory_grid.environment.rules import Rule
|
||||||
from marl_factory_grid.utils.helpers import locate_and_import_class
|
from marl_factory_grid.utils.helpers import locate_and_import_class
|
||||||
|
from marl_factory_grid.environment.constants import DEFAULT_PATH, MODULE_PATH
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
|
|
||||||
DEFAULT_PATH = 'environment'
|
|
||||||
MODULE_PATH = 'modules'
|
|
||||||
|
|
||||||
|
|
||||||
class FactoryConfigParser(object):
|
class FactoryConfigParser(object):
|
||||||
default_entites = []
|
default_entites = []
|
||||||
default_rules = ['MaxStepsReached', 'Collision']
|
default_rules = ['DoneAtMaxStepsReached', 'WatchCollision']
|
||||||
default_actions = [c.MOVE8, c.NOOP]
|
default_actions = [c.MOVE8, c.NOOP]
|
||||||
default_observations = [c.WALLS, c.AGENT]
|
default_observations = [c.WALLS, c.AGENT]
|
||||||
|
|
||||||
def __init__(self, config_path, custom_modules_path: Union[None, PathLike] = None):
|
def __init__(self, config_path, custom_modules_path: Union[PathLike] = None):
|
||||||
self.config_path = Path(config_path)
|
self.config_path = Path(config_path)
|
||||||
self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path
|
self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path
|
||||||
self.config = yaml.safe_load(self.config_path.open())
|
self.config = yaml.safe_load(self.config_path.open())
|
||||||
@ -46,6 +42,10 @@ class FactoryConfigParser(object):
|
|||||||
def rules(self):
|
def rules(self):
|
||||||
return self.config['Rules']
|
return self.config['Rules']
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tests(self):
|
||||||
|
return self.config.get('Tests', [])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def agents(self):
|
def agents(self):
|
||||||
return self.config['Agents']
|
return self.config['Agents']
|
||||||
@ -61,7 +61,6 @@ class FactoryConfigParser(object):
|
|||||||
return self.config[item]
|
return self.config[item]
|
||||||
|
|
||||||
def load_entities(self):
|
def load_entities(self):
|
||||||
# entites = Entities()
|
|
||||||
entity_classes = dict()
|
entity_classes = dict()
|
||||||
entities = []
|
entities = []
|
||||||
if c.DEFAULTS in self.entities:
|
if c.DEFAULTS in self.entities:
|
||||||
@ -69,28 +68,40 @@ class FactoryConfigParser(object):
|
|||||||
entities.extend(x for x in self.entities if x != c.DEFAULTS)
|
entities.extend(x for x in self.entities if x != c.DEFAULTS)
|
||||||
|
|
||||||
for entity in entities:
|
for entity in entities:
|
||||||
|
e1 = e2 = e3 = None
|
||||||
try:
|
try:
|
||||||
folder_path = Path(__file__).parent.parent / DEFAULT_PATH
|
folder_path = Path(__file__).parent.parent / DEFAULT_PATH
|
||||||
entity_class = locate_and_import_class(entity, folder_path)
|
entity_class = locate_and_import_class(entity, folder_path)
|
||||||
except AttributeError as e1:
|
except AttributeError as e:
|
||||||
|
e1 = e
|
||||||
try:
|
try:
|
||||||
folder_path = Path(__file__).parent.parent / MODULE_PATH
|
module_path = Path(__file__).parent.parent / MODULE_PATH
|
||||||
entity_class = locate_and_import_class(entity, folder_path)
|
entity_class = locate_and_import_class(entity, module_path)
|
||||||
except AttributeError as e2:
|
except AttributeError as e:
|
||||||
try:
|
e2 = e
|
||||||
folder_path = self.custom_modules_path
|
if self.custom_modules_path:
|
||||||
entity_class = locate_and_import_class(entity, folder_path)
|
try:
|
||||||
except AttributeError as e3:
|
entity_class = locate_and_import_class(entity, self.custom_modules_path)
|
||||||
ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x]
|
except AttributeError as e:
|
||||||
print('### Error ### Error ### Error ### Error ### Error ###')
|
e3 = e
|
||||||
print()
|
pass
|
||||||
print(f'Class "{entity}" was not found in "{folder_path.name}"')
|
if (e1 and e2) or e3:
|
||||||
print('Possible Entitys are:', str(ents))
|
ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]]
|
||||||
print()
|
print('##############################################################')
|
||||||
print('Goodbye')
|
print('### Error ### Error ### Error ### Error ### Error ###')
|
||||||
print()
|
print('##############################################################')
|
||||||
exit()
|
print(f'Class "{entity}" was not found in "{module_path.name}"')
|
||||||
# raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are:', str(ents))
|
print(f'Class "{entity}" was not found in "{folder_path.name}"')
|
||||||
|
print('##############################################################')
|
||||||
|
if self.custom_modules_path:
|
||||||
|
print(f'Class "{entity}" was not found in "{self.custom_modules_path}"')
|
||||||
|
print('Possible Entitys are:', str(ents))
|
||||||
|
print('##############################################################')
|
||||||
|
print('Goodbye')
|
||||||
|
print('##############################################################')
|
||||||
|
print('### Error ### Error ### Error ### Error ### Error ###')
|
||||||
|
print('##############################################################')
|
||||||
|
exit(-99999)
|
||||||
|
|
||||||
entity_kwargs = self.entities.get(entity, {})
|
entity_kwargs = self.entities.get(entity, {})
|
||||||
entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
|
entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
|
||||||
@ -128,31 +139,86 @@ class FactoryConfigParser(object):
|
|||||||
observations.extend(self.default_observations)
|
observations.extend(self.default_observations)
|
||||||
observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS)
|
observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS)
|
||||||
positions = [ast.literal_eval(x) for x in self.agents[name].get('Positions', [])]
|
positions = [ast.literal_eval(x) for x in self.agents[name].get('Positions', [])]
|
||||||
parsed_agents_conf[name] = dict(actions=parsed_actions, observations=observations, positions=positions)
|
other_kwargs = {k: v for k, v in self.agents[name].items() if k not in
|
||||||
|
['Actions', 'Observations', 'Positions']}
|
||||||
|
parsed_agents_conf[name] = dict(
|
||||||
|
actions=parsed_actions, observations=observations, positions=positions, other=other_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
return parsed_agents_conf
|
return parsed_agents_conf
|
||||||
|
|
||||||
def load_rules(self):
|
def load_env_rules(self) -> List[Rule]:
|
||||||
# entites = Entities()
|
rules = self.rules.copy()
|
||||||
rules_classes = dict()
|
|
||||||
rules = []
|
|
||||||
if c.DEFAULTS in self.rules:
|
if c.DEFAULTS in self.rules:
|
||||||
for rule in self.default_rules:
|
for rule in self.default_rules:
|
||||||
if rule not in rules:
|
if rule not in rules:
|
||||||
rules.append(rule)
|
rules.append({rule: {}})
|
||||||
rules.extend(x for x in self.rules if x != c.DEFAULTS)
|
|
||||||
|
|
||||||
for rule in rules:
|
return self._load_smth(rules, Rule)
|
||||||
|
|
||||||
|
def load_env_tests(self) -> List[Rule]:
|
||||||
|
return self._load_smth(self.tests, None) # Test
|
||||||
|
|
||||||
|
def _load_smth(self, config, class_obj):
|
||||||
|
rules = list()
|
||||||
|
rules_names = list()
|
||||||
|
for rule in config:
|
||||||
|
e1 = e2 = e3 = None
|
||||||
try:
|
try:
|
||||||
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
|
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
|
||||||
rule_class = locate_and_import_class(rule, folder_path)
|
rule_class = locate_and_import_class(rule, folder_path)
|
||||||
except AttributeError:
|
except AttributeError as e:
|
||||||
|
e1 = e
|
||||||
try:
|
try:
|
||||||
folder_path = (Path(__file__).parent.parent / MODULE_PATH)
|
module_path = (Path(__file__).parent.parent / MODULE_PATH)
|
||||||
rule_class = locate_and_import_class(rule, folder_path)
|
rule_class = locate_and_import_class(rule, module_path)
|
||||||
|
except AttributeError as e:
|
||||||
|
e2 = e
|
||||||
|
if self.custom_modules_path:
|
||||||
|
try:
|
||||||
|
rule_class = locate_and_import_class(rule, self.custom_modules_path)
|
||||||
|
except AttributeError as e:
|
||||||
|
e3 = e
|
||||||
|
pass
|
||||||
|
if (e1 and e2) or e3:
|
||||||
|
ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]]
|
||||||
|
print('### Error ### Error ### Error ### Error ### Error ###')
|
||||||
|
print('')
|
||||||
|
print(f'Class "{rule}" was not found in "{module_path.name}"')
|
||||||
|
print(f'Class "{rule}" was not found in "{folder_path.name}"')
|
||||||
|
if self.custom_modules_path:
|
||||||
|
print(f'Class "{rule}" was not found in "{self.custom_modules_path}"')
|
||||||
|
print('Possible Entitys are:', str(ents))
|
||||||
|
print('')
|
||||||
|
print('Goodbye')
|
||||||
|
print('')
|
||||||
|
exit(-99999)
|
||||||
|
|
||||||
|
if issubclass(rule_class, class_obj):
|
||||||
|
rule_kwargs = config.get(rule, {})
|
||||||
|
rules.append(rule_class(**(rule_kwargs or {})))
|
||||||
|
return rules
|
||||||
|
|
||||||
|
def load_entity_spawn_rules(self, entities) -> List[Rule]:
|
||||||
|
rules = list()
|
||||||
|
rules_dicts = list()
|
||||||
|
for e in entities:
|
||||||
|
try:
|
||||||
|
if spawn_rule := e.spawn_rule:
|
||||||
|
rules_dicts.append(spawn_rule)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
for rule_dict in rules_dicts:
|
||||||
|
for rule_name, rule_kwargs in rule_dict.items():
|
||||||
|
try:
|
||||||
|
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
|
||||||
|
rule_class = locate_and_import_class(rule_name, folder_path)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
rule_class = locate_and_import_class(rule, self.custom_modules_path)
|
try:
|
||||||
# Fixme This check does not work!
|
folder_path = (Path(__file__).parent.parent / MODULE_PATH)
|
||||||
# assert isinstance(rule_class, Rule), f'{rule_class.__name__} is no valid "Rule".'
|
rule_class = locate_and_import_class(rule_name, folder_path)
|
||||||
rule_kwargs = self.rules.get(rule, {})
|
except AttributeError:
|
||||||
rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}})
|
rule_class = locate_and_import_class(rule_name, self.custom_modules_path)
|
||||||
return rules_classes
|
rules.append(rule_class(**rule_kwargs))
|
||||||
|
return rules
|
||||||
|
@ -2,7 +2,7 @@ import importlib
|
|||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import PurePath, Path
|
from pathlib import PurePath, Path
|
||||||
from typing import Union, Dict, List
|
from typing import Union, Dict, List, Iterable, Callable
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.typing import ArrayLike
|
from numpy.typing import ArrayLike
|
||||||
@ -222,7 +222,7 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
|
|||||||
mod = importlib.import_module('.'.join(module_parts))
|
mod = importlib.import_module('.'.join(module_parts))
|
||||||
all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
|
all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
|
||||||
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
|
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
|
||||||
'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin',
|
'TickResult', 'ActionResult', 'Action', 'Agent',
|
||||||
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
|
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
|
||||||
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any'
|
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any'
|
||||||
]])
|
]])
|
||||||
@ -240,7 +240,13 @@ def add_bound_name(name_str, bound_e):
|
|||||||
|
|
||||||
def add_pos_name(name_str, bound_e):
|
def add_pos_name(name_str, bound_e):
|
||||||
if bound_e.var_has_position:
|
if bound_e.var_has_position:
|
||||||
return f'{name_str}({bound_e.pos})'
|
return f'{name_str}@{bound_e.pos}'
|
||||||
return name_str
|
return name_str
|
||||||
|
|
||||||
|
|
||||||
|
def get_first(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True):
|
||||||
|
return next((x for x in iterable if filter_by(x)), None)
|
||||||
|
|
||||||
|
|
||||||
|
def get_first_index(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True):
|
||||||
|
return next((idx for idx, x in enumerate(iterable) if filter_by(x)), None)
|
||||||
|
@ -47,6 +47,7 @@ class LevelParser(object):
|
|||||||
# All other
|
# All other
|
||||||
for es_name in self.e_p_dict:
|
for es_name in self.e_p_dict:
|
||||||
e_class, e_kwargs = self.e_p_dict[es_name]['class'], self.e_p_dict[es_name]['kwargs']
|
e_class, e_kwargs = self.e_p_dict[es_name]['class'], self.e_p_dict[es_name]['kwargs']
|
||||||
|
e_kwargs = e_kwargs if e_kwargs else {}
|
||||||
|
|
||||||
if hasattr(e_class, 'symbol') and e_class.symbol is not None:
|
if hasattr(e_class, 'symbol') and e_class.symbol is not None:
|
||||||
symbols = e_class.symbol
|
symbols = e_class.symbol
|
||||||
|
@ -1,17 +1,17 @@
|
|||||||
import math
|
|
||||||
import re
|
import re
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from itertools import product
|
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numba import njit
|
|
||||||
|
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
|
from marl_factory_grid.environment.entity.object import _Object
|
||||||
from marl_factory_grid.environment.groups.utils import Combined
|
from marl_factory_grid.environment.groups.utils import Combined
|
||||||
import marl_factory_grid.utils.helpers as h
|
|
||||||
from marl_factory_grid.utils.states import Gamestate
|
|
||||||
from marl_factory_grid.utils.utility_classes import Floor
|
from marl_factory_grid.utils.utility_classes import Floor
|
||||||
|
from marl_factory_grid.utils.ray_caster import RayCaster
|
||||||
|
from marl_factory_grid.utils.states import Gamestate
|
||||||
|
from marl_factory_grid.utils import helpers as h
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class OBSBuilder(object):
|
class OBSBuilder(object):
|
||||||
@ -77,11 +77,13 @@ class OBSBuilder(object):
|
|||||||
|
|
||||||
def place_entity_in_observation(self, obs_array, agent, e):
|
def place_entity_in_observation(self, obs_array, agent, e):
|
||||||
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
|
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
|
||||||
try:
|
if not min([y, x]) < 0:
|
||||||
obs_array[x, y] += e.encoding
|
try:
|
||||||
except IndexError:
|
obs_array[x, y] += e.encoding
|
||||||
# Seemded to be visible but is out of range
|
except IndexError:
|
||||||
pass
|
# Seemded to be visible but is out of range
|
||||||
|
pass
|
||||||
|
pass
|
||||||
|
|
||||||
def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
|
def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
|
||||||
assert self._curr_env_step == state.curr_step, (
|
assert self._curr_env_step == state.curr_step, (
|
||||||
@ -121,18 +123,24 @@ class OBSBuilder(object):
|
|||||||
e = self.all_obs[l_name]
|
e = self.all_obs[l_name]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
try:
|
try:
|
||||||
# Look for bound entity names!
|
# Look for bound entity REPRs!
|
||||||
pattern = re.compile(f'{re.escape(l_name)}(.*){re.escape(agent.name)}')
|
pattern = re.compile(f'{re.escape(l_name)}'
|
||||||
name = next((x for x in self.all_obs if pattern.search(x)), None)
|
f'{re.escape("[")}(.*){re.escape("]")}'
|
||||||
|
f'{re.escape("(")}{re.escape(agent.name)}{re.escape(")")}')
|
||||||
|
name = next((key for key, val in self.all_obs.items()
|
||||||
|
if pattern.search(str(val)) and isinstance(val, _Object)), None)
|
||||||
e = self.all_obs[name]
|
e = self.all_obs[name]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
try:
|
try:
|
||||||
e = next(v for k, v in self.all_obs.items() if l_name in k and agent.name in k)
|
e = next(v for k, v in self.all_obs.items() if l_name in k and agent.name in k)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
raise KeyError(
|
print(f'# Check for spelling errors!')
|
||||||
f'Check for spelling errors! \n '
|
print(f'# No combination of "{l_name}" and "{agent.name}" could not be found in:')
|
||||||
f'No combination of "{l_name} and {agent.name}" could not be found in:\n '
|
print(f'# {list(dict(self.all_obs).keys())}')
|
||||||
f'{list(dict(self.all_obs).keys())}')
|
print('#')
|
||||||
|
print('# exiting...')
|
||||||
|
print('#')
|
||||||
|
exit(-99999)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
positional = e.var_has_position
|
positional = e.var_has_position
|
||||||
@ -161,15 +169,14 @@ class OBSBuilder(object):
|
|||||||
try:
|
try:
|
||||||
light_map = np.zeros(self.obs_shape)
|
light_map = np.zeros(self.obs_shape)
|
||||||
visible_floor = self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False)
|
visible_floor = self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False)
|
||||||
if self.pomdp_r:
|
|
||||||
for f in set(visible_floor):
|
for f in set(visible_floor):
|
||||||
self.place_entity_in_observation(light_map, agent, f)
|
self.place_entity_in_observation(light_map, agent, f)
|
||||||
else:
|
# else:
|
||||||
for f in set(visible_floor):
|
# for f in set(visible_floor):
|
||||||
light_map[f.x, f.y] += f.encoding
|
# light_map[f.x, f.y] += f.encoding
|
||||||
self.curr_lightmaps[agent.name] = light_map
|
self.curr_lightmaps[agent.name] = light_map
|
||||||
except (KeyError, ValueError):
|
except (KeyError, ValueError):
|
||||||
print()
|
|
||||||
pass
|
pass
|
||||||
return obs, self.obs_layers[agent.name]
|
return obs, self.obs_layers[agent.name]
|
||||||
|
|
||||||
@ -185,7 +192,7 @@ class OBSBuilder(object):
|
|||||||
|
|
||||||
for obs_str in agent.observations:
|
for obs_str in agent.observations:
|
||||||
if isinstance(obs_str, dict):
|
if isinstance(obs_str, dict):
|
||||||
obs_str, vals = next(obs_str.items().__iter__())
|
obs_str, vals = h.get_first(obs_str.items())
|
||||||
else:
|
else:
|
||||||
vals = None
|
vals = None
|
||||||
if obs_str == c.SELF:
|
if obs_str == c.SELF:
|
||||||
@ -214,129 +221,3 @@ class OBSBuilder(object):
|
|||||||
obs_layers.append(obs_str)
|
obs_layers.append(obs_str)
|
||||||
self.obs_layers[agent.name] = obs_layers
|
self.obs_layers[agent.name] = obs_layers
|
||||||
self.curr_lightmaps[agent.name] = np.zeros(self.obs_shape)
|
self.curr_lightmaps[agent.name] = np.zeros(self.obs_shape)
|
||||||
|
|
||||||
|
|
||||||
class RayCaster:
|
|
||||||
def __init__(self, agent, pomdp_r, degs=360):
|
|
||||||
self.agent = agent
|
|
||||||
self.pomdp_r = pomdp_r
|
|
||||||
self.n_rays = (self.pomdp_r + 1) * 8
|
|
||||||
self.degs = degs
|
|
||||||
self.ray_targets = self.build_ray_targets()
|
|
||||||
self.obs_shape_cube = np.array([self.pomdp_r, self.pomdp_r])
|
|
||||||
self._cache_dict = {}
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f'{self.__class__.__name__}({self.agent.name})'
|
|
||||||
|
|
||||||
def build_ray_targets(self):
|
|
||||||
north = np.array([0, -1]) * self.pomdp_r
|
|
||||||
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
|
|
||||||
rot_M = [
|
|
||||||
[[math.cos(theta), -math.sin(theta)],
|
|
||||||
[math.sin(theta), math.cos(theta)]] for theta in thetas
|
|
||||||
]
|
|
||||||
rot_M = np.stack(rot_M, 0)
|
|
||||||
rot_M = np.unique(np.round(rot_M @ north), axis=0)
|
|
||||||
return rot_M.astype(int)
|
|
||||||
|
|
||||||
def ray_block_cache(self, key, callback):
|
|
||||||
if key not in self._cache_dict:
|
|
||||||
self._cache_dict[key] = callback()
|
|
||||||
return self._cache_dict[key]
|
|
||||||
|
|
||||||
def visible_entities(self, pos_dict, reset_cache=True):
|
|
||||||
visible = list()
|
|
||||||
if reset_cache:
|
|
||||||
self._cache_dict = {}
|
|
||||||
|
|
||||||
for ray in self.get_rays():
|
|
||||||
rx, ry = ray[0]
|
|
||||||
for x, y in ray:
|
|
||||||
cx, cy = x - rx, y - ry
|
|
||||||
|
|
||||||
entities_hit = pos_dict[(x, y)]
|
|
||||||
hits = self.ray_block_cache((x, y),
|
|
||||||
lambda: any(True for e in entities_hit if e.var_is_blocking_light)
|
|
||||||
)
|
|
||||||
|
|
||||||
diag_hits = all([
|
|
||||||
self.ray_block_cache(
|
|
||||||
key,
|
|
||||||
lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) and bool(
|
|
||||||
pos_dict[key]))
|
|
||||||
for key in ((x, y - cy), (x - cx, y))
|
|
||||||
]) if (cx != 0 and cy != 0) else False
|
|
||||||
|
|
||||||
visible += entities_hit if not diag_hits else []
|
|
||||||
if hits or diag_hits:
|
|
||||||
break
|
|
||||||
rx, ry = x, y
|
|
||||||
return visible
|
|
||||||
|
|
||||||
def get_rays(self):
|
|
||||||
a_pos = self.agent.pos
|
|
||||||
outline = self.ray_targets + a_pos
|
|
||||||
return self.bresenham_loop(a_pos, outline)
|
|
||||||
|
|
||||||
# todo do this once and cache the points!
|
|
||||||
def get_fov_outline(self) -> np.ndarray:
|
|
||||||
return self.ray_targets + self.agent.pos
|
|
||||||
|
|
||||||
def get_square_outline(self):
|
|
||||||
agent = self.agent
|
|
||||||
x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
|
|
||||||
y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
|
|
||||||
outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \
|
|
||||||
+ list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
|
|
||||||
return outline
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@njit
|
|
||||||
def bresenham_loop(a_pos, points):
|
|
||||||
results = []
|
|
||||||
for end in points:
|
|
||||||
x1, y1 = a_pos
|
|
||||||
x2, y2 = end
|
|
||||||
dx = x2 - x1
|
|
||||||
dy = y2 - y1
|
|
||||||
|
|
||||||
# Determine how steep the line is
|
|
||||||
is_steep = abs(dy) > abs(dx)
|
|
||||||
|
|
||||||
# Rotate line
|
|
||||||
if is_steep:
|
|
||||||
x1, y1 = y1, x1
|
|
||||||
x2, y2 = y2, x2
|
|
||||||
|
|
||||||
# Swap start and end points if necessary and store swap state
|
|
||||||
swapped = False
|
|
||||||
if x1 > x2:
|
|
||||||
x1, x2 = x2, x1
|
|
||||||
y1, y2 = y2, y1
|
|
||||||
swapped = True
|
|
||||||
|
|
||||||
# Recalculate differentials
|
|
||||||
dx = x2 - x1
|
|
||||||
dy = y2 - y1
|
|
||||||
|
|
||||||
# Calculate error
|
|
||||||
error = int(dx / 2.0)
|
|
||||||
ystep = 1 if y1 < y2 else -1
|
|
||||||
|
|
||||||
# Iterate over bounding box generating points between start and end
|
|
||||||
y = y1
|
|
||||||
points = []
|
|
||||||
for x in range(int(x1), int(x2) + 1):
|
|
||||||
coord = [y, x] if is_steep else [x, y]
|
|
||||||
points.append(coord)
|
|
||||||
error -= abs(dy)
|
|
||||||
if error < 0:
|
|
||||||
y += ystep
|
|
||||||
error += dx
|
|
||||||
|
|
||||||
# Reverse the list if the coordinates were swapped
|
|
||||||
if swapped:
|
|
||||||
points.reverse()
|
|
||||||
results.append(points)
|
|
||||||
return results
|
|
||||||
|
@ -39,8 +39,9 @@ class RayCaster:
|
|||||||
if reset_cache:
|
if reset_cache:
|
||||||
self._cache_dict = dict()
|
self._cache_dict = dict()
|
||||||
|
|
||||||
for ray in self.get_rays():
|
for ray in self.get_rays(): # Do not check, just trust.
|
||||||
rx, ry = ray[0]
|
rx, ry = ray[0]
|
||||||
|
# self.ray_block_cache(ray[0], lambda: False) We do not do that, because of doors etc...
|
||||||
for x, y in ray:
|
for x, y in ray:
|
||||||
cx, cy = x - rx, y - ry
|
cx, cy = x - rx, y - ry
|
||||||
|
|
||||||
@ -52,7 +53,8 @@ class RayCaster:
|
|||||||
diag_hits = all([
|
diag_hits = all([
|
||||||
self.ray_block_cache(
|
self.ray_block_cache(
|
||||||
key,
|
key,
|
||||||
lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light))
|
lambda: any(True for e in pos_dict[key] if e.var_is_blocking_light))
|
||||||
|
# lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light))
|
||||||
for key in ((x, y-cy), (x-cx, y))
|
for key in ((x, y-cy), (x-cx, y))
|
||||||
]) if (cx != 0 and cy != 0) else False
|
]) if (cx != 0 and cy != 0) else False
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ class Renderer:
|
|||||||
|
|
||||||
def __init__(self, lvl_shape: Tuple[int, int] = (16, 16),
|
def __init__(self, lvl_shape: Tuple[int, int] = (16, 16),
|
||||||
lvl_padded_shape: Union[Tuple[int, int], None] = None,
|
lvl_padded_shape: Union[Tuple[int, int], None] = None,
|
||||||
cell_size: int = 40, fps: int = 7,
|
cell_size: int = 40, fps: int = 7, factor: float = 0.9,
|
||||||
grid_lines: bool = True, view_radius: int = 2):
|
grid_lines: bool = True, view_radius: int = 2):
|
||||||
# TODO: Customn_assets paths
|
# TODO: Customn_assets paths
|
||||||
self.grid_h, self.grid_w = lvl_shape
|
self.grid_h, self.grid_w = lvl_shape
|
||||||
@ -45,7 +45,7 @@ class Renderer:
|
|||||||
self.screen = pygame.display.set_mode(self.screen_size)
|
self.screen = pygame.display.set_mode(self.screen_size)
|
||||||
self.clock = pygame.time.Clock()
|
self.clock = pygame.time.Clock()
|
||||||
assets = list(self.ASSETS.rglob('*.png'))
|
assets = list(self.ASSETS.rglob('*.png'))
|
||||||
self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets}
|
self.assets = {path.stem: self.load_asset(str(path), factor) for path in assets}
|
||||||
self.fill_bg()
|
self.fill_bg()
|
||||||
|
|
||||||
now = time.time()
|
now = time.time()
|
||||||
@ -110,22 +110,22 @@ class Renderer:
|
|||||||
pygame.quit()
|
pygame.quit()
|
||||||
sys.exit()
|
sys.exit()
|
||||||
self.fill_bg()
|
self.fill_bg()
|
||||||
blits = deque()
|
# First all others
|
||||||
for entity in [x for x in entities]:
|
blits = deque(self.blit_params(x) for x in entities if not x.name.lower() == AGENT)
|
||||||
bp = self.blit_params(entity)
|
# Then Agents, so that agents are rendered on top.
|
||||||
blits.append(bp)
|
for agent in (x for x in entities if x.name.lower() == AGENT):
|
||||||
if entity.name.lower() == AGENT:
|
agent_blit = self.blit_params(agent)
|
||||||
if self.view_radius > 0:
|
if self.view_radius > 0:
|
||||||
vis_rects = self.visibility_rects(bp, entity.aux)
|
vis_rects = self.visibility_rects(agent_blit, agent.aux)
|
||||||
blits.extendleft(vis_rects)
|
blits.extendleft(vis_rects)
|
||||||
if entity.state != BLANK:
|
if agent.state != BLANK:
|
||||||
agent_state_blits = self.blit_params(
|
state_blit = self.blit_params(
|
||||||
RenderEntity(entity.state, (entity.pos[0] + 0.12, entity.pos[1]), 0.48, SCALE)
|
RenderEntity(agent.state, (agent.pos[0] + 0.12, agent.pos[1]), 0.48, SCALE)
|
||||||
)
|
)
|
||||||
textsurface = self.font.render(str(entity.id), False, (0, 0, 0))
|
textsurface = self.font.render(str(agent.id), False, (0, 0, 0))
|
||||||
text_blit = dict(source=textsurface, dest=(bp['dest'].center[0]-.07*self.cell_size,
|
text_blit = dict(source=textsurface, dest=(agent_blit['dest'].center[0]-.07*self.cell_size,
|
||||||
bp['dest'].center[1]))
|
agent_blit['dest'].center[1]))
|
||||||
blits += [agent_state_blits, text_blit]
|
blits += [agent_blit, state_blit, text_blit]
|
||||||
|
|
||||||
for blit in blits:
|
for blit in blits:
|
||||||
self.screen.blit(**blit)
|
self.screen.blit(**blit)
|
||||||
|
@ -28,7 +28,10 @@ class Result:
|
|||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
valid = "not " if not self.validity else ""
|
valid = "not " if not self.validity else ""
|
||||||
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid: {self.reward})'
|
reward = f" | Reward: {self.reward}" if self.reward is not None else ""
|
||||||
|
value = f" | Value: {self.value}" if self.value is not None else ""
|
||||||
|
entity = f" | by: {self.entity.name}" if self.entity is not None else ""
|
||||||
|
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value})'
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from itertools import islice
|
||||||
from typing import List, Dict, Tuple
|
from typing import List, Dict, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -59,14 +60,15 @@ class Gamestate(object):
|
|||||||
def moving_entites(self):
|
def moving_entites(self):
|
||||||
return [y for x in self.entities for y in x if x.var_can_move]
|
return [y for x in self.entities for y in x if x.var_can_move]
|
||||||
|
|
||||||
def __init__(self, entities, agents_conf, rules: Dict[str, dict], env_seed=69, verbose=False):
|
def __init__(self, entities, agents_conf, rules: List[Rule], lvl_shape, env_seed=69, verbose=False):
|
||||||
|
self.lvl_shape = lvl_shape
|
||||||
self.entities = entities
|
self.entities = entities
|
||||||
self.curr_step = 0
|
self.curr_step = 0
|
||||||
self.curr_actions = None
|
self.curr_actions = None
|
||||||
self.agents_conf = agents_conf
|
self.agents_conf = agents_conf
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.rng = np.random.default_rng(env_seed)
|
self.rng = np.random.default_rng(env_seed)
|
||||||
self.rules = StepRules(*(v['class'](**v['kwargs']) for v in rules.values()))
|
self.rules = StepRules(*rules)
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
return self.entities[item]
|
return self.entities[item]
|
||||||
@ -80,6 +82,13 @@ class Gamestate(object):
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})'
|
return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})'
|
||||||
|
|
||||||
|
@property
|
||||||
|
def random_free_position(self):
|
||||||
|
return self.get_n_random_free_positions(1)[0]
|
||||||
|
|
||||||
|
def get_n_random_free_positions(self, n):
|
||||||
|
return list(islice(self.entities.free_positions_generator, n))
|
||||||
|
|
||||||
def tick(self, actions) -> List[Result]:
|
def tick(self, actions) -> List[Result]:
|
||||||
results = list()
|
results = list()
|
||||||
self.curr_step += 1
|
self.curr_step += 1
|
||||||
@ -115,8 +124,7 @@ class Gamestate(object):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]:
|
def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]:
|
||||||
positions = [pos for pos, entity_list_for_position in self.entities.pos_dict.items()
|
positions = [pos for pos, entities in self.entities.pos_dict.items() if len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2)]
|
||||||
if any([e.var_can_collide for e in entity_list_for_position])]
|
|
||||||
return positions
|
return positions
|
||||||
|
|
||||||
def check_move_validity(self, moving_entity, position):
|
def check_move_validity(self, moving_entity, position):
|
||||||
|
@ -135,4 +135,3 @@ if __name__ == '__main__':
|
|||||||
ce.get_observations()
|
ce.get_observations()
|
||||||
ce.get_assets()
|
ce.get_assets()
|
||||||
all_conf = ce.get_all()
|
all_conf = ce.get_all()
|
||||||
print()
|
|
||||||
|
@ -52,3 +52,6 @@ class Floor:
|
|||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash(self.name)
|
return hash(self.name)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"Floor{self.pos}"
|
||||||
|
@ -6,6 +6,7 @@ import yaml
|
|||||||
from marl_factory_grid.environment.factory import Factory
|
from marl_factory_grid.environment.factory import Factory
|
||||||
from marl_factory_grid.utils.logging.envmonitor import EnvMonitor
|
from marl_factory_grid.utils.logging.envmonitor import EnvMonitor
|
||||||
from marl_factory_grid.utils.logging.recorder import EnvRecorder
|
from marl_factory_grid.utils.logging.recorder import EnvRecorder
|
||||||
|
from marl_factory_grid.utils import helpers as h
|
||||||
|
|
||||||
from marl_factory_grid.modules.doors import constants as d
|
from marl_factory_grid.modules.doors import constants as d
|
||||||
|
|
||||||
@ -61,7 +62,7 @@ if __name__ == '__main__':
|
|||||||
if render:
|
if render:
|
||||||
env.render()
|
env.render()
|
||||||
try:
|
try:
|
||||||
door = next(x for x in env.unwrapped.unwrapped[d.DOORS] if x.is_open)
|
door = h.get_first([x for x in env.unwrapped.unwrapped[d.DOORS] if x.is_open])
|
||||||
print('openDoor found')
|
print('openDoor found')
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
pass
|
pass
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
from algorithms.utils import Checkpointer
|
from algorithms.utils import Checkpointer
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from algorithms.utils import load_yaml_file, add_env_props, instantiate_class, load_class
|
from algorithms.utils import load_yaml_file, add_env_props, instantiate_class, load_class
|
||||||
#from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC
|
|
||||||
|
|
||||||
|
# from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC
|
||||||
|
|
||||||
|
|
||||||
for i in range(0, 5):
|
for i in range(0, 5):
|
||||||
|
43
transform_wg_to_json_no_priv.py
Normal file
43
transform_wg_to_json_no_priv.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
import configparser
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
|
||||||
|
conf_path = Path('wg0')
|
||||||
|
wg0_conf = configparser.ConfigParser()
|
||||||
|
wg0_conf.read(conf_path/'wg0.conf')
|
||||||
|
interface = wg0_conf['Interface']
|
||||||
|
# Iterate all pears
|
||||||
|
for client_name in wg0_conf.sections():
|
||||||
|
if client_name == 'Interface':
|
||||||
|
continue
|
||||||
|
# Delete any old conf.json for the current peer
|
||||||
|
(conf_path / f'{client_name}.json').unlink(missing_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
peer = wg0_conf[client_name]
|
||||||
|
|
||||||
|
date_time = datetime.now().strftime('%Y-%m-%dT%H:%M:%S.%f000Z')
|
||||||
|
|
||||||
|
jdict = dict(
|
||||||
|
id=client_name,
|
||||||
|
private_key=peer['PublicKey'],
|
||||||
|
public_key=peer['PublicKey'],
|
||||||
|
# preshared_key=wg0_conf[client_name_wg0]['PresharedKey'],
|
||||||
|
name=client_name,
|
||||||
|
email=f"sysadmin@mobile.ifi.lmu.de",
|
||||||
|
allocated_ips=[interface['Address'].replace('/24', '')],
|
||||||
|
allowed_ips=['10.4.0.0/24', '10.153.199.0/24'],
|
||||||
|
extra_allowed_ips=[],
|
||||||
|
use_server_dns=True,
|
||||||
|
enabled=True,
|
||||||
|
created_at=date_time,
|
||||||
|
updated_at=date_time
|
||||||
|
)
|
||||||
|
|
||||||
|
with (conf_path / f'{client_name}.json').open('w+') as f:
|
||||||
|
json.dump(jdict, f, indent='\t', separators=(',', ': '))
|
||||||
|
print(client_name, ' written...')
|
Loading…
x
Reference in New Issue
Block a user