diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..b58b603
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,5 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
diff --git a/README.md b/README.md
index a1d2740..d0c0a19 100644
--- a/README.md
+++ b/README.md
@@ -56,7 +56,7 @@ Just define what your environment needs in a *yaml*-configfile like:
- Items
Rules:
Defaults: {}
- Collision:
+ WatchCollisions:
done_at_collisions: !!bool True
ItemRespawn:
spawn_freq: 5
@@ -94,7 +94,7 @@ All [Entites](marl_factory_grid/environment/entity/global_entities.py) are avail
#### Rules
-[Rules](marl_factory_grid/environment/entity/object.py) define how the environment behaves on micro-scale.
+[Rules](marl_factory_grid/environment/entity/object.py) define how the environment behaves on microscale.
Each of the hookes (`on_init`, `pre_step`, `on_step`, '`post_step`', `on_done`)
provide env-access to implement customn logic, calculate rewards, or gather information.
@@ -107,6 +107,7 @@ Make sure to bring your own assets for each Entity living in the Gridworld as th
PNG-files (transparent background) of square aspect-ratio should do the job, in general.
+
diff --git a/marl_factory_grid/__init__.py b/marl_factory_grid/__init__.py
index b2bbfa3..259e3cf 100644
--- a/marl_factory_grid/__init__.py
+++ b/marl_factory_grid/__init__.py
@@ -1,6 +1 @@
-from .environment import *
-from .modules import *
-from .utils import *
-
from .quickstart import init
-
diff --git a/marl_factory_grid/algorithms/__init__.py b/marl_factory_grid/algorithms/__init__.py
index 0980070..cc2c489 100644
--- a/marl_factory_grid/algorithms/__init__.py
+++ b/marl_factory_grid/algorithms/__init__.py
@@ -1 +1,4 @@
-import os, sys; sys.path.append(os.path.dirname(os.path.realpath(__file__)))
+import os
+import sys
+
+sys.path.append(os.path.dirname(os.path.realpath(__file__)))
diff --git a/marl_factory_grid/algorithms/marl/__init__.py b/marl_factory_grid/algorithms/marl/__init__.py
index 984588c..a4c30ef 100644
--- a/marl_factory_grid/algorithms/marl/__init__.py
+++ b/marl_factory_grid/algorithms/marl/__init__.py
@@ -1 +1 @@
-from marl_factory_grid.algorithms.marl.memory import MARLActorCriticMemory
\ No newline at end of file
+from marl_factory_grid.algorithms.marl.memory import MARLActorCriticMemory
diff --git a/marl_factory_grid/algorithms/marl/base_ac.py b/marl_factory_grid/algorithms/marl/base_ac.py
index 3bb0318..ef195b7 100644
--- a/marl_factory_grid/algorithms/marl/base_ac.py
+++ b/marl_factory_grid/algorithms/marl/base_ac.py
@@ -28,6 +28,7 @@ class Names:
BATCH_SIZE = 'bnatch_size'
N_ACTIONS = 'n_actions'
+
nms = Names
ListOrTensor = Union[List, torch.Tensor]
@@ -112,10 +113,9 @@ class BaseActorCritic:
next_obs, reward, done, info = env.step(action)
done = [done] * self.n_agents if isinstance(done, bool) else done
- last_hiddens = dict(hidden_actor =out[nms.HIDDEN_ACTOR],
+ last_hiddens = dict(hidden_actor=out[nms.HIDDEN_ACTOR],
hidden_critic=out[nms.HIDDEN_CRITIC])
-
tm.add(observation=obs, action=action, reward=reward, done=done,
logits=out.get(nms.LOGITS, None), values=out.get(nms.CRITIC, None),
**last_hiddens)
@@ -142,7 +142,9 @@ class BaseActorCritic:
print(f'reward at episode: {episode} = {rew_log}')
episode += 1
df_results.append([episode, rew_log, *reward])
- df_results = pd.DataFrame(df_results, columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]])
+ df_results = pd.DataFrame(df_results,
+ columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]]
+ )
if checkpointer is not None:
df_results.to_csv(checkpointer.path / 'results.csv', index=False)
return df_results
@@ -157,24 +159,27 @@ class BaseActorCritic:
last_action, reward = [-1] * self.n_agents, [0.] * self.n_agents
done, rew_log, eps_rew = [False] * self.n_agents, 0, torch.zeros(self.n_agents)
while not all(done):
- if render: env.render()
+ if render:
+ env.render()
out = self.forward(obs, last_action, **last_hiddens)
action = self.get_actions(out)
next_obs, reward, done, info = env.step(action)
- if isinstance(done, bool): done = [done] * obs.shape[0]
+ if isinstance(done, bool):
+ done = [done] * obs.shape[0]
obs = next_obs
last_action = action
last_hiddens = dict(hidden_actor=out.get(nms.HIDDEN_ACTOR, None),
hidden_critic=out.get(nms.HIDDEN_CRITIC, None)
)
eps_rew += torch.tensor(reward)
- results.append(eps_rew.tolist() + [sum(eps_rew).item()] + [episode])
+ results.append(eps_rew.tolist() + [np.sum(eps_rew).item()] + [episode])
episode += 1
agent_columns = [f'agent#{i}' for i in range(self.cfg['environment']['n_agents'])]
results = pd.DataFrame(results, columns=agent_columns + ['sum', 'episode'])
- results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'], value_name='reward', var_name='agent')
+ results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'],
+ value_name='reward', var_name='agent')
return results
@staticmethod
diff --git a/marl_factory_grid/algorithms/marl/mappo.py b/marl_factory_grid/algorithms/marl/mappo.py
index d22fa08..faf3b0d 100644
--- a/marl_factory_grid/algorithms/marl/mappo.py
+++ b/marl_factory_grid/algorithms/marl/mappo.py
@@ -36,7 +36,7 @@ class LoopMAPPO(LoopSNAC):
rewards_ = torch.stack(rewards_, dim=1)
return rewards_
- def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **kwargs):
+ def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **__):
out = network(batch[nms.OBSERVATION], batch[nms.ACTION], batch[nms.HIDDEN_ACTOR], batch[nms.HIDDEN_CRITIC])
logits = out[nms.LOGITS][:, :-1] # last one only needed for v_{t+1}
@@ -45,7 +45,7 @@ class LoopMAPPO(LoopSNAC):
# monte carlo returns
mc_returns = self.monte_carlo_returns(batch[nms.REWARD], batch[nms.DONE], gamma)
- mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8) #todo: norm across agent ok?
+ mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8) # todo: norm across agent ok?
advantages = mc_returns - out[nms.CRITIC][:, :-1]
# policy loss
diff --git a/marl_factory_grid/algorithms/marl/networks.py b/marl_factory_grid/algorithms/marl/networks.py
index c4fdb72..796c03f 100644
--- a/marl_factory_grid/algorithms/marl/networks.py
+++ b/marl_factory_grid/algorithms/marl/networks.py
@@ -1,8 +1,7 @@
+import numpy as np
import torch
import torch.nn as nn
-import numpy as np
import torch.nn.functional as F
-from torch.nn.utils import spectral_norm
class RecurrentAC(nn.Module):
@@ -88,8 +87,8 @@ class NormalizedLinear(nn.Linear):
self.trainable_magnitude = trainable_magnitude
self.scale = nn.Parameter(torch.tensor([1.]), requires_grad=trainable_magnitude)
- def forward(self, input):
- normalized_input = F.normalize(input, dim=-1, p=2, eps=1e-5)
+ def forward(self, in_array):
+ normalized_input = F.normalize(in_array, dim=-1, p=2, eps=1e-5)
normalized_weight = F.normalize(self.weight, dim=-1, p=2, eps=1e-5)
return F.linear(normalized_input, normalized_weight) * self.d_sqrt * self.scale
diff --git a/marl_factory_grid/algorithms/marl/seac.py b/marl_factory_grid/algorithms/marl/seac.py
index 9c458c7..07e8267 100644
--- a/marl_factory_grid/algorithms/marl/seac.py
+++ b/marl_factory_grid/algorithms/marl/seac.py
@@ -16,7 +16,7 @@ class LoopSEAC(LoopIAC):
with torch.inference_mode(True):
true_action_logp = torch.stack([
torch.log_softmax(out[nms.LOGITS][ag_i, :-1], -1)
- .gather(index=actions[ag_i, 1:, None], dim=-1)
+ .gather(index=actions[ag_i, 1:, None], dim=-1)
for ag_i, out in enumerate(outputs)
], 0).squeeze()
@@ -38,7 +38,6 @@ class LoopSEAC(LoopIAC):
a2c_loss = (-iw*log_ap * advantages.detach()).mean(-1)
-
value_loss = (iw*advantages.pow(2)).mean(-1) # n_agent
# weighted loss
@@ -53,4 +52,4 @@ class LoopSEAC(LoopIAC):
self.optimizer[ag_i].zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.net[ag_i].parameters(), 0.5)
- self.optimizer[ag_i].step()
\ No newline at end of file
+ self.optimizer[ag_i].step()
diff --git a/marl_factory_grid/algorithms/marl/snac.py b/marl_factory_grid/algorithms/marl/snac.py
index b249754..11be902 100644
--- a/marl_factory_grid/algorithms/marl/snac.py
+++ b/marl_factory_grid/algorithms/marl/snac.py
@@ -30,4 +30,4 @@ class LoopSNAC(BaseActorCritic):
self._as_torch(actions).unsqueeze(1),
hidden_actor, hidden_critic
)
- return out
\ No newline at end of file
+ return out
diff --git a/marl_factory_grid/algorithms/static/TSP_base_agent.py b/marl_factory_grid/algorithms/static/TSP_base_agent.py
index bc48f7c..7d25f63 100644
--- a/marl_factory_grid/algorithms/static/TSP_base_agent.py
+++ b/marl_factory_grid/algorithms/static/TSP_base_agent.py
@@ -56,8 +56,8 @@ class TSPBaseAgent(ABC):
def _door_is_close(self, state):
try:
- # return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name)
- return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name)
+ return next(y for x in state.entities.neighboring_positions(self.state.pos)
+ for y in state.entities.pos_dict[x] if do.DOOR in y.name)
except StopIteration:
return None
diff --git a/marl_factory_grid/algorithms/static/TSP_target_agent.py b/marl_factory_grid/algorithms/static/TSP_target_agent.py
index 0c5de3a..b0d8b29 100644
--- a/marl_factory_grid/algorithms/static/TSP_target_agent.py
+++ b/marl_factory_grid/algorithms/static/TSP_target_agent.py
@@ -14,8 +14,8 @@ class TSPTargetAgent(TSPBaseAgent):
def _handle_doors(self, state):
try:
- # return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name)
- return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name)
+ return next(y for x in state.entities.neighboring_positions(self.state.pos)
+ for y in state.entities.pos_dict[x] if do.DOOR in y.name)
except StopIteration:
return None
@@ -30,4 +30,3 @@ class TSPTargetAgent(TSPBaseAgent):
except (StopIteration, UnboundLocalError):
print('Will not happen')
return action_obj
-
diff --git a/marl_factory_grid/algorithms/static/utils.py b/marl_factory_grid/algorithms/static/utils.py
index d5119db..60cba30 100644
--- a/marl_factory_grid/algorithms/static/utils.py
+++ b/marl_factory_grid/algorithms/static/utils.py
@@ -26,12 +26,16 @@ def points_to_graph(coordiniates, allow_euclidean_connections=True, allow_manhat
assert allow_euclidean_connections or allow_manhattan_connections
possible_connections = itertools.combinations(coordiniates, 2)
graph = nx.Graph()
- for a, b in possible_connections:
- diff = np.linalg.norm(np.asarray(a)-np.asarray(b))
- if allow_manhattan_connections and allow_euclidean_connections and diff <= np.sqrt(2):
- graph.add_edge(a, b)
- elif not allow_manhattan_connections and allow_euclidean_connections and diff == np.sqrt(2):
- graph.add_edge(a, b)
- elif allow_manhattan_connections and not allow_euclidean_connections and diff == 1:
- graph.add_edge(a, b)
+ if allow_manhattan_connections and allow_euclidean_connections:
+ graph.add_edges_from(
+ (a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) <= np.sqrt(2)
+ )
+ elif not allow_manhattan_connections and allow_euclidean_connections:
+ graph.add_edges_from(
+ (a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) == np.sqrt(2)
+ )
+ elif allow_manhattan_connections and not allow_euclidean_connections:
+ graph.add_edges_from(
+ (a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) == 1
+ )
return graph
diff --git a/marl_factory_grid/algorithms/utils.py b/marl_factory_grid/algorithms/utils.py
index 59e78bd..8c60386 100644
--- a/marl_factory_grid/algorithms/utils.py
+++ b/marl_factory_grid/algorithms/utils.py
@@ -1,8 +1,9 @@
-import torch
-import numpy as np
-import yaml
from pathlib import Path
+import numpy as np
+import torch
+import yaml
+
def load_class(classname):
from importlib import import_module
@@ -42,7 +43,6 @@ def get_class(arguments):
def get_arguments(arguments):
- from importlib import import_module
d = dict(arguments)
if "classname" in d:
del d["classname"]
@@ -82,4 +82,4 @@ class Checkpointer(object):
for name, model in to_save:
self.save_experiment(name, model)
self.__current_checkpoint += 1
- self.__current_step += 1
\ No newline at end of file
+ self.__current_step += 1
diff --git a/marl_factory_grid/configs/default_config.yaml b/marl_factory_grid/configs/default_config.yaml
index 44a0977..d3015c9 100644
--- a/marl_factory_grid/configs/default_config.yaml
+++ b/marl_factory_grid/configs/default_config.yaml
@@ -22,26 +22,41 @@ Agents:
- Inventory
- DropOffLocations
- Maintainers
+ # This is special for agents, as each one is differten and can act as an adversary e.g.
+ Positions:
+ - (16, 7)
+ - (16, 6)
+ - (16, 3)
+ - (16, 4)
+ - (16, 5)
Entities:
Batteries:
initial_charge: 0.8
per_action_costs: 0.02
- ChargePods: {}
- Destinations: {}
+ ChargePods:
+ coords_or_quantity: 2
+ Destinations:
+ coords_or_quantity: 1
+ spawn_mode: GROUPED
DirtPiles:
+ coords_or_quantity: 10
+ initial_amount: 2
clean_amount: 1
dirt_spawn_r_var: 0.1
- initial_amount: 2
- initial_dirt_ratio: 0.05
max_global_amount: 20
max_local_amount: 5
- Doors: {}
- DropOffLocations: {}
+ Doors:
+ DropOffLocations:
+ coords_or_quantity: 1
+ max_dropoff_storage_size: 0
GlobalPositions: {}
Inventories: {}
- Items: {}
- Machines: {}
- Maintainers: {}
+ Items:
+ coords_or_quantity: 5
+ Machines:
+ coords_or_quantity: 2
+ Maintainers:
+ coords_or_quantity: 1
Zones: {}
General:
@@ -49,32 +64,31 @@ General:
individual_rewards: true
level_name: large
pomdp_r: 3
- verbose: false
+ verbose: True
+ tests: false
Rules:
- SpawnAgents: {}
- DoneAtBatteryDischarge: {}
- Collision:
- done_at_collisions: false
- AssignGlobalPositions: {}
- DoneAtDestinationReachAny: {}
- DestinationReachReward: {}
- SpawnDestinations:
- n_dests: 1
- spawn_mode: GROUPED
- DoneOnAllDirtCleaned: {}
- SpawnDirt:
- spawn_freq: 15
+ # Environment Dynamics
EntitiesSmearDirtOnMove:
smear_ratio: 0.2
DoorAutoClose:
close_frequency: 10
- ItemRules:
- max_dropoff_storage_size: 0
- n_items: 5
- n_locations: 5
- spawn_frequency: 15
- MaxStepsReached:
+ MoveMaintainers:
+
+ # Respawn Stuff
+ RespawnDirt:
+ respawn_freq: 15
+ RespawnItems:
+ respawn_freq: 15
+
+ # Utilities
+ WatchCollisions:
+ done_at_collisions: false
+
+ # Done Conditions
+ DoneAtDestinationReachAny:
+ DoneOnAllDirtCleaned:
+ DoneAtBatteryDischarge:
+ DoneAtMaintainerCollision:
+ DoneAtMaxStepsReached:
max_steps: 500
-# AgentSingleZonePlacement:
-# n_zones: 4
diff --git a/marl_factory_grid/configs/narrow_corridor.yaml b/marl_factory_grid/configs/narrow_corridor.yaml
index 0006513..f53b972 100644
--- a/marl_factory_grid/configs/narrow_corridor.yaml
+++ b/marl_factory_grid/configs/narrow_corridor.yaml
@@ -1,15 +1,41 @@
+General:
+ # Your Seed
+ env_seed: 69
+ # Individual or global rewards?
+ individual_rewards: true
+ # The level.txt file to load
+ level_name: narrow_corridor
+ # View Radius; 0 = full observatbility
+ pomdp_r: 0
+ # print all messages and events
+ verbose: true
+
Agents:
+ # Agents are identified by their name
Wolfgang:
+ # The available actions for this particular agent
Actions:
+ # Able to do nothing
- Noop
+ # Able to move in all 8 directions
- Move8
+ # Stuff the agent can observe (per 2d slice)
+ # use "Combined" if you want to merge multiple slices into one
Observations:
+ # He sees walls
- Walls
+ # he sees other agent, "karl-Heinz" in this setting would be fine, too
- Other
+ # He can see Destinations, that are assigned to him (hence the singular)
- Destination
+ # Avaiable Spawn Positions as list
Positions:
- (2, 1)
- (2, 5)
+ # It is okay to collide with other agents, so that
+ # they end up on the same position
+ is_blocking_pos: true
+ # See Above....
Karl-Heinz:
Actions:
- Noop
@@ -21,26 +47,43 @@ Agents:
Positions:
- (2, 1)
- (2, 5)
+ is_blocking_pos: true
+
+# Other noteworthy Entitites
Entities:
- Destinations: {}
-
-General:
- env_seed: 69
- individual_rewards: true
- level_name: narrow_corridor
- pomdp_r: 0
- verbose: true
+ # The destiantions or positional targets to reach
+ Destinations:
+ # Let them spawn on closed doors and agent positions
+ ignore_blocking: true
+ # We need a special spawn rule...
+ spawnrule:
+ # ...which assigns the destinations per agent
+ SpawnDestinationsPerAgent:
+ # we use this parameter
+ coords_or_quantity:
+ # to enable and assign special positions per agent
+ Wolfgang:
+ - (2, 1)
+ - (2, 5)
+ Karl-Heinz:
+ - (2, 1)
+ - (2, 5)
+ # Whether you want to provide a numeric Position observation.
+ # GlobalPositions:
+ # normalized: false
+# Define the env. dynamics
Rules:
- SpawnAgents: {}
- Collision:
+ # Utilities
+ # This rule Checks for Collision, also it assigns the (negative) reward
+ WatchCollisions:
+ reward: -0.1
+ reward_at_done: -1
done_at_collisions: false
- FixedDestinationSpawn:
- per_agent_positions:
- Wolfgang:
- - (2, 1)
- - (2, 5)
- Karl-Heinz:
- - (2, 1)
- - (2, 5)
- DestinationReachAll: {}
+ # Done Conditions
+ # Load any of the rules, to check for done conditions.
+ # DoneAtDestinationReachAny:
+ DoneAtDestinationReachAll:
+ # reward_at_done: 1
+ DoneAtMaxStepsReached:
+ max_steps: 200
diff --git a/marl_factory_grid/environment/actions.py b/marl_factory_grid/environment/actions.py
index 4edfe24..606832c 100644
--- a/marl_factory_grid/environment/actions.py
+++ b/marl_factory_grid/environment/actions.py
@@ -48,9 +48,9 @@ class Move(Action, abc.ABC):
reward = r.MOVEMENTS_VALID if move_validity else r.MOVEMENTS_FAIL
return ActionResult(entity=entity, identifier=self._identifier, validity=move_validity, reward=reward)
else: # There is no place to go, propably collision
- # This is currently handeld by the Collision rule, so that it can be switched on and off by conf.yml
+ # This is currently handeld by the WatchCollisions rule, so that it can be switched on and off by conf.yml
# return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION)
- return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=0)
+ return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID)
def _calc_new_pos(self, pos):
x_diff, y_diff = MOVEMAP[self._identifier]
diff --git a/marl_factory_grid/environment/constants.py b/marl_factory_grid/environment/constants.py
index 1fdf639..6ddb19a 100644
--- a/marl_factory_grid/environment/constants.py
+++ b/marl_factory_grid/environment/constants.py
@@ -10,6 +10,7 @@ AGENT = 'Agent' # Identifier of Agent-objects an
OTHERS = 'Other'
COMBINED = 'Combined'
GLOBALPOSITIONS = 'GlobalPositions' # Identifier of the global position slice
+SPAWN_ENTITY_RULE = 'SpawnEntity'
# Attributes
IS_BLOCKING_LIGHT = 'var_is_blocking_light'
@@ -29,7 +30,7 @@ VALUE_NO_POS = (-9999, -9999) # Invalid Position value used in the e
ACTION = 'action' # Identifier of Action-objects and groups (groups).
-COLLISION = 'Collision' # Identifier to use in the context of collitions.
+COLLISION = 'Collisions' # Identifier to use in the context of collitions.
# LAST_POS = 'LAST_POS' # Identifiert for retrieving an enitites last pos.
VALIDITY = 'VALIDITY' # Identifiert for retrieving the Validity of Action, Tick, etc. ...
@@ -54,3 +55,5 @@ NOOP = 'Noop'
# Result Identifier
MOVEMENTS_VALID = 'motion_valid'
MOVEMENTS_FAIL = 'motion_not_valid'
+DEFAULT_PATH = 'environment'
+MODULE_PATH = 'modules'
diff --git a/marl_factory_grid/environment/entity/agent.py b/marl_factory_grid/environment/entity/agent.py
index 285c8d2..0920604 100644
--- a/marl_factory_grid/environment/entity/agent.py
+++ b/marl_factory_grid/environment/entity/agent.py
@@ -12,14 +12,6 @@ from marl_factory_grid.environment import constants as c
class Agent(Entity):
- @property
- def var_is_blocking_light(self):
- return False
-
- @property
- def var_can_move(self):
- return True
-
@property
def var_is_paralyzed(self):
return len(self._paralyzed)
@@ -28,14 +20,6 @@ class Agent(Entity):
def paralyze_reasons(self):
return [x for x in self._paralyzed]
- @property
- def var_is_blocking_pos(self):
- return False
-
- @property
- def var_has_position(self):
- return True
-
@property
def obs_tag(self):
return self.name
@@ -48,10 +32,6 @@ class Agent(Entity):
def observations(self):
return self._observations
- @property
- def var_can_collide(self):
- return True
-
def step_result(self):
pass
@@ -60,16 +40,21 @@ class Agent(Entity):
return self._collection
@property
- def state(self):
- return self._state or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0)
+ def var_is_blocking_pos(self):
+ return self._is_blocking_pos
- def __init__(self, actions: List[Action], observations: List[str], *args, **kwargs):
+ @property
+ def state(self):
+ return self._state or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID)
+
+ def __init__(self, actions: List[Action], observations: List[str], *args, is_blocking_pos=False, **kwargs):
super(Agent, self).__init__(*args, **kwargs)
self._paralyzed = set()
self.step_result = dict()
self._actions = actions
self._observations = observations
self._state: Union[Result, None] = None
+ self._is_blocking_pos = is_blocking_pos
# noinspection PyAttributeOutsideInit
def clear_temp_state(self):
diff --git a/marl_factory_grid/environment/entity/entity.py b/marl_factory_grid/environment/entity/entity.py
index 637827f..999787b 100644
--- a/marl_factory_grid/environment/entity/entity.py
+++ b/marl_factory_grid/environment/entity/entity.py
@@ -1,20 +1,19 @@
import abc
-from collections import defaultdict
import numpy as np
-from .object import _Object
+from .object import Object
from .. import constants as c
from ...utils.results import ActionResult
from ...utils.utility_classes import RenderEntity
-class Entity(_Object, abc.ABC):
+class Entity(Object, abc.ABC):
"""Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc..."""
@property
def state(self):
- return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0)
+ return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID)
@property
def var_has_position(self):
@@ -60,6 +59,10 @@ class Entity(_Object, abc.ABC):
def pos(self):
return self._pos
+ def set_pos(self, pos):
+ assert isinstance(pos, tuple) and len(pos) == 2
+ self._pos = pos
+
@property
def last_pos(self):
try:
@@ -84,7 +87,7 @@ class Entity(_Object, abc.ABC):
for observer in self.observers:
observer.notify_del_entity(self)
self._view_directory = curr_pos[0] - next_pos[0], curr_pos[1] - next_pos[1]
- self._pos = next_pos
+ self.set_pos(next_pos)
for observer in self.observers:
observer.notify_add_entity(self)
return valid
@@ -92,6 +95,7 @@ class Entity(_Object, abc.ABC):
def __init__(self, pos, bind_to=None, **kwargs):
super().__init__(**kwargs)
+ self._view_directory = c.VALUE_NO_POS
self._status = None
self._pos = pos
self._last_pos = pos
@@ -109,9 +113,6 @@ class Entity(_Object, abc.ABC):
def render(self):
return RenderEntity(self.__class__.__name__.lower(), self.pos)
- def __repr__(self):
- return super(Entity, self).__repr__() + f'(@{self.pos})'
-
@property
def obs_tag(self):
try:
@@ -128,25 +129,3 @@ class Entity(_Object, abc.ABC):
self._collection.delete_env_object(self)
self._collection = other_collection
return self._collection == other_collection
-
- @classmethod
- def from_coordinates(cls, positions: [(int, int)], *args, entity_kwargs=None, **kwargs, ):
- collection = cls(*args, **kwargs)
- collection.add_items(
- [cls._entity(tuple(pos), **entity_kwargs if entity_kwargs is not None else {}) for pos in positions])
- return collection
-
- def notify_del_entity(self, entity):
- try:
- self.pos_dict[entity.pos].remove(entity)
- except (ValueError, AttributeError):
- pass
-
- def by_pos(self, pos: (int, int)):
- pos = tuple(pos)
- try:
- return self.state.entities.pos_dict[pos]
- except StopIteration:
- pass
- except ValueError:
- print()
diff --git a/marl_factory_grid/environment/entity/mixin.py b/marl_factory_grid/environment/entity/mixin.py
deleted file mode 100644
index bab6343..0000000
--- a/marl_factory_grid/environment/entity/mixin.py
+++ /dev/null
@@ -1,24 +0,0 @@
-
-
-# noinspection PyAttributeOutsideInit
-class BoundEntityMixin:
-
- @property
- def bound_entity(self):
- return self._bound_entity
-
- @property
- def name(self):
- if self.bound_entity:
- return f'{self.__class__.__name__}({self.bound_entity.name})'
- else:
- pass
-
- def belongs_to_entity(self, entity):
- return entity == self.bound_entity
-
- def bind_to(self, entity):
- self._bound_entity = entity
-
- def unbind(self):
- self._bound_entity = None
diff --git a/marl_factory_grid/environment/entity/object.py b/marl_factory_grid/environment/entity/object.py
index 8810baf..e8c69da 100644
--- a/marl_factory_grid/environment/entity/object.py
+++ b/marl_factory_grid/environment/entity/object.py
@@ -5,7 +5,7 @@ from marl_factory_grid.environment import constants as c
import marl_factory_grid.utils.helpers as h
-class _Object:
+class Object:
"""Generell Objects for Organisation and Maintanance such as Actions etc..."""
_u_idx = defaultdict(lambda: 0)
@@ -13,10 +13,6 @@ class _Object:
def __bool__(self):
return True
- @property
- def var_has_position(self):
- return False
-
@property
def var_can_be_bound(self):
try:
@@ -30,22 +26,14 @@ class _Object:
@property
def name(self):
- if self._str_ident is not None:
- name = f'{self.__class__.__name__}[{self._str_ident}]'
- else:
- name = f'{self.__class__.__name__}#{self.u_int}'
- if self.bound_entity:
- name = h.add_bound_name(name, self.bound_entity)
- if self.var_has_position:
- name = h.add_pos_name(name, self)
- return name
+ return f'{self.__class__.__name__}[{self.identifier}]'
@property
def identifier(self):
if self._str_ident is not None:
return self._str_ident
else:
- return self.name
+ return self.u_int
def reset_uid(self):
self._u_idx = defaultdict(lambda: 0)
@@ -62,7 +50,15 @@ class _Object:
print(f'Following kwargs were passed, but ignored: {kwargs}')
def __repr__(self):
- return f'{self.name}'
+ name = self.name
+ if self.bound_entity:
+ name = h.add_bound_name(name, self.bound_entity)
+ try:
+ if self.var_has_position:
+ name = h.add_pos_name(name, self)
+ except AttributeError:
+ pass
+ return name
def __eq__(self, other) -> bool:
return other == self.identifier
@@ -71,8 +67,8 @@ class _Object:
return hash(self.identifier)
def _identify_and_count_up(self):
- idx = _Object._u_idx[self.__class__.__name__]
- _Object._u_idx[self.__class__.__name__] += 1
+ idx = Object._u_idx[self.__class__.__name__]
+ Object._u_idx[self.__class__.__name__] += 1
return idx
def set_collection(self, collection):
@@ -88,7 +84,7 @@ class _Object:
def summarize_state(self):
return dict()
- def bind(self, entity):
+ def bind_to(self, entity):
# noinspection PyAttributeOutsideInit
self._bound_entity = entity
return c.VALID
@@ -100,84 +96,5 @@ class _Object:
def bound_entity(self):
return self._bound_entity
- def bind_to(self, entity):
- self._bound_entity = entity
-
def unbind(self):
self._bound_entity = None
-
-
-# class EnvObject(_Object):
-# """Objects that hold Information that are observable, but have no position on the environment grid. Inventories etc..."""
-#
- # _u_idx = defaultdict(lambda: 0)
-#
-# @property
-# def obs_tag(self):
-# try:
-# return self._collection.name or self.name
-# except AttributeError:
-# return self.name
-#
-# @property
-# def var_is_blocking_light(self):
-# try:
-# return self._collection.var_is_blocking_light or False
-# except AttributeError:
-# return False
-#
-# @property
-# def var_can_be_bound(self):
-# try:
-# return self._collection.var_can_be_bound or False
-# except AttributeError:
-# return False
-#
-# @property
-# def var_can_move(self):
-# try:
-# return self._collection.var_can_move or False
-# except AttributeError:
-# return False
-#
-# @property
-# def var_is_blocking_pos(self):
-# try:
-# return self._collection.var_is_blocking_pos or False
-# except AttributeError:
-# return False
-#
-# @property
-# def var_has_position(self):
-# try:
-# return self._collection.var_has_position or False
-# except AttributeError:
-# return False
-#
-# @property
-# def var_can_collide(self):
-# try:
-# return self._collection.var_can_collide or False
-# except AttributeError:
-# return False
-#
-#
-# @property
-# def encoding(self):
-# return c.VALUE_OCCUPIED_CELL
-#
-#
-# def __init__(self, **kwargs):
-# self._bound_entity = None
-# super(EnvObject, self).__init__(**kwargs)
-#
-#
-# def change_parent_collection(self, other_collection):
-# other_collection.add_item(self)
-# self._collection.delete_env_object(self)
-# self._collection = other_collection
-# return self._collection == other_collection
-#
-#
-# def summarize_state(self):
-# return dict(name=str(self.name))
diff --git a/marl_factory_grid/environment/entity/util.py b/marl_factory_grid/environment/entity/util.py
index 1a5cbe3..2a15c41 100644
--- a/marl_factory_grid/environment/entity/util.py
+++ b/marl_factory_grid/environment/entity/util.py
@@ -1,6 +1,6 @@
import numpy as np
-from marl_factory_grid.environment.entity.object import _Object
+from marl_factory_grid.environment.entity.object import Object
##########################################################################
@@ -8,7 +8,7 @@ from marl_factory_grid.environment.entity.object import _Object
##########################################################################
-class PlaceHolder(_Object):
+class PlaceHolder(Object):
def __init__(self, *args, fill_value=0, **kwargs):
super().__init__(*args, **kwargs)
@@ -24,10 +24,10 @@ class PlaceHolder(_Object):
@property
def name(self):
- return "PlaceHolder"
+ return self.__class__.__name__
-class GlobalPosition(_Object):
+class GlobalPosition(Object):
@property
def encoding(self):
@@ -36,7 +36,8 @@ class GlobalPosition(_Object):
else:
return self.bound_entity.pos
- def __init__(self, level_shape, *args, normalized: bool = True, **kwargs):
+ def __init__(self, agent, level_shape, *args, normalized: bool = True, **kwargs):
super(GlobalPosition, self).__init__(*args, **kwargs)
+ self.bind_to(agent)
self._normalized = normalized
self._shape = level_shape
diff --git a/marl_factory_grid/environment/entity/wall.py b/marl_factory_grid/environment/entity/wall.py
index 3f0fb7c..83044cd 100644
--- a/marl_factory_grid/environment/entity/wall.py
+++ b/marl_factory_grid/environment/entity/wall.py
@@ -5,13 +5,8 @@ from marl_factory_grid.utils.utility_classes import RenderEntity
class Wall(Entity):
- @property
- def var_has_position(self):
- return True
-
- @property
- def var_can_collide(self):
- return True
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
@property
def encoding(self):
@@ -19,11 +14,3 @@ class Wall(Entity):
def render(self):
return RenderEntity(c.WALL, self.pos)
-
- @property
- def var_is_blocking_pos(self):
- return True
-
- @property
- def var_is_blocking_light(self):
- return True
diff --git a/marl_factory_grid/environment/factory.py b/marl_factory_grid/environment/factory.py
index 651444e..6662581 100644
--- a/marl_factory_grid/environment/factory.py
+++ b/marl_factory_grid/environment/factory.py
@@ -56,15 +56,18 @@ class Factory(gym.Env):
self.level_filepath = Path(custom_level_path)
else:
self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt'
- self._renderer = None # expensive - don't use; unless required !
parsed_entities = self.conf.load_entities()
self.map = LevelParser(self.level_filepath, parsed_entities, self.conf.pomdp_r)
# Init for later usage:
- self.state: Gamestate
- self.map: LevelParser
- self.obs_builder: OBSBuilder
+ # noinspection PyTypeChecker
+ self.state: Gamestate = None
+ # noinspection PyTypeChecker
+ self.obs_builder: OBSBuilder = None
+
+ # expensive - don't use; unless required !
+ self._renderer = None
# reset env to initial state, preparing env for new episode.
# returns tuple where the first dict contains initial observation for each agent in the env
@@ -74,7 +77,7 @@ class Factory(gym.Env):
return self.state.entities[item]
def reset(self) -> (dict, dict):
- if hasattr(self, 'state'):
+ if self.state is not None:
for entity_group in self.state.entities:
try:
entity_group[0].reset_uid()
@@ -87,12 +90,16 @@ class Factory(gym.Env):
entities = self.map.do_init()
# Init rules
- rules = self.conf.load_env_rules()
+ env_rules = self.conf.load_env_rules()
+ entity_rules = self.conf.load_entity_spawn_rules(entities)
+ env_rules.extend(entity_rules)
+
env_tests = self.conf.load_env_tests() if self.conf.tests else []
# Parse the agent conf
parsed_agents_conf = self.conf.parse_agents_conf()
- self.state = Gamestate(entities, parsed_agents_conf, rules, env_tests, self.conf.env_seed, self.conf.verbose)
+ self.state = Gamestate(entities, parsed_agents_conf, env_rules, env_tests, self.map.level_shape,
+ self.conf.env_seed, self.conf.verbose)
# All is set up, trigger entity init with variable pos
# All is set up, trigger additional init (after agent entity spawn etc)
@@ -160,7 +167,7 @@ class Factory(gym.Env):
# Finalize
reward, reward_info, done = self.summarize_step_results(tick_result, done_results)
- info = reward_info
+ info = dict(reward_info)
info.update(step_reward=sum(reward), step=self.state.curr_step)
diff --git a/marl_factory_grid/environment/groups/agents.py b/marl_factory_grid/environment/groups/agents.py
index f4a6ac6..d549384 100644
--- a/marl_factory_grid/environment/groups/agents.py
+++ b/marl_factory_grid/environment/groups/agents.py
@@ -1,10 +1,15 @@
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.environment.groups.collection import Collection
+from marl_factory_grid.environment.rules import SpawnAgents
class Agents(Collection):
_entity = Agent
+ @property
+ def spawn_rule(self):
+ return {SpawnAgents.__name__: {}}
+
@property
def var_is_blocking_light(self):
return False
diff --git a/marl_factory_grid/environment/groups/collection.py b/marl_factory_grid/environment/groups/collection.py
index 640c3b4..c0f0f6b 100644
--- a/marl_factory_grid/environment/groups/collection.py
+++ b/marl_factory_grid/environment/groups/collection.py
@@ -1,18 +1,25 @@
-from typing import List, Tuple, Union
+from typing import List, Tuple, Union, Dict
from marl_factory_grid.environment.entity.entity import Entity
-from marl_factory_grid.environment.groups.objects import _Objects
-from marl_factory_grid.environment.entity.object import _Object
+from marl_factory_grid.environment.groups.objects import Objects
+# noinspection PyProtectedMember
+from marl_factory_grid.environment.entity.object import Object
import marl_factory_grid.environment.constants as c
+from marl_factory_grid.utils.results import Result
-class Collection(_Objects):
- _entity = _Object # entity?
+class Collection(Objects):
+ _entity = Object # entity?
+ symbol = None
@property
def var_is_blocking_light(self):
return False
+ @property
+ def var_is_blocking_pos(self):
+ return False
+
@property
def var_can_collide(self):
return False
@@ -23,33 +30,65 @@ class Collection(_Objects):
@property
def var_has_position(self):
- return False
-
- # @property
- # def var_has_bound(self):
- # return False # batteries, globalpos, inventories true
-
- @property
- def var_can_be_bound(self):
- return False
+ return True
@property
def encodings(self):
return [x.encoding for x in self]
- def __init__(self, size, *args, **kwargs):
- super(Collection, self).__init__(*args, **kwargs)
- self.size = size
-
- def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): # woihn mit den args
- if isinstance(coords_or_quantity, int):
- self.add_items([self._entity() for _ in range(coords_or_quantity)])
+ @property
+ def spawn_rule(self):
+ """Prevent SpawnRule creation if Objects are spawned by map, Doors e.g."""
+ if self.symbol:
+ return None
+ elif self._spawnrule:
+ return self._spawnrule
else:
- self.add_items([self._entity(pos) for pos in coords_or_quantity])
+ return {c.SPAWN_ENTITY_RULE: dict(collection=self, coords_or_quantity=self._coords_or_quantity)}
+
+ def __init__(self, size, *args, coords_or_quantity: int = None, ignore_blocking=False,
+ spawnrule: Union[None, Dict[str, dict]] = None,
+ **kwargs):
+ super(Collection, self).__init__(*args, **kwargs)
+ self._coords_or_quantity = coords_or_quantity
+ self.size = size
+ self._spawnrule = spawnrule
+ self._ignore_blocking = ignore_blocking
+
+ def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, ignore_blocking=False, **entity_kwargs):
+ coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity
+ if self.var_has_position:
+ if self.var_has_position and isinstance(coords_or_quantity, int):
+ if ignore_blocking or self._ignore_blocking:
+ coords_or_quantity = state.entities.floorlist[:coords_or_quantity]
+ else:
+ coords_or_quantity = state.get_n_random_free_positions(coords_or_quantity)
+ self.spawn(coords_or_quantity, *entity_args, **entity_kwargs)
+ state.print(f'{len(coords_or_quantity)} new {self.name} have been spawned at {coords_or_quantity}')
+ return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(coords_or_quantity))
+ else:
+ if isinstance(coords_or_quantity, int):
+ self.spawn(coords_or_quantity, *entity_args, **entity_kwargs)
+ state.print(f'{coords_or_quantity} new {self.name} have been spawned randomly.')
+ return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=coords_or_quantity)
+ else:
+ raise ValueError(f'{self._entity.__name__} has no position!')
+
+ def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args, **entity_kwargs):
+ if self.var_has_position:
+ if isinstance(coords_or_quantity, int):
+ raise ValueError(f'{self._entity.__name__} should have a position!')
+ else:
+ self.add_items([self._entity(pos, *entity_args, **entity_kwargs) for pos in coords_or_quantity])
+ else:
+ if isinstance(coords_or_quantity, int):
+ self.add_items([self._entity(*entity_args, **entity_kwargs) for _ in range(coords_or_quantity)])
+ else:
+ raise ValueError(f'{self._entity.__name__} has no position!')
return c.VALID
- def despawn(self, items: List[_Object]):
- items = [items] if isinstance(items, _Object) else items
+ def despawn(self, items: List[Object]):
+ items = [items] if isinstance(items, Object) else items
for item in items:
del self[item]
@@ -115,7 +154,7 @@ class Collection(_Objects):
except StopIteration:
pass
except ValueError:
- print()
+ pass
@property
def positions(self):
diff --git a/marl_factory_grid/environment/groups/global_entities.py b/marl_factory_grid/environment/groups/global_entities.py
index 8bfc9fe..37779f9 100644
--- a/marl_factory_grid/environment/groups/global_entities.py
+++ b/marl_factory_grid/environment/groups/global_entities.py
@@ -1,21 +1,21 @@
from collections import defaultdict
from operator import itemgetter
-from random import shuffle, random
+from random import shuffle
from typing import Dict
-from marl_factory_grid.environment.groups.objects import _Objects
+from marl_factory_grid.environment.groups.objects import Objects
from marl_factory_grid.utils.helpers import POS_MASK
-class Entities(_Objects):
- _entity = _Objects
+class Entities(Objects):
+ _entity = Objects
@staticmethod
def neighboring_positions(pos):
- return (POS_MASK + pos).reshape(-1, 2)
+ return [tuple(x) for x in (POS_MASK + pos).reshape(-1, 2)]
def get_entities_near_pos(self, pos):
- return [y for x in itemgetter(*(tuple(x) for x in self.neighboring_positions(pos)))(self.pos_dict) for y in x]
+ return [y for x in itemgetter(*self.neighboring_positions(pos))(self.pos_dict) for y in x]
def render(self):
return [y for x in self for y in x.render() if x is not None]
@@ -35,8 +35,9 @@ class Entities(_Objects):
super().__init__()
def guests_that_can_collide(self, pos):
- return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
+ return [x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
+ @property
def empty_positions(self):
empty_positions = [key for key in self.floorlist if not self.pos_dict[key]]
shuffle(empty_positions)
@@ -48,11 +49,23 @@ class Entities(_Objects):
shuffle(empty_positions)
return empty_positions
- def is_blocked(self):
- return[key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])]
+ @property
+ def blocked_positions(self):
+ blocked_positions = [key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])]
+ shuffle(blocked_positions)
+ return blocked_positions
- def is_not_blocked(self):
- return[key for key, val in self.pos_dict.items() if not all([x.var_is_blocking_pos for x in val])]
+ @property
+ def free_positions_generator(self):
+ generator = (
+ key for key in self.floorlist if all(not x.var_can_collide and not x.var_is_blocking_pos
+ for x in self.pos_dict[key])
+ )
+ return generator
+
+ @property
+ def free_positions_list(self):
+ return [x for x in self.free_positions_generator]
def iter_entities(self):
return iter((x for sublist in self.values() for x in sublist))
@@ -74,7 +87,7 @@ class Entities(_Objects):
def __delitem__(self, name):
assert_str = 'This group of entity does not exist in this collection!'
assert any([key for key in name.keys() if key in self.keys()]), assert_str
- self[name]._observers.delete(self)
+ self[name].del_observer(self)
for entity in self[name]:
entity.del_observer(self)
return super(Entities, self).__delitem__(name)
@@ -92,3 +105,6 @@ class Entities(_Objects):
@property
def positions(self):
return [k for k, v in self.pos_dict.items() for _ in v]
+
+ def is_occupied(self, pos):
+ return len([x for x in self.pos_dict[pos] if x.var_can_collide or x.var_is_blocking_pos]) >= 1
diff --git a/marl_factory_grid/environment/groups/mixins.py b/marl_factory_grid/environment/groups/mixins.py
index 48333ca..acfac7e 100644
--- a/marl_factory_grid/environment/groups/mixins.py
+++ b/marl_factory_grid/environment/groups/mixins.py
@@ -4,10 +4,6 @@ from marl_factory_grid.environment import constants as c
# noinspection PyUnresolvedReferences,PyTypeChecker
class IsBoundMixin:
- @property
- def name(self):
- return f'{self.__class__.__name__}({self._bound_entity.name})'
-
def __repr__(self):
return f'{self.__class__.__name__}#{self._bound_entity.name}({self._data})'
diff --git a/marl_factory_grid/environment/groups/objects.py b/marl_factory_grid/environment/groups/objects.py
index d3f32af..9229787 100644
--- a/marl_factory_grid/environment/groups/objects.py
+++ b/marl_factory_grid/environment/groups/objects.py
@@ -1,14 +1,19 @@
from collections import defaultdict
-from typing import List
+from typing import List, Iterator, Union
import numpy as np
-from marl_factory_grid.environment.entity.object import _Object
+from marl_factory_grid.environment.entity.object import Object
import marl_factory_grid.environment.constants as c
+from marl_factory_grid.utils import helpers as h
-class _Objects:
- _entity = _Object
+class Objects:
+ _entity = Object
+
+ @property
+ def var_can_be_bound(self):
+ return False
@property
def observers(self):
@@ -45,7 +50,7 @@ class _Objects:
def __len__(self):
return len(self._data)
- def __iter__(self):
+ def __iter__(self) -> Iterator[Union[Object, None]]:
return iter(self.values())
def add_item(self, item: _entity):
@@ -125,13 +130,14 @@ class _Objects:
repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]}
return f'{self.__class__.__name__}[{repr_dict}]'
- def notify_del_entity(self, entity: _Object):
+ def notify_del_entity(self, entity: Object):
try:
+ # noinspection PyUnresolvedReferences
self.pos_dict[entity.pos].remove(entity)
except (AttributeError, ValueError, IndexError):
pass
- def notify_add_entity(self, entity: _Object):
+ def notify_add_entity(self, entity: Object):
try:
if self not in entity.observers:
entity.add_observer(self)
@@ -148,12 +154,12 @@ class _Objects:
def by_entity(self, entity):
try:
- return next((x for x in self if x.belongs_to_entity(entity)))
+ return h.get_first(self, filter_by=lambda x: x.belongs_to_entity(entity))
except (StopIteration, AttributeError):
return None
def idx_by_entity(self, entity):
try:
- return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
+ return h.get_first_index(self, filter_by=lambda x: x.belongs_to_entity(entity))
except (StopIteration, AttributeError):
return None
diff --git a/marl_factory_grid/environment/groups/utils.py b/marl_factory_grid/environment/groups/utils.py
index 5619041..d272152 100644
--- a/marl_factory_grid/environment/groups/utils.py
+++ b/marl_factory_grid/environment/groups/utils.py
@@ -1,7 +1,10 @@
from typing import List, Union
+from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.util import GlobalPosition
from marl_factory_grid.environment.groups.collection import Collection
+from marl_factory_grid.utils.results import Result
+from marl_factory_grid.utils.states import Gamestate
class Combined(Collection):
@@ -36,17 +39,17 @@ class GlobalPositions(Collection):
_entity = GlobalPosition
- @property
- def var_is_blocking_light(self):
- return False
-
- @property
- def var_can_collide(self):
- return False
-
- @property
- def var_can_be_bound(self):
- return True
+ var_is_blocking_light = False
+ var_can_be_bound = True
+ var_can_collide = False
+ var_has_position = False
def __init__(self, *args, **kwargs):
super(GlobalPositions, self).__init__(*args, **kwargs)
+
+ def spawn(self, agents, level_shape, *args, **kwargs):
+ self.add_items([self._entity(agent, level_shape, *args, **kwargs) for agent in agents])
+ return [Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))]
+
+ def trigger_spawn(self, state: Gamestate, *args, **kwargs) -> [Result]:
+ return self.spawn(state[c.AGENT], state.lvl_shape, *args, **kwargs)
diff --git a/marl_factory_grid/environment/groups/walls.py b/marl_factory_grid/environment/groups/walls.py
index 2d85362..776bbca 100644
--- a/marl_factory_grid/environment/groups/walls.py
+++ b/marl_factory_grid/environment/groups/walls.py
@@ -7,9 +7,12 @@ class Walls(Collection):
_entity = Wall
symbol = c.SYMBOL_WALL
- @property
- def var_has_position(self):
- return True
+ var_can_collide = True
+ var_is_blocking_light = True
+ var_can_move = False
+ var_has_position = True
+ var_can_be_bound = False
+ var_is_blocking_pos = True
def __init__(self, *args, **kwargs):
super(Walls, self).__init__(*args, **kwargs)
diff --git a/marl_factory_grid/environment/rewards.py b/marl_factory_grid/environment/rewards.py
index b3ebe8c..aa0acbd 100644
--- a/marl_factory_grid/environment/rewards.py
+++ b/marl_factory_grid/environment/rewards.py
@@ -2,3 +2,4 @@ MOVEMENTS_VALID: float = -0.001
MOVEMENTS_FAIL: float = -0.05
NOOP: float = -0.01
COLLISION: float = -0.5
+COLLISION_DONE: float = -1
diff --git a/marl_factory_grid/environment/rules.py b/marl_factory_grid/environment/rules.py
index f9678b0..f5b6836 100644
--- a/marl_factory_grid/environment/rules.py
+++ b/marl_factory_grid/environment/rules.py
@@ -1,11 +1,11 @@
import abc
from random import shuffle
-from typing import List
+from typing import List, Collection
+from marl_factory_grid.environment import rewards as r, constants as c
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.utils import helpers as h
from marl_factory_grid.utils.results import TickResult, DoneResult
-from marl_factory_grid.environment import rewards as r, constants as c
class Rule(abc.ABC):
@@ -39,6 +39,29 @@ class Rule(abc.ABC):
return []
+class SpawnEntity(Rule):
+
+ @property
+ def _collection(self) -> Collection:
+ return Collection()
+
+ @property
+ def name(self):
+ return f'{self.__class__.__name__}({self.collection.name})'
+
+ def __init__(self, collection, coords_or_quantity, ignore_blocking=False):
+ super().__init__()
+ self.coords_or_quantity = coords_or_quantity
+ self.collection = collection
+ self.ignore_blocking = ignore_blocking
+
+ def on_init(self, state, lvl_map) -> [TickResult]:
+ results = self.collection.trigger_spawn(state, ignore_blocking=self.ignore_blocking)
+ pos_str = f' on: {[x.pos for x in self.collection]}' if self.collection.var_has_position else ''
+ state.print(f'Initial {self.collection.__class__.__name__} were spawned{pos_str}')
+ return results
+
+
class SpawnAgents(Rule):
def __init__(self):
@@ -46,14 +69,14 @@ class SpawnAgents(Rule):
pass
def on_init(self, state, lvl_map):
- agent_conf = state.agents_conf
# agents = Agents(lvl_map.size)
agents = state[c.AGENT]
- empty_positions = state.entities.empty_positions()[:len(agent_conf)]
- for agent_name in agent_conf:
- actions = agent_conf[agent_name]['actions'].copy()
- observations = agent_conf[agent_name]['observations'].copy()
- positions = agent_conf[agent_name]['positions'].copy()
+ empty_positions = state.entities.empty_positions[:len(state.agents_conf)]
+ for agent_name, agent_conf in state.agents_conf.items():
+ actions = agent_conf['actions'].copy()
+ observations = agent_conf['observations'].copy()
+ positions = agent_conf['positions'].copy()
+ other = agent_conf['other'].copy()
if positions:
shuffle(positions)
while True:
@@ -61,18 +84,18 @@ class SpawnAgents(Rule):
pos = positions.pop()
except IndexError:
raise ValueError(f'It was not possible to spawn an Agent on the available position: '
- f'\n{agent_name[agent_name]["positions"].copy()}')
- if agents.by_pos(pos) and state.check_pos_validity(pos):
+ f'\n{agent_conf["positions"].copy()}')
+ if bool(agents.by_pos(pos)) or not state.check_pos_validity(pos):
continue
else:
- agents.add_item(Agent(actions, observations, pos, str_ident=agent_name))
+ agents.add_item(Agent(actions, observations, pos, str_ident=agent_name, **other))
break
else:
- agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name))
+ agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name, **other))
pass
-class MaxStepsReached(Rule):
+class DoneAtMaxStepsReached(Rule):
def __init__(self, max_steps: int = 500):
super().__init__()
@@ -83,8 +106,8 @@ class MaxStepsReached(Rule):
def on_check_done(self, state):
if self.max_steps <= state.curr_step:
- return [DoneResult(validity=c.VALID, identifier=self.name, reward=0)]
- return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
+ return [DoneResult(validity=c.VALID, identifier=self.name)]
+ return [DoneResult(validity=c.NOT_VALID, identifier=self.name)]
class AssignGlobalPositions(Rule):
@@ -95,16 +118,17 @@ class AssignGlobalPositions(Rule):
def on_init(self, state, lvl_map):
from marl_factory_grid.environment.entity.util import GlobalPosition
for agent in state[c.AGENT]:
- gp = GlobalPosition(lvl_map.level_shape)
- gp.bind_to(agent)
+ gp = GlobalPosition(agent, lvl_map.level_shape)
state[c.GLOBALPOSITIONS].add_item(gp)
return []
-class Collision(Rule):
+class WatchCollisions(Rule):
- def __init__(self, done_at_collisions: bool = False):
+ def __init__(self, reward=r.COLLISION, done_at_collisions: bool = False, reward_at_done=r.COLLISION_DONE):
super().__init__()
+ self.reward_at_done = reward_at_done
+ self.reward = reward
self.done_at_collisions = done_at_collisions
self.curr_done = False
@@ -117,12 +141,12 @@ class Collision(Rule):
if len(guests) >= 2:
for i, guest in enumerate(guests):
try:
- guest.set_state(TickResult(identifier=c.COLLISION, reward=r.COLLISION,
+ guest.set_state(TickResult(identifier=c.COLLISION, reward=self.reward,
validity=c.NOT_VALID, entity=self))
except AttributeError:
pass
results.append(TickResult(entity=guest, identifier=c.COLLISION,
- reward=r.COLLISION, validity=c.VALID))
+ reward=self.reward, validity=c.VALID))
self.curr_done = True if self.done_at_collisions else False
return results
@@ -131,5 +155,5 @@ class Collision(Rule):
inter_entity_collision_detected = self.curr_done
move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT])
if inter_entity_collision_detected or move_failed:
- return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)]
- return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
+ return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=self.reward_at_done)]
+ return []
diff --git a/marl_factory_grid/modules/_template/rules.py b/marl_factory_grid/modules/_template/rules.py
index 6ed2f2d..7696616 100644
--- a/marl_factory_grid/modules/_template/rules.py
+++ b/marl_factory_grid/modules/_template/rules.py
@@ -6,7 +6,9 @@ from marl_factory_grid.utils.results import TickResult, DoneResult
class TemplateRule(Rule):
def __init__(self, *args, **kwargs):
- super(TemplateRule, self).__init__(*args, **kwargs)
+ super(TemplateRule, self).__init__()
+ self.args = args
+ self.kwargs = kwargs
def on_init(self, state, lvl_map):
pass
diff --git a/marl_factory_grid/modules/batteries/__init__.py b/marl_factory_grid/modules/batteries/__init__.py
index 0218021..80671fd 100644
--- a/marl_factory_grid/modules/batteries/__init__.py
+++ b/marl_factory_grid/modules/batteries/__init__.py
@@ -1,4 +1,4 @@
from .actions import BtryCharge
-from .entitites import Pod, Battery
+from .entitites import ChargePod, Battery
from .groups import ChargePods, Batteries
from .rules import DoneAtBatteryDischarge, BatteryDecharge
diff --git a/marl_factory_grid/modules/batteries/actions.py b/marl_factory_grid/modules/batteries/actions.py
index 343bbcc..7d1c4a2 100644
--- a/marl_factory_grid/modules/batteries/actions.py
+++ b/marl_factory_grid/modules/batteries/actions.py
@@ -1,11 +1,11 @@
from typing import Union
-import marl_factory_grid.modules.batteries.constants
from marl_factory_grid.environment.actions import Action
from marl_factory_grid.utils.results import ActionResult
from marl_factory_grid.modules.batteries import constants as b
from marl_factory_grid.environment import constants as c
+from marl_factory_grid.utils import helpers as h
class BtryCharge(Action):
@@ -14,8 +14,8 @@ class BtryCharge(Action):
super().__init__(b.ACTION_CHARGE)
def do(self, entity, state) -> Union[None, ActionResult]:
- if charge_pod := state[b.CHARGE_PODS].by_pos(entity.pos):
- valid = charge_pod.charge_battery(state[b.BATTERIES].by_entity(entity))
+ if charge_pod := h.get_first(state[b.CHARGE_PODS].by_pos(entity.pos)):
+ valid = h.get_first(charge_pod.charge_battery(state[b.BATTERIES].by_entity(entity)))
if valid:
state.print(f'{entity.name} just charged batteries at {charge_pod.name}.')
else:
@@ -23,5 +23,6 @@ class BtryCharge(Action):
else:
valid = c.NOT_VALID
state.print(f'{entity.name} failed to charged batteries at {entity.pos}.')
+
return ActionResult(entity=entity, identifier=self._identifier, validity=valid,
- reward=marl_factory_grid.modules.batteries.constants.REWARD_CHARGE_VALID if valid else marl_factory_grid.modules.batteries.constants.Reward_CHARGE_FAIL)
+ reward=b.REWARD_CHARGE_VALID if valid else b.Reward_CHARGE_FAIL)
diff --git a/marl_factory_grid/modules/batteries/chargepods.png b/marl_factory_grid/modules/batteries/chargepods.png
new file mode 100644
index 0000000..7221daa
Binary files /dev/null and b/marl_factory_grid/modules/batteries/chargepods.png differ
diff --git a/marl_factory_grid/modules/batteries/entitites.py b/marl_factory_grid/modules/batteries/entitites.py
index b51f2dd..7675fe9 100644
--- a/marl_factory_grid/modules/batteries/entitites.py
+++ b/marl_factory_grid/modules/batteries/entitites.py
@@ -1,11 +1,11 @@
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.entity import Entity
-from marl_factory_grid.environment.entity.object import _Object
+from marl_factory_grid.environment.entity.object import Object
from marl_factory_grid.modules.batteries import constants as b
from marl_factory_grid.utils.utility_classes import RenderEntity
-class Battery(_Object):
+class Battery(Object):
@property
def var_can_be_bound(self):
@@ -50,7 +50,7 @@ class Battery(_Object):
return summary
-class Pod(Entity):
+class ChargePod(Entity):
@property
def encoding(self):
@@ -58,7 +58,7 @@ class Pod(Entity):
def __init__(self, *args, charge_rate: float = 0.4,
multi_charge: bool = False, **kwargs):
- super(Pod, self).__init__(*args, **kwargs)
+ super(ChargePod, self).__init__(*args, **kwargs)
self.charge_rate = charge_rate
self.multi_charge = multi_charge
diff --git a/marl_factory_grid/modules/batteries/groups.py b/marl_factory_grid/modules/batteries/groups.py
index 8d9e060..7db43bd 100644
--- a/marl_factory_grid/modules/batteries/groups.py
+++ b/marl_factory_grid/modules/batteries/groups.py
@@ -1,52 +1,36 @@
from typing import Union, List, Tuple
+from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.groups.collection import Collection
-from marl_factory_grid.modules.batteries.entitites import Pod, Battery
+from marl_factory_grid.modules.batteries.entitites import ChargePod, Battery
+from marl_factory_grid.utils.results import Result
class Batteries(Collection):
_entity = Battery
- @property
- def var_is_blocking_light(self):
- return False
-
- @property
- def var_can_collide(self):
- return False
-
- @property
- def var_can_move(self):
- return False
-
- @property
- def var_has_position(self):
- return False
-
- @property
- def var_can_be_bound(self):
- return True
+ var_has_position = False
+ var_can_be_bound = True
@property
def obs_tag(self):
return self.__class__.__name__
- def __init__(self, *args, **kwargs):
- super(Batteries, self).__init__(*args, **kwargs)
+ def __init__(self, size, initial_charge_level: float=1.0, *args, **kwargs):
+ super(Batteries, self).__init__(size, *args, **kwargs)
+ self.initial_charge_level = initial_charge_level
- def spawn(self, agents, initial_charge_level):
- batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)]
+ def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], agents, *entity_args, **entity_kwargs):
+ batteries = [self._entity(self.initial_charge_level, agent) for _, agent in enumerate(agents)]
self.add_items(batteries)
- # def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): hat keine pos
- # agents = entity_args[0]
- # initial_charge_level = entity_args[1]
- # batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)]
- # self.add_items(batteries)
+ def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, **entity_kwargs):
+ self.spawn(0, state[c.AGENT])
+ return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))
class ChargePods(Collection):
- _entity = Pod
+ _entity = ChargePod
def __init__(self, *args, **kwargs):
super(ChargePods, self).__init__(*args, **kwargs)
diff --git a/marl_factory_grid/modules/batteries/rules.py b/marl_factory_grid/modules/batteries/rules.py
index e060629..8a4725b 100644
--- a/marl_factory_grid/modules/batteries/rules.py
+++ b/marl_factory_grid/modules/batteries/rules.py
@@ -1,11 +1,9 @@
from typing import List, Union
-import marl_factory_grid.modules.batteries.constants
-from marl_factory_grid.environment.rules import Rule
-from marl_factory_grid.utils.results import TickResult, DoneResult
-
from marl_factory_grid.environment import constants as c
+from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.modules.batteries import constants as b
+from marl_factory_grid.utils.results import TickResult, DoneResult
class BatteryDecharge(Rule):
@@ -49,10 +47,6 @@ class BatteryDecharge(Rule):
self.per_action_costs = per_action_costs
self.initial_charge = initial_charge
- def on_init(self, state, lvl_map): # on reset?
- assert len(state[c.AGENT]), "There are no agents, did you already spawn them?"
- state[b.BATTERIES].spawn(state[c.AGENT], self.initial_charge)
-
def tick_step(self, state) -> List[TickResult]:
# Decharge
batteries = state[b.BATTERIES]
@@ -66,7 +60,7 @@ class BatteryDecharge(Rule):
batteries.by_entity(agent).decharge(energy_consumption)
- results.append(TickResult(self.name, reward=0, entity=agent, validity=c.VALID))
+ results.append(TickResult(self.name, entity=agent, validity=c.VALID))
return results
@@ -82,13 +76,13 @@ class BatteryDecharge(Rule):
if self.paralyze_agents_on_discharge:
btry.bound_entity.paralyze(self.name)
results.append(
- TickResult("Paralyzed", entity=btry.bound_entity, reward=0, validity=c.VALID)
+ TickResult("Paralyzed", entity=btry.bound_entity, validity=c.VALID)
)
state.print(f'{btry.bound_entity.name} has just been paralyzed!')
if btry.bound_entity.var_is_paralyzed and not btry.is_discharged:
btry.bound_entity.de_paralyze(self.name)
results.append(
- TickResult("De-Paralyzed", entity=btry.bound_entity, reward=0, validity=c.VALID)
+ TickResult("De-Paralyzed", entity=btry.bound_entity, validity=c.VALID)
)
state.print(f'{btry.bound_entity.name} has just been de-paralyzed!')
return results
@@ -132,7 +126,7 @@ class DoneAtBatteryDischarge(BatteryDecharge):
if any_discharged or all_discharged:
return [DoneResult(self.name, validity=c.VALID, reward=self.reward_discharge_done)]
else:
- return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)]
+ return [DoneResult(self.name, validity=c.NOT_VALID)]
class SpawnChargePods(Rule):
@@ -155,7 +149,7 @@ class SpawnChargePods(Rule):
def on_init(self, state, lvl_map):
pod_collection = state[b.CHARGE_PODS]
- empty_positions = state.entities.empty_positions()
+ empty_positions = state.entities.empty_positions
pods = pod_collection.from_coordinates(empty_positions, entity_kwargs=dict(
multi_charge=self.multi_charge, charge_rate=self.charge_rate)
)
diff --git a/marl_factory_grid/modules/clean_up/__init__.py b/marl_factory_grid/modules/clean_up/__init__.py
index 31cb841..ec4d1e7 100644
--- a/marl_factory_grid/modules/clean_up/__init__.py
+++ b/marl_factory_grid/modules/clean_up/__init__.py
@@ -1,4 +1,4 @@
from .actions import CleanUp
from .entitites import DirtPile
from .groups import DirtPiles
-from .rules import SpawnDirt, EntitiesSmearDirtOnMove, DoneOnAllDirtCleaned
+from .rules import EntitiesSmearDirtOnMove, DoneOnAllDirtCleaned
diff --git a/marl_factory_grid/modules/clean_up/entitites.py b/marl_factory_grid/modules/clean_up/entitites.py
index 8ac8a0c..25c6eb1 100644
--- a/marl_factory_grid/modules/clean_up/entitites.py
+++ b/marl_factory_grid/modules/clean_up/entitites.py
@@ -1,5 +1,3 @@
-from numpy import random
-
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.utils.utility_classes import RenderEntity
from marl_factory_grid.modules.clean_up import constants as d
@@ -7,22 +5,6 @@ from marl_factory_grid.modules.clean_up import constants as d
class DirtPile(Entity):
- @property
- def var_can_collide(self):
- return False
-
- @property
- def var_can_move(self):
- return False
-
- @property
- def var_is_blocking_light(self):
- return False
-
- @property
- def var_has_position(self):
- return True
-
@property
def amount(self):
return self._amount
diff --git a/marl_factory_grid/modules/clean_up/groups.py b/marl_factory_grid/modules/clean_up/groups.py
index 63e5898..7ae3247 100644
--- a/marl_factory_grid/modules/clean_up/groups.py
+++ b/marl_factory_grid/modules/clean_up/groups.py
@@ -1,76 +1,61 @@
-from typing import Union, List, Tuple
-
from marl_factory_grid.environment import constants as c
-from marl_factory_grid.utils.results import Result
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.modules.clean_up.entitites import DirtPile
+from marl_factory_grid.utils.results import Result
class DirtPiles(Collection):
_entity = DirtPile
- @property
- def var_is_blocking_light(self):
- return False
+ var_is_blocking_light = False
+ var_can_collide = False
+ var_can_move = False
+ var_has_position = True
@property
- def var_can_collide(self):
- return False
-
- @property
- def var_can_move(self):
- return False
-
- @property
- def var_has_position(self):
- return True
-
- @property
- def amount(self):
+ def global_amount(self):
return sum([dirt.amount for dirt in self])
def __init__(self, *args,
max_local_amount=5,
clean_amount=1,
- max_global_amount: int = 20, **kwargs):
+ max_global_amount: int = 20,
+ coords_or_quantity=10,
+ initial_amount=2,
+ amount_var=0.2,
+ n_var=0.2,
+ **kwargs):
super(DirtPiles, self).__init__(*args, **kwargs)
+ self.amount_var = amount_var
+ self.n_var = n_var
self.clean_amount = clean_amount
self.max_global_amount = max_global_amount
self.max_local_amount = max_local_amount
+ self.coords_or_quantity = coords_or_quantity
+ self.initial_amount = initial_amount
- def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):
- amount_s = entity_args[0]
+ def trigger_spawn(self, state, coords_or_quantity=0, amount=0) -> [Result]:
+ coords_or_quantity = coords_or_quantity if coords_or_quantity else self.coords_or_quantity
+ n_new = int(abs(coords_or_quantity + (state.rng.uniform(-self.n_var, self.n_var))))
+ n_new = state.get_n_random_free_positions(n_new)
+
+ amounts = [amount if amount else (self.initial_amount + state.rng.uniform(-self.amount_var, self.amount_var))
+ for _ in range(coords_or_quantity)]
spawn_counter = 0
- for idx, pos in enumerate(coords_or_quantity):
- if not self.amount > self.max_global_amount:
- amount = amount_s[idx] if isinstance(amount_s, list) else amount_s
+ for idx, (pos, a) in enumerate(zip(n_new, amounts)):
+ if not self.global_amount > self.max_global_amount:
if dirt := self.by_pos(pos):
dirt = next(dirt.iter())
- new_value = dirt.amount + amount
+ new_value = dirt.amount + a
dirt.set_new_amount(new_value)
else:
- dirt = DirtPile(pos, amount=amount)
- self.add_item(dirt)
+ super().spawn([pos], amount=a)
spawn_counter += 1
else:
- return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, reward=0,
- value=spawn_counter)
- return Result(identifier=f'{self.name}_spawn', validity=c.VALID, reward=0, value=spawn_counter)
+ return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, value=spawn_counter)
- def trigger_dirt_spawn(self, n, amount, state, n_var=0.2, amount_var=0.2) -> Result:
- free_for_dirt = [x for x in state.entities.floorlist if len(state.entities.pos_dict[x]) == 0 or (
- len(state.entities.pos_dict[x]) >= 1 and isinstance(next(y for y in x), DirtPile))]
- # free_for_dirt = [x for x in state[c.FLOOR]
- # if len(x.guests) == 0 or (
- # len(x.guests) == 1 and
- # isinstance(next(y for y in x.guests), DirtPile))]
- state.rng.shuffle(free_for_dirt)
-
- new_spawn = int(abs(n + (state.rng.uniform(-n_var, n_var))))
- new_amount_s = [abs(amount + (amount*state.rng.uniform(-amount_var, amount_var))) for _ in range(new_spawn)]
- n_dirty_positions = free_for_dirt[:new_spawn]
- return self.spawn(n_dirty_positions, new_amount_s)
+ return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=spawn_counter)
def __repr__(self):
s = super(DirtPiles, self).__repr__()
- return f'{s[:-1]}, {self.amount})'
+ return f'{s[:-1]}, {self.global_amount}]'
diff --git a/marl_factory_grid/modules/clean_up/rules.py b/marl_factory_grid/modules/clean_up/rules.py
index 3f58cdb..b81ee41 100644
--- a/marl_factory_grid/modules/clean_up/rules.py
+++ b/marl_factory_grid/modules/clean_up/rules.py
@@ -22,58 +22,37 @@ class DoneOnAllDirtCleaned(Rule):
def on_check_done(self, state) -> [DoneResult]:
if len(state[d.DIRT]) == 0 and state.curr_step:
return [DoneResult(validity=c.VALID, identifier=self.name, reward=self.reward)]
- return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
+ return [DoneResult(validity=c.NOT_VALID, identifier=self.name)]
-class SpawnDirt(Rule):
+class RespawnDirt(Rule):
- def __init__(self, initial_n: int = 5, initial_amount: float = 1.3,
- respawn_n: int = 3, respawn_amount: float = 0.8,
- n_var: float = 0.2, amount_var: float = 0.2, spawn_freq: int = 15):
+ def __init__(self, respawn_freq: int = 15, respawn_n: int = 5, respawn_amount: float = 1.0):
"""
Defines the spawn pattern of intial and additional 'Dirt'-entitites.
First chooses positions, then trys to spawn dirt until 'respawn_n' or the maximal global amount is reached.
If there is allready some, it is topped up to min(max_local_amount, amount).
- :type spawn_freq: int
- :parameter spawn_freq: In which frequency should this Rule try to spawn new 'Dirt'?
+ :type respawn_freq: int
+ :parameter respawn_freq: In which frequency should this Rule try to spawn new 'Dirt'?
:type respawn_n: int
:parameter respawn_n: How many respawn positions are considered.
- :type initial_n: int
- :parameter initial_n: How much initial positions are considered.
- :type amount_var: float
- :parameter amount_var: Variance of amount to spawn.
- :type n_var: float
- :parameter n_var: Variance of n to spawn.
:type respawn_amount: float
:parameter respawn_amount: Defines how much dirt 'amount' is placed every 'spawn_freq' ticks.
- :type initial_amount: float
- :parameter initial_amount: Defines how much dirt 'amount' is initially placed.
-
"""
super().__init__()
- self.amount_var = amount_var
- self.n_var = n_var
- self.respawn_amount = respawn_amount
self.respawn_n = respawn_n
- self.initial_amount = initial_amount
- self.initial_n = initial_n
- self.spawn_freq = spawn_freq
- self._next_dirt_spawn = spawn_freq
-
- def on_init(self, state, lvl_map) -> str:
- result = state[d.DIRT].trigger_dirt_spawn(self.initial_n, self.initial_amount, state,
- n_var=self.n_var, amount_var=self.amount_var)
- state.print(f'Initial Dirt was spawned on: {[x.pos for x in state[d.DIRT]]}')
- return result
+ self.respawn_amount = respawn_amount
+ self.respawn_freq = respawn_freq
+ self._next_dirt_spawn = respawn_freq
def tick_step(self, state):
+ collection = state[d.DIRT]
if self._next_dirt_spawn < 0:
- pass # No DirtPile Spawn
+ result = [] # No DirtPile Spawn
elif not self._next_dirt_spawn:
- result = [state[d.DIRT].trigger_dirt_spawn(self.respawn_n, self.respawn_amount, state,
- n_var=self.n_var, amount_var=self.amount_var)]
- self._next_dirt_spawn = self.spawn_freq
+ result = [collection.trigger_spawn(state, coords_or_quantity=self.respawn_n, amount=self.respawn_amount)]
+ self._next_dirt_spawn = self.respawn_freq
else:
self._next_dirt_spawn -= 1
result = []
@@ -99,8 +78,8 @@ class EntitiesSmearDirtOnMove(Rule):
for entity in state.moving_entites:
if is_move(entity.state.identifier) and entity.state.validity == c.VALID:
if old_pos_dirt := state[d.DIRT].by_pos(entity.last_pos):
+ old_pos_dirt = next(iter(old_pos_dirt))
if smeared_dirt := round(old_pos_dirt.amount * self.smear_ratio, 2):
if state[d.DIRT].spawn(entity.pos, amount=smeared_dirt):
- results.append(TickResult(identifier=self.name, entity=entity,
- reward=0, validity=c.VALID))
+ results.append(TickResult(identifier=self.name, entity=entity, validity=c.VALID))
return results
diff --git a/marl_factory_grid/modules/destinations/__init__.py b/marl_factory_grid/modules/destinations/__init__.py
index 83e5988..4614dd7 100644
--- a/marl_factory_grid/modules/destinations/__init__.py
+++ b/marl_factory_grid/modules/destinations/__init__.py
@@ -1,4 +1,7 @@
from .actions import DestAction
from .entitites import Destination
from .groups import Destinations
-from .rules import DoneAtDestinationReachAll, SpawnDestinations
+from .rules import (DoneAtDestinationReachAll,
+ DoneAtDestinationReachAny,
+ SpawnDestinationsPerAgent,
+ DestinationReachReward)
diff --git a/marl_factory_grid/modules/destinations/actions.py b/marl_factory_grid/modules/destinations/actions.py
index 13f7fe3..6367acd 100644
--- a/marl_factory_grid/modules/destinations/actions.py
+++ b/marl_factory_grid/modules/destinations/actions.py
@@ -21,4 +21,4 @@ class DestAction(Action):
valid = c.NOT_VALID
state.print(f'{entity.name} just tried to do_wait_action do_wait_action at {entity.pos} but failed')
return ActionResult(entity=entity, identifier=self._identifier, validity=valid,
- reward=marl_factory_grid.modules.destinations.constants.REWARD_WAIT_VALID if valid else marl_factory_grid.modules.destinations.constants.REWARD_WAIT_FAIL)
+ reward=d.REWARD_WAIT_VALID if valid else d.REWARD_WAIT_FAIL)
diff --git a/marl_factory_grid/modules/destinations/entitites.py b/marl_factory_grid/modules/destinations/entitites.py
index 7b866b7..d75f9e0 100644
--- a/marl_factory_grid/modules/destinations/entitites.py
+++ b/marl_factory_grid/modules/destinations/entitites.py
@@ -9,30 +9,6 @@ from marl_factory_grid.utils.utility_classes import RenderEntity
class Destination(Entity):
- @property
- def var_can_move(self):
- return False
-
- @property
- def var_can_collide(self):
- return False
-
- @property
- def var_has_position(self):
- return True
-
- @property
- def var_is_blocking_pos(self):
- return False
-
- @property
- def var_is_blocking_light(self):
- return False
-
- @property
- def var_can_be_bound(self):
- return True
-
def was_reached(self):
return self._was_reached
diff --git a/marl_factory_grid/modules/destinations/groups.py b/marl_factory_grid/modules/destinations/groups.py
index 5f91bb4..f0b7f9e 100644
--- a/marl_factory_grid/modules/destinations/groups.py
+++ b/marl_factory_grid/modules/destinations/groups.py
@@ -1,43 +1,18 @@
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.modules.destinations.entitites import Destination
-from marl_factory_grid.environment import constants as c
-from marl_factory_grid.modules.destinations import constants as d
class Destinations(Collection):
_entity = Destination
- @property
- def var_is_blocking_light(self):
- return False
-
- @property
- def var_can_collide(self):
- return False
-
- @property
- def var_can_move(self):
- return False
-
- @property
- def var_has_position(self):
- return True
+ var_is_blocking_light = False
+ var_can_collide = False
+ var_can_move = False
+ var_has_position = True
+ var_can_be_bound = True
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __repr__(self):
return super(Destinations, self).__repr__()
-
- @staticmethod
- def trigger_destination_spawn(n_dests, state):
- coordinates = state.entities.floorlist[:n_dests]
- if destinations := [Destination(pos) for pos in coordinates]:
- state[d.DESTINATION].add_items(destinations)
- state.print(f'{n_dests} new destinations have been spawned')
- return c.VALID
- else:
- state.print('No Destiantions are spawning, limit is reached.')
- return c.NOT_VALID
-
-
diff --git a/marl_factory_grid/modules/destinations/rules.py b/marl_factory_grid/modules/destinations/rules.py
index afb8575..8e72141 100644
--- a/marl_factory_grid/modules/destinations/rules.py
+++ b/marl_factory_grid/modules/destinations/rules.py
@@ -2,8 +2,8 @@ import ast
from random import shuffle
from typing import List, Dict, Tuple
-import marl_factory_grid.modules.destinations.constants
from marl_factory_grid.environment.rules import Rule
+from marl_factory_grid.utils import helpers as h
from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c
@@ -54,7 +54,7 @@ class DoneAtDestinationReachAll(DestinationReachReward):
"""
This rule triggers and sets the done flag if ALL Destinations have been reached.
- :type reward_at_done: object
+ :type reward_at_done: float
:param reward_at_done: Specifies the reward, agent get, whenn all destinations are reached.
:type dest_reach_reward: float
:param dest_reach_reward: Specify the reward, agents get when reaching a single destination.
@@ -65,7 +65,7 @@ class DoneAtDestinationReachAll(DestinationReachReward):
def on_check_done(self, state) -> List[DoneResult]:
if all(x.was_reached() for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=self.reward)]
- return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)]
+ return [DoneResult(self.name, validity=c.NOT_VALID)]
class DoneAtDestinationReachAny(DestinationReachReward):
@@ -75,7 +75,7 @@ class DoneAtDestinationReachAny(DestinationReachReward):
This rule triggers and sets the done flag if ANY Destinations has been reached.
!!! IMPORTANT: 'reward_at_done' is shared between the agents; 'dest_reach_reward' is bound to a specific one.
- :type reward_at_done: object
+ :type reward_at_done: float
:param reward_at_done: Specifies the reward, all agent get, when any destinations has been reached.
Default {d.REWARD_DEST_DONE}
:type dest_reach_reward: float
@@ -87,67 +87,29 @@ class DoneAtDestinationReachAny(DestinationReachReward):
def on_check_done(self, state) -> List[DoneResult]:
if any(x.was_reached() for x in state[d.DESTINATION]):
- return [DoneResult(self.name, validity=c.VALID, reward=marl_factory_grid.modules.destinations.constants.REWARD_DEST_REACHED)]
+ return [DoneResult(self.name, validity=c.VALID, reward=d.REWARD_DEST_REACHED)]
return []
-class SpawnDestinations(Rule):
-
- def __init__(self, n_dests: int = 1, spawn_mode: str = d.MODE_GROUPED):
- f"""
- Defines how destinations are initially spawned and respawned in addition.
- !!! This rule introduces no kind of reward or Env.-Done condition!
-
- :type n_dests: int
- :param n_dests: How many destiantions should be maintained (and initally spawnewd) on the map?
- :type spawn_mode: str
- :param spawn_mode: One of {d.SPAWN_MODES}. {d.MODE_GROUPED}: Always wait for all Dstiantions do be gone,
- then respawn after the given time. {d.MODE_SINGLE}: Just spawn every destination,
- that has been reached, after the given time
-
- """
- super(SpawnDestinations, self).__init__()
- self.n_dests = n_dests
- self.spawn_mode = spawn_mode
-
- def on_init(self, state, lvl_map):
- # noinspection PyAttributeOutsideInit
- state[d.DESTINATION].trigger_destination_spawn(self.n_dests, state)
- pass
-
- def tick_pre_step(self, state) -> List[TickResult]:
- pass
-
- def tick_step(self, state) -> List[TickResult]:
- if n_dest_spawn := max(0, self.n_dests - len(state[d.DESTINATION])):
- if self.spawn_mode == d.MODE_GROUPED and n_dest_spawn == self.n_dests:
- validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state)
- return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
- elif self.spawn_mode == d.MODE_SINGLE and n_dest_spawn:
- validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state)
- return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
- else:
- pass
-
-
class SpawnDestinationsPerAgent(Rule):
- def __init__(self, per_agent_positions: Dict[str, List[Tuple[int, int]]]):
+ def __init__(self, coords_or_quantity: Dict[str, List[Tuple[int, int]]]):
"""
Special rule, that spawn distinations, that are bound to a single agent a fixed set of positions.
Usefull for introducing specialists, etc. ..
!!! This rule does not introduce any reward or done condition.
- :type per_agent_positions: Dict[str, List[Tuple[int, int]]
- :param per_agent_positions: Please provide a dictionary with agent names as keys; and a list of possible
+ :type coords_or_quantity: Dict[str, List[Tuple[int, int]]
+ :param coords_or_quantity: Please provide a dictionary with agent names as keys; and a list of possible
destiantion coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]}
"""
super(Rule, self).__init__()
- self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in per_agent_positions.items()}
+ self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in coords_or_quantity.items()}
def on_init(self, state, lvl_map):
for (agent_name, position_list) in self.per_agent_positions.items():
- agent = next(x for x in state[c.AGENT] if agent_name in x.name) # Fixme: Ugly AF
+ agent = h.get_first(state[c.AGENT], lambda x: agent_name in x.name)
+ assert agent
position_list = position_list.copy()
shuffle(position_list)
while True:
@@ -155,7 +117,7 @@ class SpawnDestinationsPerAgent(Rule):
pos = position_list.pop()
except IndexError:
print(f"Could not spawn Destinations at: {self.per_agent_positions[agent_name]}")
- print(f'Check your agent palcement: {state[c.AGENT]} ... Exit ...')
+ print(f'Check your agent placement: {state[c.AGENT]} ... Exit ...')
exit(9999)
if (not pos == agent.pos) and (not state[d.DESTINATION].by_pos(pos)):
destination = Destination(pos, bind_to=agent)
diff --git a/marl_factory_grid/modules/doors/entitites.py b/marl_factory_grid/modules/doors/entitites.py
index 669f74e..1c33d7b 100644
--- a/marl_factory_grid/modules/doors/entitites.py
+++ b/marl_factory_grid/modules/doors/entitites.py
@@ -1,4 +1,5 @@
from marl_factory_grid.environment.entity.entity import Entity
+from marl_factory_grid.utils import Result
from marl_factory_grid.utils.utility_classes import RenderEntity
from marl_factory_grid.environment import constants as c
@@ -41,21 +42,6 @@ class Door(Entity):
def str_state(self):
return 'open' if self.is_open else 'closed'
- def __init__(self, *args, closed_on_init=True, auto_close_interval=10, **kwargs):
- self._status = d.STATE_CLOSED
- super(Door, self).__init__(*args, **kwargs)
- self.auto_close_interval = auto_close_interval
- self.time_to_close = 0
- if not closed_on_init:
- self._open()
- else:
- self._close()
-
- def summarize_state(self):
- state_dict = super().summarize_state()
- state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close))
- return state_dict
-
@property
def is_closed(self):
return self._status == d.STATE_CLOSED
@@ -68,6 +54,25 @@ class Door(Entity):
def status(self):
return self._status
+ @property
+ def time_to_close(self):
+ return self._time_to_close
+
+ def __init__(self, *args, closed_on_init=True, auto_close_interval=10, **kwargs):
+ self._status = d.STATE_CLOSED
+ super(Door, self).__init__(*args, **kwargs)
+ self._auto_close_interval = auto_close_interval
+ self._time_to_close = 0
+ if not closed_on_init:
+ self._open()
+ else:
+ self._close()
+
+ def summarize_state(self):
+ state_dict = super().summarize_state()
+ state_dict.update(state=str(self.str_state), time_to_close=self.time_to_close)
+ return state_dict
+
def render(self):
name, state = 'door_open' if self.is_open else 'door_closed', 'blank'
return RenderEntity(name, self.pos, 1, 'none', state, self.u_int + 1)
@@ -80,18 +85,35 @@ class Door(Entity):
return c.VALID
def tick(self, state):
- if self.is_open and len(state.entities.pos_dict[self.pos]) == 2 and self.time_to_close:
- self.time_to_close -= 1
- return c.NOT_VALID
- elif self.is_open and not self.time_to_close and len(state.entities.pos_dict[self.pos]) == 2:
- self.use()
- return c.VALID
+ # Check if no entity is standing in the door
+ if len(state.entities.pos_dict[self.pos]) <= 2:
+ if self.is_open and self.time_to_close:
+ self._decrement_timer()
+ return Result(f"{d.DOOR}_tick", c.VALID, entity=self)
+ elif self.is_open and not self.time_to_close:
+ self.use()
+ return Result(f"{d.DOOR}_closed", c.VALID, entity=self)
+ else:
+ # No one is in door, but it is closed... Nothing to do....
+ return None
else:
- return c.NOT_VALID
+ # Entity is standing in the door, reset timer
+ self._reset_timer()
+ return Result(f"{d.DOOR}_reset", c.VALID, entity=self)
def _open(self):
self._status = d.STATE_OPEN
- self.time_to_close = self.auto_close_interval
+ self._reset_timer()
+ return True
def _close(self):
self._status = d.STATE_CLOSED
+ return True
+
+ def _decrement_timer(self):
+ self._time_to_close -= 1
+ return True
+
+ def _reset_timer(self):
+ self._time_to_close = self._auto_close_interval
+ return True
diff --git a/marl_factory_grid/modules/doors/groups.py b/marl_factory_grid/modules/doors/groups.py
index 687846e..973d1ab 100644
--- a/marl_factory_grid/modules/doors/groups.py
+++ b/marl_factory_grid/modules/doors/groups.py
@@ -1,5 +1,3 @@
-from typing import Union
-
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.modules.doors import constants as d
from marl_factory_grid.modules.doors.entitites import Door
@@ -18,8 +16,10 @@ class Doors(Collection):
super(Doors, self).__init__(*args, can_collide=True, **kwargs)
def tick_doors(self, state):
- result_dict = dict()
+ results = list()
for door in self:
- did_tick = door.tick(state)
- result_dict.update({door.name: did_tick})
- return result_dict
+ tick_result = door.tick(state)
+ if tick_result is not None:
+ results.append(tick_result)
+ # TODO: Should return a Result object, not a random dict.
+ return results
diff --git a/marl_factory_grid/modules/doors/rewards.py b/marl_factory_grid/modules/doors/rewards.py
index c87d123..b38c7c5 100644
--- a/marl_factory_grid/modules/doors/rewards.py
+++ b/marl_factory_grid/modules/doors/rewards.py
@@ -1,2 +1,2 @@
USE_DOOR_VALID: float = -0.00
-USE_DOOR_FAIL: float = -0.01
\ No newline at end of file
+USE_DOOR_FAIL: float = -0.01
diff --git a/marl_factory_grid/modules/doors/rules.py b/marl_factory_grid/modules/doors/rules.py
index da312cd..599d975 100644
--- a/marl_factory_grid/modules/doors/rules.py
+++ b/marl_factory_grid/modules/doors/rules.py
@@ -19,10 +19,10 @@ class DoorAutoClose(Rule):
def tick_step(self, state):
if doors := state[d.DOORS]:
- doors_tick_result = doors.tick_doors(state)
- doors_that_ticked = [key for key, val in doors_tick_result.items() if val]
- state.print(f'{doors_that_ticked} were auto-closed'
- if doors_that_ticked else 'No Doors were auto-closed')
+ doors_tick_results = doors.tick_doors(state)
+ doors_that_closed = [x.entity.name for x in doors_tick_results if 'closed' in x.identifier]
+ door_str = doors_that_closed if doors_that_closed else "No Doors"
+ state.print(f'{door_str} were auto-closed')
return [TickResult(self.name, validity=c.VALID, value=1)]
state.print('There are no doors, but you loaded the corresponding Module')
return []
diff --git a/marl_factory_grid/modules/factory/rules.py b/marl_factory_grid/modules/factory/rules.py
index d736f7a..e056135 100644
--- a/marl_factory_grid/modules/factory/rules.py
+++ b/marl_factory_grid/modules/factory/rules.py
@@ -1,8 +1,8 @@
import random
-from typing import List, Union
+from typing import List
-from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.environment import constants as c
+from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult
@@ -14,8 +14,6 @@ class AgentSingleZonePlacementBeta(Rule):
super().__init__()
def on_init(self, state, lvl_map):
- zones = state[c.ZONES]
- n_zones = state[c.ZONES]
agents = state[c.AGENT]
if len(self.coordinates) == len(agents):
coordinates = self.coordinates
@@ -31,4 +29,4 @@ class AgentSingleZonePlacementBeta(Rule):
return []
def tick_post_step(self, state) -> List[TickResult]:
- return []
\ No newline at end of file
+ return []
diff --git a/marl_factory_grid/modules/items/__init__.py b/marl_factory_grid/modules/items/__init__.py
index 157c385..cb9b69b 100644
--- a/marl_factory_grid/modules/items/__init__.py
+++ b/marl_factory_grid/modules/items/__init__.py
@@ -1,4 +1,3 @@
from .actions import ItemAction
from .entitites import Item, DropOffLocation
from .groups import DropOffLocations, Items, Inventory, Inventories
-from .rules import ItemRules
diff --git a/marl_factory_grid/modules/items/actions.py b/marl_factory_grid/modules/items/actions.py
index f9e4f6f..ef6aa99 100644
--- a/marl_factory_grid/modules/items/actions.py
+++ b/marl_factory_grid/modules/items/actions.py
@@ -29,7 +29,7 @@ class ItemAction(Action):
elif items := state[i.ITEM].by_pos(entity.pos):
item = items[0]
item.change_parent_collection(inventory)
- item.set_pos_to(c.VALUE_NO_POS)
+ item.set_pos(c.VALUE_NO_POS)
state.print(f'{entity.name} just picked up an item at {entity.pos}')
return ActionResult(entity=entity, identifier=self._identifier, validity=c.VALID, reward=r.PICK_UP_VALID)
diff --git a/marl_factory_grid/modules/items/constants.py b/marl_factory_grid/modules/items/constants.py
index 86b8b0c..5cb82c3 100644
--- a/marl_factory_grid/modules/items/constants.py
+++ b/marl_factory_grid/modules/items/constants.py
@@ -1,6 +1,3 @@
-from typing import NamedTuple
-
-
SYMBOL_NO_ITEM = 0
SYMBOL_DROP_OFF = 1
# Item Env
diff --git a/marl_factory_grid/modules/items/entitites.py b/marl_factory_grid/modules/items/entitites.py
index b710282..8549134 100644
--- a/marl_factory_grid/modules/items/entitites.py
+++ b/marl_factory_grid/modules/items/entitites.py
@@ -8,56 +8,20 @@ from marl_factory_grid.modules.items import constants as i
class Item(Entity):
- @property
- def var_can_collide(self):
- return False
-
def render(self):
return RenderEntity(i.ITEM, self.pos) if self.pos != c.VALUE_NO_POS else None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self._auto_despawn = -1
-
- @property
- def auto_despawn(self):
- return self._auto_despawn
@property
def encoding(self):
# Edit this if you want items to be drawn in the ops differently
return 1
- def set_auto_despawn(self, auto_despawn):
- self._auto_despawn = auto_despawn
-
- def set_pos_to(self, no_pos):
- self._pos = no_pos
-
- def summarize_state(self) -> dict:
- super_summarization = super(Item, self).summarize_state()
- super_summarization.update(dict(auto_despawn=self.auto_despawn))
- return super_summarization
-
class DropOffLocation(Entity):
- @property
- def var_can_collide(self):
- return False
-
- @property
- def var_can_move(self):
- return False
-
- @property
- def var_is_blocking_light(self):
- return False
-
- @property
- def var_has_position(self):
- return True
-
def render(self):
return RenderEntity(i.DROP_OFF, self.pos)
@@ -65,18 +29,16 @@ class DropOffLocation(Entity):
def encoding(self):
return i.SYMBOL_DROP_OFF
- def __init__(self, *args, storage_size_until_full: int = 5, auto_item_despawn_interval: int = 5, **kwargs):
+ def __init__(self, *args, storage_size_until_full: int = 5, **kwargs):
super(DropOffLocation, self).__init__(*args, **kwargs)
- self.auto_item_despawn_interval = auto_item_despawn_interval
self.storage = deque(maxlen=storage_size_until_full or None)
def place_item(self, item: Item):
if self.is_full:
raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.")
- return bc.NOT_VALID # in Zeile 81 verschieben?
+ return bc.NOT_VALID
else:
self.storage.append(item)
- item.set_auto_despawn(self.auto_item_despawn_interval)
return c.VALID
@property
diff --git a/marl_factory_grid/modules/items/groups.py b/marl_factory_grid/modules/items/groups.py
index 707f743..be5ca49 100644
--- a/marl_factory_grid/modules/items/groups.py
+++ b/marl_factory_grid/modules/items/groups.py
@@ -1,13 +1,11 @@
-from random import shuffle
-
-from marl_factory_grid.modules.items import constants as i
from marl_factory_grid.environment import constants as c
-
-from marl_factory_grid.environment.groups.collection import Collection
-from marl_factory_grid.environment.groups.objects import _Objects
-from marl_factory_grid.environment.groups.mixins import IsBoundMixin
from marl_factory_grid.environment.entity.agent import Agent
+from marl_factory_grid.environment.groups.collection import Collection
+from marl_factory_grid.environment.groups.mixins import IsBoundMixin
+from marl_factory_grid.environment.groups.objects import Objects
+from marl_factory_grid.modules.items import constants as i
from marl_factory_grid.modules.items.entitites import Item, DropOffLocation
+from marl_factory_grid.utils.results import Result
class Items(Collection):
@@ -15,7 +13,7 @@ class Items(Collection):
@property
def var_has_position(self):
- return False
+ return True
@property
def is_blocking_light(self):
@@ -28,18 +26,18 @@ class Items(Collection):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- @staticmethod
- def trigger_item_spawn(state, n_items, spawn_frequency):
- if item_to_spawns := max(0, (n_items - len(state[i.ITEM]))):
- position_list = [x for x in state.entities.floorlist]
- shuffle(position_list)
- position_list = state.entities.floorlist[:item_to_spawns]
- state[i.ITEM].spawn(position_list)
- state.print(f'{item_to_spawns} new items have been spawned; next spawn in {spawn_frequency}')
- return len(position_list)
+ def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, **entity_kwargs) -> [Result]:
+ coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity
+ assert coords_or_quantity
+
+ if item_to_spawns := max(0, (coords_or_quantity - len(self))):
+ return super().trigger_spawn(state,
+ *entity_args,
+ coords_or_quantity=item_to_spawns,
+ **entity_kwargs)
else:
state.print('No Items are spawning, limit is reached.')
- return 0
+ return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, value=coords_or_quantity)
class Inventory(IsBoundMixin, Collection):
@@ -73,12 +71,17 @@ class Inventory(IsBoundMixin, Collection):
self._collection = collection
-class Inventories(_Objects):
+class Inventories(Objects):
_entity = Inventory
+ var_can_move = False
+ var_has_position = False
+
+ symbol = None
+
@property
- def var_can_move(self):
- return False
+ def spawn_rule(self):
+ return {c.SPAWN_ENTITY_RULE: dict(collection=self, coords_or_quantity=None)}
def __init__(self, size: int, *args, **kwargs):
super(Inventories, self).__init__(*args, **kwargs)
@@ -86,10 +89,12 @@ class Inventories(_Objects):
self._obs = None
self._lazy_eval_transforms = []
- def spawn(self, agents):
- inventories = [self._entity(agent, self.size, )
- for _, agent in enumerate(agents)]
- self.add_items(inventories)
+ def spawn(self, agents, *args, **kwargs):
+ self.add_items([self._entity(agent, self.size, *args, **kwargs) for _, agent in enumerate(agents)])
+ return [Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))]
+
+ def trigger_spawn(self, state, *args, **kwargs) -> [Result]:
+ return self.spawn(state[c.AGENT], *args, **kwargs)
def idx_by_entity(self, entity):
try:
@@ -106,10 +111,6 @@ class Inventories(_Objects):
def summarize_states(self, **kwargs):
return [val.summarize_states(**kwargs) for key, val in self.items()]
- @staticmethod
- def trigger_inventory_spawn(state):
- state[i.INVENTORY].spawn(state[c.AGENT])
-
class DropOffLocations(Collection):
_entity = DropOffLocation
@@ -135,7 +136,7 @@ class DropOffLocations(Collection):
@staticmethod
def trigger_drop_off_location_spawn(state, n_locations):
- empty_positions = state.entities.empty_positions()[:n_locations]
+ empty_positions = state.entities.empty_positions[:n_locations]
do_entites = state[i.DROP_OFF]
drop_offs = [DropOffLocation(pos) for pos in empty_positions]
do_entites.add_items(drop_offs)
diff --git a/marl_factory_grid/modules/items/rewards.py b/marl_factory_grid/modules/items/rewards.py
index 40adf46..bcd2918 100644
--- a/marl_factory_grid/modules/items/rewards.py
+++ b/marl_factory_grid/modules/items/rewards.py
@@ -1,4 +1,4 @@
DROP_OFF_VALID: float = 0.1
DROP_OFF_FAIL: float = -0.1
PICK_UP_FAIL: float = -0.1
-PICK_UP_VALID: float = 0.1
\ No newline at end of file
+PICK_UP_VALID: float = 0.1
diff --git a/marl_factory_grid/modules/items/rules.py b/marl_factory_grid/modules/items/rules.py
index 9f8a0cc..a655956 100644
--- a/marl_factory_grid/modules/items/rules.py
+++ b/marl_factory_grid/modules/items/rules.py
@@ -6,52 +6,28 @@ from marl_factory_grid.utils.results import TickResult
from marl_factory_grid.modules.items import constants as i
-class ItemRules(Rule):
+class RespawnItems(Rule):
- def __init__(self, n_items: int = 5, spawn_frequency: int = 15,
- n_locations: int = 5, max_dropoff_storage_size: int = 0):
+ def __init__(self, n_items: int = 5, respawn_freq: int = 15, n_locations: int = 5):
super().__init__()
- self.spawn_frequency = spawn_frequency
- self._next_item_spawn = spawn_frequency
+ self.spawn_frequency = respawn_freq
+ self._next_item_spawn = respawn_freq
self.n_items = n_items
- self.max_dropoff_storage_size = max_dropoff_storage_size
self.n_locations = n_locations
- def on_init(self, state, lvl_map):
- state[i.DROP_OFF].trigger_drop_off_location_spawn(state, self.n_locations)
- self._next_item_spawn = self.spawn_frequency
- state[i.INVENTORY].trigger_inventory_spawn(state)
- state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency)
-
def tick_step(self, state):
- for item in list(state[i.ITEM].values()):
- if item.auto_despawn >= 1:
- item.set_auto_despawn(item.auto_despawn - 1)
- elif not item.auto_despawn:
- state[i.ITEM].delete_env_object(item)
- else:
- pass
-
if not self._next_item_spawn:
- state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency)
+ state[i.ITEM].trigger_spawn(state, self.n_items, self.spawn_frequency)
else:
self._next_item_spawn = max(0, self._next_item_spawn - 1)
return []
def tick_post_step(self, state) -> List[TickResult]:
- for item in list(state[i.ITEM].values()):
- if item.auto_despawn >= 1:
- item.set_auto_despawn(item.auto_despawn-1)
- elif not item.auto_despawn:
- state[i.ITEM].delete_env_object(item)
- else:
- pass
-
if not self._next_item_spawn:
- if spawned_items := state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency):
- return [TickResult(self.name, validity=c.VALID, value=spawned_items, entity=None)]
+ if spawned_items := state[i.ITEM].trigger_spawn(state, self.n_items, self.spawn_frequency):
+ return [TickResult(self.name, validity=c.VALID, value=spawned_items.value)]
else:
- return [TickResult(self.name, validity=c.NOT_VALID, value=0, entity=None)]
+ return [TickResult(self.name, validity=c.NOT_VALID, value=0)]
else:
self._next_item_spawn = max(0, self._next_item_spawn-1)
return []
diff --git a/marl_factory_grid/modules/machines/__init__.py b/marl_factory_grid/modules/machines/__init__.py
index 36ba51d..233efbb 100644
--- a/marl_factory_grid/modules/machines/__init__.py
+++ b/marl_factory_grid/modules/machines/__init__.py
@@ -1,3 +1,2 @@
from .entitites import Machine
from .groups import Machines
-from .rules import MachineRule
diff --git a/marl_factory_grid/modules/machines/actions.py b/marl_factory_grid/modules/machines/actions.py
index 8f4eaaa..dbb303f 100644
--- a/marl_factory_grid/modules/machines/actions.py
+++ b/marl_factory_grid/modules/machines/actions.py
@@ -1,10 +1,12 @@
from typing import Union
+import marl_factory_grid.modules.machines.constants
from marl_factory_grid.environment.actions import Action
from marl_factory_grid.utils.results import ActionResult
-from marl_factory_grid.modules.machines import constants as m, rewards as r
+from marl_factory_grid.modules.machines import constants as m
from marl_factory_grid.environment import constants as c
+from marl_factory_grid.utils import helpers as h
class MachineAction(Action):
@@ -13,13 +15,12 @@ class MachineAction(Action):
super().__init__(m.MACHINE_ACTION)
def do(self, entity, state) -> Union[None, ActionResult]:
- if machine := state[m.MACHINES].by_pos(entity.pos):
+ if machine := h.get_first(state[m.MACHINES].by_pos(entity.pos)):
if valid := machine.maintain():
- return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_VALID)
+ return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_VALID)
else:
- return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_FAIL)
+ return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_FAIL)
else:
- return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.MAINTAIN_FAIL)
-
-
-
+ return ActionResult(entity=entity, identifier=self._identifier,
+ validity=c.NOT_VALID, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_FAIL
+ )
diff --git a/marl_factory_grid/modules/machines/constants.py b/marl_factory_grid/modules/machines/constants.py
index 29ce3bc..3771cbb 100644
--- a/marl_factory_grid/modules/machines/constants.py
+++ b/marl_factory_grid/modules/machines/constants.py
@@ -11,3 +11,7 @@ STATE_MAINTAIN = 'maintenance'
SYMBOL_WORK = 1
SYMBOL_IDLE = 0.6
SYMBOL_MAINTAIN = 0.3
+MAINTAIN_VALID: float = 0.5
+MAINTAIN_FAIL: float = -0.1
+FAIL_MISSING_MAINTENANCE: float = -0.5
+NONE: float = 0
diff --git a/marl_factory_grid/modules/machines/entitites.py b/marl_factory_grid/modules/machines/entitites.py
index 36a87cc..581adf6 100644
--- a/marl_factory_grid/modules/machines/entitites.py
+++ b/marl_factory_grid/modules/machines/entitites.py
@@ -8,22 +8,6 @@ from . import constants as m
class Machine(Entity):
- @property
- def var_can_collide(self):
- return False
-
- @property
- def var_can_move(self):
- return False
-
- @property
- def var_is_blocking_light(self):
- return False
-
- @property
- def var_has_position(self):
- return True
-
@property
def encoding(self):
return self._encodings[self.status]
@@ -46,12 +30,11 @@ class Machine(Entity):
else:
return c.NOT_VALID
- def tick(self):
- # if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]):
- if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.state.entities.pos_dict[self.pos]]):
- return TickResult(identifier=self.name, validity=c.VALID, reward=0, entity=self)
- # elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]):
- elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.state.entities.pos_dict[self.pos]]):
+ def tick(self, state):
+ others = state.entities.pos_dict[self.pos]
+ if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in others]):
+ return TickResult(identifier=self.name, validity=c.VALID, entity=self)
+ elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in others]):
self.status = m.STATE_WORK
self.reset_counter()
return None
diff --git a/marl_factory_grid/modules/machines/groups.py b/marl_factory_grid/modules/machines/groups.py
index 5f2d970..9d89d6c 100644
--- a/marl_factory_grid/modules/machines/groups.py
+++ b/marl_factory_grid/modules/machines/groups.py
@@ -1,5 +1,3 @@
-from typing import Union, List, Tuple
-
from marl_factory_grid.environment.groups.collection import Collection
from .entitites import Machine
diff --git a/marl_factory_grid/modules/machines/rewards.py b/marl_factory_grid/modules/machines/rewards.py
deleted file mode 100644
index c868196..0000000
--- a/marl_factory_grid/modules/machines/rewards.py
+++ /dev/null
@@ -1,5 +0,0 @@
-MAINTAIN_VALID: float = 0.5
-MAINTAIN_FAIL: float = -0.1
-FAIL_MISSING_MAINTENANCE: float = -0.5
-
-NONE: float = 0
diff --git a/marl_factory_grid/modules/machines/rules.py b/marl_factory_grid/modules/machines/rules.py
index 84e3410..e69de29 100644
--- a/marl_factory_grid/modules/machines/rules.py
+++ b/marl_factory_grid/modules/machines/rules.py
@@ -1,28 +0,0 @@
-from typing import List
-from marl_factory_grid.environment.rules import Rule
-from marl_factory_grid.utils.results import TickResult, DoneResult
-from marl_factory_grid.environment import constants as c
-from marl_factory_grid.modules.machines import constants as m
-from marl_factory_grid.modules.machines.entitites import Machine
-
-
-class MachineRule(Rule):
-
- def __init__(self, n_machines: int = 2):
- super(MachineRule, self).__init__()
- self.n_machines = n_machines
-
- def on_init(self, state, lvl_map):
- state[m.MACHINES].spawn(state.entities.empty_positions())
-
- def tick_pre_step(self, state) -> List[TickResult]:
- pass
-
- def tick_step(self, state) -> List[TickResult]:
- pass
-
- def tick_post_step(self, state) -> List[TickResult]:
- pass
-
- def on_check_done(self, state) -> List[DoneResult]:
- pass
diff --git a/marl_factory_grid/modules/maintenance/constants.py b/marl_factory_grid/modules/maintenance/constants.py
index e0ab12c..3aed36c 100644
--- a/marl_factory_grid/modules/maintenance/constants.py
+++ b/marl_factory_grid/modules/maintenance/constants.py
@@ -1,3 +1,4 @@
MAINTAINER = 'Maintainer' # TEMPLATE _identifier. Define your own!
MAINTAINERS = 'Maintainers' # TEMPLATE _identifier. Define your own!
+MAINTAINER_COLLISION_REWARD = -5
diff --git a/marl_factory_grid/modules/maintenance/entities.py b/marl_factory_grid/modules/maintenance/entities.py
index e084b0c..1a043c8 100644
--- a/marl_factory_grid/modules/maintenance/entities.py
+++ b/marl_factory_grid/modules/maintenance/entities.py
@@ -1,48 +1,35 @@
+from random import shuffle
+
import networkx as nx
import numpy as np
+
from ...algorithms.static.utils import points_to_graph
from ...environment import constants as c
from ...environment.actions import Action, ALL_BASEACTIONS
from ...environment.entity.entity import Entity
from ..doors import constants as do
from ..maintenance import constants as mi
-from ...utils.helpers import MOVEMAP
-from ...utils.utility_classes import RenderEntity
-from ...utils.states import Gamestate
+from ...utils import helpers as h
+from ...utils.utility_classes import RenderEntity, Floor
+from ..doors import DoorUse
class Maintainer(Entity):
- @property
- def var_can_collide(self):
- return True
-
- @property
- def var_can_move(self):
- return False
-
- @property
- def var_is_blocking_light(self):
- return False
-
- @property
- def var_has_position(self):
- return True
-
- def __init__(self, state: Gamestate, objective: str, action: Action, *args, **kwargs):
+ def __init__(self, objective: str, action: Action, *args, **kwargs):
super().__init__(*args, **kwargs)
self.action = action
- self.actions = [x() for x in ALL_BASEACTIONS]
+ self.actions = [x() for x in ALL_BASEACTIONS] + [DoorUse()]
self.objective = objective
self._path = None
self._next = []
self._last = []
self._last_serviced = 'None'
- self._floortile_graph = points_to_graph(state.entities.floorlist)
+ self._floortile_graph = None
def tick(self, state):
- if found_objective := state[self.objective].by_pos(self.pos):
+ if found_objective := h.get_first(state[self.objective].by_pos(self.pos)):
if found_objective.name != self._last_serviced:
self.action.do(self, state)
self._last_serviced = found_objective.name
@@ -54,24 +41,27 @@ class Maintainer(Entity):
return action.do(self, state)
def get_move_action(self, state) -> Action:
+ if not self._floortile_graph:
+ state.print("Generating Floorgraph....")
+ self._floortile_graph = points_to_graph(state.entities.floorlist)
if self._path is None or not self._path:
if not self._next:
- self._next = list(state[self.objective].values())
+ self._next = list(state[self.objective].values()) + [Floor(*state.random_free_position)]
+ shuffle(self._next)
self._last = []
self._last.append(self._next.pop())
+ state.print("Calculating shortest path....")
self._path = self.calculate_route(self._last[-1])
- if door := self._door_is_close(state):
- if door.is_closed:
- # Translate the action_object to an integer to have the same output as any other model
- action = do.ACTION_DOOR_USE
- else:
- action = self._predict_move(state)
+ if door := self._closed_door_in_path(state):
+ state.print(f"{self} found {door} that is closed. Attempt to open.")
+ # Translate the action_object to an integer to have the same output as any other model
+ action = do.ACTION_DOOR_USE
else:
action = self._predict_move(state)
# Translate the action_object to an integer to have the same output as any other model
try:
- action_obj = next(x for x in self.actions if x.name == action)
+ action_obj = h.get_first(self.actions, lambda x: x.name == action)
except (StopIteration, UnboundLocalError):
print('Will not happen')
raise EnvironmentError
@@ -81,11 +71,10 @@ class Maintainer(Entity):
route = nx.shortest_path(self._floortile_graph, self.pos, entity.pos)
return route[1:]
- def _door_is_close(self, state):
- state.print("Found a door that is close.")
- try:
- return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name)
- except StopIteration:
+ def _closed_door_in_path(self, state):
+ if self._path:
+ return h.get_first(state[do.DOORS].by_pos(self._path[0]), lambda x: x.is_closed)
+ else:
return None
def _predict_move(self, state):
@@ -96,7 +85,7 @@ class Maintainer(Entity):
next_pos = self._path.pop(0)
diff = np.subtract(next_pos, self.pos)
# Retrieve action based on the pos dif (like in: What do I have to do to get there?)
- action = next(action for action, pos_diff in MOVEMAP.items() if np.all(diff == pos_diff))
+ action = next(action for action, pos_diff in h.MOVEMAP.items() if np.all(diff == pos_diff))
return action
def render(self):
diff --git a/marl_factory_grid/modules/maintenance/groups.py b/marl_factory_grid/modules/maintenance/groups.py
index 2df70cb..5b09c9c 100644
--- a/marl_factory_grid/modules/maintenance/groups.py
+++ b/marl_factory_grid/modules/maintenance/groups.py
@@ -1,34 +1,27 @@
-from typing import Union, List, Tuple
+from typing import Union, List, Tuple, Dict
from marl_factory_grid.environment.groups.collection import Collection
from .entities import Maintainer
from ..machines import constants as mc
from ..machines.actions import MachineAction
-from ...utils.states import Gamestate
class Maintainers(Collection):
_entity = Maintainer
- @property
- def var_can_collide(self):
- return True
+ var_can_collide = True
+ var_can_move = True
+ var_is_blocking_light = False
+ var_has_position = True
- @property
- def var_can_move(self):
- return True
+ def __init__(self, size, *args, coords_or_quantity: int = None,
+ spawnrule: Union[None, Dict[str, dict]] = None,
+ **kwargs):
+ super(Collection, self).__init__(*args, **kwargs)
+ self._coords_or_quantity = coords_or_quantity
+ self.size = size
+ self._spawnrule = spawnrule
- @property
- def var_is_blocking_light(self):
- return False
-
- @property
- def var_has_position(self):
- return True
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):
- state = entity_args[0]
- self.add_items([self._entity(state, mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity])
+ self.add_items([self._entity(mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity])
diff --git a/marl_factory_grid/modules/maintenance/rewards.py b/marl_factory_grid/modules/maintenance/rewards.py
deleted file mode 100644
index 425ac3b..0000000
--- a/marl_factory_grid/modules/maintenance/rewards.py
+++ /dev/null
@@ -1 +0,0 @@
-MAINTAINER_COLLISION_REWARD = -5
\ No newline at end of file
diff --git a/marl_factory_grid/modules/maintenance/rules.py b/marl_factory_grid/modules/maintenance/rules.py
index 820183e..92e6e75 100644
--- a/marl_factory_grid/modules/maintenance/rules.py
+++ b/marl_factory_grid/modules/maintenance/rules.py
@@ -1,32 +1,28 @@
from typing import List
+
+import marl_factory_grid.modules.maintenance.constants
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c
-from . import rewards as r
from . import constants as M
-from marl_factory_grid.utils.states import Gamestate
-class MaintenanceRule(Rule):
+class MoveMaintainers(Rule):
- def __init__(self, n_maintainer: int = 1, *args, **kwargs):
- super(MaintenanceRule, self).__init__(*args, **kwargs)
- self.n_maintainer = n_maintainer
-
- def on_init(self, state: Gamestate, lvl_map):
- state[M.MAINTAINERS].spawn(state.entities.empty_positions[:self.n_maintainer], state)
- pass
-
- def tick_pre_step(self, state) -> List[TickResult]:
- pass
+ def __init__(self):
+ super().__init__()
def tick_step(self, state) -> List[TickResult]:
for maintainer in state[M.MAINTAINERS]:
maintainer.tick(state)
+ # Todo: Return a Result Object.
return []
- def tick_post_step(self, state) -> List[TickResult]:
- pass
+
+class DoneAtMaintainerCollision(Rule):
+
+ def __init__(self):
+ super().__init__()
def on_check_done(self, state) -> List[DoneResult]:
agents = list(state[c.AGENT].values())
@@ -35,5 +31,5 @@ class MaintenanceRule(Rule):
for agent in agents:
if agent.pos in m_pos:
done_results.append(DoneResult(entity=agent, validity=c.VALID, identifier=self.name,
- reward=r.MAINTAINER_COLLISION_REWARD))
+ reward=marl_factory_grid.modules.maintenance.constants.MAINTAINER_COLLISION_REWARD))
return done_results
diff --git a/marl_factory_grid/modules/zones/entitites.py b/marl_factory_grid/modules/zones/entitites.py
index cfd313f..4aa0f70 100644
--- a/marl_factory_grid/modules/zones/entitites.py
+++ b/marl_factory_grid/modules/zones/entitites.py
@@ -1,10 +1,10 @@
import random
from typing import List, Tuple
-from marl_factory_grid.environment.entity.object import _Object
+from marl_factory_grid.environment.entity.object import Object
-class Zone(_Object):
+class Zone(Object):
@property
def positions(self):
diff --git a/marl_factory_grid/modules/zones/groups.py b/marl_factory_grid/modules/zones/groups.py
index 71eb329..f5494cd 100644
--- a/marl_factory_grid/modules/zones/groups.py
+++ b/marl_factory_grid/modules/zones/groups.py
@@ -1,8 +1,8 @@
-from marl_factory_grid.environment.groups.objects import _Objects
+from marl_factory_grid.environment.groups.objects import Objects
from marl_factory_grid.modules.zones import Zone
-class Zones(_Objects):
+class Zones(Objects):
symbol = None
_entity = Zone
diff --git a/marl_factory_grid/modules/zones/rules.py b/marl_factory_grid/modules/zones/rules.py
index 2969186..f9b5c11 100644
--- a/marl_factory_grid/modules/zones/rules.py
+++ b/marl_factory_grid/modules/zones/rules.py
@@ -1,8 +1,8 @@
from random import choices, choice
from . import constants as z, Zone
+from .. import Destination
from ..destinations import constants as d
-from ... import Destination
from ...environment.rules import Rule
from ...environment import constants as c
diff --git a/marl_factory_grid/utils/__init__.py b/marl_factory_grid/utils/__init__.py
index e69de29..23848e0 100644
--- a/marl_factory_grid/utils/__init__.py
+++ b/marl_factory_grid/utils/__init__.py
@@ -0,0 +1,3 @@
+from . import helpers as h
+from . import helpers
+from .results import Result, DoneResult, ActionResult, TickResult
diff --git a/marl_factory_grid/utils/config_parser.py b/marl_factory_grid/utils/config_parser.py
index c9223f8..5cad113 100644
--- a/marl_factory_grid/utils/config_parser.py
+++ b/marl_factory_grid/utils/config_parser.py
@@ -1,4 +1,5 @@
import ast
+
from os import PathLike
from pathlib import Path
from typing import Union, List
@@ -9,18 +10,17 @@ from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.environment.tests import Test
from marl_factory_grid.utils.helpers import locate_and_import_class
-
-DEFAULT_PATH = 'environment'
-MODULE_PATH = 'modules'
+from marl_factory_grid.environment.constants import DEFAULT_PATH, MODULE_PATH
+from marl_factory_grid.environment import constants as c
class FactoryConfigParser(object):
default_entites = []
- default_rules = ['MaxStepsReached', 'Collision']
+ default_rules = ['DoneAtMaxStepsReached', 'WatchCollision']
default_actions = [c.MOVE8, c.NOOP]
default_observations = [c.WALLS, c.AGENT]
- def __init__(self, config_path, custom_modules_path: Union[None, PathLike] = None):
+ def __init__(self, config_path, custom_modules_path: Union[PathLike] = None):
self.config_path = Path(config_path)
self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path
self.config = yaml.safe_load(self.config_path.open())
@@ -44,6 +44,10 @@ class FactoryConfigParser(object):
def rules(self):
return self.config['Rules']
+ @property
+ def tests(self):
+ return self.config.get('Tests', [])
+
@property
def agents(self):
return self.config['Agents']
@@ -56,10 +60,12 @@ class FactoryConfigParser(object):
return str(self.config)
def __getitem__(self, item):
- return self.config[item]
+ try:
+ return self.config[item]
+ except KeyError:
+ print(f'The mandatory {item} section could not be found in your .config gile. Check Spelling!')
def load_entities(self):
- # entites = Entities()
entity_classes = dict()
entities = []
if c.DEFAULTS in self.entities:
@@ -67,28 +73,40 @@ class FactoryConfigParser(object):
entities.extend(x for x in self.entities if x != c.DEFAULTS)
for entity in entities:
+ e1 = e2 = e3 = None
try:
folder_path = Path(__file__).parent.parent / DEFAULT_PATH
entity_class = locate_and_import_class(entity, folder_path)
- except AttributeError as e1:
+ except AttributeError as e:
+ e1 = e
try:
- folder_path = Path(__file__).parent.parent / MODULE_PATH
- entity_class = locate_and_import_class(entity, folder_path)
- except AttributeError as e2:
- try:
- folder_path = self.custom_modules_path
- entity_class = locate_and_import_class(entity, folder_path)
- except AttributeError as e3:
- ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x]
- print('### Error ### Error ### Error ### Error ### Error ###')
- print()
- print(f'Class "{entity}" was not found in "{folder_path.name}"')
- print('Possible Entitys are:', str(ents))
- print()
- print('Goodbye')
- print()
- exit()
- # raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are:', str(ents))
+ module_path = Path(__file__).parent.parent / MODULE_PATH
+ entity_class = locate_and_import_class(entity, module_path)
+ except AttributeError as e:
+ e2 = e
+ if self.custom_modules_path:
+ try:
+ entity_class = locate_and_import_class(entity, self.custom_modules_path)
+ except AttributeError as e:
+ e3 = e
+ pass
+ if (e1 and e2) or e3:
+ ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]]
+ print('##############################################################')
+ print('### Error ### Error ### Error ### Error ### Error ###')
+ print('##############################################################')
+ print(f'Class "{entity}" was not found in "{module_path.name}"')
+ print(f'Class "{entity}" was not found in "{folder_path.name}"')
+ print('##############################################################')
+ if self.custom_modules_path:
+ print(f'Class "{entity}" was not found in "{self.custom_modules_path}"')
+ print('Possible Entitys are:', str(ents))
+ print('##############################################################')
+ print('Goodbye')
+ print('##############################################################')
+ print('### Error ### Error ### Error ### Error ### Error ###')
+ print('##############################################################')
+ exit(-99999)
entity_kwargs = self.entities.get(entity, {})
entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
@@ -126,7 +144,12 @@ class FactoryConfigParser(object):
observations.extend(self.default_observations)
observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS)
positions = [ast.literal_eval(x) for x in self.agents[name].get('Positions', [])]
- parsed_agents_conf[name] = dict(actions=parsed_actions, observations=observations, positions=positions)
+ other_kwargs = {k: v for k, v in self.agents[name].items() if k not in
+ ['Actions', 'Observations', 'Positions']}
+ parsed_agents_conf[name] = dict(
+ actions=parsed_actions, observations=observations, positions=positions, other=other_kwargs
+ )
+
return parsed_agents_conf
def load_env_rules(self) -> List[Rule]:
@@ -137,28 +160,69 @@ class FactoryConfigParser(object):
rules.append({rule: {}})
return self._load_smth(rules, Rule)
- pass
- def load_env_tests(self) -> List[Test]:
+ def load_env_tests(self) -> List[Rule]:
return self._load_smth(self.tests, None) # Test
- pass
def _load_smth(self, config, class_obj):
rules = list()
- rules_names = list()
-
- for rule in rules_names:
+ for rule in config:
+ e1 = e2 = e3 = None
try:
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
rule_class = locate_and_import_class(rule, folder_path)
- except AttributeError:
+ except AttributeError as e:
+ e1 = e
try:
- folder_path = (Path(__file__).parent.parent / MODULE_PATH)
- rule_class = locate_and_import_class(rule, folder_path)
- except AttributeError:
- rule_class = locate_and_import_class(rule, self.custom_modules_path)
- # Fixme This check does not work!
- # assert isinstance(rule_class, class_obj), f'{rule_class.__name__} is no valid "class_obj.__name__".'
- rule_kwargs = config.get(rule, {})
- rules.append(rule_class(**rule_kwargs))
+ module_path = (Path(__file__).parent.parent / MODULE_PATH)
+ rule_class = locate_and_import_class(rule, module_path)
+ except AttributeError as e:
+ e2 = e
+ if self.custom_modules_path:
+ try:
+ rule_class = locate_and_import_class(rule, self.custom_modules_path)
+ except AttributeError as e:
+ e3 = e
+ pass
+ if (e1 and e2) or e3:
+ ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]]
+ print('### Error ### Error ### Error ### Error ### Error ###')
+ print('')
+ print(f'Class "{rule}" was not found in "{module_path.name}"')
+ print(f'Class "{rule}" was not found in "{folder_path.name}"')
+ if self.custom_modules_path:
+ print(f'Class "{rule}" was not found in "{self.custom_modules_path}"')
+ print('Possible Entitys are:', str(ents))
+ print('')
+ print('Goodbye')
+ print('')
+ exit(-99999)
+
+ if issubclass(rule_class, class_obj):
+ rule_kwargs = config.get(rule, {})
+ rules.append(rule_class(**(rule_kwargs or {})))
+ return rules
+
+ def load_entity_spawn_rules(self, entities) -> List[Rule]:
+ rules = list()
+ rules_dicts = list()
+ for e in entities:
+ try:
+ if spawn_rule := e.spawn_rule:
+ rules_dicts.append(spawn_rule)
+ except AttributeError:
+ pass
+
+ for rule_dict in rules_dicts:
+ for rule_name, rule_kwargs in rule_dict.items():
+ try:
+ folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
+ rule_class = locate_and_import_class(rule_name, folder_path)
+ except AttributeError:
+ try:
+ folder_path = (Path(__file__).parent.parent / MODULE_PATH)
+ rule_class = locate_and_import_class(rule_name, folder_path)
+ except AttributeError:
+ rule_class = locate_and_import_class(rule_name, self.custom_modules_path)
+ rules.append(rule_class(**rule_kwargs))
return rules
diff --git a/marl_factory_grid/utils/helpers.py b/marl_factory_grid/utils/helpers.py
index e2f3c9a..f5f6d00 100644
--- a/marl_factory_grid/utils/helpers.py
+++ b/marl_factory_grid/utils/helpers.py
@@ -2,7 +2,7 @@ import importlib
from collections import defaultdict
from pathlib import PurePath, Path
-from typing import Union, Dict, List
+from typing import Union, Dict, List, Iterable, Callable
import numpy as np
from numpy.typing import ArrayLike
@@ -61,8 +61,8 @@ class ObservationTranslator:
:param per_agent_named_obs_spaces: `Named observation space` one for each agent. Overloaded.
type per_agent_named_obs_spaces: Dict[str, dict]
- :param placeholder_fill_value: Currently not fully implemented!!!
- :type placeholder_fill_value: Union[int, str] = 'N')
+ :param placeholder_fill_value: Currently, not fully implemented!!!
+ :type placeholder_fill_value: Union[int, str] = 'N'
"""
if isinstance(placeholder_fill_value, str):
@@ -222,7 +222,7 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
mod = importlib.import_module('.'.join(module_parts))
all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
- 'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin',
+ 'TickResult', 'ActionResult', 'Action', 'Agent',
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any'
]])
@@ -240,7 +240,13 @@ def add_bound_name(name_str, bound_e):
def add_pos_name(name_str, bound_e):
if bound_e.var_has_position:
- return f'{name_str}({bound_e.pos})'
+ return f'{name_str}@{bound_e.pos}'
return name_str
+def get_first(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True):
+ return next((x for x in iterable if filter_by(x)), None)
+
+
+def get_first_index(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True):
+ return next((idx for idx, x in enumerate(iterable) if filter_by(x)), None)
diff --git a/marl_factory_grid/utils/level_parser.py b/marl_factory_grid/utils/level_parser.py
index fc8b948..24a05df 100644
--- a/marl_factory_grid/utils/level_parser.py
+++ b/marl_factory_grid/utils/level_parser.py
@@ -47,6 +47,7 @@ class LevelParser(object):
# All other
for es_name in self.e_p_dict:
e_class, e_kwargs = self.e_p_dict[es_name]['class'], self.e_p_dict[es_name]['kwargs']
+ e_kwargs = e_kwargs if e_kwargs else {}
if hasattr(e_class, 'symbol') and e_class.symbol is not None:
symbols = e_class.symbol
diff --git a/marl_factory_grid/utils/logging/envmonitor.py b/marl_factory_grid/utils/logging/envmonitor.py
index 67eac73..e2551c8 100644
--- a/marl_factory_grid/utils/logging/envmonitor.py
+++ b/marl_factory_grid/utils/logging/envmonitor.py
@@ -9,7 +9,7 @@ from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
import pandas as pd
-from marl_factory_grid.utils.plotting.compare_runs import plot_single_run
+from marl_factory_grid.utils.plotting.plot_single_runs import plot_single_run
class EnvMonitor(Wrapper):
@@ -22,7 +22,6 @@ class EnvMonitor(Wrapper):
self._monitor_df = pd.DataFrame()
self._monitor_dict = dict()
-
def step(self, action):
obs_type, obs, reward, done, info = self.env.step(action)
self._read_info(info)
diff --git a/marl_factory_grid/utils/logging/recorder.py b/marl_factory_grid/utils/logging/recorder.py
index fac2e16..797866e 100644
--- a/marl_factory_grid/utils/logging/recorder.py
+++ b/marl_factory_grid/utils/logging/recorder.py
@@ -2,11 +2,9 @@ from os import PathLike
from pathlib import Path
from typing import Union, List
-import yaml
-from gymnasium import Wrapper
-
import numpy as np
import pandas as pd
+from gymnasium import Wrapper
class EnvRecorder(Wrapper):
@@ -106,7 +104,7 @@ class EnvRecorder(Wrapper):
out_dict = {'episodes': self._recorder_out_list}
out_dict.update(
{'n_episodes': self._curr_episode,
- 'metadata':dict(
+ 'metadata': dict(
level_name=self.env.params['General']['level_name'],
verbose=False,
n_agents=len(self.env.params['Agents']),
diff --git a/marl_factory_grid/utils/observation_builder.py b/marl_factory_grid/utils/observation_builder.py
index 9fd1d26..55d6ec0 100644
--- a/marl_factory_grid/utils/observation_builder.py
+++ b/marl_factory_grid/utils/observation_builder.py
@@ -1,17 +1,16 @@
-import math
import re
from collections import defaultdict
-from itertools import product
from typing import Dict, List
import numpy as np
-from numba import njit
from marl_factory_grid.environment import constants as c
+from marl_factory_grid.environment.entity.object import Object
from marl_factory_grid.environment.groups.utils import Combined
-import marl_factory_grid.utils.helpers as h
-from marl_factory_grid.utils.states import Gamestate
from marl_factory_grid.utils.utility_classes import Floor
+from marl_factory_grid.utils.ray_caster import RayCaster
+from marl_factory_grid.utils.states import Gamestate
+from marl_factory_grid.utils import helpers as h
class OBSBuilder(object):
@@ -77,11 +76,13 @@ class OBSBuilder(object):
def place_entity_in_observation(self, obs_array, agent, e):
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
- try:
- obs_array[x, y] += e.encoding
- except IndexError:
- # Seemded to be visible but is out of range
- pass
+ if not min([y, x]) < 0:
+ try:
+ obs_array[x, y] += e.encoding
+ except IndexError:
+ # Seemded to be visible but is out of range
+ pass
+ pass
def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
assert self._curr_env_step == state.curr_step, (
@@ -121,18 +122,24 @@ class OBSBuilder(object):
e = self.all_obs[l_name]
except KeyError:
try:
- # Look for bound entity names!
- pattern = re.compile(f'{re.escape(l_name)}(.*){re.escape(agent.name)}')
- name = next((x for x in self.all_obs if pattern.search(x)), None)
+ # Look for bound entity REPRs!
+ pattern = re.compile(f'{re.escape(l_name)}'
+ f'{re.escape("[")}(.*){re.escape("]")}'
+ f'{re.escape("(")}{re.escape(agent.name)}{re.escape(")")}')
+ name = next((key for key, val in self.all_obs.items()
+ if pattern.search(str(val)) and isinstance(val, Object)), None)
e = self.all_obs[name]
except KeyError:
try:
e = next(v for k, v in self.all_obs.items() if l_name in k and agent.name in k)
except StopIteration:
- raise KeyError(
- f'Check for spelling errors! \n '
- f'No combination of "{l_name} and {agent.name}" could not be found in:\n '
- f'{list(dict(self.all_obs).keys())}')
+ print(f'# Check for spelling errors!')
+ print(f'# No combination of "{l_name}" and "{agent.name}" could not be found in:')
+ print(f'# {list(dict(self.all_obs).keys())}')
+ print('#')
+ print('# exiting...')
+ print('#')
+ exit(-99999)
try:
positional = e.var_has_position
@@ -161,31 +168,30 @@ class OBSBuilder(object):
try:
light_map = np.zeros(self.obs_shape)
visible_floor = self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False)
- if self.pomdp_r:
- for f in set(visible_floor):
- self.place_entity_in_observation(light_map, agent, f)
- else:
- for f in set(visible_floor):
- light_map[f.x, f.y] += f.encoding
+
+ for f in set(visible_floor):
+ self.place_entity_in_observation(light_map, agent, f)
+ # else:
+ # for f in set(visible_floor):
+ # light_map[f.x, f.y] += f.encoding
self.curr_lightmaps[agent.name] = light_map
except (KeyError, ValueError):
- print()
pass
return obs, self.obs_layers[agent.name]
def _sort_and_name_observation_conf(self, agent):
- '''
+ """
Builds the useable observation scheme per agent from conf.yaml.
:param agent:
:return:
- '''
+ """
# Fixme: no asymetric shapes possible.
self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape))
obs_layers = []
for obs_str in agent.observations:
if isinstance(obs_str, dict):
- obs_str, vals = next(obs_str.items().__iter__())
+ obs_str, vals = h.get_first(obs_str.items())
else:
vals = None
if obs_str == c.SELF:
@@ -214,129 +220,3 @@ class OBSBuilder(object):
obs_layers.append(obs_str)
self.obs_layers[agent.name] = obs_layers
self.curr_lightmaps[agent.name] = np.zeros(self.obs_shape)
-
-
-class RayCaster:
- def __init__(self, agent, pomdp_r, degs=360):
- self.agent = agent
- self.pomdp_r = pomdp_r
- self.n_rays = (self.pomdp_r + 1) * 8
- self.degs = degs
- self.ray_targets = self.build_ray_targets()
- self.obs_shape_cube = np.array([self.pomdp_r, self.pomdp_r])
- self._cache_dict = {}
-
- def __repr__(self):
- return f'{self.__class__.__name__}({self.agent.name})'
-
- def build_ray_targets(self):
- north = np.array([0, -1]) * self.pomdp_r
- thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
- rot_M = [
- [[math.cos(theta), -math.sin(theta)],
- [math.sin(theta), math.cos(theta)]] for theta in thetas
- ]
- rot_M = np.stack(rot_M, 0)
- rot_M = np.unique(np.round(rot_M @ north), axis=0)
- return rot_M.astype(int)
-
- def ray_block_cache(self, key, callback):
- if key not in self._cache_dict:
- self._cache_dict[key] = callback()
- return self._cache_dict[key]
-
- def visible_entities(self, pos_dict, reset_cache=True):
- visible = list()
- if reset_cache:
- self._cache_dict = {}
-
- for ray in self.get_rays():
- rx, ry = ray[0]
- for x, y in ray:
- cx, cy = x - rx, y - ry
-
- entities_hit = pos_dict[(x, y)]
- hits = self.ray_block_cache((x, y),
- lambda: any(True for e in entities_hit if e.var_is_blocking_light)
- )
-
- diag_hits = all([
- self.ray_block_cache(
- key,
- lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) and bool(
- pos_dict[key]))
- for key in ((x, y - cy), (x - cx, y))
- ]) if (cx != 0 and cy != 0) else False
-
- visible += entities_hit if not diag_hits else []
- if hits or diag_hits:
- break
- rx, ry = x, y
- return visible
-
- def get_rays(self):
- a_pos = self.agent.pos
- outline = self.ray_targets + a_pos
- return self.bresenham_loop(a_pos, outline)
-
- # todo do this once and cache the points!
- def get_fov_outline(self) -> np.ndarray:
- return self.ray_targets + self.agent.pos
-
- def get_square_outline(self):
- agent = self.agent
- x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
- y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
- outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \
- + list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
- return outline
-
- @staticmethod
- @njit
- def bresenham_loop(a_pos, points):
- results = []
- for end in points:
- x1, y1 = a_pos
- x2, y2 = end
- dx = x2 - x1
- dy = y2 - y1
-
- # Determine how steep the line is
- is_steep = abs(dy) > abs(dx)
-
- # Rotate line
- if is_steep:
- x1, y1 = y1, x1
- x2, y2 = y2, x2
-
- # Swap start and end points if necessary and store swap state
- swapped = False
- if x1 > x2:
- x1, x2 = x2, x1
- y1, y2 = y2, y1
- swapped = True
-
- # Recalculate differentials
- dx = x2 - x1
- dy = y2 - y1
-
- # Calculate error
- error = int(dx / 2.0)
- ystep = 1 if y1 < y2 else -1
-
- # Iterate over bounding box generating points between start and end
- y = y1
- points = []
- for x in range(int(x1), int(x2) + 1):
- coord = [y, x] if is_steep else [x, y]
- points.append(coord)
- error -= abs(dy)
- if error < 0:
- y += ystep
- error += dx
-
- # Reverse the list if the coordinates were swapped
- if swapped:
- points.reverse()
- results.append(points)
- return results
diff --git a/marl_factory_grid/utils/plotting/compare_runs.py b/marl_factory_grid/utils/plotting/plot_compare_runs.py
similarity index 83%
rename from marl_factory_grid/utils/plotting/compare_runs.py
rename to marl_factory_grid/utils/plotting/plot_compare_runs.py
index cb5c853..5115478 100644
--- a/marl_factory_grid/utils/plotting/compare_runs.py
+++ b/marl_factory_grid/utils/plotting/plot_compare_runs.py
@@ -7,50 +7,11 @@ from typing import Union, List
import pandas as pd
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
-from marl_factory_grid.utils.plotting.plotting import prepare_plot
+from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
MODEL_MAP = None
-def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None):
- run_path = Path(run_path)
- df_list = list()
- if run_path.is_dir():
- monitor_file = next(run_path.glob('*monitor*.pick'))
- elif run_path.exists() and run_path.is_file():
- monitor_file = run_path
- else:
- raise ValueError
-
- with monitor_file.open('rb') as f:
- monitor_df = pickle.load(f)
-
- monitor_df = monitor_df.fillna(0)
- df_list.append(monitor_df)
-
- df = pd.concat(df_list, ignore_index=True)
- df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
- if column_keys is not None:
- columns = [col for col in column_keys if col in df.columns]
- else:
- columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
-
- roll_n = 50
-
- non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean()
-
- df_melted = df[columns + ['Episode']].reset_index().melt(
- id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score"
- )
-
- if df_melted['Episode'].max() > 800:
- skip_n = round(df_melted['Episode'].max() * 0.02)
- df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
-
- prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
- print('Plotting done.')
-
-
def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
run_path = Path(run_path)
df_list = list()
diff --git a/marl_factory_grid/utils/plotting/plot_single_runs.py b/marl_factory_grid/utils/plotting/plot_single_runs.py
new file mode 100644
index 0000000..7316d6a
--- /dev/null
+++ b/marl_factory_grid/utils/plotting/plot_single_runs.py
@@ -0,0 +1,48 @@
+import pickle
+from os import PathLike
+from pathlib import Path
+from typing import Union
+
+import pandas as pd
+
+from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
+from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
+
+
+def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None,
+ file_key: str ='monitor', file_ext: str ='pkl'):
+ run_path = Path(run_path)
+ df_list = list()
+ if run_path.is_dir():
+ monitor_file = next(run_path.glob(f'*{file_key}*.{file_ext}'))
+ elif run_path.exists() and run_path.is_file():
+ monitor_file = run_path
+ else:
+ raise ValueError
+
+ with monitor_file.open('rb') as f:
+ monitor_df = pickle.load(f)
+
+ monitor_df = monitor_df.fillna(0)
+ df_list.append(monitor_df)
+
+ df = pd.concat(df_list, ignore_index=True)
+ df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
+ if column_keys is not None:
+ columns = [col for col in column_keys if col in df.columns]
+ else:
+ columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
+
+ # roll_n = 50
+ # non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean()
+
+ df_melted = df[columns + ['Episode']].reset_index().melt(
+ id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score"
+ )
+
+ if df_melted['Episode'].max() > 800:
+ skip_n = round(df_melted['Episode'].max() * 0.02)
+ df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
+
+ prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
+ print('Plotting done.')
diff --git a/marl_factory_grid/utils/plotting/plotting.py b/marl_factory_grid/utils/plotting/plotting_utils.py
similarity index 98%
rename from marl_factory_grid/utils/plotting/plotting.py
rename to marl_factory_grid/utils/plotting/plotting_utils.py
index 455f81a..17bb7ff 100644
--- a/marl_factory_grid/utils/plotting/plotting.py
+++ b/marl_factory_grid/utils/plotting/plotting_utils.py
@@ -60,7 +60,7 @@ def prepare_center_double_column_legend(df, hue, style, hue_order):
print('Struggling to plot Figure using LaTeX - going back to normal.')
plt.close('all')
sns.set(rc={'text.usetex': False}, style='whitegrid')
- fig = plt.figure(figsize=(10, 11))
+ _ = plt.figure(figsize=(10, 11))
lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
ci=95, palette=PALETTE, hue_order=hue_order, legend=False)
# plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
diff --git a/marl_factory_grid/utils/ray_caster.py b/marl_factory_grid/utils/ray_caster.py
index cf17bd1..d89997e 100644
--- a/marl_factory_grid/utils/ray_caster.py
+++ b/marl_factory_grid/utils/ray_caster.py
@@ -19,7 +19,7 @@ class RayCaster:
return f'{self.__class__.__name__}({self.agent.name})'
def build_ray_targets(self):
- north = np.array([0, -1])*self.pomdp_r
+ north = np.array([0, -1]) * self.pomdp_r
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
rot_M = [
[[math.cos(theta), -math.sin(theta)],
@@ -39,8 +39,9 @@ class RayCaster:
if reset_cache:
self._cache_dict = dict()
- for ray in self.get_rays():
+ for ray in self.get_rays(): # Do not check, just trust.
rx, ry = ray[0]
+ # self.ray_block_cache(ray[0], lambda: False) We do not do that, because of doors etc...
for x, y in ray:
cx, cy = x - rx, y - ry
@@ -52,8 +53,9 @@ class RayCaster:
diag_hits = all([
self.ray_block_cache(
key,
- lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light))
- for key in ((x, y-cy), (x-cx, y))
+ # lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light)
+ lambda: any(True for e in pos_dict[key] if e.var_is_blocking_light))
+ for key in ((x, y - cy), (x - cx, y))
]) if (cx != 0 and cy != 0) else False
visible += entities_hit if not diag_hits else []
@@ -75,8 +77,8 @@ class RayCaster:
agent = self.agent
x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
- outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \
- + list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
+ outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r]))
+ outline += list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
return outline
@staticmethod
diff --git a/marl_factory_grid/utils/renderer.py b/marl_factory_grid/utils/renderer.py
index db6a93f..1976974 100644
--- a/marl_factory_grid/utils/renderer.py
+++ b/marl_factory_grid/utils/renderer.py
@@ -31,7 +31,7 @@ class Renderer:
def __init__(self, lvl_shape: Tuple[int, int] = (16, 16),
lvl_padded_shape: Union[Tuple[int, int], None] = None,
- cell_size: int = 40, fps: int = 7,
+ cell_size: int = 40, fps: int = 7, factor: float = 0.9,
grid_lines: bool = True, view_radius: int = 2):
# TODO: Customn_assets paths
self.grid_h, self.grid_w = lvl_shape
@@ -45,7 +45,7 @@ class Renderer:
self.screen = pygame.display.set_mode(self.screen_size)
self.clock = pygame.time.Clock()
assets = list(self.ASSETS.rglob('*.png'))
- self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets}
+ self.assets = {path.stem: self.load_asset(str(path), factor) for path in assets}
self.fill_bg()
now = time.time()
@@ -110,22 +110,22 @@ class Renderer:
pygame.quit()
sys.exit()
self.fill_bg()
- blits = deque()
- for entity in [x for x in entities]:
- bp = self.blit_params(entity)
- blits.append(bp)
- if entity.name.lower() == AGENT:
- if self.view_radius > 0:
- vis_rects = self.visibility_rects(bp, entity.aux)
- blits.extendleft(vis_rects)
- if entity.state != BLANK:
- agent_state_blits = self.blit_params(
- RenderEntity(entity.state, (entity.pos[0] + 0.12, entity.pos[1]), 0.48, SCALE)
- )
- textsurface = self.font.render(str(entity.id), False, (0, 0, 0))
- text_blit = dict(source=textsurface, dest=(bp['dest'].center[0]-.07*self.cell_size,
- bp['dest'].center[1]))
- blits += [agent_state_blits, text_blit]
+ # First all others
+ blits = deque(self.blit_params(x) for x in entities if not x.name.lower() == AGENT)
+ # Then Agents, so that agents are rendered on top.
+ for agent in (x for x in entities if x.name.lower() == AGENT):
+ agent_blit = self.blit_params(agent)
+ if self.view_radius > 0:
+ vis_rects = self.visibility_rects(agent_blit, agent.aux)
+ blits.extendleft(vis_rects)
+ if agent.state != BLANK:
+ state_blit = self.blit_params(
+ RenderEntity(agent.state, (agent.pos[0] + 0.12, agent.pos[1]), 0.48, SCALE)
+ )
+ textsurface = self.font.render(str(agent.id), False, (0, 0, 0))
+ text_blit = dict(source=textsurface, dest=(agent_blit['dest'].center[0]-.07*self.cell_size,
+ agent_blit['dest'].center[1]))
+ blits += [agent_blit, state_blit, text_blit]
for blit in blits:
self.screen.blit(**blit)
diff --git a/marl_factory_grid/utils/results.py b/marl_factory_grid/utils/results.py
index 9f0fa38..b4b07fc 100644
--- a/marl_factory_grid/utils/results.py
+++ b/marl_factory_grid/utils/results.py
@@ -1,9 +1,12 @@
from typing import Union
from dataclasses import dataclass
+from marl_factory_grid.environment.entity.object import Object
+
TYPE_VALUE = 'value'
TYPE_REWARD = 'reward'
-types = [TYPE_VALUE, TYPE_REWARD]
+TYPES = [TYPE_VALUE, TYPE_REWARD]
+
@dataclass
class InfoObject:
@@ -18,17 +21,21 @@ class Result:
validity: bool
reward: Union[float, None] = None
value: Union[float, None] = None
- entity: None = None
+ entity: Object = None
def get_infos(self):
n = self.entity.name if self.entity is not None else "Global"
- return [InfoObject(identifier=f'{n}_{self.identifier}_{t}',
- val_type=t, value=self.__getattribute__(t)) for t in types
+ # Return multiple Info Dicts
+ return [InfoObject(identifier=f'{n}_{self.identifier}',
+ val_type=t, value=self.__getattribute__(t)) for t in TYPES
if self.__getattribute__(t) is not None]
def __repr__(self):
valid = "not " if not self.validity else ""
- return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid: {self.reward})'
+ reward = f" | Reward: {self.reward}" if self.reward is not None else ""
+ value = f" | Value: {self.value}" if self.value is not None else ""
+ entity = f" | by: {self.entity.name}" if self.entity is not None else ""
+ return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value}{entity})'
@dataclass
diff --git a/marl_factory_grid/utils/states.py b/marl_factory_grid/utils/states.py
index 1461826..d54db6a 100644
--- a/marl_factory_grid/utils/states.py
+++ b/marl_factory_grid/utils/states.py
@@ -1,9 +1,12 @@
-from typing import List, Dict, Tuple
+from itertools import islice
+from typing import List, Tuple
import numpy as np
from marl_factory_grid.environment import constants as c
+from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.rules import Rule
+from marl_factory_grid.utils.results import Result, DoneResult
from marl_factory_grid.environment.tests import Test
from marl_factory_grid.utils.results import Result
@@ -60,7 +63,8 @@ class Gamestate(object):
def moving_entites(self):
return [y for x in self.entities for y in x if x.var_can_move]
- def __init__(self, entities, agents_conf, rules: [Rule], tests: [Test], env_seed=69, verbose=False):
+ def __init__(self, entities, agents_conf, rules: List[Rule], tests: [Test], lvl_shape, env_seed=69, verbose=False):
+ self.lvl_shape = lvl_shape
self.entities = entities
self.curr_step = 0
self.curr_actions = None
@@ -82,7 +86,52 @@ class Gamestate(object):
def __repr__(self):
return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})'
- def tick(self, actions) -> List[Result]:
+ @property
+ def random_free_position(self) -> (int, int):
+ """
+ Returns a single **free** position (x, y), which is **free** for spawning or walking.
+ No Entity at this position posses *var_is_blocking_pos* or *var_can_collide*.
+
+ :return: Single **free** position.
+ """
+ return self.get_n_random_free_positions(1)[0]
+
+ def get_n_random_free_positions(self, n) -> list[tuple[int, int]]:
+ """
+ Returns a list of *n* **free** positions [(x, y), ... ], which are **free** for spawning or walking.
+ No Entity at this position posses *var_is_blocking_pos* or *var_can_collide*.
+
+ :return: List of n **free** position.
+ """
+ return list(islice(self.entities.free_positions_generator, n))
+
+ @property
+ def random_position(self) -> (int, int):
+ """
+ Returns a single available position (x, y), ignores all entity attributes.
+
+ :return: Single random position.
+ """
+ return self.get_n_random_positions(1)[0]
+
+ def get_n_random_positions(self, n) -> list[tuple[int, int]]:
+ """
+ Returns a list of *n* available positions [(x, y), ... ], ignores all entity attributes.
+
+ :return: List of n random positions.
+ """
+ return list(islice(self.entities.floorlist, n))
+
+ def tick(self, actions) -> list[Result]:
+ """
+ Performs a single **Gamestate Tick**by calling the inner rule hooks in sequential order.
+ - tick_pre_step_all: Things to do before the agents do their actions. Statechange, Moving, Spawning etc...
+ - agent tick: Agents do their actions.
+ - tick_step_all: Things to do after the agents did their actions. Statechange, Moving, Spawning etc...
+ - tick_post_step_all: Things to do at the very end of each step. Counting, Reward calculations etc...
+
+ :return: List of *Result*-objects.
+ """
results = list()
test_results = list()
self.curr_step += 1
@@ -112,11 +161,23 @@ class Gamestate(object):
return results
- def print(self, string):
+ def print(self, string) -> None:
+ """
+ When *verbose* is active, print stuff.
+
+ :param string: *String* to print.
+ :type string: str
+ :return: Nothing
+ """
if self.verbose:
print(string)
- def check_done(self):
+ def check_done(self) -> List[DoneResult]:
+ """
+ Iterate all **Rules** that override tehe *on_ckeck_done* hook.
+
+ :return: List of Results
+ """
results = list()
for rule in self.rules:
if on_check_done_result := rule.on_check_done(self):
@@ -124,24 +185,47 @@ class Gamestate(object):
return results
def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]:
- positions = [pos for pos, entity_list_for_position in self.entities.pos_dict.items()
- if any([e.var_can_collide for e in entity_list_for_position])]
+ """
+ Returns a list positions [(x, y), ... ] on which collisions occur. This does not include agents,
+ that were unable to move because their target direction was blocked, also a form of collision.
+
+ :return: List of positions.
+ """
+ positions = [pos for pos, entities in self.entities.pos_dict.items() if
+ len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2)
+ ]
return positions
- def check_move_validity(self, moving_entity, position):
- if moving_entity.pos != position and not any(
- entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]) and not (
- moving_entity.var_is_blocking_pos and self.entities.is_occupied(position)):
- return True
- else:
- return False
+ def check_move_validity(self, moving_entity: Entity, target_position: (int, int)) -> bool:
+ """
+ Whether it is safe to move to the target positions and moving entity does not introduce a blocking attribute,
+ when position is allready occupied.
- def check_pos_validity(self, position):
- if not any(entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]):
- return True
- else:
- return False
+ :param moving_entity: Entity
+ :param target_position: pos
+ :return: Safe to move to
+ """
+ is_not_blocked = self.check_pos_validity(target_position)
+ will_not_block_others = moving_entity.var_is_blocking_pos and self.entities.is_occupied(target_position)
+
+ if moving_entity.pos != target_position and is_not_blocked and not will_not_block_others:
+ return c.VALID
+ else:
+ return c.NOT_VALID
+
+ def check_pos_validity(self, pos: (int, int)) -> bool:
+ """
+ Check if *pos* is a valid position to move or spawn to.
+
+ :param pos: position to check
+ :return: Wheter pos is a valid target.
+ """
+
+ if not any(e.var_is_blocking_pos for e in self.entities.pos_dict[pos]) and pos in self.entities.floorlist:
+ return c.VALID
+ else:
+ return c.NOT_VALID
class StepTests:
def __init__(self, *args):
diff --git a/marl_factory_grid/utils/tools.py b/marl_factory_grid/utils/tools.py
index d2f9bd1..73fa50d 100644
--- a/marl_factory_grid/utils/tools.py
+++ b/marl_factory_grid/utils/tools.py
@@ -28,7 +28,9 @@ class ConfigExplainer:
def explain_module(self, class_to_explain):
parameters = inspect.signature(class_to_explain).parameters
- explained = {class_to_explain.__name__: {key: val.default for key, val in parameters.items() if key not in EXCLUDED}}
+ explained = {class_to_explain.__name__:
+ {key: val.default for key, val in parameters.items() if key not in EXCLUDED}
+ }
return explained
def _load_and_compare(self, compare_class, paths):
@@ -135,4 +137,3 @@ if __name__ == '__main__':
ce.get_observations()
ce.get_assets()
all_conf = ce.get_all()
- print()
diff --git a/marl_factory_grid/utils/utility_classes.py b/marl_factory_grid/utils/utility_classes.py
index 4844133..4d1cfe1 100644
--- a/marl_factory_grid/utils/utility_classes.py
+++ b/marl_factory_grid/utils/utility_classes.py
@@ -52,3 +52,6 @@ class Floor:
def __hash__(self):
return hash(self.name)
+
+ def __repr__(self):
+ return f"Floor{self.pos}"
diff --git a/random_testrun.py b/random_testrun.py
index 9bebf17..ef8df08 100644
--- a/random_testrun.py
+++ b/random_testrun.py
@@ -6,18 +6,21 @@ from marl_factory_grid.environment.factory import Factory
from marl_factory_grid.utils.logging.envmonitor import EnvMonitor
from marl_factory_grid.utils.logging.recorder import EnvRecorder
+from marl_factory_grid.utils.plotting.plot_single_runs import plot_single_run
from marl_factory_grid.utils.tools import ConfigExplainer
if __name__ == '__main__':
# Render at each step?
- render = True
+ render = False
# Reveal all possible Modules (Entities, Rules, Agents[Actions, Observations], etc.)
explain_config = False
# Collect statistics?
- monitor = False
+ monitor = True
# Record as Protobuf?
record = False
+ # Plot Results?
+ plotting = True
run_path = Path('study_out')
@@ -38,7 +41,7 @@ if __name__ == '__main__':
factory = EnvRecorder(factory)
# RL learn Loop
- for episode in trange(500):
+ for episode in trange(10):
_ = factory.reset()
done = False
if render:
@@ -54,7 +57,10 @@ if __name__ == '__main__':
break
if monitor:
- factory.save_run(run_path / 'test.pkl')
+ factory.save_run(run_path / 'test_monitor.pkl')
if record:
factory.save_records(run_path / 'test.pb')
+ if plotting:
+ plot_single_run(run_path)
+
print('Done!!! Goodbye....')
diff --git a/reload_agent.py b/reload_agent.py
index 8c16069..0fb8066 100644
--- a/reload_agent.py
+++ b/reload_agent.py
@@ -6,6 +6,7 @@ import yaml
from marl_factory_grid.environment.factory import Factory
from marl_factory_grid.utils.logging.envmonitor import EnvMonitor
from marl_factory_grid.utils.logging.recorder import EnvRecorder
+from marl_factory_grid.utils import helpers as h
from marl_factory_grid.modules.doors import constants as d
@@ -55,13 +56,14 @@ if __name__ == '__main__':
for model_idx, model in enumerate(models)]
else:
actions = models[0].predict(env_state, deterministic=determin)[0]
+ # noinspection PyTupleAssignmentBalance
env_state, step_r, done_bool, info_obj = env.step(actions)
rew += step_r
if render:
env.render()
try:
- door = next(x for x in env.unwrapped.unwrapped[d.DOORS] if x.is_open)
+ door = h.get_first([x for x in env.unwrapped.unwrapped[d.DOORS] if x.is_open])
print('openDoor found')
except StopIteration:
pass
diff --git a/studies/normalization_study.py b/studies/normalization_study.py
index 37e10c4..7c72982 100644
--- a/studies/normalization_study.py
+++ b/studies/normalization_study.py
@@ -1,8 +1,8 @@
from algorithms.utils import Checkpointer
from pathlib import Path
from algorithms.utils import load_yaml_file, add_env_props, instantiate_class, load_class
-#from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC
+# from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC
for i in range(0, 5):
diff --git a/transform_wg_to_json_no_priv.py b/transform_wg_to_json_no_priv.py
new file mode 100644
index 0000000..1b7ef3e
--- /dev/null
+++ b/transform_wg_to_json_no_priv.py
@@ -0,0 +1,41 @@
+import configparser
+import json
+from datetime import datetime
+from pathlib import Path
+
+if __name__ == '__main__':
+
+ conf_path = Path('wg0')
+ wg0_conf = configparser.ConfigParser()
+ wg0_conf.read(conf_path/'wg0.conf')
+ interface = wg0_conf['Interface']
+ # Iterate all pears
+ for client_name in wg0_conf.sections():
+ if client_name == 'Interface':
+ continue
+ # Delete any old conf.json for the current peer
+ (conf_path / f'{client_name}.json').unlink(missing_ok=True)
+
+ peer = wg0_conf[client_name]
+
+ date_time = datetime.now().strftime('%Y-%m-%dT%H:%M:%S.%f000Z')
+
+ jdict = dict(
+ id=client_name,
+ private_key=peer['PublicKey'],
+ public_key=peer['PublicKey'],
+ # preshared_key=wg0_conf[client_name_wg0]['PresharedKey'],
+ name=client_name,
+ email=f"sysadmin@mobile.ifi.lmu.de",
+ allocated_ips=[interface['Address'].replace('/24', '')],
+ allowed_ips=['10.4.0.0/24', '10.153.199.0/24'],
+ extra_allowed_ips=[],
+ use_server_dns=True,
+ enabled=True,
+ created_at=date_time,
+ updated_at=date_time
+ )
+
+ with (conf_path / f'{client_name}.json').open('w+') as f:
+ json.dump(jdict, f, indent='\t', separators=(',', ': '))
+ print(client_name, ' written...')