Merge branch 'main' into unit_testing

# Conflicts:
#	marl_factory_grid/environment/factory.py
#	marl_factory_grid/utils/config_parser.py
#	marl_factory_grid/utils/states.py
This commit is contained in:
Chanumask 2023-11-10 10:54:00 +01:00
commit 209b317105
97 changed files with 1088 additions and 1239 deletions

5
.idea/.gitignore generated vendored Normal file
View File

@ -0,0 +1,5 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/

View File

@ -56,7 +56,7 @@ Just define what your environment needs in a *yaml*-configfile like:
- Items
Rules:
Defaults: {}
Collision:
WatchCollisions:
done_at_collisions: !!bool True
ItemRespawn:
spawn_freq: 5
@ -94,7 +94,7 @@ All [Entites](marl_factory_grid/environment/entity/global_entities.py) are avail
#### Rules
[Rules](marl_factory_grid/environment/entity/object.py) define how the environment behaves on micro-scale.
[Rules](marl_factory_grid/environment/entity/object.py) define how the environment behaves on microscale.
Each of the hooks (`on_init`, `pre_step`, `on_step`, `post_step`, `on_done`)
provides env-access to implement custom logic, calculate rewards, or gather information.
@ -107,6 +107,7 @@ Make sure to bring your own assets for each Entity living in the Gridworld as th
PNG-files (transparent background) of square aspect-ratio should do the job, in general.
<img src="/marl_factory_grid/environment/assets/wall.png" width="5%">
<!--suppress HtmlUnknownAttribute -->
<html &nbsp&nbsp&nbsp&nbsp html>
<img src="/marl_factory_grid/environment/assets/agent/agent.png" width="5%">

View File

@ -1,6 +1 @@
from .environment import *
from .modules import *
from .utils import *
from .quickstart import init

View File

@ -1 +1,4 @@
import os, sys; sys.path.append(os.path.dirname(os.path.realpath(__file__)))
import os
import sys
sys.path.append(os.path.dirname(os.path.realpath(__file__)))

View File

@ -1 +1 @@
from marl_factory_grid.algorithms.marl.memory import MARLActorCriticMemory
from marl_factory_grid.algorithms.marl.memory import MARLActorCriticMemory

View File

@ -28,6 +28,7 @@ class Names:
BATCH_SIZE = 'bnatch_size'
N_ACTIONS = 'n_actions'
nms = Names
ListOrTensor = Union[List, torch.Tensor]
@ -112,10 +113,9 @@ class BaseActorCritic:
next_obs, reward, done, info = env.step(action)
done = [done] * self.n_agents if isinstance(done, bool) else done
last_hiddens = dict(hidden_actor =out[nms.HIDDEN_ACTOR],
last_hiddens = dict(hidden_actor=out[nms.HIDDEN_ACTOR],
hidden_critic=out[nms.HIDDEN_CRITIC])
tm.add(observation=obs, action=action, reward=reward, done=done,
logits=out.get(nms.LOGITS, None), values=out.get(nms.CRITIC, None),
**last_hiddens)
@ -142,7 +142,9 @@ class BaseActorCritic:
print(f'reward at episode: {episode} = {rew_log}')
episode += 1
df_results.append([episode, rew_log, *reward])
df_results = pd.DataFrame(df_results, columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]])
df_results = pd.DataFrame(df_results,
columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]]
)
if checkpointer is not None:
df_results.to_csv(checkpointer.path / 'results.csv', index=False)
return df_results
@ -157,24 +159,27 @@ class BaseActorCritic:
last_action, reward = [-1] * self.n_agents, [0.] * self.n_agents
done, rew_log, eps_rew = [False] * self.n_agents, 0, torch.zeros(self.n_agents)
while not all(done):
if render: env.render()
if render:
env.render()
out = self.forward(obs, last_action, **last_hiddens)
action = self.get_actions(out)
next_obs, reward, done, info = env.step(action)
if isinstance(done, bool): done = [done] * obs.shape[0]
if isinstance(done, bool):
done = [done] * obs.shape[0]
obs = next_obs
last_action = action
last_hiddens = dict(hidden_actor=out.get(nms.HIDDEN_ACTOR, None),
hidden_critic=out.get(nms.HIDDEN_CRITIC, None)
)
eps_rew += torch.tensor(reward)
results.append(eps_rew.tolist() + [sum(eps_rew).item()] + [episode])
results.append(eps_rew.tolist() + [np.sum(eps_rew).item()] + [episode])
episode += 1
agent_columns = [f'agent#{i}' for i in range(self.cfg['environment']['n_agents'])]
results = pd.DataFrame(results, columns=agent_columns + ['sum', 'episode'])
results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'], value_name='reward', var_name='agent')
results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'],
value_name='reward', var_name='agent')
return results
@staticmethod

View File

@ -36,7 +36,7 @@ class LoopMAPPO(LoopSNAC):
rewards_ = torch.stack(rewards_, dim=1)
return rewards_
def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **kwargs):
def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **__):
out = network(batch[nms.OBSERVATION], batch[nms.ACTION], batch[nms.HIDDEN_ACTOR], batch[nms.HIDDEN_CRITIC])
logits = out[nms.LOGITS][:, :-1] # last one only needed for v_{t+1}
@ -45,7 +45,7 @@ class LoopMAPPO(LoopSNAC):
# monte carlo returns
mc_returns = self.monte_carlo_returns(batch[nms.REWARD], batch[nms.DONE], gamma)
mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8) #todo: norm across agent ok?
mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8) # todo: norm across agent ok?
advantages = mc_returns - out[nms.CRITIC][:, :-1]
# policy loss

View File

@ -1,8 +1,7 @@
import numpy as np
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
class RecurrentAC(nn.Module):
@ -88,8 +87,8 @@ class NormalizedLinear(nn.Linear):
self.trainable_magnitude = trainable_magnitude
self.scale = nn.Parameter(torch.tensor([1.]), requires_grad=trainable_magnitude)
def forward(self, input):
normalized_input = F.normalize(input, dim=-1, p=2, eps=1e-5)
def forward(self, in_array):
normalized_input = F.normalize(in_array, dim=-1, p=2, eps=1e-5)
normalized_weight = F.normalize(self.weight, dim=-1, p=2, eps=1e-5)
return F.linear(normalized_input, normalized_weight) * self.d_sqrt * self.scale

View File

@ -16,7 +16,7 @@ class LoopSEAC(LoopIAC):
with torch.inference_mode(True):
true_action_logp = torch.stack([
torch.log_softmax(out[nms.LOGITS][ag_i, :-1], -1)
.gather(index=actions[ag_i, 1:, None], dim=-1)
.gather(index=actions[ag_i, 1:, None], dim=-1)
for ag_i, out in enumerate(outputs)
], 0).squeeze()
@ -38,7 +38,6 @@ class LoopSEAC(LoopIAC):
a2c_loss = (-iw*log_ap * advantages.detach()).mean(-1)
value_loss = (iw*advantages.pow(2)).mean(-1) # n_agent
# weighted loss
@ -53,4 +52,4 @@ class LoopSEAC(LoopIAC):
self.optimizer[ag_i].zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.net[ag_i].parameters(), 0.5)
self.optimizer[ag_i].step()
self.optimizer[ag_i].step()

View File

@ -30,4 +30,4 @@ class LoopSNAC(BaseActorCritic):
self._as_torch(actions).unsqueeze(1),
hidden_actor, hidden_critic
)
return out
return out

View File

@ -56,8 +56,8 @@ class TSPBaseAgent(ABC):
def _door_is_close(self, state):
try:
# return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name)
return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name)
return next(y for x in state.entities.neighboring_positions(self.state.pos)
for y in state.entities.pos_dict[x] if do.DOOR in y.name)
except StopIteration:
return None

View File

@ -14,8 +14,8 @@ class TSPTargetAgent(TSPBaseAgent):
def _handle_doors(self, state):
try:
# return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name)
return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name)
return next(y for x in state.entities.neighboring_positions(self.state.pos)
for y in state.entities.pos_dict[x] if do.DOOR in y.name)
except StopIteration:
return None
@ -30,4 +30,3 @@ class TSPTargetAgent(TSPBaseAgent):
except (StopIteration, UnboundLocalError):
print('Will not happen')
return action_obj

View File

@ -26,12 +26,16 @@ def points_to_graph(coordiniates, allow_euclidean_connections=True, allow_manhat
assert allow_euclidean_connections or allow_manhattan_connections
possible_connections = itertools.combinations(coordiniates, 2)
graph = nx.Graph()
for a, b in possible_connections:
diff = np.linalg.norm(np.asarray(a)-np.asarray(b))
if allow_manhattan_connections and allow_euclidean_connections and diff <= np.sqrt(2):
graph.add_edge(a, b)
elif not allow_manhattan_connections and allow_euclidean_connections and diff == np.sqrt(2):
graph.add_edge(a, b)
elif allow_manhattan_connections and not allow_euclidean_connections and diff == 1:
graph.add_edge(a, b)
if allow_manhattan_connections and allow_euclidean_connections:
graph.add_edges_from(
(a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) <= np.sqrt(2)
)
elif not allow_manhattan_connections and allow_euclidean_connections:
graph.add_edges_from(
(a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) == np.sqrt(2)
)
elif allow_manhattan_connections and not allow_euclidean_connections:
graph.add_edges_from(
(a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) == 1
)
return graph

View File

@ -1,8 +1,9 @@
import torch
import numpy as np
import yaml
from pathlib import Path
import numpy as np
import torch
import yaml
def load_class(classname):
from importlib import import_module
@ -42,7 +43,6 @@ def get_class(arguments):
def get_arguments(arguments):
from importlib import import_module
d = dict(arguments)
if "classname" in d:
del d["classname"]
@ -82,4 +82,4 @@ class Checkpointer(object):
for name, model in to_save:
self.save_experiment(name, model)
self.__current_checkpoint += 1
self.__current_step += 1
self.__current_step += 1

View File

@ -22,26 +22,41 @@ Agents:
- Inventory
- DropOffLocations
- Maintainers
# This is special for agents, as each one is different and can, e.g., act as an adversary.
Positions:
- (16, 7)
- (16, 6)
- (16, 3)
- (16, 4)
- (16, 5)
Entities:
Batteries:
initial_charge: 0.8
per_action_costs: 0.02
ChargePods: {}
Destinations: {}
ChargePods:
coords_or_quantity: 2
Destinations:
coords_or_quantity: 1
spawn_mode: GROUPED
DirtPiles:
coords_or_quantity: 10
initial_amount: 2
clean_amount: 1
dirt_spawn_r_var: 0.1
initial_amount: 2
initial_dirt_ratio: 0.05
max_global_amount: 20
max_local_amount: 5
Doors: {}
DropOffLocations: {}
Doors:
DropOffLocations:
coords_or_quantity: 1
max_dropoff_storage_size: 0
GlobalPositions: {}
Inventories: {}
Items: {}
Machines: {}
Maintainers: {}
Items:
coords_or_quantity: 5
Machines:
coords_or_quantity: 2
Maintainers:
coords_or_quantity: 1
Zones: {}
General:
@ -49,32 +64,31 @@ General:
individual_rewards: true
level_name: large
pomdp_r: 3
verbose: false
verbose: True
tests: false
Rules:
SpawnAgents: {}
DoneAtBatteryDischarge: {}
Collision:
done_at_collisions: false
AssignGlobalPositions: {}
DoneAtDestinationReachAny: {}
DestinationReachReward: {}
SpawnDestinations:
n_dests: 1
spawn_mode: GROUPED
DoneOnAllDirtCleaned: {}
SpawnDirt:
spawn_freq: 15
# Environment Dynamics
EntitiesSmearDirtOnMove:
smear_ratio: 0.2
DoorAutoClose:
close_frequency: 10
ItemRules:
max_dropoff_storage_size: 0
n_items: 5
n_locations: 5
spawn_frequency: 15
MaxStepsReached:
MoveMaintainers:
# Respawn Stuff
RespawnDirt:
respawn_freq: 15
RespawnItems:
respawn_freq: 15
# Utilities
WatchCollisions:
done_at_collisions: false
# Done Conditions
DoneAtDestinationReachAny:
DoneOnAllDirtCleaned:
DoneAtBatteryDischarge:
DoneAtMaintainerCollision:
DoneAtMaxStepsReached:
max_steps: 500
# AgentSingleZonePlacement:
# n_zones: 4

View File

@ -1,15 +1,41 @@
General:
# Your Seed
env_seed: 69
# Individual or global rewards?
individual_rewards: true
# The level.txt file to load
level_name: narrow_corridor
# View Radius; 0 = full observability
pomdp_r: 0
# print all messages and events
verbose: true
Agents:
# Agents are identified by their name
Wolfgang:
# The available actions for this particular agent
Actions:
# Able to do nothing
- Noop
# Able to move in all 8 directions
- Move8
# Stuff the agent can observe (per 2d slice)
# use "Combined" if you want to merge multiple slices into one
Observations:
# He sees walls
- Walls
# He sees other agents; "Karl-Heinz" in this setting would be fine, too
- Other
# He can see Destinations, that are assigned to him (hence the singular)
- Destination
# Available Spawn Positions as list
Positions:
- (2, 1)
- (2, 5)
# It is okay to collide with other agents, so that
# they end up on the same position
is_blocking_pos: true
# See Above....
Karl-Heinz:
Actions:
- Noop
@ -21,26 +47,43 @@ Agents:
Positions:
- (2, 1)
- (2, 5)
is_blocking_pos: true
# Other noteworthy Entities
Entities:
Destinations: {}
General:
env_seed: 69
individual_rewards: true
level_name: narrow_corridor
pomdp_r: 0
verbose: true
# The destinations or positional targets to reach
Destinations:
# Let them spawn on closed doors and agent positions
ignore_blocking: true
# We need a special spawn rule...
spawnrule:
# ...which assigns the destinations per agent
SpawnDestinationsPerAgent:
# we use this parameter
coords_or_quantity:
# to enable and assign special positions per agent
Wolfgang:
- (2, 1)
- (2, 5)
Karl-Heinz:
- (2, 1)
- (2, 5)
# Whether you want to provide a numeric Position observation.
# GlobalPositions:
# normalized: false
# Define the env. dynamics
Rules:
SpawnAgents: {}
Collision:
# Utilities
# This rule checks for collisions and also assigns the (negative) reward
WatchCollisions:
reward: -0.1
reward_at_done: -1
done_at_collisions: false
FixedDestinationSpawn:
per_agent_positions:
Wolfgang:
- (2, 1)
- (2, 5)
Karl-Heinz:
- (2, 1)
- (2, 5)
DestinationReachAll: {}
# Done Conditions
# Load any of the rules, to check for done conditions.
# DoneAtDestinationReachAny:
DoneAtDestinationReachAll:
# reward_at_done: 1
DoneAtMaxStepsReached:
max_steps: 200

View File

@ -48,9 +48,9 @@ class Move(Action, abc.ABC):
reward = r.MOVEMENTS_VALID if move_validity else r.MOVEMENTS_FAIL
return ActionResult(entity=entity, identifier=self._identifier, validity=move_validity, reward=reward)
else: # There is no place to go, propably collision
# This is currently handeld by the Collision rule, so that it can be switched on and off by conf.yml
# This is currently handeld by the WatchCollisions rule, so that it can be switched on and off by conf.yml
# return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION)
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=0)
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID)
def _calc_new_pos(self, pos):
x_diff, y_diff = MOVEMAP[self._identifier]

View File

@ -10,6 +10,7 @@ AGENT = 'Agent' # Identifier of Agent-objects an
OTHERS = 'Other'
COMBINED = 'Combined'
GLOBALPOSITIONS = 'GlobalPositions' # Identifier of the global position slice
SPAWN_ENTITY_RULE = 'SpawnEntity'
# Attributes
IS_BLOCKING_LIGHT = 'var_is_blocking_light'
@ -29,7 +30,7 @@ VALUE_NO_POS = (-9999, -9999) # Invalid Position value used in the e
ACTION = 'action' # Identifier of Action-objects and groups (groups).
COLLISION = 'Collision' # Identifier to use in the context of collitions.
COLLISION = 'Collisions' # Identifier to use in the context of collitions.
# LAST_POS = 'LAST_POS' # Identifiert for retrieving an enitites last pos.
VALIDITY = 'VALIDITY' # Identifiert for retrieving the Validity of Action, Tick, etc. ...
@ -54,3 +55,5 @@ NOOP = 'Noop'
# Result Identifier
MOVEMENTS_VALID = 'motion_valid'
MOVEMENTS_FAIL = 'motion_not_valid'
DEFAULT_PATH = 'environment'
MODULE_PATH = 'modules'

View File

@ -12,14 +12,6 @@ from marl_factory_grid.environment import constants as c
class Agent(Entity):
@property
def var_is_blocking_light(self):
return False
@property
def var_can_move(self):
return True
@property
def var_is_paralyzed(self):
return len(self._paralyzed)
@ -28,14 +20,6 @@ class Agent(Entity):
def paralyze_reasons(self):
return [x for x in self._paralyzed]
@property
def var_is_blocking_pos(self):
return False
@property
def var_has_position(self):
return True
@property
def obs_tag(self):
return self.name
@ -48,10 +32,6 @@ class Agent(Entity):
def observations(self):
return self._observations
@property
def var_can_collide(self):
return True
def step_result(self):
pass
@ -60,16 +40,21 @@ class Agent(Entity):
return self._collection
@property
def state(self):
return self._state or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0)
def var_is_blocking_pos(self):
return self._is_blocking_pos
def __init__(self, actions: List[Action], observations: List[str], *args, **kwargs):
@property
def state(self):
return self._state or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID)
def __init__(self, actions: List[Action], observations: List[str], *args, is_blocking_pos=False, **kwargs):
super(Agent, self).__init__(*args, **kwargs)
self._paralyzed = set()
self.step_result = dict()
self._actions = actions
self._observations = observations
self._state: Union[Result, None] = None
self._is_blocking_pos = is_blocking_pos
# noinspection PyAttributeOutsideInit
def clear_temp_state(self):

View File

@ -1,20 +1,19 @@
import abc
from collections import defaultdict
import numpy as np
from .object import _Object
from .object import Object
from .. import constants as c
from ...utils.results import ActionResult
from ...utils.utility_classes import RenderEntity
class Entity(_Object, abc.ABC):
class Entity(Object, abc.ABC):
"""Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc..."""
@property
def state(self):
return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0)
return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID)
@property
def var_has_position(self):
@ -60,6 +59,10 @@ class Entity(_Object, abc.ABC):
def pos(self):
return self._pos
def set_pos(self, pos):
assert isinstance(pos, tuple) and len(pos) == 2
self._pos = pos
@property
def last_pos(self):
try:
@ -84,7 +87,7 @@ class Entity(_Object, abc.ABC):
for observer in self.observers:
observer.notify_del_entity(self)
self._view_directory = curr_pos[0] - next_pos[0], curr_pos[1] - next_pos[1]
self._pos = next_pos
self.set_pos(next_pos)
for observer in self.observers:
observer.notify_add_entity(self)
return valid
@ -92,6 +95,7 @@ class Entity(_Object, abc.ABC):
def __init__(self, pos, bind_to=None, **kwargs):
super().__init__(**kwargs)
self._view_directory = c.VALUE_NO_POS
self._status = None
self._pos = pos
self._last_pos = pos
@ -109,9 +113,6 @@ class Entity(_Object, abc.ABC):
def render(self):
return RenderEntity(self.__class__.__name__.lower(), self.pos)
def __repr__(self):
return super(Entity, self).__repr__() + f'(@{self.pos})'
@property
def obs_tag(self):
try:
@ -128,25 +129,3 @@ class Entity(_Object, abc.ABC):
self._collection.delete_env_object(self)
self._collection = other_collection
return self._collection == other_collection
@classmethod
def from_coordinates(cls, positions: [(int, int)], *args, entity_kwargs=None, **kwargs, ):
collection = cls(*args, **kwargs)
collection.add_items(
[cls._entity(tuple(pos), **entity_kwargs if entity_kwargs is not None else {}) for pos in positions])
return collection
def notify_del_entity(self, entity):
try:
self.pos_dict[entity.pos].remove(entity)
except (ValueError, AttributeError):
pass
def by_pos(self, pos: (int, int)):
pos = tuple(pos)
try:
return self.state.entities.pos_dict[pos]
except StopIteration:
pass
except ValueError:
print()

View File

@ -1,24 +0,0 @@
# noinspection PyAttributeOutsideInit
class BoundEntityMixin:
@property
def bound_entity(self):
return self._bound_entity
@property
def name(self):
if self.bound_entity:
return f'{self.__class__.__name__}({self.bound_entity.name})'
else:
pass
def belongs_to_entity(self, entity):
return entity == self.bound_entity
def bind_to(self, entity):
self._bound_entity = entity
def unbind(self):
self._bound_entity = None

View File

@ -5,7 +5,7 @@ from marl_factory_grid.environment import constants as c
import marl_factory_grid.utils.helpers as h
class _Object:
class Object:
"""Generell Objects for Organisation and Maintanance such as Actions etc..."""
_u_idx = defaultdict(lambda: 0)
@ -13,10 +13,6 @@ class _Object:
def __bool__(self):
return True
@property
def var_has_position(self):
return False
@property
def var_can_be_bound(self):
try:
@ -30,22 +26,14 @@ class _Object:
@property
def name(self):
if self._str_ident is not None:
name = f'{self.__class__.__name__}[{self._str_ident}]'
else:
name = f'{self.__class__.__name__}#{self.u_int}'
if self.bound_entity:
name = h.add_bound_name(name, self.bound_entity)
if self.var_has_position:
name = h.add_pos_name(name, self)
return name
return f'{self.__class__.__name__}[{self.identifier}]'
@property
def identifier(self):
if self._str_ident is not None:
return self._str_ident
else:
return self.name
return self.u_int
def reset_uid(self):
self._u_idx = defaultdict(lambda: 0)
@ -62,7 +50,15 @@ class _Object:
print(f'Following kwargs were passed, but ignored: {kwargs}')
def __repr__(self):
return f'{self.name}'
name = self.name
if self.bound_entity:
name = h.add_bound_name(name, self.bound_entity)
try:
if self.var_has_position:
name = h.add_pos_name(name, self)
except AttributeError:
pass
return name
def __eq__(self, other) -> bool:
return other == self.identifier
@ -71,8 +67,8 @@ class _Object:
return hash(self.identifier)
def _identify_and_count_up(self):
idx = _Object._u_idx[self.__class__.__name__]
_Object._u_idx[self.__class__.__name__] += 1
idx = Object._u_idx[self.__class__.__name__]
Object._u_idx[self.__class__.__name__] += 1
return idx
def set_collection(self, collection):
@ -88,7 +84,7 @@ class _Object:
def summarize_state(self):
return dict()
def bind(self, entity):
def bind_to(self, entity):
# noinspection PyAttributeOutsideInit
self._bound_entity = entity
return c.VALID
@ -100,84 +96,5 @@ class _Object:
def bound_entity(self):
return self._bound_entity
def bind_to(self, entity):
self._bound_entity = entity
def unbind(self):
self._bound_entity = None
# class EnvObject(_Object):
# """Objects that hold Information that are observable, but have no position on the environment grid. Inventories etc..."""
#
# _u_idx = defaultdict(lambda: 0)
#
# @property
# def obs_tag(self):
# try:
# return self._collection.name or self.name
# except AttributeError:
# return self.name
#
# @property
# def var_is_blocking_light(self):
# try:
# return self._collection.var_is_blocking_light or False
# except AttributeError:
# return False
#
# @property
# def var_can_be_bound(self):
# try:
# return self._collection.var_can_be_bound or False
# except AttributeError:
# return False
#
# @property
# def var_can_move(self):
# try:
# return self._collection.var_can_move or False
# except AttributeError:
# return False
#
# @property
# def var_is_blocking_pos(self):
# try:
# return self._collection.var_is_blocking_pos or False
# except AttributeError:
# return False
#
# @property
# def var_has_position(self):
# try:
# return self._collection.var_has_position or False
# except AttributeError:
# return False
#
# @property
# def var_can_collide(self):
# try:
# return self._collection.var_can_collide or False
# except AttributeError:
# return False
#
#
# @property
# def encoding(self):
# return c.VALUE_OCCUPIED_CELL
#
#
# def __init__(self, **kwargs):
# self._bound_entity = None
# super(EnvObject, self).__init__(**kwargs)
#
#
# def change_parent_collection(self, other_collection):
# other_collection.add_item(self)
# self._collection.delete_env_object(self)
# self._collection = other_collection
# return self._collection == other_collection
#
#
# def summarize_state(self):
# return dict(name=str(self.name))

View File

@ -1,6 +1,6 @@
import numpy as np
from marl_factory_grid.environment.entity.object import _Object
from marl_factory_grid.environment.entity.object import Object
##########################################################################
@ -8,7 +8,7 @@ from marl_factory_grid.environment.entity.object import _Object
##########################################################################
class PlaceHolder(_Object):
class PlaceHolder(Object):
def __init__(self, *args, fill_value=0, **kwargs):
super().__init__(*args, **kwargs)
@ -24,10 +24,10 @@ class PlaceHolder(_Object):
@property
def name(self):
return "PlaceHolder"
return self.__class__.__name__
class GlobalPosition(_Object):
class GlobalPosition(Object):
@property
def encoding(self):
@ -36,7 +36,8 @@ class GlobalPosition(_Object):
else:
return self.bound_entity.pos
def __init__(self, level_shape, *args, normalized: bool = True, **kwargs):
def __init__(self, agent, level_shape, *args, normalized: bool = True, **kwargs):
super(GlobalPosition, self).__init__(*args, **kwargs)
self.bind_to(agent)
self._normalized = normalized
self._shape = level_shape

View File

@ -5,13 +5,8 @@ from marl_factory_grid.utils.utility_classes import RenderEntity
class Wall(Entity):
@property
def var_has_position(self):
return True
@property
def var_can_collide(self):
return True
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@property
def encoding(self):
@ -19,11 +14,3 @@ class Wall(Entity):
def render(self):
return RenderEntity(c.WALL, self.pos)
@property
def var_is_blocking_pos(self):
return True
@property
def var_is_blocking_light(self):
return True

View File

@ -56,15 +56,18 @@ class Factory(gym.Env):
self.level_filepath = Path(custom_level_path)
else:
self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt'
self._renderer = None # expensive - don't use; unless required !
parsed_entities = self.conf.load_entities()
self.map = LevelParser(self.level_filepath, parsed_entities, self.conf.pomdp_r)
# Init for later usage:
self.state: Gamestate
self.map: LevelParser
self.obs_builder: OBSBuilder
# noinspection PyTypeChecker
self.state: Gamestate = None
# noinspection PyTypeChecker
self.obs_builder: OBSBuilder = None
# expensive - don't use; unless required !
self._renderer = None
# reset env to initial state, preparing env for new episode.
# returns tuple where the first dict contains initial observation for each agent in the env
@ -74,7 +77,7 @@ class Factory(gym.Env):
return self.state.entities[item]
def reset(self) -> (dict, dict):
if hasattr(self, 'state'):
if self.state is not None:
for entity_group in self.state.entities:
try:
entity_group[0].reset_uid()
@ -87,12 +90,16 @@ class Factory(gym.Env):
entities = self.map.do_init()
# Init rules
rules = self.conf.load_env_rules()
env_rules = self.conf.load_env_rules()
entity_rules = self.conf.load_entity_spawn_rules(entities)
env_rules.extend(entity_rules)
env_tests = self.conf.load_env_tests() if self.conf.tests else []
# Parse the agent conf
parsed_agents_conf = self.conf.parse_agents_conf()
self.state = Gamestate(entities, parsed_agents_conf, rules, env_tests, self.conf.env_seed, self.conf.verbose)
self.state = Gamestate(entities, parsed_agents_conf, env_rules, env_tests, self.map.level_shape,
self.conf.env_seed, self.conf.verbose)
# All is set up, trigger entity init with variable pos
# All is set up, trigger additional init (after agent entity spawn etc)
@ -160,7 +167,7 @@ class Factory(gym.Env):
# Finalize
reward, reward_info, done = self.summarize_step_results(tick_result, done_results)
info = reward_info
info = dict(reward_info)
info.update(step_reward=sum(reward), step=self.state.curr_step)

View File

@ -1,10 +1,15 @@
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.environment.rules import SpawnAgents
class Agents(Collection):
_entity = Agent
@property
def spawn_rule(self):
return {SpawnAgents.__name__: {}}
@property
def var_is_blocking_light(self):
return False

View File

@ -1,18 +1,25 @@
from typing import List, Tuple, Union
from typing import List, Tuple, Union, Dict
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.groups.objects import _Objects
from marl_factory_grid.environment.entity.object import _Object
from marl_factory_grid.environment.groups.objects import Objects
# noinspection PyProtectedMember
from marl_factory_grid.environment.entity.object import Object
import marl_factory_grid.environment.constants as c
from marl_factory_grid.utils.results import Result
class Collection(_Objects):
_entity = _Object # entity?
class Collection(Objects):
_entity = Object # entity?
symbol = None
@property
def var_is_blocking_light(self):
return False
@property
def var_is_blocking_pos(self):
return False
@property
def var_can_collide(self):
return False
@ -23,33 +30,65 @@ class Collection(_Objects):
@property
def var_has_position(self):
return False
# @property
# def var_has_bound(self):
# return False # batteries, globalpos, inventories true
@property
def var_can_be_bound(self):
return False
return True
@property
def encodings(self):
return [x.encoding for x in self]
def __init__(self, size, *args, **kwargs):
super(Collection, self).__init__(*args, **kwargs)
self.size = size
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): # woihn mit den args
if isinstance(coords_or_quantity, int):
self.add_items([self._entity() for _ in range(coords_or_quantity)])
@property
def spawn_rule(self):
"""Prevent SpawnRule creation if Objects are spawned by map, Doors e.g."""
if self.symbol:
return None
elif self._spawnrule:
return self._spawnrule
else:
self.add_items([self._entity(pos) for pos in coords_or_quantity])
return {c.SPAWN_ENTITY_RULE: dict(collection=self, coords_or_quantity=self._coords_or_quantity)}
def __init__(self, size, *args, coords_or_quantity: int = None, ignore_blocking=False,
spawnrule: Union[None, Dict[str, dict]] = None,
**kwargs):
super(Collection, self).__init__(*args, **kwargs)
self._coords_or_quantity = coords_or_quantity
self.size = size
self._spawnrule = spawnrule
self._ignore_blocking = ignore_blocking
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, ignore_blocking=False, **entity_kwargs):
coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity
if self.var_has_position:
if self.var_has_position and isinstance(coords_or_quantity, int):
if ignore_blocking or self._ignore_blocking:
coords_or_quantity = state.entities.floorlist[:coords_or_quantity]
else:
coords_or_quantity = state.get_n_random_free_positions(coords_or_quantity)
self.spawn(coords_or_quantity, *entity_args, **entity_kwargs)
state.print(f'{len(coords_or_quantity)} new {self.name} have been spawned at {coords_or_quantity}')
return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(coords_or_quantity))
else:
if isinstance(coords_or_quantity, int):
self.spawn(coords_or_quantity, *entity_args, **entity_kwargs)
state.print(f'{coords_or_quantity} new {self.name} have been spawned randomly.')
return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=coords_or_quantity)
else:
raise ValueError(f'{self._entity.__name__} has no position!')
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args, **entity_kwargs):
if self.var_has_position:
if isinstance(coords_or_quantity, int):
raise ValueError(f'{self._entity.__name__} should have a position!')
else:
self.add_items([self._entity(pos, *entity_args, **entity_kwargs) for pos in coords_or_quantity])
else:
if isinstance(coords_or_quantity, int):
self.add_items([self._entity(*entity_args, **entity_kwargs) for _ in range(coords_or_quantity)])
else:
raise ValueError(f'{self._entity.__name__} has no position!')
return c.VALID
def despawn(self, items: List[_Object]):
items = [items] if isinstance(items, _Object) else items
def despawn(self, items: List[Object]):
items = [items] if isinstance(items, Object) else items
for item in items:
del self[item]
@ -115,7 +154,7 @@ class Collection(_Objects):
except StopIteration:
pass
except ValueError:
print()
pass
@property
def positions(self):

View File

@ -1,21 +1,21 @@
from collections import defaultdict
from operator import itemgetter
from random import shuffle, random
from random import shuffle
from typing import Dict
from marl_factory_grid.environment.groups.objects import _Objects
from marl_factory_grid.environment.groups.objects import Objects
from marl_factory_grid.utils.helpers import POS_MASK
class Entities(_Objects):
_entity = _Objects
class Entities(Objects):
_entity = Objects
@staticmethod
def neighboring_positions(pos):
return (POS_MASK + pos).reshape(-1, 2)
return [tuple(x) for x in (POS_MASK + pos).reshape(-1, 2)]
def get_entities_near_pos(self, pos):
return [y for x in itemgetter(*(tuple(x) for x in self.neighboring_positions(pos)))(self.pos_dict) for y in x]
return [y for x in itemgetter(*self.neighboring_positions(pos))(self.pos_dict) for y in x]
def render(self):
    """
    Collect the render output of every member into one flat list.

    ``None`` members are skipped. The ``None`` check has to come *before* the inner ``for`` clause:
    comprehension clauses nest left to right, so a trailing ``if x is not None`` would only run after
    ``x.render()`` had already been called (and had already failed) on a ``None`` member.

    :return: Flat list of whatever the members' ``render()`` calls yield, in iteration order.
    """
    return [y for x in self if x is not None for y in x.render()]
@ -35,8 +35,9 @@ class Entities(_Objects):
super().__init__()
def guests_that_can_collide(self, pos):
return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
return [x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
@property
def empty_positions(self):
empty_positions = [key for key in self.floorlist if not self.pos_dict[key]]
shuffle(empty_positions)
@ -48,11 +49,23 @@ class Entities(_Objects):
shuffle(empty_positions)
return empty_positions
def is_blocked(self):
    """
    List every position that holds at least one position-blocking entity.

    :return: Keys of ``self.pos_dict`` (in dict order) for which any entity has a truthy
             ``var_is_blocking_pos``.
    """
    return [key for key, val in self.pos_dict.items() if any(x.var_is_blocking_pos for x in val)]
@property
def blocked_positions(self):
    """
    Positions that hold at least one position-blocking entity, in random order.

    :return: A freshly built, shuffled list of ``self.pos_dict`` keys for which any entity has a truthy
             ``var_is_blocking_pos``.
    """
    blocked_positions = [key for key, val in self.pos_dict.items() if any(x.var_is_blocking_pos for x in val)]
    # Shuffled in place so callers that slice the result get a random subset.
    shuffle(blocked_positions)
    return blocked_positions
def is_not_blocked(self):
    """
    List every position that holds at least one entity which does not block the position.

    Note that a position without any entities is *not* reported: ``all()`` over an empty sequence is
    ``True``, so ``not all(...)`` is ``False`` for it.

    :return: Keys of ``self.pos_dict`` (in dict order) for which not every entity has a truthy
             ``var_is_blocking_pos``.
    """
    return [key for key, val in self.pos_dict.items() if not all(x.var_is_blocking_pos for x in val)]
@property
def free_positions_generator(self):
    """
    Lazily yield the floor positions that are free.

    A position counts as free when every entity stored for it in ``self.pos_dict`` neither collides
    (``var_can_collide``) nor blocks the position (``var_is_blocking_pos``); positions without entities
    therefore qualify as well.

    :return: A generator over the matching entries of ``self.floorlist``, in ``floorlist`` order.
    """
    return (
        key for key in self.floorlist
        if all(not x.var_can_collide and not x.var_is_blocking_pos for x in self.pos_dict[key])
    )
@property
def free_positions_list(self):
    """
    Materialise ``self.free_positions_generator`` into a list.

    :return: List of everything ``self.free_positions_generator`` yields, in the same order.
    """
    return list(self.free_positions_generator)
def iter_entities(self):
    """
    Iterate over every element of every group returned by ``self.values()``.

    :return: A generator that flattens the groups one level, preserving their order.
    """
    # A generator is already its own iterator, so no extra ``iter()`` wrapper is needed.
    return (x for sublist in self.values() for x in sublist)
@ -74,7 +87,7 @@ class Entities(_Objects):
def __delitem__(self, name):
assert_str = 'This group of entity does not exist in this collection!'
assert any([key for key in name.keys() if key in self.keys()]), assert_str
self[name]._observers.delete(self)
self[name].del_observer(self)
for entity in self[name]:
entity.del_observer(self)
return super(Entities, self).__delitem__(name)
@ -92,3 +105,6 @@ class Entities(_Objects):
@property
def positions(self):
    """
    List the positions of all entities registered in ``self.pos_dict``.

    :return: One entry per stored entity, i.e. a position appears as many times as it has entities;
             positions without entities do not appear.
    """
    return [k for k, v in self.pos_dict.items() for _ in v]
def is_occupied(self, pos):
    """
    Tell whether ``pos`` holds an entity that collides or blocks the position.

    :param pos: Key into ``self.pos_dict``.
    :return: ``True`` if any entity at ``pos`` has a truthy ``var_can_collide`` or ``var_is_blocking_pos``,
             else ``False``.
    """
    return any(x.var_can_collide or x.var_is_blocking_pos for x in self.pos_dict[pos])

View File

@ -4,10 +4,6 @@ from marl_factory_grid.environment import constants as c
# noinspection PyUnresolvedReferences,PyTypeChecker
class IsBoundMixin:
@property
def name(self):
return f'{self.__class__.__name__}({self._bound_entity.name})'
def __repr__(self):
return f'{self.__class__.__name__}#{self._bound_entity.name}({self._data})'

View File

@ -1,14 +1,19 @@
from collections import defaultdict
from typing import List
from typing import List, Iterator, Union
import numpy as np
from marl_factory_grid.environment.entity.object import _Object
from marl_factory_grid.environment.entity.object import Object
import marl_factory_grid.environment.constants as c
from marl_factory_grid.utils import helpers as h
class _Objects:
_entity = _Object
class Objects:
_entity = Object
@property
def var_can_be_bound(self):
    """
    Flag telling whether members of this group can be bound to another entity.

    :return: Always ``False`` here; the value does not depend on instance state.
    """
    return False
@property
def observers(self):
@ -45,7 +50,7 @@ class _Objects:
def __len__(self):
return len(self._data)
def __iter__(self):
def __iter__(self) -> Iterator[Union[Object, None]]:
return iter(self.values())
def add_item(self, item: _entity):
@ -125,13 +130,14 @@ class _Objects:
repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]}
return f'{self.__class__.__name__}[{repr_dict}]'
def notify_del_entity(self, entity: _Object):
def notify_del_entity(self, entity: Object):
try:
# noinspection PyUnresolvedReferences
self.pos_dict[entity.pos].remove(entity)
except (AttributeError, ValueError, IndexError):
pass
def notify_add_entity(self, entity: _Object):
def notify_add_entity(self, entity: Object):
try:
if self not in entity.observers:
entity.add_observer(self)
@ -148,12 +154,12 @@ class _Objects:
def by_entity(self, entity):
try:
return next((x for x in self if x.belongs_to_entity(entity)))
return h.get_first(self, filter_by=lambda x: x.belongs_to_entity(entity))
except (StopIteration, AttributeError):
return None
def idx_by_entity(self, entity):
try:
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
return h.get_first_index(self, filter_by=lambda x: x.belongs_to_entity(entity))
except (StopIteration, AttributeError):
return None

View File

@ -1,7 +1,10 @@
from typing import List, Union
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.util import GlobalPosition
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.utils.results import Result
from marl_factory_grid.utils.states import Gamestate
class Combined(Collection):
@ -36,17 +39,17 @@ class GlobalPositions(Collection):
_entity = GlobalPosition
@property
def var_is_blocking_light(self):
return False
@property
def var_can_collide(self):
return False
@property
def var_can_be_bound(self):
return True
var_is_blocking_light = False
var_can_be_bound = True
var_can_collide = False
var_has_position = False
def __init__(self, *args, **kwargs):
super(GlobalPositions, self).__init__(*args, **kwargs)
def spawn(self, agents, level_shape, *args, **kwargs):
    """
    Create one ``self._entity`` per agent and add them all to this collection.

    :param agents: Iterable of agents; each agent is passed as first constructor argument.
    :param level_shape: Passed as second constructor argument to every entity.
    :param args: Extra positional constructor arguments.
    :param kwargs: Extra keyword constructor arguments.
    :return: A one-element list holding a ``Result`` identified as ``'<name>_spawn'``; its ``value`` is
             ``len(self)`` after the entities were added.
    """
    new_entities = [self._entity(agent, level_shape, *args, **kwargs) for agent in agents]
    self.add_items(new_entities)
    return [Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))]
def trigger_spawn(self, state: Gamestate, *args, **kwargs) -> [Result]:
    """
    Spawn entities for all agents of ``state`` by delegating to ``self.spawn``.

    :param state: Game state; ``state[c.AGENT]`` supplies the agents and ``state.lvl_shape`` the level shape.
    :param args: Extra positional arguments forwarded to ``self.spawn``.
    :param kwargs: Extra keyword arguments forwarded to ``self.spawn``.
    :return: Whatever ``self.spawn`` returns.
    """
    return self.spawn(state[c.AGENT], state.lvl_shape, *args, **kwargs)

View File

@ -7,9 +7,12 @@ class Walls(Collection):
_entity = Wall
symbol = c.SYMBOL_WALL
@property
def var_has_position(self):
return True
var_can_collide = True
var_is_blocking_light = True
var_can_move = False
var_has_position = True
var_can_be_bound = False
var_is_blocking_pos = True
def __init__(self, *args, **kwargs):
super(Walls, self).__init__(*args, **kwargs)

View File

@ -2,3 +2,4 @@ MOVEMENTS_VALID: float = -0.001
MOVEMENTS_FAIL: float = -0.05
NOOP: float = -0.01
COLLISION: float = -0.5
COLLISION_DONE: float = -1

View File

@ -1,11 +1,11 @@
import abc
from random import shuffle
from typing import List
from typing import List, Collection
from marl_factory_grid.environment import rewards as r, constants as c
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.utils import helpers as h
from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import rewards as r, constants as c
class Rule(abc.ABC):
@ -39,6 +39,29 @@ class Rule(abc.ABC):
return []
class SpawnEntity(Rule):
    """
    Rule that asks a given collection to spawn its entities when the environment is initialised.
    """

    @property
    def _collection(self) -> Collection:
        # NOTE(review): if ``Collection`` here is the name imported from ``typing``, calling it raises a
        # TypeError; the environment's own ``Collection`` group may have been intended - confirm.
        # Nothing visible in this class reads this property.
        return Collection()

    @property
    def name(self):
        """Rule name, composed of this class' name and the name of the wrapped collection."""
        return f'{self.__class__.__name__}({self.collection.name})'

    def __init__(self, collection, coords_or_quantity, ignore_blocking=False):
        """
        :param collection: The entity collection whose ``trigger_spawn`` is called in ``on_init``.
        :param coords_or_quantity: Stored on the rule as given.
        :param ignore_blocking: Forwarded to ``collection.trigger_spawn`` in ``on_init``.
        """
        super().__init__()
        # NOTE(review): stored but not forwarded by ``on_init``; presumably the collection holds its own
        # copy of this setting - confirm.
        self.coords_or_quantity = coords_or_quantity
        self.collection = collection
        self.ignore_blocking = ignore_blocking

    def on_init(self, state, lvl_map) -> [TickResult]:
        """
        Trigger the collection's spawn and log what was spawned.

        :param state: Game state handed to ``collection.trigger_spawn``; also used for ``print``.
        :param lvl_map: Not used by this rule.
        :return: Whatever ``collection.trigger_spawn`` returns.
        """
        results = self.collection.trigger_spawn(state, ignore_blocking=self.ignore_blocking)
        # Positions are only listed for collections whose entities actually have one.
        pos_str = f' on: {[x.pos for x in self.collection]}' if self.collection.var_has_position else ''
        state.print(f'Initial {self.collection.__class__.__name__} were spawned{pos_str}')
        return results
class SpawnAgents(Rule):
def __init__(self):
@ -46,14 +69,14 @@ class SpawnAgents(Rule):
pass
def on_init(self, state, lvl_map):
agent_conf = state.agents_conf
# agents = Agents(lvl_map.size)
agents = state[c.AGENT]
empty_positions = state.entities.empty_positions()[:len(agent_conf)]
for agent_name in agent_conf:
actions = agent_conf[agent_name]['actions'].copy()
observations = agent_conf[agent_name]['observations'].copy()
positions = agent_conf[agent_name]['positions'].copy()
empty_positions = state.entities.empty_positions[:len(state.agents_conf)]
for agent_name, agent_conf in state.agents_conf.items():
actions = agent_conf['actions'].copy()
observations = agent_conf['observations'].copy()
positions = agent_conf['positions'].copy()
other = agent_conf['other'].copy()
if positions:
shuffle(positions)
while True:
@ -61,18 +84,18 @@ class SpawnAgents(Rule):
pos = positions.pop()
except IndexError:
raise ValueError(f'It was not possible to spawn an Agent on the available position: '
f'\n{agent_name[agent_name]["positions"].copy()}')
if agents.by_pos(pos) and state.check_pos_validity(pos):
f'\n{agent_conf["positions"].copy()}')
if bool(agents.by_pos(pos)) or not state.check_pos_validity(pos):
continue
else:
agents.add_item(Agent(actions, observations, pos, str_ident=agent_name))
agents.add_item(Agent(actions, observations, pos, str_ident=agent_name, **other))
break
else:
agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name))
agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name, **other))
pass
class MaxStepsReached(Rule):
class DoneAtMaxStepsReached(Rule):
def __init__(self, max_steps: int = 500):
super().__init__()
@ -83,8 +106,8 @@ class MaxStepsReached(Rule):
def on_check_done(self, state):
if self.max_steps <= state.curr_step:
return [DoneResult(validity=c.VALID, identifier=self.name, reward=0)]
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
return [DoneResult(validity=c.VALID, identifier=self.name)]
return [DoneResult(validity=c.NOT_VALID, identifier=self.name)]
class AssignGlobalPositions(Rule):
@ -95,16 +118,17 @@ class AssignGlobalPositions(Rule):
def on_init(self, state, lvl_map):
from marl_factory_grid.environment.entity.util import GlobalPosition
for agent in state[c.AGENT]:
gp = GlobalPosition(lvl_map.level_shape)
gp.bind_to(agent)
gp = GlobalPosition(agent, lvl_map.level_shape)
state[c.GLOBALPOSITIONS].add_item(gp)
return []
class Collision(Rule):
class WatchCollisions(Rule):
def __init__(self, done_at_collisions: bool = False):
def __init__(self, reward=r.COLLISION, done_at_collisions: bool = False, reward_at_done=r.COLLISION_DONE):
super().__init__()
self.reward_at_done = reward_at_done
self.reward = reward
self.done_at_collisions = done_at_collisions
self.curr_done = False
@ -117,12 +141,12 @@ class Collision(Rule):
if len(guests) >= 2:
for i, guest in enumerate(guests):
try:
guest.set_state(TickResult(identifier=c.COLLISION, reward=r.COLLISION,
guest.set_state(TickResult(identifier=c.COLLISION, reward=self.reward,
validity=c.NOT_VALID, entity=self))
except AttributeError:
pass
results.append(TickResult(entity=guest, identifier=c.COLLISION,
reward=r.COLLISION, validity=c.VALID))
reward=self.reward, validity=c.VALID))
self.curr_done = True if self.done_at_collisions else False
return results
@ -131,5 +155,5 @@ class Collision(Rule):
inter_entity_collision_detected = self.curr_done
move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT])
if inter_entity_collision_detected or move_failed:
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)]
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=self.reward_at_done)]
return []

View File

@ -6,7 +6,9 @@ from marl_factory_grid.utils.results import TickResult, DoneResult
class TemplateRule(Rule):
def __init__(self, *args, **kwargs):
super(TemplateRule, self).__init__(*args, **kwargs)
super(TemplateRule, self).__init__()
self.args = args
self.kwargs = kwargs
def on_init(self, state, lvl_map):
pass

View File

@ -1,4 +1,4 @@
from .actions import BtryCharge
from .entitites import Pod, Battery
from .entitites import ChargePod, Battery
from .groups import ChargePods, Batteries
from .rules import DoneAtBatteryDischarge, BatteryDecharge

View File

@ -1,11 +1,11 @@
from typing import Union
import marl_factory_grid.modules.batteries.constants
from marl_factory_grid.environment.actions import Action
from marl_factory_grid.utils.results import ActionResult
from marl_factory_grid.modules.batteries import constants as b
from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils import helpers as h
class BtryCharge(Action):
@ -14,8 +14,8 @@ class BtryCharge(Action):
super().__init__(b.ACTION_CHARGE)
def do(self, entity, state) -> Union[None, ActionResult]:
if charge_pod := state[b.CHARGE_PODS].by_pos(entity.pos):
valid = charge_pod.charge_battery(state[b.BATTERIES].by_entity(entity))
if charge_pod := h.get_first(state[b.CHARGE_PODS].by_pos(entity.pos)):
valid = h.get_first(charge_pod.charge_battery(state[b.BATTERIES].by_entity(entity)))
if valid:
state.print(f'{entity.name} just charged batteries at {charge_pod.name}.')
else:
@ -23,5 +23,6 @@ class BtryCharge(Action):
else:
valid = c.NOT_VALID
state.print(f'{entity.name} failed to charged batteries at {entity.pos}.')
return ActionResult(entity=entity, identifier=self._identifier, validity=valid,
reward=marl_factory_grid.modules.batteries.constants.REWARD_CHARGE_VALID if valid else marl_factory_grid.modules.batteries.constants.Reward_CHARGE_FAIL)
reward=b.REWARD_CHARGE_VALID if valid else b.Reward_CHARGE_FAIL)

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.9 KiB

View File

@ -1,11 +1,11 @@
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.entity.object import _Object
from marl_factory_grid.environment.entity.object import Object
from marl_factory_grid.modules.batteries import constants as b
from marl_factory_grid.utils.utility_classes import RenderEntity
class Battery(_Object):
class Battery(Object):
@property
def var_can_be_bound(self):
@ -50,7 +50,7 @@ class Battery(_Object):
return summary
class Pod(Entity):
class ChargePod(Entity):
@property
def encoding(self):
@ -58,7 +58,7 @@ class Pod(Entity):
def __init__(self, *args, charge_rate: float = 0.4,
multi_charge: bool = False, **kwargs):
super(Pod, self).__init__(*args, **kwargs)
super(ChargePod, self).__init__(*args, **kwargs)
self.charge_rate = charge_rate
self.multi_charge = multi_charge

View File

@ -1,52 +1,36 @@
from typing import Union, List, Tuple
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.modules.batteries.entitites import Pod, Battery
from marl_factory_grid.modules.batteries.entitites import ChargePod, Battery
from marl_factory_grid.utils.results import Result
class Batteries(Collection):
_entity = Battery
@property
def var_is_blocking_light(self):
return False
@property
def var_can_collide(self):
return False
@property
def var_can_move(self):
return False
@property
def var_has_position(self):
return False
@property
def var_can_be_bound(self):
return True
var_has_position = False
var_can_be_bound = True
@property
def obs_tag(self):
return self.__class__.__name__
def __init__(self, *args, **kwargs):
super(Batteries, self).__init__(*args, **kwargs)
def __init__(self, size, initial_charge_level: float=1.0, *args, **kwargs):
super(Batteries, self).__init__(size, *args, **kwargs)
self.initial_charge_level = initial_charge_level
def spawn(self, agents, initial_charge_level):
batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)]
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], agents, *entity_args, **entity_kwargs):
batteries = [self._entity(self.initial_charge_level, agent) for _, agent in enumerate(agents)]
self.add_items(batteries)
# def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): hat keine pos
# agents = entity_args[0]
# initial_charge_level = entity_args[1]
# batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)]
# self.add_items(batteries)
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, **entity_kwargs):
    """
    Spawn this collection's entities for the agents found in ``state``.

    :param state: Game state; ``state[c.AGENT]`` is handed to ``self.spawn``.
    :param entity_args: Accepted for signature compatibility; not used here.
    :param coords_or_quantity: Accepted for signature compatibility; not used here (``0`` is always passed
                               to ``self.spawn``).
    :param entity_kwargs: Accepted for signature compatibility; not used here.
    :return: A ``Result`` identified as ``'<name>_spawn'`` whose ``value`` is ``len(self)`` after spawning.
    """
    self.spawn(0, state[c.AGENT])
    return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))
class ChargePods(Collection):
_entity = Pod
_entity = ChargePod
def __init__(self, *args, **kwargs):
super(ChargePods, self).__init__(*args, **kwargs)

View File

@ -1,11 +1,9 @@
from typing import List, Union
import marl_factory_grid.modules.batteries.constants
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.modules.batteries import constants as b
from marl_factory_grid.utils.results import TickResult, DoneResult
class BatteryDecharge(Rule):
@ -49,10 +47,6 @@ class BatteryDecharge(Rule):
self.per_action_costs = per_action_costs
self.initial_charge = initial_charge
def on_init(self, state, lvl_map): # on reset?
assert len(state[c.AGENT]), "There are no agents, did you already spawn them?"
state[b.BATTERIES].spawn(state[c.AGENT], self.initial_charge)
def tick_step(self, state) -> List[TickResult]:
# Decharge
batteries = state[b.BATTERIES]
@ -66,7 +60,7 @@ class BatteryDecharge(Rule):
batteries.by_entity(agent).decharge(energy_consumption)
results.append(TickResult(self.name, reward=0, entity=agent, validity=c.VALID))
results.append(TickResult(self.name, entity=agent, validity=c.VALID))
return results
@ -82,13 +76,13 @@ class BatteryDecharge(Rule):
if self.paralyze_agents_on_discharge:
btry.bound_entity.paralyze(self.name)
results.append(
TickResult("Paralyzed", entity=btry.bound_entity, reward=0, validity=c.VALID)
TickResult("Paralyzed", entity=btry.bound_entity, validity=c.VALID)
)
state.print(f'{btry.bound_entity.name} has just been paralyzed!')
if btry.bound_entity.var_is_paralyzed and not btry.is_discharged:
btry.bound_entity.de_paralyze(self.name)
results.append(
TickResult("De-Paralyzed", entity=btry.bound_entity, reward=0, validity=c.VALID)
TickResult("De-Paralyzed", entity=btry.bound_entity, validity=c.VALID)
)
state.print(f'{btry.bound_entity.name} has just been de-paralyzed!')
return results
@ -132,7 +126,7 @@ class DoneAtBatteryDischarge(BatteryDecharge):
if any_discharged or all_discharged:
return [DoneResult(self.name, validity=c.VALID, reward=self.reward_discharge_done)]
else:
return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)]
return [DoneResult(self.name, validity=c.NOT_VALID)]
class SpawnChargePods(Rule):
@ -155,7 +149,7 @@ class SpawnChargePods(Rule):
def on_init(self, state, lvl_map):
pod_collection = state[b.CHARGE_PODS]
empty_positions = state.entities.empty_positions()
empty_positions = state.entities.empty_positions
pods = pod_collection.from_coordinates(empty_positions, entity_kwargs=dict(
multi_charge=self.multi_charge, charge_rate=self.charge_rate)
)

View File

@ -1,4 +1,4 @@
from .actions import CleanUp
from .entitites import DirtPile
from .groups import DirtPiles
from .rules import SpawnDirt, EntitiesSmearDirtOnMove, DoneOnAllDirtCleaned
from .rules import EntitiesSmearDirtOnMove, DoneOnAllDirtCleaned

View File

@ -1,5 +1,3 @@
from numpy import random
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.utils.utility_classes import RenderEntity
from marl_factory_grid.modules.clean_up import constants as d
@ -7,22 +5,6 @@ from marl_factory_grid.modules.clean_up import constants as d
class DirtPile(Entity):
@property
def var_can_collide(self):
return False
@property
def var_can_move(self):
return False
@property
def var_is_blocking_light(self):
return False
@property
def var_has_position(self):
return True
@property
def amount(self):
return self._amount

View File

@ -1,76 +1,61 @@
from typing import Union, List, Tuple
from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils.results import Result
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.modules.clean_up.entitites import DirtPile
from marl_factory_grid.utils.results import Result
class DirtPiles(Collection):
_entity = DirtPile
@property
def var_is_blocking_light(self):
return False
var_is_blocking_light = False
var_can_collide = False
var_can_move = False
var_has_position = True
@property
def var_can_collide(self):
return False
@property
def var_can_move(self):
return False
@property
def var_has_position(self):
return True
@property
def amount(self):
def global_amount(self):
return sum([dirt.amount for dirt in self])
def __init__(self, *args,
max_local_amount=5,
clean_amount=1,
max_global_amount: int = 20, **kwargs):
max_global_amount: int = 20,
coords_or_quantity=10,
initial_amount=2,
amount_var=0.2,
n_var=0.2,
**kwargs):
super(DirtPiles, self).__init__(*args, **kwargs)
self.amount_var = amount_var
self.n_var = n_var
self.clean_amount = clean_amount
self.max_global_amount = max_global_amount
self.max_local_amount = max_local_amount
self.coords_or_quantity = coords_or_quantity
self.initial_amount = initial_amount
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):
amount_s = entity_args[0]
def trigger_spawn(self, state, coords_or_quantity=0, amount=0) -> [Result]:
coords_or_quantity = coords_or_quantity if coords_or_quantity else self.coords_or_quantity
n_new = int(abs(coords_or_quantity + (state.rng.uniform(-self.n_var, self.n_var))))
n_new = state.get_n_random_free_positions(n_new)
amounts = [amount if amount else (self.initial_amount + state.rng.uniform(-self.amount_var, self.amount_var))
for _ in range(coords_or_quantity)]
spawn_counter = 0
for idx, pos in enumerate(coords_or_quantity):
if not self.amount > self.max_global_amount:
amount = amount_s[idx] if isinstance(amount_s, list) else amount_s
for idx, (pos, a) in enumerate(zip(n_new, amounts)):
if not self.global_amount > self.max_global_amount:
if dirt := self.by_pos(pos):
dirt = next(dirt.iter())
new_value = dirt.amount + amount
new_value = dirt.amount + a
dirt.set_new_amount(new_value)
else:
dirt = DirtPile(pos, amount=amount)
self.add_item(dirt)
super().spawn([pos], amount=a)
spawn_counter += 1
else:
return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, reward=0,
value=spawn_counter)
return Result(identifier=f'{self.name}_spawn', validity=c.VALID, reward=0, value=spawn_counter)
return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, value=spawn_counter)
def trigger_dirt_spawn(self, n, amount, state, n_var=0.2, amount_var=0.2) -> Result:
free_for_dirt = [x for x in state.entities.floorlist if len(state.entities.pos_dict[x]) == 0 or (
len(state.entities.pos_dict[x]) >= 1 and isinstance(next(y for y in x), DirtPile))]
# free_for_dirt = [x for x in state[c.FLOOR]
# if len(x.guests) == 0 or (
# len(x.guests) == 1 and
# isinstance(next(y for y in x.guests), DirtPile))]
state.rng.shuffle(free_for_dirt)
new_spawn = int(abs(n + (state.rng.uniform(-n_var, n_var))))
new_amount_s = [abs(amount + (amount*state.rng.uniform(-amount_var, amount_var))) for _ in range(new_spawn)]
n_dirty_positions = free_for_dirt[:new_spawn]
return self.spawn(n_dirty_positions, new_amount_s)
return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=spawn_counter)
def __repr__(self):
s = super(DirtPiles, self).__repr__()
return f'{s[:-1]}, {self.amount})'
return f'{s[:-1]}, {self.global_amount}]'

View File

@ -22,58 +22,37 @@ class DoneOnAllDirtCleaned(Rule):
def on_check_done(self, state) -> [DoneResult]:
if len(state[d.DIRT]) == 0 and state.curr_step:
return [DoneResult(validity=c.VALID, identifier=self.name, reward=self.reward)]
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
return [DoneResult(validity=c.NOT_VALID, identifier=self.name)]
class SpawnDirt(Rule):
class RespawnDirt(Rule):
def __init__(self, initial_n: int = 5, initial_amount: float = 1.3,
respawn_n: int = 3, respawn_amount: float = 0.8,
n_var: float = 0.2, amount_var: float = 0.2, spawn_freq: int = 15):
def __init__(self, respawn_freq: int = 15, respawn_n: int = 5, respawn_amount: float = 1.0):
"""
Defines the spawn pattern of intial and additional 'Dirt'-entitites.
First chooses positions, then trys to spawn dirt until 'respawn_n' or the maximal global amount is reached.
If there is allready some, it is topped up to min(max_local_amount, amount).
:type spawn_freq: int
:parameter spawn_freq: In which frequency should this Rule try to spawn new 'Dirt'?
:type respawn_freq: int
:parameter respawn_freq: In which frequency should this Rule try to spawn new 'Dirt'?
:type respawn_n: int
:parameter respawn_n: How many respawn positions are considered.
:type initial_n: int
:parameter initial_n: How much initial positions are considered.
:type amount_var: float
:parameter amount_var: Variance of amount to spawn.
:type n_var: float
:parameter n_var: Variance of n to spawn.
:type respawn_amount: float
:parameter respawn_amount: Defines how much dirt 'amount' is placed every 'spawn_freq' ticks.
:type initial_amount: float
:parameter initial_amount: Defines how much dirt 'amount' is initially placed.
"""
super().__init__()
self.amount_var = amount_var
self.n_var = n_var
self.respawn_amount = respawn_amount
self.respawn_n = respawn_n
self.initial_amount = initial_amount
self.initial_n = initial_n
self.spawn_freq = spawn_freq
self._next_dirt_spawn = spawn_freq
def on_init(self, state, lvl_map) -> str:
result = state[d.DIRT].trigger_dirt_spawn(self.initial_n, self.initial_amount, state,
n_var=self.n_var, amount_var=self.amount_var)
state.print(f'Initial Dirt was spawned on: {[x.pos for x in state[d.DIRT]]}')
return result
self.respawn_amount = respawn_amount
self.respawn_freq = respawn_freq
self._next_dirt_spawn = respawn_freq
def tick_step(self, state):
collection = state[d.DIRT]
if self._next_dirt_spawn < 0:
pass # No DirtPile Spawn
result = [] # No DirtPile Spawn
elif not self._next_dirt_spawn:
result = [state[d.DIRT].trigger_dirt_spawn(self.respawn_n, self.respawn_amount, state,
n_var=self.n_var, amount_var=self.amount_var)]
self._next_dirt_spawn = self.spawn_freq
result = [collection.trigger_spawn(state, coords_or_quantity=self.respawn_n, amount=self.respawn_amount)]
self._next_dirt_spawn = self.respawn_freq
else:
self._next_dirt_spawn -= 1
result = []
@ -99,8 +78,8 @@ class EntitiesSmearDirtOnMove(Rule):
for entity in state.moving_entites:
if is_move(entity.state.identifier) and entity.state.validity == c.VALID:
if old_pos_dirt := state[d.DIRT].by_pos(entity.last_pos):
old_pos_dirt = next(iter(old_pos_dirt))
if smeared_dirt := round(old_pos_dirt.amount * self.smear_ratio, 2):
if state[d.DIRT].spawn(entity.pos, amount=smeared_dirt):
results.append(TickResult(identifier=self.name, entity=entity,
reward=0, validity=c.VALID))
results.append(TickResult(identifier=self.name, entity=entity, validity=c.VALID))
return results

View File

@ -1,4 +1,7 @@
from .actions import DestAction
from .entitites import Destination
from .groups import Destinations
from .rules import DoneAtDestinationReachAll, SpawnDestinations
from .rules import (DoneAtDestinationReachAll,
DoneAtDestinationReachAny,
SpawnDestinationsPerAgent,
DestinationReachReward)

View File

@ -21,4 +21,4 @@ class DestAction(Action):
valid = c.NOT_VALID
state.print(f'{entity.name} just tried to do_wait_action do_wait_action at {entity.pos} but failed')
return ActionResult(entity=entity, identifier=self._identifier, validity=valid,
reward=marl_factory_grid.modules.destinations.constants.REWARD_WAIT_VALID if valid else marl_factory_grid.modules.destinations.constants.REWARD_WAIT_FAIL)
reward=d.REWARD_WAIT_VALID if valid else d.REWARD_WAIT_FAIL)

View File

@ -9,30 +9,6 @@ from marl_factory_grid.utils.utility_classes import RenderEntity
class Destination(Entity):
@property
def var_can_move(self):
return False
@property
def var_can_collide(self):
return False
@property
def var_has_position(self):
return True
@property
def var_is_blocking_pos(self):
return False
@property
def var_is_blocking_light(self):
return False
@property
def var_can_be_bound(self):
return True
def was_reached(self):
return self._was_reached

View File

@ -1,43 +1,18 @@
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.modules.destinations.entitites import Destination
from marl_factory_grid.environment import constants as c
from marl_factory_grid.modules.destinations import constants as d
class Destinations(Collection):
_entity = Destination
@property
def var_is_blocking_light(self):
return False
@property
def var_can_collide(self):
return False
@property
def var_can_move(self):
return False
@property
def var_has_position(self):
return True
var_is_blocking_light = False
var_can_collide = False
var_can_move = False
var_has_position = True
var_can_be_bound = True
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __repr__(self):
return super(Destinations, self).__repr__()
@staticmethod
def trigger_destination_spawn(n_dests, state):
    """
    Spawn destinations on the first ``n_dests`` entries of the floor list.

    :param n_dests: Upper bound on the number of destinations to create.
    :param state: Game state; read for ``state.entities.floorlist``, written via
                  ``state[d.DESTINATION].add_items`` and used for ``print``.
    :return: ``c.VALID`` if at least one destination was created, otherwise ``c.NOT_VALID``.
    """
    coordinates = state.entities.floorlist[:n_dests]
    if destinations := [Destination(pos) for pos in coordinates]:
        state[d.DESTINATION].add_items(destinations)
        state.print(f'{n_dests} new destinations have been spawned')
        return c.VALID
    # Empty slice: either no floor positions are available or ``n_dests`` was 0.
    state.print('No Destiantions are spawning, limit is reached.')
    return c.NOT_VALID

View File

@ -2,8 +2,8 @@ import ast
from random import shuffle
from typing import List, Dict, Tuple
import marl_factory_grid.modules.destinations.constants
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils import helpers as h
from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c
@ -54,7 +54,7 @@ class DoneAtDestinationReachAll(DestinationReachReward):
"""
This rule triggers and sets the done flag if ALL Destinations have been reached.
:type reward_at_done: object
:type reward_at_done: float
:param reward_at_done: Specifies the reward, agent get, whenn all destinations are reached.
:type dest_reach_reward: float
:param dest_reach_reward: Specify the reward, agents get when reaching a single destination.
@ -65,7 +65,7 @@ class DoneAtDestinationReachAll(DestinationReachReward):
def on_check_done(self, state) -> List[DoneResult]:
if all(x.was_reached() for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=self.reward)]
return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)]
return [DoneResult(self.name, validity=c.NOT_VALID)]
class DoneAtDestinationReachAny(DestinationReachReward):
@ -75,7 +75,7 @@ class DoneAtDestinationReachAny(DestinationReachReward):
This rule triggers and sets the done flag if ANY Destinations has been reached.
!!! IMPORTANT: 'reward_at_done' is shared between the agents; 'dest_reach_reward' is bound to a specific one.
:type reward_at_done: object
:type reward_at_done: float
:param reward_at_done: Specifies the reward, all agent get, when any destinations has been reached.
Default {d.REWARD_DEST_DONE}
:type dest_reach_reward: float
@ -87,67 +87,29 @@ class DoneAtDestinationReachAny(DestinationReachReward):
def on_check_done(self, state) -> List[DoneResult]:
if any(x.was_reached() for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=marl_factory_grid.modules.destinations.constants.REWARD_DEST_REACHED)]
return [DoneResult(self.name, validity=c.VALID, reward=d.REWARD_DEST_REACHED)]
return []
class SpawnDestinations(Rule):
def __init__(self, n_dests: int = 1, spawn_mode: str = d.MODE_GROUPED):
f"""
Defines how destinations are initially spawned and respawned in addition.
!!! This rule introduces no kind of reward or Env.-Done condition!
:type n_dests: int
:param n_dests: How many destiantions should be maintained (and initally spawnewd) on the map?
:type spawn_mode: str
:param spawn_mode: One of {d.SPAWN_MODES}. {d.MODE_GROUPED}: Always wait for all Dstiantions do be gone,
then respawn after the given time. {d.MODE_SINGLE}: Just spawn every destination,
that has been reached, after the given time
"""
super(SpawnDestinations, self).__init__()
self.n_dests = n_dests
self.spawn_mode = spawn_mode
def on_init(self, state, lvl_map):
# noinspection PyAttributeOutsideInit
state[d.DESTINATION].trigger_destination_spawn(self.n_dests, state)
pass
def tick_pre_step(self, state) -> List[TickResult]:
pass
def tick_step(self, state) -> List[TickResult]:
if n_dest_spawn := max(0, self.n_dests - len(state[d.DESTINATION])):
if self.spawn_mode == d.MODE_GROUPED and n_dest_spawn == self.n_dests:
validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state)
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
elif self.spawn_mode == d.MODE_SINGLE and n_dest_spawn:
validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state)
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
else:
pass
class SpawnDestinationsPerAgent(Rule):
def __init__(self, per_agent_positions: Dict[str, List[Tuple[int, int]]]):
def __init__(self, coords_or_quantity: Dict[str, List[Tuple[int, int]]]):
"""
Special rule that spawns destinations which are bound to a single agent, chosen from a fixed set of positions.
Useful for introducing specialists, etc.
!!! This rule does not introduce any reward or done condition.
:type per_agent_positions: Dict[str, List[Tuple[int, int]]
:param per_agent_positions: Please provide a dictionary with agent names as keys; and a list of possible
:type coords_or_quantity: Dict[str, List[Tuple[int, int]]]
:param coords_or_quantity: Please provide a dictionary with agent names as keys; and a list of possible
destination coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]}
"""
super(Rule, self).__init__()
self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in per_agent_positions.items()}
self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in coords_or_quantity.items()}
def on_init(self, state, lvl_map):
for (agent_name, position_list) in self.per_agent_positions.items():
agent = next(x for x in state[c.AGENT] if agent_name in x.name) # Fixme: Ugly AF
agent = h.get_first(state[c.AGENT], lambda x: agent_name in x.name)
assert agent
position_list = position_list.copy()
shuffle(position_list)
while True:
@ -155,7 +117,7 @@ class SpawnDestinationsPerAgent(Rule):
pos = position_list.pop()
except IndexError:
print(f"Could not spawn Destinations at: {self.per_agent_positions[agent_name]}")
print(f'Check your agent palcement: {state[c.AGENT]} ... Exit ...')
print(f'Check your agent placement: {state[c.AGENT]} ... Exit ...')
exit(9999)
if (not pos == agent.pos) and (not state[d.DESTINATION].by_pos(pos)):
destination = Destination(pos, bind_to=agent)

View File

@ -1,4 +1,5 @@
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.utils import Result
from marl_factory_grid.utils.utility_classes import RenderEntity
from marl_factory_grid.environment import constants as c
@ -41,21 +42,6 @@ class Door(Entity):
def str_state(self):
return 'open' if self.is_open else 'closed'
def __init__(self, *args, closed_on_init=True, auto_close_interval=10, **kwargs):
self._status = d.STATE_CLOSED
super(Door, self).__init__(*args, **kwargs)
self.auto_close_interval = auto_close_interval
self.time_to_close = 0
if not closed_on_init:
self._open()
else:
self._close()
def summarize_state(self):
state_dict = super().summarize_state()
state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close))
return state_dict
@property
def is_closed(self):
return self._status == d.STATE_CLOSED
@ -68,6 +54,25 @@ class Door(Entity):
def status(self):
return self._status
@property
def time_to_close(self):
return self._time_to_close
def __init__(self, *args, closed_on_init=True, auto_close_interval=10, **kwargs):
self._status = d.STATE_CLOSED
super(Door, self).__init__(*args, **kwargs)
self._auto_close_interval = auto_close_interval
self._time_to_close = 0
if not closed_on_init:
self._open()
else:
self._close()
def summarize_state(self):
state_dict = super().summarize_state()
state_dict.update(state=str(self.str_state), time_to_close=self.time_to_close)
return state_dict
def render(self):
name, state = 'door_open' if self.is_open else 'door_closed', 'blank'
return RenderEntity(name, self.pos, 1, 'none', state, self.u_int + 1)
@ -80,18 +85,35 @@ class Door(Entity):
return c.VALID
def tick(self, state):
if self.is_open and len(state.entities.pos_dict[self.pos]) == 2 and self.time_to_close:
self.time_to_close -= 1
return c.NOT_VALID
elif self.is_open and not self.time_to_close and len(state.entities.pos_dict[self.pos]) == 2:
self.use()
return c.VALID
# Check if no entity is standing in the door
if len(state.entities.pos_dict[self.pos]) <= 2:
if self.is_open and self.time_to_close:
self._decrement_timer()
return Result(f"{d.DOOR}_tick", c.VALID, entity=self)
elif self.is_open and not self.time_to_close:
self.use()
return Result(f"{d.DOOR}_closed", c.VALID, entity=self)
else:
# No one is in the door, but it is closed; nothing to do.
return None
else:
return c.NOT_VALID
# Entity is standing in the door, reset timer
self._reset_timer()
return Result(f"{d.DOOR}_reset", c.VALID, entity=self)
def _open(self):
self._status = d.STATE_OPEN
self.time_to_close = self.auto_close_interval
self._reset_timer()
return True
def _close(self):
self._status = d.STATE_CLOSED
return True
def _decrement_timer(self):
self._time_to_close -= 1
return True
def _reset_timer(self):
self._time_to_close = self._auto_close_interval
return True

View File

@ -1,5 +1,3 @@
from typing import Union
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.modules.doors import constants as d
from marl_factory_grid.modules.doors.entitites import Door
@ -18,8 +16,10 @@ class Doors(Collection):
super(Doors, self).__init__(*args, can_collide=True, **kwargs)
def tick_doors(self, state):
result_dict = dict()
results = list()
for door in self:
did_tick = door.tick(state)
result_dict.update({door.name: did_tick})
return result_dict
tick_result = door.tick(state)
if tick_result is not None:
results.append(tick_result)
# TODO: Should return a Result object, not a plain list.
return results

View File

@ -1,2 +1,2 @@
USE_DOOR_VALID: float = -0.00
USE_DOOR_FAIL: float = -0.01
USE_DOOR_FAIL: float = -0.01

View File

@ -19,10 +19,10 @@ class DoorAutoClose(Rule):
def tick_step(self, state):
if doors := state[d.DOORS]:
doors_tick_result = doors.tick_doors(state)
doors_that_ticked = [key for key, val in doors_tick_result.items() if val]
state.print(f'{doors_that_ticked} were auto-closed'
if doors_that_ticked else 'No Doors were auto-closed')
doors_tick_results = doors.tick_doors(state)
doors_that_closed = [x.entity.name for x in doors_tick_results if 'closed' in x.identifier]
door_str = doors_that_closed if doors_that_closed else "No Doors"
state.print(f'{door_str} were auto-closed')
return [TickResult(self.name, validity=c.VALID, value=1)]
state.print('There are no doors, but you loaded the corresponding Module')
return []

View File

@ -1,8 +1,8 @@
import random
from typing import List, Union
from typing import List
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult
@ -14,8 +14,6 @@ class AgentSingleZonePlacementBeta(Rule):
super().__init__()
def on_init(self, state, lvl_map):
zones = state[c.ZONES]
n_zones = state[c.ZONES]
agents = state[c.AGENT]
if len(self.coordinates) == len(agents):
coordinates = self.coordinates
@ -31,4 +29,4 @@ class AgentSingleZonePlacementBeta(Rule):
return []
def tick_post_step(self, state) -> List[TickResult]:
return []
return []

View File

@ -1,4 +1,3 @@
from .actions import ItemAction
from .entitites import Item, DropOffLocation
from .groups import DropOffLocations, Items, Inventory, Inventories
from .rules import ItemRules

View File

@ -29,7 +29,7 @@ class ItemAction(Action):
elif items := state[i.ITEM].by_pos(entity.pos):
item = items[0]
item.change_parent_collection(inventory)
item.set_pos_to(c.VALUE_NO_POS)
item.set_pos(c.VALUE_NO_POS)
state.print(f'{entity.name} just picked up an item at {entity.pos}')
return ActionResult(entity=entity, identifier=self._identifier, validity=c.VALID, reward=r.PICK_UP_VALID)

View File

@ -1,6 +1,3 @@
from typing import NamedTuple
SYMBOL_NO_ITEM = 0
SYMBOL_DROP_OFF = 1
# Item Env

View File

@ -8,56 +8,20 @@ from marl_factory_grid.modules.items import constants as i
class Item(Entity):
@property
def var_can_collide(self):
return False
def render(self):
return RenderEntity(i.ITEM, self.pos) if self.pos != c.VALUE_NO_POS else None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._auto_despawn = -1
@property
def auto_despawn(self):
return self._auto_despawn
@property
def encoding(self):
# Edit this if you want items to be drawn in the ops differently
return 1
def set_auto_despawn(self, auto_despawn):
self._auto_despawn = auto_despawn
def set_pos_to(self, no_pos):
self._pos = no_pos
def summarize_state(self) -> dict:
super_summarization = super(Item, self).summarize_state()
super_summarization.update(dict(auto_despawn=self.auto_despawn))
return super_summarization
class DropOffLocation(Entity):
@property
def var_can_collide(self):
return False
@property
def var_can_move(self):
return False
@property
def var_is_blocking_light(self):
return False
@property
def var_has_position(self):
return True
def render(self):
return RenderEntity(i.DROP_OFF, self.pos)
@ -65,18 +29,16 @@ class DropOffLocation(Entity):
def encoding(self):
return i.SYMBOL_DROP_OFF
def __init__(self, *args, storage_size_until_full: int = 5, auto_item_despawn_interval: int = 5, **kwargs):
def __init__(self, *args, storage_size_until_full: int = 5, **kwargs):
super(DropOffLocation, self).__init__(*args, **kwargs)
self.auto_item_despawn_interval = auto_item_despawn_interval
self.storage = deque(maxlen=storage_size_until_full or None)
def place_item(self, item: Item):
if self.is_full:
raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.")
return bc.NOT_VALID # in Zeile 81 verschieben?
return bc.NOT_VALID
else:
self.storage.append(item)
item.set_auto_despawn(self.auto_item_despawn_interval)
return c.VALID
@property

View File

@ -1,13 +1,11 @@
from random import shuffle
from marl_factory_grid.modules.items import constants as i
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.environment.groups.objects import _Objects
from marl_factory_grid.environment.groups.mixins import IsBoundMixin
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.environment.groups.mixins import IsBoundMixin
from marl_factory_grid.environment.groups.objects import Objects
from marl_factory_grid.modules.items import constants as i
from marl_factory_grid.modules.items.entitites import Item, DropOffLocation
from marl_factory_grid.utils.results import Result
class Items(Collection):
@ -15,7 +13,7 @@ class Items(Collection):
@property
def var_has_position(self):
return False
return True
@property
def is_blocking_light(self):
@ -28,18 +26,18 @@ class Items(Collection):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@staticmethod
def trigger_item_spawn(state, n_items, spawn_frequency):
if item_to_spawns := max(0, (n_items - len(state[i.ITEM]))):
position_list = [x for x in state.entities.floorlist]
shuffle(position_list)
position_list = state.entities.floorlist[:item_to_spawns]
state[i.ITEM].spawn(position_list)
state.print(f'{item_to_spawns} new items have been spawned; next spawn in {spawn_frequency}')
return len(position_list)
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, **entity_kwargs) -> [Result]:
coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity
assert coords_or_quantity
if item_to_spawns := max(0, (coords_or_quantity - len(self))):
return super().trigger_spawn(state,
*entity_args,
coords_or_quantity=item_to_spawns,
**entity_kwargs)
else:
state.print('No Items are spawning, limit is reached.')
return 0
return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, value=coords_or_quantity)
class Inventory(IsBoundMixin, Collection):
@ -73,12 +71,17 @@ class Inventory(IsBoundMixin, Collection):
self._collection = collection
class Inventories(_Objects):
class Inventories(Objects):
_entity = Inventory
var_can_move = False
var_has_position = False
symbol = None
@property
def var_can_move(self):
return False
def spawn_rule(self):
return {c.SPAWN_ENTITY_RULE: dict(collection=self, coords_or_quantity=None)}
def __init__(self, size: int, *args, **kwargs):
super(Inventories, self).__init__(*args, **kwargs)
@ -86,10 +89,12 @@ class Inventories(_Objects):
self._obs = None
self._lazy_eval_transforms = []
def spawn(self, agents):
inventories = [self._entity(agent, self.size, )
for _, agent in enumerate(agents)]
self.add_items(inventories)
def spawn(self, agents, *args, **kwargs):
self.add_items([self._entity(agent, self.size, *args, **kwargs) for _, agent in enumerate(agents)])
return [Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))]
def trigger_spawn(self, state, *args, **kwargs) -> [Result]:
return self.spawn(state[c.AGENT], *args, **kwargs)
def idx_by_entity(self, entity):
try:
@ -106,10 +111,6 @@ class Inventories(_Objects):
def summarize_states(self, **kwargs):
return [val.summarize_states(**kwargs) for key, val in self.items()]
@staticmethod
def trigger_inventory_spawn(state):
state[i.INVENTORY].spawn(state[c.AGENT])
class DropOffLocations(Collection):
_entity = DropOffLocation
@ -135,7 +136,7 @@ class DropOffLocations(Collection):
@staticmethod
def trigger_drop_off_location_spawn(state, n_locations):
empty_positions = state.entities.empty_positions()[:n_locations]
empty_positions = state.entities.empty_positions[:n_locations]
do_entites = state[i.DROP_OFF]
drop_offs = [DropOffLocation(pos) for pos in empty_positions]
do_entites.add_items(drop_offs)

View File

@ -1,4 +1,4 @@
DROP_OFF_VALID: float = 0.1
DROP_OFF_FAIL: float = -0.1
PICK_UP_FAIL: float = -0.1
PICK_UP_VALID: float = 0.1
PICK_UP_VALID: float = 0.1

View File

@ -6,52 +6,28 @@ from marl_factory_grid.utils.results import TickResult
from marl_factory_grid.modules.items import constants as i
class ItemRules(Rule):
class RespawnItems(Rule):
def __init__(self, n_items: int = 5, spawn_frequency: int = 15,
n_locations: int = 5, max_dropoff_storage_size: int = 0):
def __init__(self, n_items: int = 5, respawn_freq: int = 15, n_locations: int = 5):
super().__init__()
self.spawn_frequency = spawn_frequency
self._next_item_spawn = spawn_frequency
self.spawn_frequency = respawn_freq
self._next_item_spawn = respawn_freq
self.n_items = n_items
self.max_dropoff_storage_size = max_dropoff_storage_size
self.n_locations = n_locations
def on_init(self, state, lvl_map):
state[i.DROP_OFF].trigger_drop_off_location_spawn(state, self.n_locations)
self._next_item_spawn = self.spawn_frequency
state[i.INVENTORY].trigger_inventory_spawn(state)
state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency)
def tick_step(self, state):
for item in list(state[i.ITEM].values()):
if item.auto_despawn >= 1:
item.set_auto_despawn(item.auto_despawn - 1)
elif not item.auto_despawn:
state[i.ITEM].delete_env_object(item)
else:
pass
if not self._next_item_spawn:
state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency)
state[i.ITEM].trigger_spawn(state, self.n_items, self.spawn_frequency)
else:
self._next_item_spawn = max(0, self._next_item_spawn - 1)
return []
def tick_post_step(self, state) -> List[TickResult]:
for item in list(state[i.ITEM].values()):
if item.auto_despawn >= 1:
item.set_auto_despawn(item.auto_despawn-1)
elif not item.auto_despawn:
state[i.ITEM].delete_env_object(item)
else:
pass
if not self._next_item_spawn:
if spawned_items := state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency):
return [TickResult(self.name, validity=c.VALID, value=spawned_items, entity=None)]
if spawned_items := state[i.ITEM].trigger_spawn(state, self.n_items, self.spawn_frequency):
return [TickResult(self.name, validity=c.VALID, value=spawned_items.value)]
else:
return [TickResult(self.name, validity=c.NOT_VALID, value=0, entity=None)]
return [TickResult(self.name, validity=c.NOT_VALID, value=0)]
else:
self._next_item_spawn = max(0, self._next_item_spawn-1)
return []

View File

@ -1,3 +1,2 @@
from .entitites import Machine
from .groups import Machines
from .rules import MachineRule

View File

@ -1,10 +1,12 @@
from typing import Union
import marl_factory_grid.modules.machines.constants
from marl_factory_grid.environment.actions import Action
from marl_factory_grid.utils.results import ActionResult
from marl_factory_grid.modules.machines import constants as m, rewards as r
from marl_factory_grid.modules.machines import constants as m
from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils import helpers as h
class MachineAction(Action):
@ -13,13 +15,12 @@ class MachineAction(Action):
super().__init__(m.MACHINE_ACTION)
def do(self, entity, state) -> Union[None, ActionResult]:
if machine := state[m.MACHINES].by_pos(entity.pos):
if machine := h.get_first(state[m.MACHINES].by_pos(entity.pos)):
if valid := machine.maintain():
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_VALID)
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_VALID)
else:
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_FAIL)
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_FAIL)
else:
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.MAINTAIN_FAIL)
return ActionResult(entity=entity, identifier=self._identifier,
validity=c.NOT_VALID, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_FAIL
)

View File

@ -11,3 +11,7 @@ STATE_MAINTAIN = 'maintenance'
SYMBOL_WORK = 1
SYMBOL_IDLE = 0.6
SYMBOL_MAINTAIN = 0.3
MAINTAIN_VALID: float = 0.5
MAINTAIN_FAIL: float = -0.1
FAIL_MISSING_MAINTENANCE: float = -0.5
NONE: float = 0

View File

@ -8,22 +8,6 @@ from . import constants as m
class Machine(Entity):
@property
def var_can_collide(self):
return False
@property
def var_can_move(self):
return False
@property
def var_is_blocking_light(self):
return False
@property
def var_has_position(self):
return True
@property
def encoding(self):
return self._encodings[self.status]
@ -46,12 +30,11 @@ class Machine(Entity):
else:
return c.NOT_VALID
def tick(self):
# if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]):
if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.state.entities.pos_dict[self.pos]]):
return TickResult(identifier=self.name, validity=c.VALID, reward=0, entity=self)
# elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]):
elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.state.entities.pos_dict[self.pos]]):
def tick(self, state):
others = state.entities.pos_dict[self.pos]
if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in others]):
return TickResult(identifier=self.name, validity=c.VALID, entity=self)
elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in others]):
self.status = m.STATE_WORK
self.reset_counter()
return None

View File

@ -1,5 +1,3 @@
from typing import Union, List, Tuple
from marl_factory_grid.environment.groups.collection import Collection
from .entitites import Machine

View File

@ -1,5 +0,0 @@
MAINTAIN_VALID: float = 0.5
MAINTAIN_FAIL: float = -0.1
FAIL_MISSING_MAINTENANCE: float = -0.5
NONE: float = 0

View File

@ -1,28 +0,0 @@
from typing import List
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c
from marl_factory_grid.modules.machines import constants as m
from marl_factory_grid.modules.machines.entitites import Machine
class MachineRule(Rule):
def __init__(self, n_machines: int = 2):
super(MachineRule, self).__init__()
self.n_machines = n_machines
def on_init(self, state, lvl_map):
state[m.MACHINES].spawn(state.entities.empty_positions())
def tick_pre_step(self, state) -> List[TickResult]:
pass
def tick_step(self, state) -> List[TickResult]:
pass
def tick_post_step(self, state) -> List[TickResult]:
pass
def on_check_done(self, state) -> List[DoneResult]:
pass

View File

@ -1,3 +1,4 @@
MAINTAINER = 'Maintainer' # TEMPLATE _identifier. Define your own!
MAINTAINERS = 'Maintainers' # TEMPLATE _identifier. Define your own!
MAINTAINER_COLLISION_REWARD = -5

View File

@ -1,48 +1,35 @@
from random import shuffle
import networkx as nx
import numpy as np
from ...algorithms.static.utils import points_to_graph
from ...environment import constants as c
from ...environment.actions import Action, ALL_BASEACTIONS
from ...environment.entity.entity import Entity
from ..doors import constants as do
from ..maintenance import constants as mi
from ...utils.helpers import MOVEMAP
from ...utils.utility_classes import RenderEntity
from ...utils.states import Gamestate
from ...utils import helpers as h
from ...utils.utility_classes import RenderEntity, Floor
from ..doors import DoorUse
class Maintainer(Entity):
@property
def var_can_collide(self):
return True
@property
def var_can_move(self):
return False
@property
def var_is_blocking_light(self):
return False
@property
def var_has_position(self):
return True
def __init__(self, state: Gamestate, objective: str, action: Action, *args, **kwargs):
def __init__(self, objective: str, action: Action, *args, **kwargs):
super().__init__(*args, **kwargs)
self.action = action
self.actions = [x() for x in ALL_BASEACTIONS]
self.actions = [x() for x in ALL_BASEACTIONS] + [DoorUse()]
self.objective = objective
self._path = None
self._next = []
self._last = []
self._last_serviced = 'None'
self._floortile_graph = points_to_graph(state.entities.floorlist)
self._floortile_graph = None
def tick(self, state):
if found_objective := state[self.objective].by_pos(self.pos):
if found_objective := h.get_first(state[self.objective].by_pos(self.pos)):
if found_objective.name != self._last_serviced:
self.action.do(self, state)
self._last_serviced = found_objective.name
@ -54,24 +41,27 @@ class Maintainer(Entity):
return action.do(self, state)
def get_move_action(self, state) -> Action:
if not self._floortile_graph:
state.print("Generating Floorgraph....")
self._floortile_graph = points_to_graph(state.entities.floorlist)
if self._path is None or not self._path:
if not self._next:
self._next = list(state[self.objective].values())
self._next = list(state[self.objective].values()) + [Floor(*state.random_free_position)]
shuffle(self._next)
self._last = []
self._last.append(self._next.pop())
state.print("Calculating shortest path....")
self._path = self.calculate_route(self._last[-1])
if door := self._door_is_close(state):
if door.is_closed:
# Translate the action_object to an integer to have the same output as any other model
action = do.ACTION_DOOR_USE
else:
action = self._predict_move(state)
if door := self._closed_door_in_path(state):
state.print(f"{self} found {door} that is closed. Attempt to open.")
# Translate the action_object to an integer to have the same output as any other model
action = do.ACTION_DOOR_USE
else:
action = self._predict_move(state)
# Translate the action_object to an integer to have the same output as any other model
try:
action_obj = next(x for x in self.actions if x.name == action)
action_obj = h.get_first(self.actions, lambda x: x.name == action)
except (StopIteration, UnboundLocalError):
print('Will not happen')
raise EnvironmentError
@ -81,11 +71,10 @@ class Maintainer(Entity):
route = nx.shortest_path(self._floortile_graph, self.pos, entity.pos)
return route[1:]
def _door_is_close(self, state):
state.print("Found a door that is close.")
try:
return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name)
except StopIteration:
def _closed_door_in_path(self, state):
if self._path:
return h.get_first(state[do.DOORS].by_pos(self._path[0]), lambda x: x.is_closed)
else:
return None
def _predict_move(self, state):
@ -96,7 +85,7 @@ class Maintainer(Entity):
next_pos = self._path.pop(0)
diff = np.subtract(next_pos, self.pos)
# Retrieve the action based on the position difference (i.e.: what do I have to do to get there?)
action = next(action for action, pos_diff in MOVEMAP.items() if np.all(diff == pos_diff))
action = next(action for action, pos_diff in h.MOVEMAP.items() if np.all(diff == pos_diff))
return action
def render(self):

View File

@ -1,34 +1,27 @@
from typing import Union, List, Tuple
from typing import Union, List, Tuple, Dict
from marl_factory_grid.environment.groups.collection import Collection
from .entities import Maintainer
from ..machines import constants as mc
from ..machines.actions import MachineAction
from ...utils.states import Gamestate
class Maintainers(Collection):
_entity = Maintainer
@property
def var_can_collide(self):
return True
var_can_collide = True
var_can_move = True
var_is_blocking_light = False
var_has_position = True
@property
def var_can_move(self):
return True
def __init__(self, size, *args, coords_or_quantity: int = None,
spawnrule: Union[None, Dict[str, dict]] = None,
**kwargs):
super(Collection, self).__init__(*args, **kwargs)
self._coords_or_quantity = coords_or_quantity
self.size = size
self._spawnrule = spawnrule
@property
def var_is_blocking_light(self):
return False
@property
def var_has_position(self):
return True
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):
state = entity_args[0]
self.add_items([self._entity(state, mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity])
self.add_items([self._entity(mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity])

View File

@ -1 +0,0 @@
MAINTAINER_COLLISION_REWARD = -5

View File

@ -1,32 +1,28 @@
from typing import List
import marl_factory_grid.modules.maintenance.constants
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c
from . import rewards as r
from . import constants as M
from marl_factory_grid.utils.states import Gamestate
class MaintenanceRule(Rule):
class MoveMaintainers(Rule):
def __init__(self, n_maintainer: int = 1, *args, **kwargs):
super(MaintenanceRule, self).__init__(*args, **kwargs)
self.n_maintainer = n_maintainer
def on_init(self, state: Gamestate, lvl_map):
state[M.MAINTAINERS].spawn(state.entities.empty_positions[:self.n_maintainer], state)
pass
def tick_pre_step(self, state) -> List[TickResult]:
pass
def __init__(self):
super().__init__()
def tick_step(self, state) -> List[TickResult]:
for maintainer in state[M.MAINTAINERS]:
maintainer.tick(state)
# Todo: Return a Result Object.
return []
def tick_post_step(self, state) -> List[TickResult]:
pass
class DoneAtMaintainerCollision(Rule):
def __init__(self):
super().__init__()
def on_check_done(self, state) -> List[DoneResult]:
agents = list(state[c.AGENT].values())
@ -35,5 +31,5 @@ class MaintenanceRule(Rule):
for agent in agents:
if agent.pos in m_pos:
done_results.append(DoneResult(entity=agent, validity=c.VALID, identifier=self.name,
reward=r.MAINTAINER_COLLISION_REWARD))
reward=marl_factory_grid.modules.maintenance.constants.MAINTAINER_COLLISION_REWARD))
return done_results

View File

@ -1,10 +1,10 @@
import random
from typing import List, Tuple
from marl_factory_grid.environment.entity.object import _Object
from marl_factory_grid.environment.entity.object import Object
class Zone(_Object):
class Zone(Object):
@property
def positions(self):

View File

@ -1,8 +1,8 @@
from marl_factory_grid.environment.groups.objects import _Objects
from marl_factory_grid.environment.groups.objects import Objects
from marl_factory_grid.modules.zones import Zone
class Zones(_Objects):
class Zones(Objects):
symbol = None
_entity = Zone

View File

@ -1,8 +1,8 @@
from random import choices, choice
from . import constants as z, Zone
from .. import Destination
from ..destinations import constants as d
from ... import Destination
from ...environment.rules import Rule
from ...environment import constants as c

View File

@ -0,0 +1,3 @@
from . import helpers as h
from . import helpers
from .results import Result, DoneResult, ActionResult, TickResult

View File

@ -1,4 +1,5 @@
import ast
from os import PathLike
from pathlib import Path
from typing import Union, List
@ -9,18 +10,17 @@ from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.environment.tests import Test
from marl_factory_grid.utils.helpers import locate_and_import_class
DEFAULT_PATH = 'environment'
MODULE_PATH = 'modules'
from marl_factory_grid.environment.constants import DEFAULT_PATH, MODULE_PATH
from marl_factory_grid.environment import constants as c
class FactoryConfigParser(object):
default_entites = []
default_rules = ['MaxStepsReached', 'Collision']
default_rules = ['DoneAtMaxStepsReached', 'WatchCollision']
default_actions = [c.MOVE8, c.NOOP]
default_observations = [c.WALLS, c.AGENT]
def __init__(self, config_path, custom_modules_path: Union[None, PathLike] = None):
def __init__(self, config_path, custom_modules_path: Union[PathLike] = None):
self.config_path = Path(config_path)
self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path
self.config = yaml.safe_load(self.config_path.open())
@ -44,6 +44,10 @@ class FactoryConfigParser(object):
def rules(self):
return self.config['Rules']
@property
def tests(self):
return self.config.get('Tests', [])
@property
def agents(self):
return self.config['Agents']
@ -56,10 +60,12 @@ class FactoryConfigParser(object):
return str(self.config)
def __getitem__(self, item):
return self.config[item]
try:
return self.config[item]
except KeyError:
print(f'The mandatory {item} section could not be found in your .config gile. Check Spelling!')
def load_entities(self):
# entites = Entities()
entity_classes = dict()
entities = []
if c.DEFAULTS in self.entities:
@ -67,28 +73,40 @@ class FactoryConfigParser(object):
entities.extend(x for x in self.entities if x != c.DEFAULTS)
for entity in entities:
e1 = e2 = e3 = None
try:
folder_path = Path(__file__).parent.parent / DEFAULT_PATH
entity_class = locate_and_import_class(entity, folder_path)
except AttributeError as e1:
except AttributeError as e:
e1 = e
try:
folder_path = Path(__file__).parent.parent / MODULE_PATH
entity_class = locate_and_import_class(entity, folder_path)
except AttributeError as e2:
try:
folder_path = self.custom_modules_path
entity_class = locate_and_import_class(entity, folder_path)
except AttributeError as e3:
ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x]
print('### Error ### Error ### Error ### Error ### Error ###')
print()
print(f'Class "{entity}" was not found in "{folder_path.name}"')
print('Possible Entitys are:', str(ents))
print()
print('Goodbye')
print()
exit()
# raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are:', str(ents))
module_path = Path(__file__).parent.parent / MODULE_PATH
entity_class = locate_and_import_class(entity, module_path)
except AttributeError as e:
e2 = e
if self.custom_modules_path:
try:
entity_class = locate_and_import_class(entity, self.custom_modules_path)
except AttributeError as e:
e3 = e
pass
if (e1 and e2) or e3:
ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]]
print('##############################################################')
print('### Error ### Error ### Error ### Error ### Error ###')
print('##############################################################')
print(f'Class "{entity}" was not found in "{module_path.name}"')
print(f'Class "{entity}" was not found in "{folder_path.name}"')
print('##############################################################')
if self.custom_modules_path:
print(f'Class "{entity}" was not found in "{self.custom_modules_path}"')
print('Possible Entitys are:', str(ents))
print('##############################################################')
print('Goodbye')
print('##############################################################')
print('### Error ### Error ### Error ### Error ### Error ###')
print('##############################################################')
exit(-99999)
entity_kwargs = self.entities.get(entity, {})
entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
@ -126,7 +144,12 @@ class FactoryConfigParser(object):
observations.extend(self.default_observations)
observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS)
positions = [ast.literal_eval(x) for x in self.agents[name].get('Positions', [])]
parsed_agents_conf[name] = dict(actions=parsed_actions, observations=observations, positions=positions)
other_kwargs = {k: v for k, v in self.agents[name].items() if k not in
['Actions', 'Observations', 'Positions']}
parsed_agents_conf[name] = dict(
actions=parsed_actions, observations=observations, positions=positions, other=other_kwargs
)
return parsed_agents_conf
def load_env_rules(self) -> List[Rule]:
@ -137,28 +160,69 @@ class FactoryConfigParser(object):
rules.append({rule: {}})
return self._load_smth(rules, Rule)
pass
def load_env_tests(self) -> List[Test]:
def load_env_tests(self) -> List[Rule]:
return self._load_smth(self.tests, None) # Test
pass
def _load_smth(self, config, class_obj):
rules = list()
rules_names = list()
for rule in rules_names:
for rule in config:
e1 = e2 = e3 = None
try:
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
rule_class = locate_and_import_class(rule, folder_path)
except AttributeError:
except AttributeError as e:
e1 = e
try:
folder_path = (Path(__file__).parent.parent / MODULE_PATH)
rule_class = locate_and_import_class(rule, folder_path)
except AttributeError:
rule_class = locate_and_import_class(rule, self.custom_modules_path)
# Fixme This check does not work!
# assert isinstance(rule_class, class_obj), f'{rule_class.__name__} is no valid "class_obj.__name__".'
rule_kwargs = config.get(rule, {})
rules.append(rule_class(**rule_kwargs))
module_path = (Path(__file__).parent.parent / MODULE_PATH)
rule_class = locate_and_import_class(rule, module_path)
except AttributeError as e:
e2 = e
if self.custom_modules_path:
try:
rule_class = locate_and_import_class(rule, self.custom_modules_path)
except AttributeError as e:
e3 = e
pass
if (e1 and e2) or e3:
ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]]
print('### Error ### Error ### Error ### Error ### Error ###')
print('')
print(f'Class "{rule}" was not found in "{module_path.name}"')
print(f'Class "{rule}" was not found in "{folder_path.name}"')
if self.custom_modules_path:
print(f'Class "{rule}" was not found in "{self.custom_modules_path}"')
print('Possible Entitys are:', str(ents))
print('')
print('Goodbye')
print('')
exit(-99999)
if issubclass(rule_class, class_obj):
rule_kwargs = config.get(rule, {})
rules.append(rule_class(**(rule_kwargs or {})))
return rules
def load_entity_spawn_rules(self, entities) -> List[Rule]:
    """
    Instantiate the spawn rules that the given entities declare.

    Every entity exposing a truthy ``spawn_rule`` attribute contributes a mapping of
    ``rule name -> keyword arguments``. Each rule class is looked up in the package's
    ``DEFAULT_PATH`` folder first, then in ``MODULE_PATH`` and finally in
    ``self.custom_modules_path``.

    :param entities: Iterable of objects that may carry a ``spawn_rule`` attribute.
    :return: List of instantiated rules, in the order the entities were given.
    """
    rules = []
    for entity in entities:
        # Entities without the attribute (or with an empty/None value) simply declare no spawn rule.
        spawn_rule = getattr(entity, 'spawn_rule', None)
        if not spawn_rule:
            continue
        for rule_name, rule_kwargs in spawn_rule.items():
            for sub_folder in (DEFAULT_PATH, MODULE_PATH):
                try:
                    rule_class = locate_and_import_class(rule_name, Path(__file__).parent.parent / sub_folder)
                except AttributeError:
                    # Not found in this folder, try the next one.
                    continue
                break
            else:
                # Last resort: the user supplied modules folder. Errors raised here are propagated.
                rule_class = locate_and_import_class(rule_name, self.custom_modules_path)
            # A rule declared without arguments may arrive as None; treat it like "no kwargs",
            # as _load_smth does.
            rules.append(rule_class(**(rule_kwargs or {})))
    return rules

View File

@ -2,7 +2,7 @@ import importlib
from collections import defaultdict
from pathlib import PurePath, Path
from typing import Union, Dict, List
from typing import Union, Dict, List, Iterable, Callable
import numpy as np
from numpy.typing import ArrayLike
@ -61,8 +61,8 @@ class ObservationTranslator:
:param per_agent_named_obs_spaces: `Named observation space` one for each agent. Overloaded.
type per_agent_named_obs_spaces: Dict[str, dict]
:param placeholder_fill_value: Currently not fully implemented!!!
:type placeholder_fill_value: Union[int, str] = 'N')
:param placeholder_fill_value: Currently, not fully implemented!!!
:type placeholder_fill_value: Union[int, str] = 'N'
"""
if isinstance(placeholder_fill_value, str):
@ -222,7 +222,7 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
mod = importlib.import_module('.'.join(module_parts))
all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin',
'TickResult', 'ActionResult', 'Action', 'Agent',
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any'
]])
@ -240,7 +240,13 @@ def add_bound_name(name_str, bound_e):
def add_pos_name(name_str, bound_e):
if bound_e.var_has_position:
return f'{name_str}({bound_e.pos})'
return f'{name_str}@{bound_e.pos}'
return name_str
def get_first(iterable: Iterable, filter_by: Callable[..., bool] = lambda _: True):
    """
    Return the first element of *iterable* that satisfies *filter_by*.

    :param iterable: Any iterable; it is consumed only up to the first match.
    :param filter_by: Predicate that is called with each element. Defaults to accepting every element.
    :return: The first matching element or ``None`` if the iterable is empty or nothing matches.
    """
    return next((x for x in iterable if filter_by(x)), None)
def get_first_index(iterable: Iterable, filter_by: Callable[..., bool] = lambda _: True):
    """
    Return the index of the first element of *iterable* that satisfies *filter_by*.

    :param iterable: Any iterable; it is consumed only up to the first match.
    :param filter_by: Predicate that is called with each element. Defaults to accepting every element.
    :return: Zero-based index of the first matching element or ``None`` if nothing matches.
    """
    return next((idx for idx, x in enumerate(iterable) if filter_by(x)), None)

View File

@ -47,6 +47,7 @@ class LevelParser(object):
# All other
for es_name in self.e_p_dict:
e_class, e_kwargs = self.e_p_dict[es_name]['class'], self.e_p_dict[es_name]['kwargs']
e_kwargs = e_kwargs if e_kwargs else {}
if hasattr(e_class, 'symbol') and e_class.symbol is not None:
symbols = e_class.symbol

View File

@ -9,7 +9,7 @@ from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
import pandas as pd
from marl_factory_grid.utils.plotting.compare_runs import plot_single_run
from marl_factory_grid.utils.plotting.plot_single_runs import plot_single_run
class EnvMonitor(Wrapper):
@ -22,7 +22,6 @@ class EnvMonitor(Wrapper):
self._monitor_df = pd.DataFrame()
self._monitor_dict = dict()
def step(self, action):
obs_type, obs, reward, done, info = self.env.step(action)
self._read_info(info)

View File

@ -2,11 +2,9 @@ from os import PathLike
from pathlib import Path
from typing import Union, List
import yaml
from gymnasium import Wrapper
import numpy as np
import pandas as pd
from gymnasium import Wrapper
class EnvRecorder(Wrapper):
@ -106,7 +104,7 @@ class EnvRecorder(Wrapper):
out_dict = {'episodes': self._recorder_out_list}
out_dict.update(
{'n_episodes': self._curr_episode,
'metadata':dict(
'metadata': dict(
level_name=self.env.params['General']['level_name'],
verbose=False,
n_agents=len(self.env.params['Agents']),

View File

@ -1,17 +1,16 @@
import math
import re
from collections import defaultdict
from itertools import product
from typing import Dict, List
import numpy as np
from numba import njit
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.object import Object
from marl_factory_grid.environment.groups.utils import Combined
import marl_factory_grid.utils.helpers as h
from marl_factory_grid.utils.states import Gamestate
from marl_factory_grid.utils.utility_classes import Floor
from marl_factory_grid.utils.ray_caster import RayCaster
from marl_factory_grid.utils.states import Gamestate
from marl_factory_grid.utils import helpers as h
class OBSBuilder(object):
@ -77,11 +76,13 @@ class OBSBuilder(object):
def place_entity_in_observation(self, obs_array, agent, e):
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
try:
obs_array[x, y] += e.encoding
except IndexError:
# Seemded to be visible but is out of range
pass
if not min([y, x]) < 0:
try:
obs_array[x, y] += e.encoding
except IndexError:
# Seemded to be visible but is out of range
pass
pass
def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
assert self._curr_env_step == state.curr_step, (
@ -121,18 +122,24 @@ class OBSBuilder(object):
e = self.all_obs[l_name]
except KeyError:
try:
# Look for bound entity names!
pattern = re.compile(f'{re.escape(l_name)}(.*){re.escape(agent.name)}')
name = next((x for x in self.all_obs if pattern.search(x)), None)
# Look for bound entity REPRs!
pattern = re.compile(f'{re.escape(l_name)}'
f'{re.escape("[")}(.*){re.escape("]")}'
f'{re.escape("(")}{re.escape(agent.name)}{re.escape(")")}')
name = next((key for key, val in self.all_obs.items()
if pattern.search(str(val)) and isinstance(val, Object)), None)
e = self.all_obs[name]
except KeyError:
try:
e = next(v for k, v in self.all_obs.items() if l_name in k and agent.name in k)
except StopIteration:
raise KeyError(
f'Check for spelling errors! \n '
f'No combination of "{l_name} and {agent.name}" could not be found in:\n '
f'{list(dict(self.all_obs).keys())}')
print(f'# Check for spelling errors!')
print(f'# No combination of "{l_name}" and "{agent.name}" could not be found in:')
print(f'# {list(dict(self.all_obs).keys())}')
print('#')
print('# exiting...')
print('#')
exit(-99999)
try:
positional = e.var_has_position
@ -161,31 +168,30 @@ class OBSBuilder(object):
try:
light_map = np.zeros(self.obs_shape)
visible_floor = self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False)
if self.pomdp_r:
for f in set(visible_floor):
self.place_entity_in_observation(light_map, agent, f)
else:
for f in set(visible_floor):
light_map[f.x, f.y] += f.encoding
for f in set(visible_floor):
self.place_entity_in_observation(light_map, agent, f)
# else:
# for f in set(visible_floor):
# light_map[f.x, f.y] += f.encoding
self.curr_lightmaps[agent.name] = light_map
except (KeyError, ValueError):
print()
pass
return obs, self.obs_layers[agent.name]
def _sort_and_name_observation_conf(self, agent):
'''
"""
Builds the useable observation scheme per agent from conf.yaml.
:param agent:
:return:
'''
"""
# Fixme: no asymetric shapes possible.
self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape))
obs_layers = []
for obs_str in agent.observations:
if isinstance(obs_str, dict):
obs_str, vals = next(obs_str.items().__iter__())
obs_str, vals = h.get_first(obs_str.items())
else:
vals = None
if obs_str == c.SELF:
@ -214,129 +220,3 @@ class OBSBuilder(object):
obs_layers.append(obs_str)
self.obs_layers[agent.name] = obs_layers
self.curr_lightmaps[agent.name] = np.zeros(self.obs_shape)
class RayCaster:
def __init__(self, agent, pomdp_r, degs=360):
self.agent = agent
self.pomdp_r = pomdp_r
self.n_rays = (self.pomdp_r + 1) * 8
self.degs = degs
self.ray_targets = self.build_ray_targets()
self.obs_shape_cube = np.array([self.pomdp_r, self.pomdp_r])
self._cache_dict = {}
def __repr__(self):
return f'{self.__class__.__name__}({self.agent.name})'
def build_ray_targets(self):
north = np.array([0, -1]) * self.pomdp_r
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
rot_M = [
[[math.cos(theta), -math.sin(theta)],
[math.sin(theta), math.cos(theta)]] for theta in thetas
]
rot_M = np.stack(rot_M, 0)
rot_M = np.unique(np.round(rot_M @ north), axis=0)
return rot_M.astype(int)
def ray_block_cache(self, key, callback):
if key not in self._cache_dict:
self._cache_dict[key] = callback()
return self._cache_dict[key]
def visible_entities(self, pos_dict, reset_cache=True):
visible = list()
if reset_cache:
self._cache_dict = {}
for ray in self.get_rays():
rx, ry = ray[0]
for x, y in ray:
cx, cy = x - rx, y - ry
entities_hit = pos_dict[(x, y)]
hits = self.ray_block_cache((x, y),
lambda: any(True for e in entities_hit if e.var_is_blocking_light)
)
diag_hits = all([
self.ray_block_cache(
key,
lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) and bool(
pos_dict[key]))
for key in ((x, y - cy), (x - cx, y))
]) if (cx != 0 and cy != 0) else False
visible += entities_hit if not diag_hits else []
if hits or diag_hits:
break
rx, ry = x, y
return visible
def get_rays(self):
a_pos = self.agent.pos
outline = self.ray_targets + a_pos
return self.bresenham_loop(a_pos, outline)
# todo do this once and cache the points!
def get_fov_outline(self) -> np.ndarray:
return self.ray_targets + self.agent.pos
def get_square_outline(self):
agent = self.agent
x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \
+ list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
return outline
@staticmethod
@njit
def bresenham_loop(a_pos, points):
results = []
for end in points:
x1, y1 = a_pos
x2, y2 = end
dx = x2 - x1
dy = y2 - y1
# Determine how steep the line is
is_steep = abs(dy) > abs(dx)
# Rotate line
if is_steep:
x1, y1 = y1, x1
x2, y2 = y2, x2
# Swap start and end points if necessary and store swap state
swapped = False
if x1 > x2:
x1, x2 = x2, x1
y1, y2 = y2, y1
swapped = True
# Recalculate differentials
dx = x2 - x1
dy = y2 - y1
# Calculate error
error = int(dx / 2.0)
ystep = 1 if y1 < y2 else -1
# Iterate over bounding box generating points between start and end
y = y1
points = []
for x in range(int(x1), int(x2) + 1):
coord = [y, x] if is_steep else [x, y]
points.append(coord)
error -= abs(dy)
if error < 0:
y += ystep
error += dx
# Reverse the list if the coordinates were swapped
if swapped:
points.reverse()
results.append(points)
return results

View File

@ -7,50 +7,11 @@ from typing import Union, List
import pandas as pd
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
from marl_factory_grid.utils.plotting.plotting import prepare_plot
from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
MODEL_MAP = None
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None):
run_path = Path(run_path)
df_list = list()
if run_path.is_dir():
monitor_file = next(run_path.glob('*monitor*.pick'))
elif run_path.exists() and run_path.is_file():
monitor_file = run_path
else:
raise ValueError
with monitor_file.open('rb') as f:
monitor_df = pickle.load(f)
monitor_df = monitor_df.fillna(0)
df_list.append(monitor_df)
df = pd.concat(df_list, ignore_index=True)
df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
if column_keys is not None:
columns = [col for col in column_keys if col in df.columns]
else:
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
roll_n = 50
non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean()
df_melted = df[columns + ['Episode']].reset_index().melt(
id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score"
)
if df_melted['Episode'].max() > 800:
skip_n = round(df_melted['Episode'].max() * 0.02)
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
print('Plotting done.')
def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
run_path = Path(run_path)
df_list = list()

View File

@ -0,0 +1,48 @@
import pickle
from os import PathLike
from pathlib import Path
from typing import Union
import pandas as pd
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None,
                    file_key: str = 'monitor', file_ext: str = 'pkl'):
    """
    Plot the monitored metrics of a single run as a line plot.

    :param run_path: Either a directory that contains a pickled monitor file or the path to such a file.
    :param use_tex: Forwarded to ``prepare_plot``.
    :param column_keys: Optional column names to plot; names that are not present in the data are ignored.
                        Defaults to all columns that are not listed in ``IGNORED_DF_COLUMNS``.
    :param file_key: Substring a file name has to contain to be picked up when *run_path* is a directory.
    :param file_ext: File extension (without the dot) of the monitor file.
    :raises ValueError: If *run_path* is neither an existing directory nor an existing file.
    :raises FileNotFoundError: If *run_path* is a directory that holds no matching monitor file.
    """
    run_path = Path(run_path)
    if run_path.is_dir():
        pattern = f'*{file_key}*.{file_ext}'
        # Use a default so that an empty directory does not leak a bare StopIteration to the caller.
        monitor_file = next(run_path.glob(pattern), None)
        if monitor_file is None:
            raise FileNotFoundError(f'No file matching "{pattern}" was found in "{run_path}".')
    elif run_path.is_file():
        monitor_file = run_path
    else:
        raise ValueError(f'"{run_path}" is neither an existing directory nor an existing file.')

    # NOTE: unpickling can execute arbitrary code; only load monitor files from trusted runs.
    with monitor_file.open('rb') as f:
        monitor_df = pickle.load(f)

    df = monitor_df.fillna(0).reset_index(drop=True)
    df = df.rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
    if column_keys is not None:
        columns = [col for col in column_keys if col in df.columns]
    else:
        columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]

    # Long format: one row per (Episode, Measurement) pair, as expected by the plotting helper.
    df_melted = df[columns + ['Episode']].reset_index().melt(
        id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score"
    )

    # Thin out very long runs: keep only every "2 % of max episode"-th episode.
    if df_melted['Episode'].max() > 800:
        skip_n = round(df_melted['Episode'].max() * 0.02)
        df_melted = df_melted[df_melted['Episode'] % skip_n == 0]

    prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
    print('Plotting done.')

View File

@ -60,7 +60,7 @@ def prepare_center_double_column_legend(df, hue, style, hue_order):
print('Struggling to plot Figure using LaTeX - going back to normal.')
plt.close('all')
sns.set(rc={'text.usetex': False}, style='whitegrid')
fig = plt.figure(figsize=(10, 11))
_ = plt.figure(figsize=(10, 11))
lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
ci=95, palette=PALETTE, hue_order=hue_order, legend=False)
# plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)

View File

@ -19,7 +19,7 @@ class RayCaster:
return f'{self.__class__.__name__}({self.agent.name})'
def build_ray_targets(self):
north = np.array([0, -1])*self.pomdp_r
north = np.array([0, -1]) * self.pomdp_r
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
rot_M = [
[[math.cos(theta), -math.sin(theta)],
@ -39,8 +39,9 @@ class RayCaster:
if reset_cache:
self._cache_dict = dict()
for ray in self.get_rays():
for ray in self.get_rays(): # Do not check, just trust.
rx, ry = ray[0]
# self.ray_block_cache(ray[0], lambda: False) We do not do that, because of doors etc...
for x, y in ray:
cx, cy = x - rx, y - ry
@ -52,8 +53,9 @@ class RayCaster:
diag_hits = all([
self.ray_block_cache(
key,
lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light))
for key in ((x, y-cy), (x-cx, y))
# lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light)
lambda: any(True for e in pos_dict[key] if e.var_is_blocking_light))
for key in ((x, y - cy), (x - cx, y))
]) if (cx != 0 and cy != 0) else False
visible += entities_hit if not diag_hits else []
@ -75,8 +77,8 @@ class RayCaster:
agent = self.agent
x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \
+ list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r]))
outline += list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
return outline
@staticmethod

View File

@ -31,7 +31,7 @@ class Renderer:
def __init__(self, lvl_shape: Tuple[int, int] = (16, 16),
lvl_padded_shape: Union[Tuple[int, int], None] = None,
cell_size: int = 40, fps: int = 7,
cell_size: int = 40, fps: int = 7, factor: float = 0.9,
grid_lines: bool = True, view_radius: int = 2):
# TODO: Customn_assets paths
self.grid_h, self.grid_w = lvl_shape
@ -45,7 +45,7 @@ class Renderer:
self.screen = pygame.display.set_mode(self.screen_size)
self.clock = pygame.time.Clock()
assets = list(self.ASSETS.rglob('*.png'))
self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets}
self.assets = {path.stem: self.load_asset(str(path), factor) for path in assets}
self.fill_bg()
now = time.time()
@ -110,22 +110,22 @@ class Renderer:
pygame.quit()
sys.exit()
self.fill_bg()
blits = deque()
for entity in [x for x in entities]:
bp = self.blit_params(entity)
blits.append(bp)
if entity.name.lower() == AGENT:
if self.view_radius > 0:
vis_rects = self.visibility_rects(bp, entity.aux)
blits.extendleft(vis_rects)
if entity.state != BLANK:
agent_state_blits = self.blit_params(
RenderEntity(entity.state, (entity.pos[0] + 0.12, entity.pos[1]), 0.48, SCALE)
)
textsurface = self.font.render(str(entity.id), False, (0, 0, 0))
text_blit = dict(source=textsurface, dest=(bp['dest'].center[0]-.07*self.cell_size,
bp['dest'].center[1]))
blits += [agent_state_blits, text_blit]
# First all others
blits = deque(self.blit_params(x) for x in entities if not x.name.lower() == AGENT)
# Then Agents, so that agents are rendered on top.
for agent in (x for x in entities if x.name.lower() == AGENT):
agent_blit = self.blit_params(agent)
if self.view_radius > 0:
vis_rects = self.visibility_rects(agent_blit, agent.aux)
blits.extendleft(vis_rects)
if agent.state != BLANK:
state_blit = self.blit_params(
RenderEntity(agent.state, (agent.pos[0] + 0.12, agent.pos[1]), 0.48, SCALE)
)
textsurface = self.font.render(str(agent.id), False, (0, 0, 0))
text_blit = dict(source=textsurface, dest=(agent_blit['dest'].center[0]-.07*self.cell_size,
agent_blit['dest'].center[1]))
blits += [agent_blit, state_blit, text_blit]
for blit in blits:
self.screen.blit(**blit)

View File

@ -1,9 +1,12 @@
from typing import Union
from dataclasses import dataclass
from marl_factory_grid.environment.entity.object import Object
TYPE_VALUE = 'value'
TYPE_REWARD = 'reward'
types = [TYPE_VALUE, TYPE_REWARD]
TYPES = [TYPE_VALUE, TYPE_REWARD]
@dataclass
class InfoObject:
@ -18,17 +21,21 @@ class Result:
validity: bool
reward: Union[float, None] = None
value: Union[float, None] = None
entity: None = None
entity: Object = None
def get_infos(self):
n = self.entity.name if self.entity is not None else "Global"
return [InfoObject(identifier=f'{n}_{self.identifier}_{t}',
val_type=t, value=self.__getattribute__(t)) for t in types
# Return multiple Info Dicts
return [InfoObject(identifier=f'{n}_{self.identifier}',
val_type=t, value=self.__getattribute__(t)) for t in TYPES
if self.__getattribute__(t) is not None]
def __repr__(self):
valid = "not " if not self.validity else ""
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid: {self.reward})'
reward = f" | Reward: {self.reward}" if self.reward is not None else ""
value = f" | Value: {self.value}" if self.value is not None else ""
entity = f" | by: {self.entity.name}" if self.entity is not None else ""
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value}{entity})'
@dataclass

View File

@ -1,9 +1,12 @@
from typing import List, Dict, Tuple
from itertools import islice
from typing import List, Tuple
import numpy as np
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import Result, DoneResult
from marl_factory_grid.environment.tests import Test
from marl_factory_grid.utils.results import Result
@ -60,7 +63,8 @@ class Gamestate(object):
def moving_entites(self):
return [y for x in self.entities for y in x if x.var_can_move]
def __init__(self, entities, agents_conf, rules: [Rule], tests: [Test], env_seed=69, verbose=False):
def __init__(self, entities, agents_conf, rules: List[Rule], tests: [Test], lvl_shape, env_seed=69, verbose=False):
self.lvl_shape = lvl_shape
self.entities = entities
self.curr_step = 0
self.curr_actions = None
@ -82,7 +86,52 @@ class Gamestate(object):
def __repr__(self):
return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})'
def tick(self, actions) -> List[Result]:
@property
def random_free_position(self) -> Tuple[int, int]:
    """
    Returns a single **free** position (x, y), which is **free** for spawning or walking.
    No Entity at this position possesses *var_is_blocking_pos* or *var_can_collide*.

    :return: Single **free** position.
    :raises IndexError: If no free position is available.
    """
    return self.get_n_random_free_positions(1)[0]
def get_n_random_free_positions(self, n: int) -> List[Tuple[int, int]]:
    """
    Returns a list of *n* **free** positions [(x, y), ... ], which are **free** for spawning or walking.
    No Entity at these positions possesses *var_is_blocking_pos* or *var_can_collide*.

    :param n: Maximum number of positions to take from ``entities.free_positions_generator``.
    :return: List of up to *n* **free** positions; shorter if the generator is exhausted earlier.
    """
    return list(islice(self.entities.free_positions_generator, n))
@property
def random_position(self) -> Tuple[int, int]:
    """
    Returns a single available position (x, y), ignores all entity attributes.

    :return: Single random position.
    :raises IndexError: If no position is available at all.
    """
    return self.get_n_random_positions(1)[0]
def get_n_random_positions(self, n: int) -> List[Tuple[int, int]]:
    """
    Returns a list of *n* available positions [(x, y), ... ], ignores all entity attributes.

    :param n: Maximum number of positions to return.
    :return: List of up to *n* positions.
    """
    # NOTE(review): this takes the first n entries of entities.floorlist; the result is only
    # random if floorlist itself is kept in shuffled order - confirm.
    return list(islice(self.entities.floorlist, n))
def tick(self, actions) -> list[Result]:
"""
Performs a single **Gamestate Tick** by calling the inner rule hooks in sequential order.
- tick_pre_step_all: Things to do before the agents do their actions. Statechange, Moving, Spawning etc...
- agent tick: Agents do their actions.
- tick_step_all: Things to do after the agents did their actions. Statechange, Moving, Spawning etc...
- tick_post_step_all: Things to do at the very end of each step. Counting, Reward calculations etc...
:return: List of *Result*-objects.
"""
results = list()
test_results = list()
self.curr_step += 1
@ -112,11 +161,23 @@ class Gamestate(object):
return results
def print(self, string):
def print(self, string) -> None:
"""
When *verbose* is active, print stuff.
:param string: *String* to print.
:type string: str
:return: Nothing
"""
if self.verbose:
print(string)
def check_done(self):
def check_done(self) -> List[DoneResult]:
"""
Iterate all **Rules** that override the *on_check_done* hook.
:return: List of Results
"""
results = list()
for rule in self.rules:
if on_check_done_result := rule.on_check_done(self):
@ -124,24 +185,47 @@ class Gamestate(object):
return results
def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]:
positions = [pos for pos, entity_list_for_position in self.entities.pos_dict.items()
if any([e.var_can_collide for e in entity_list_for_position])]
"""
Returns a list positions [(x, y), ... ] on which collisions occur. This does not include agents,
that were unable to move because their target direction was blocked, also a form of collision.
:return: List of positions.
"""
positions = [pos for pos, entities in self.entities.pos_dict.items() if
len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2)
]
return positions
def check_move_validity(self, moving_entity, position):
if moving_entity.pos != position and not any(
entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]) and not (
moving_entity.var_is_blocking_pos and self.entities.is_occupied(position)):
return True
else:
return False
def check_move_validity(self, moving_entity: Entity, target_position: (int, int)) -> bool:
"""
Whether it is safe to move to the target positions and moving entity does not introduce a blocking attribute,
when position is already occupied.
def check_pos_validity(self, position):
if not any(entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]):
return True
else:
return False
:param moving_entity: Entity
:param target_position: pos
:return: Safe to move to
"""
is_not_blocked = self.check_pos_validity(target_position)
will_not_block_others = moving_entity.var_is_blocking_pos and self.entities.is_occupied(target_position)
if moving_entity.pos != target_position and is_not_blocked and not will_not_block_others:
return c.VALID
else:
return c.NOT_VALID
def check_pos_validity(self, pos: (int, int)) -> bool:
"""
Check if *pos* is a valid position to move or spawn to.
:param pos: position to check
:return: Whether pos is a valid target.
"""
if not any(e.var_is_blocking_pos for e in self.entities.pos_dict[pos]) and pos in self.entities.floorlist:
return c.VALID
else:
return c.NOT_VALID
class StepTests:
def __init__(self, *args):

View File

@ -28,7 +28,9 @@ class ConfigExplainer:
def explain_module(self, class_to_explain):
parameters = inspect.signature(class_to_explain).parameters
explained = {class_to_explain.__name__: {key: val.default for key, val in parameters.items() if key not in EXCLUDED}}
explained = {class_to_explain.__name__:
{key: val.default for key, val in parameters.items() if key not in EXCLUDED}
}
return explained
def _load_and_compare(self, compare_class, paths):
@ -135,4 +137,3 @@ if __name__ == '__main__':
ce.get_observations()
ce.get_assets()
all_conf = ce.get_all()
print()

View File

@ -52,3 +52,6 @@ class Floor:
def __hash__(self):
return hash(self.name)
def __repr__(self):
return f"Floor{self.pos}"

View File

@ -6,18 +6,21 @@ from marl_factory_grid.environment.factory import Factory
from marl_factory_grid.utils.logging.envmonitor import EnvMonitor
from marl_factory_grid.utils.logging.recorder import EnvRecorder
from marl_factory_grid.utils.plotting.plot_single_runs import plot_single_run
from marl_factory_grid.utils.tools import ConfigExplainer
if __name__ == '__main__':
# Render at each step?
render = True
render = False
# Reveal all possible Modules (Entities, Rules, Agents[Actions, Observations], etc.)
explain_config = False
# Collect statistics?
monitor = False
monitor = True
# Record as Protobuf?
record = False
# Plot Results?
plotting = True
run_path = Path('study_out')
@ -38,7 +41,7 @@ if __name__ == '__main__':
factory = EnvRecorder(factory)
# RL learn Loop
for episode in trange(500):
for episode in trange(10):
_ = factory.reset()
done = False
if render:
@ -54,7 +57,10 @@ if __name__ == '__main__':
break
if monitor:
factory.save_run(run_path / 'test.pkl')
factory.save_run(run_path / 'test_monitor.pkl')
if record:
factory.save_records(run_path / 'test.pb')
if plotting:
plot_single_run(run_path)
print('Done!!! Goodbye....')

View File

@ -6,6 +6,7 @@ import yaml
from marl_factory_grid.environment.factory import Factory
from marl_factory_grid.utils.logging.envmonitor import EnvMonitor
from marl_factory_grid.utils.logging.recorder import EnvRecorder
from marl_factory_grid.utils import helpers as h
from marl_factory_grid.modules.doors import constants as d
@ -55,13 +56,14 @@ if __name__ == '__main__':
for model_idx, model in enumerate(models)]
else:
actions = models[0].predict(env_state, deterministic=determin)[0]
# noinspection PyTupleAssignmentBalance
env_state, step_r, done_bool, info_obj = env.step(actions)
rew += step_r
if render:
env.render()
try:
door = next(x for x in env.unwrapped.unwrapped[d.DOORS] if x.is_open)
door = h.get_first([x for x in env.unwrapped.unwrapped[d.DOORS] if x.is_open])
print('openDoor found')
except StopIteration:
pass

View File

@ -1,8 +1,8 @@
from algorithms.utils import Checkpointer
from pathlib import Path
from algorithms.utils import load_yaml_file, add_env_props, instantiate_class, load_class
#from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC
# from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC
for i in range(0, 5):

View File

@ -0,0 +1,41 @@
import configparser
import json
from datetime import datetime
from pathlib import Path
if __name__ == '__main__':
    # Export every client section of ``wg0/wg0.conf`` into its own
    # ``wg0/<client_name>.json`` file.

    # Function-scope import: the file header only imports ``datetime`` itself.
    from datetime import timezone

    conf_path = Path('wg0')
    conf_file = conf_path / 'wg0.conf'
    wg0_conf = configparser.ConfigParser()
    # ConfigParser.read() skips files it cannot open and returns the list of files
    # it actually parsed, so an empty result means the config could not be read.
    if not wg0_conf.read(conf_file):
        raise FileNotFoundError(f'Could not read WireGuard config: {conf_file}')
    interface = wg0_conf['Interface']

    # Iterate all peers; every section except 'Interface' is treated as one client.
    for client_name in wg0_conf.sections():
        if client_name == 'Interface':
            continue
        json_path = conf_path / f'{client_name}.json'

        # Delete any old conf.json for the current peer
        json_path.unlink(missing_ok=True)

        peer = wg0_conf[client_name]
        # The trailing 'Z' denotes UTC, so the timestamp has to be taken in UTC
        # rather than in local time. '%f' yields microseconds; the literal '000'
        # pads the fraction to nine digits.
        date_time = datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%S.%f000Z')
        jdict = dict(
            id=client_name,
            # NOTE(review): the private key is filled with the peer's *public* key -
            # presumably because the config holds no private key per peer; confirm.
            private_key=peer['PublicKey'],
            public_key=peer['PublicKey'],
            # preshared_key=wg0_conf[client_name_wg0]['PresharedKey'],
            name=client_name,
            email="sysadmin@mobile.ifi.lmu.de",
            # NOTE(review): every client receives the [Interface] address here, not an
            # address of its own - verify that this is intended.
            allocated_ips=[interface['Address'].replace('/24', '')],
            allowed_ips=['10.4.0.0/24', '10.153.199.0/24'],
            extra_allowed_ips=[],
            use_server_dns=True,
            enabled=True,
            created_at=date_time,
            updated_at=date_time
        )

        # The file is only written, never read back, so plain 'w' is sufficient.
        with json_path.open('w') as f:
            json.dump(jdict, f, indent='\t', separators=(',', ': '))
        print(client_name, ' written...')