mirror of
				https://github.com/illiumst/marl-factory-grid.git
				synced 2025-10-31 04:37:25 +01:00 
			
		
		
		
	Merge branch 'main' into unit_testing
# Conflicts: # marl_factory_grid/environment/factory.py # marl_factory_grid/utils/config_parser.py # marl_factory_grid/utils/states.py
This commit is contained in:
		
							
								
								
									
										5
									
								
								.idea/.gitignore
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								.idea/.gitignore
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | ||||
| # Default ignored files | ||||
| /shelf/ | ||||
| /workspace.xml | ||||
| # Editor-based HTTP Client requests | ||||
| /httpRequests/ | ||||
| @@ -56,7 +56,7 @@ Just define what your environment needs in a *yaml*-configfile like: | ||||
|                     - Items | ||||
|     Rules: | ||||
|         Defaults: {} | ||||
|         Collision: | ||||
|         WatchCollisions: | ||||
|             done_at_collisions: !!bool True | ||||
|         ItemRespawn: | ||||
|             spawn_freq: 5 | ||||
| @@ -94,7 +94,7 @@ All [Entites](marl_factory_grid/environment/entity/global_entities.py) are avail | ||||
|  | ||||
|  | ||||
| #### Rules | ||||
| [Rules](marl_factory_grid/environment/entity/object.py) define how the environment behaves on micro-scale. | ||||
| [Rules](marl_factory_grid/environment/entity/object.py) define how the environment behaves on microscale. | ||||
| Each of the hookes (`on_init`, `pre_step`, `on_step`, '`post_step`', `on_done`)  | ||||
| provide env-access to implement customn logic, calculate rewards, or gather information. | ||||
|  | ||||
| @@ -107,6 +107,7 @@ Make sure to bring your own assets for each Entity living in the Gridworld as th | ||||
| PNG-files (transparent background) of square aspect-ratio should do the job, in general. | ||||
|  | ||||
| <img src="/marl_factory_grid/environment/assets/wall.png"  width="5%">  | ||||
| <!--suppress HtmlUnknownAttribute --> | ||||
| <html      html>  | ||||
| <img src="/marl_factory_grid/environment/assets/agent/agent.png"  width="5%"> | ||||
|  | ||||
|   | ||||
| @@ -1,6 +1 @@ | ||||
| from .environment import * | ||||
| from .modules import * | ||||
| from .utils import * | ||||
|  | ||||
| from .quickstart import init | ||||
|  | ||||
|   | ||||
| @@ -1 +1,4 @@ | ||||
| import os, sys; sys.path.append(os.path.dirname(os.path.realpath(__file__))) | ||||
| import os | ||||
| import sys | ||||
|  | ||||
| sys.path.append(os.path.dirname(os.path.realpath(__file__))) | ||||
|   | ||||
| @@ -28,6 +28,7 @@ class Names: | ||||
|     BATCH_SIZE      = 'bnatch_size' | ||||
|     N_ACTIONS       = 'n_actions' | ||||
|  | ||||
|  | ||||
| nms = Names | ||||
| ListOrTensor = Union[List, torch.Tensor] | ||||
|  | ||||
| @@ -112,10 +113,9 @@ class BaseActorCritic: | ||||
|                 next_obs, reward, done, info = env.step(action) | ||||
|                 done = [done] * self.n_agents if isinstance(done, bool) else done | ||||
|  | ||||
|                 last_hiddens = dict(hidden_actor =out[nms.HIDDEN_ACTOR], | ||||
|                 last_hiddens = dict(hidden_actor=out[nms.HIDDEN_ACTOR], | ||||
|                                     hidden_critic=out[nms.HIDDEN_CRITIC]) | ||||
|  | ||||
|  | ||||
|                 tm.add(observation=obs, action=action, reward=reward, done=done, | ||||
|                        logits=out.get(nms.LOGITS, None), values=out.get(nms.CRITIC, None), | ||||
|                        **last_hiddens) | ||||
| @@ -142,7 +142,9 @@ class BaseActorCritic: | ||||
|             print(f'reward at episode: {episode} = {rew_log}') | ||||
|             episode += 1 | ||||
|             df_results.append([episode, rew_log, *reward]) | ||||
|         df_results = pd.DataFrame(df_results, columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]]) | ||||
|         df_results = pd.DataFrame(df_results, | ||||
|                                   columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]] | ||||
|                                   ) | ||||
|         if checkpointer is not None: | ||||
|             df_results.to_csv(checkpointer.path / 'results.csv', index=False) | ||||
|         return df_results | ||||
| @@ -157,24 +159,27 @@ class BaseActorCritic: | ||||
|             last_action, reward    = [-1] * self.n_agents, [0.] * self.n_agents | ||||
|             done, rew_log, eps_rew = [False] * self.n_agents, 0, torch.zeros(self.n_agents) | ||||
|             while not all(done): | ||||
|                 if render: env.render() | ||||
|                 if render: | ||||
|                     env.render() | ||||
|  | ||||
|                 out    = self.forward(obs, last_action, **last_hiddens) | ||||
|                 action = self.get_actions(out) | ||||
|                 next_obs, reward, done, info = env.step(action) | ||||
|  | ||||
|                 if isinstance(done, bool): done = [done] * obs.shape[0] | ||||
|                 if isinstance(done, bool): | ||||
|                     done = [done] * obs.shape[0] | ||||
|                 obs = next_obs | ||||
|                 last_action = action | ||||
|                 last_hiddens = dict(hidden_actor=out.get(nms.HIDDEN_ACTOR,   None), | ||||
|                                     hidden_critic=out.get(nms.HIDDEN_CRITIC, None) | ||||
|                                     ) | ||||
|                 eps_rew += torch.tensor(reward) | ||||
|             results.append(eps_rew.tolist() + [sum(eps_rew).item()] + [episode]) | ||||
|             results.append(eps_rew.tolist() + [np.sum(eps_rew).item()] + [episode]) | ||||
|             episode += 1 | ||||
|         agent_columns = [f'agent#{i}' for i in range(self.cfg['environment']['n_agents'])] | ||||
|         results = pd.DataFrame(results, columns=agent_columns + ['sum', 'episode']) | ||||
|         results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'], value_name='reward', var_name='agent') | ||||
|         results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'], | ||||
|                           value_name='reward', var_name='agent') | ||||
|         return results | ||||
|  | ||||
|     @staticmethod | ||||
|   | ||||
| @@ -36,7 +36,7 @@ class LoopMAPPO(LoopSNAC): | ||||
|         rewards_ = torch.stack(rewards_, dim=1) | ||||
|         return rewards_ | ||||
|  | ||||
|     def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **kwargs): | ||||
|     def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **__): | ||||
|         out = network(batch[nms.OBSERVATION], batch[nms.ACTION], batch[nms.HIDDEN_ACTOR], batch[nms.HIDDEN_CRITIC]) | ||||
|         logits = out[nms.LOGITS][:, :-1]  # last one only needed for v_{t+1} | ||||
|  | ||||
| @@ -45,7 +45,7 @@ class LoopMAPPO(LoopSNAC): | ||||
|  | ||||
|         # monte carlo returns | ||||
|         mc_returns = self.monte_carlo_returns(batch[nms.REWARD], batch[nms.DONE], gamma) | ||||
|         mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8) #todo: norm across agent ok? | ||||
|         mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8)  # todo: norm across agent ok? | ||||
|         advantages =  mc_returns - out[nms.CRITIC][:, :-1] | ||||
|  | ||||
|         # policy loss | ||||
|   | ||||
| @@ -1,8 +1,7 @@ | ||||
| import numpy as np | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import numpy as np | ||||
| import torch.nn.functional as F | ||||
| from torch.nn.utils import spectral_norm | ||||
|  | ||||
|  | ||||
| class RecurrentAC(nn.Module): | ||||
| @@ -88,8 +87,8 @@ class NormalizedLinear(nn.Linear): | ||||
|         self.trainable_magnitude = trainable_magnitude | ||||
|         self.scale = nn.Parameter(torch.tensor([1.]), requires_grad=trainable_magnitude) | ||||
|  | ||||
|     def forward(self, input): | ||||
|         normalized_input = F.normalize(input, dim=-1, p=2, eps=1e-5) | ||||
|     def forward(self, in_array): | ||||
|         normalized_input = F.normalize(in_array, dim=-1, p=2, eps=1e-5) | ||||
|         normalized_weight = F.normalize(self.weight, dim=-1, p=2, eps=1e-5) | ||||
|         return F.linear(normalized_input, normalized_weight) * self.d_sqrt * self.scale | ||||
|  | ||||
|   | ||||
| @@ -38,7 +38,6 @@ class LoopSEAC(LoopIAC): | ||||
|  | ||||
|             a2c_loss = (-iw*log_ap * advantages.detach()).mean(-1) | ||||
|  | ||||
|  | ||||
|             value_loss = (iw*advantages.pow(2)).mean(-1)  # n_agent | ||||
|  | ||||
|             # weighted loss | ||||
|   | ||||
| @@ -56,8 +56,8 @@ class TSPBaseAgent(ABC): | ||||
|  | ||||
|     def _door_is_close(self, state): | ||||
|         try: | ||||
|             # return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name) | ||||
|             return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name) | ||||
|             return next(y for x in state.entities.neighboring_positions(self.state.pos) | ||||
|                         for y in state.entities.pos_dict[x] if do.DOOR in y.name) | ||||
|         except StopIteration: | ||||
|             return None | ||||
|  | ||||
|   | ||||
| @@ -14,8 +14,8 @@ class TSPTargetAgent(TSPBaseAgent): | ||||
|     def _handle_doors(self, state): | ||||
|  | ||||
|         try: | ||||
|             # return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name) | ||||
|             return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name) | ||||
|             return next(y for x in state.entities.neighboring_positions(self.state.pos) | ||||
|                         for y in state.entities.pos_dict[x] if do.DOOR in y.name) | ||||
|         except StopIteration: | ||||
|             return None | ||||
|  | ||||
| @@ -30,4 +30,3 @@ class TSPTargetAgent(TSPBaseAgent): | ||||
|         except (StopIteration, UnboundLocalError): | ||||
|             print('Will not happen') | ||||
|         return action_obj | ||||
|  | ||||
|   | ||||
| @@ -26,12 +26,16 @@ def points_to_graph(coordiniates, allow_euclidean_connections=True, allow_manhat | ||||
|     assert allow_euclidean_connections or allow_manhattan_connections | ||||
|     possible_connections = itertools.combinations(coordiniates, 2) | ||||
|     graph = nx.Graph() | ||||
|     for a, b in possible_connections: | ||||
|         diff = np.linalg.norm(np.asarray(a)-np.asarray(b)) | ||||
|         if allow_manhattan_connections and allow_euclidean_connections and diff <= np.sqrt(2): | ||||
|             graph.add_edge(a, b) | ||||
|         elif not allow_manhattan_connections and allow_euclidean_connections and diff == np.sqrt(2): | ||||
|             graph.add_edge(a, b) | ||||
|         elif allow_manhattan_connections and not allow_euclidean_connections and diff == 1: | ||||
|             graph.add_edge(a, b) | ||||
|     if allow_manhattan_connections and allow_euclidean_connections: | ||||
|         graph.add_edges_from( | ||||
|             (a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) <= np.sqrt(2) | ||||
|         ) | ||||
|     elif not allow_manhattan_connections and allow_euclidean_connections: | ||||
|         graph.add_edges_from( | ||||
|             (a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) == np.sqrt(2) | ||||
|         ) | ||||
|     elif allow_manhattan_connections and not allow_euclidean_connections: | ||||
|         graph.add_edges_from( | ||||
|             (a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) == 1 | ||||
|         ) | ||||
|     return graph | ||||
|   | ||||
| @@ -1,8 +1,9 @@ | ||||
| import torch | ||||
| import numpy as np | ||||
| import yaml | ||||
| from pathlib import Path | ||||
|  | ||||
| import numpy as np | ||||
| import torch | ||||
| import yaml | ||||
|  | ||||
|  | ||||
| def load_class(classname): | ||||
|     from importlib import import_module | ||||
| @@ -42,7 +43,6 @@ def get_class(arguments): | ||||
|  | ||||
|  | ||||
| def get_arguments(arguments): | ||||
|     from importlib import import_module | ||||
|     d = dict(arguments) | ||||
|     if "classname" in d: | ||||
|         del d["classname"] | ||||
|   | ||||
| @@ -22,26 +22,41 @@ Agents: | ||||
|     - Inventory | ||||
|     - DropOffLocations | ||||
|     - Maintainers | ||||
|     # This is special for agents, as each one is differten and can act as an adversary e.g. | ||||
|     Positions: | ||||
|       - (16, 7) | ||||
|       - (16, 6) | ||||
|       - (16, 3) | ||||
|       - (16, 4) | ||||
|       - (16, 5) | ||||
| Entities: | ||||
|   Batteries: | ||||
|     initial_charge: 0.8 | ||||
|     per_action_costs: 0.02 | ||||
|   ChargePods: {} | ||||
|   Destinations: {} | ||||
|   ChargePods: | ||||
|     coords_or_quantity: 2 | ||||
|   Destinations: | ||||
|     coords_or_quantity: 1 | ||||
|     spawn_mode: GROUPED | ||||
|   DirtPiles: | ||||
|     coords_or_quantity: 10 | ||||
|     initial_amount: 2 | ||||
|     clean_amount: 1 | ||||
|     dirt_spawn_r_var: 0.1 | ||||
|     initial_amount: 2 | ||||
|     initial_dirt_ratio: 0.05 | ||||
|     max_global_amount: 20 | ||||
|     max_local_amount: 5 | ||||
|   Doors: {} | ||||
|   DropOffLocations: {} | ||||
|   Doors: | ||||
|   DropOffLocations: | ||||
|     coords_or_quantity: 1 | ||||
|     max_dropoff_storage_size: 0 | ||||
|   GlobalPositions: {} | ||||
|   Inventories: {} | ||||
|   Items: {} | ||||
|   Machines: {} | ||||
|   Maintainers: {} | ||||
|   Items: | ||||
|     coords_or_quantity: 5 | ||||
|   Machines: | ||||
|     coords_or_quantity: 2 | ||||
|   Maintainers: | ||||
|     coords_or_quantity: 1 | ||||
|   Zones: {} | ||||
|  | ||||
| General: | ||||
| @@ -49,32 +64,31 @@ General: | ||||
|   individual_rewards: true | ||||
|   level_name: large | ||||
|   pomdp_r: 3 | ||||
|   verbose: false | ||||
|   verbose: True | ||||
|   tests: false | ||||
|  | ||||
| Rules: | ||||
|   SpawnAgents: {} | ||||
|   DoneAtBatteryDischarge: {} | ||||
|   Collision: | ||||
|     done_at_collisions: false | ||||
|   AssignGlobalPositions: {} | ||||
|   DoneAtDestinationReachAny: {} | ||||
|   DestinationReachReward: {} | ||||
|   SpawnDestinations: | ||||
|     n_dests: 1 | ||||
|     spawn_mode: GROUPED | ||||
|   DoneOnAllDirtCleaned: {} | ||||
|   SpawnDirt: | ||||
|     spawn_freq: 15 | ||||
|   # Environment Dynamics | ||||
|   EntitiesSmearDirtOnMove: | ||||
|     smear_ratio: 0.2 | ||||
|   DoorAutoClose: | ||||
|     close_frequency: 10 | ||||
|   ItemRules: | ||||
|     max_dropoff_storage_size: 0 | ||||
|     n_items: 5 | ||||
|     n_locations: 5 | ||||
|     spawn_frequency: 15 | ||||
|   MaxStepsReached: | ||||
|   MoveMaintainers: | ||||
|  | ||||
|   # Respawn Stuff | ||||
|   RespawnDirt: | ||||
|     respawn_freq: 15 | ||||
|   RespawnItems: | ||||
|     respawn_freq: 15 | ||||
|  | ||||
|   # Utilities | ||||
|   WatchCollisions: | ||||
|     done_at_collisions: false | ||||
|  | ||||
|   # Done Conditions | ||||
|   DoneAtDestinationReachAny: | ||||
|   DoneOnAllDirtCleaned: | ||||
|   DoneAtBatteryDischarge: | ||||
|   DoneAtMaintainerCollision: | ||||
|   DoneAtMaxStepsReached: | ||||
|     max_steps: 500 | ||||
| #  AgentSingleZonePlacement: | ||||
| #    n_zones: 4 | ||||
|   | ||||
| @@ -1,46 +1,89 @@ | ||||
| Agents: | ||||
|   Wolfgang: | ||||
|     Actions: | ||||
|     - Noop | ||||
|     - Move8 | ||||
|     Observations: | ||||
|     - Walls | ||||
|     - Other | ||||
|     - Destination | ||||
|     Positions: | ||||
|       - (2, 1) | ||||
|       - (2, 5) | ||||
|   Karl-Heinz: | ||||
|     Actions: | ||||
|       - Noop | ||||
|       - Move8 | ||||
|     Observations: | ||||
|       - Walls | ||||
|       - Other | ||||
|       - Destination | ||||
|     Positions: | ||||
|       - (2, 1) | ||||
|       - (2, 5) | ||||
| Entities: | ||||
|   Destinations: {} | ||||
|  | ||||
| General: | ||||
|   # Your Seed | ||||
|   env_seed: 69 | ||||
|   # Individual or global rewards? | ||||
|   individual_rewards: true | ||||
|   # The level.txt file to load | ||||
|   level_name: narrow_corridor | ||||
|   # View Radius; 0 = full observatbility | ||||
|   pomdp_r: 0 | ||||
|   # print all messages and events | ||||
|   verbose: true | ||||
|  | ||||
| Rules: | ||||
|   SpawnAgents: {} | ||||
|   Collision: | ||||
|     done_at_collisions: false | ||||
|   FixedDestinationSpawn: | ||||
|     per_agent_positions: | ||||
| Agents: | ||||
|   # Agents are identified by their name  | ||||
|   Wolfgang: | ||||
|     # The available actions for this particular agent | ||||
|     Actions: | ||||
|     # Able to do nothing | ||||
|     - Noop | ||||
|     # Able to move in all 8 directions | ||||
|     - Move8 | ||||
|     # Stuff the agent can observe (per 2d slice) | ||||
|     #   use "Combined" if you want to merge multiple slices into one | ||||
|     Observations: | ||||
|     # He sees walls | ||||
|     - Walls | ||||
|     # he sees other agent, "karl-Heinz" in this setting would be fine, too | ||||
|     - Other | ||||
|     # He can see Destinations, that are assigned to him (hence the singular)  | ||||
|     - Destination | ||||
|     # Avaiable Spawn Positions as list | ||||
|     Positions: | ||||
|       - (2, 1) | ||||
|       - (2, 5) | ||||
|     # It is okay to collide with other agents, so that  | ||||
|     #   they end up on the same position | ||||
|     is_blocking_pos: true | ||||
|   # See Above.... | ||||
|   Karl-Heinz: | ||||
|     Actions: | ||||
|       - Noop | ||||
|       - Move8 | ||||
|     Observations: | ||||
|       - Walls | ||||
|       - Other | ||||
|       - Destination | ||||
|     Positions: | ||||
|       - (2, 1) | ||||
|       - (2, 5) | ||||
|     is_blocking_pos: true | ||||
|  | ||||
| # Other noteworthy Entitites | ||||
| Entities: | ||||
|   # The destiantions or positional targets to reach | ||||
|   Destinations: | ||||
|     # Let them spawn on closed doors and agent positions | ||||
|     ignore_blocking: true | ||||
|     # We need a special spawn rule... | ||||
|     spawnrule: | ||||
|       # ...which assigns the destinations per agent | ||||
|       SpawnDestinationsPerAgent: | ||||
|         # we use this parameter | ||||
|         coords_or_quantity: | ||||
|           # to enable and assign special positions per agent | ||||
|           Wolfgang: | ||||
|               - (2, 1) | ||||
|               - (2, 5) | ||||
|           Karl-Heinz: | ||||
|               - (2, 1) | ||||
|               - (2, 5) | ||||
|   DestinationReachAll: {} | ||||
|     # Whether you want to provide a numeric Position observation. | ||||
|     # GlobalPositions: | ||||
|     #   normalized: false | ||||
|  | ||||
| # Define the env. dynamics | ||||
| Rules: | ||||
|   # Utilities | ||||
|   #  This rule Checks for Collision, also it assigns the (negative) reward | ||||
|   WatchCollisions: | ||||
|     reward: -0.1 | ||||
|     reward_at_done: -1 | ||||
|     done_at_collisions: false | ||||
|   # Done Conditions | ||||
|   #   Load any of the rules, to check for done conditions.  | ||||
|   # DoneAtDestinationReachAny: | ||||
|   DoneAtDestinationReachAll: | ||||
|   #  reward_at_done: 1 | ||||
|   DoneAtMaxStepsReached: | ||||
|     max_steps: 200 | ||||
|   | ||||
| @@ -48,9 +48,9 @@ class Move(Action, abc.ABC): | ||||
|             reward = r.MOVEMENTS_VALID if move_validity else r.MOVEMENTS_FAIL | ||||
|             return ActionResult(entity=entity, identifier=self._identifier, validity=move_validity, reward=reward) | ||||
|         else:  # There is no place to go, propably collision | ||||
|             # This is currently handeld by the Collision rule, so that it can be switched on and off by conf.yml | ||||
|             # This is currently handeld by the WatchCollisions rule, so that it can be switched on and off by conf.yml | ||||
|             # return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION) | ||||
|             return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=0) | ||||
|             return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID) | ||||
|  | ||||
|     def _calc_new_pos(self, pos): | ||||
|         x_diff, y_diff = MOVEMAP[self._identifier] | ||||
|   | ||||
| @@ -10,6 +10,7 @@ AGENT                   = 'Agent'               # Identifier of Agent-objects an | ||||
| OTHERS                  = 'Other' | ||||
| COMBINED                = 'Combined' | ||||
| GLOBALPOSITIONS         = 'GlobalPositions'     # Identifier of the global position slice | ||||
| SPAWN_ENTITY_RULE       = 'SpawnEntity' | ||||
|  | ||||
| # Attributes | ||||
| IS_BLOCKING_LIGHT       = 'var_is_blocking_light' | ||||
| @@ -29,7 +30,7 @@ VALUE_NO_POS            = (-9999, -9999)  # Invalid Position value used in the e | ||||
|  | ||||
|  | ||||
| ACTION                  = 'action'  # Identifier of Action-objects and groups (groups). | ||||
| COLLISION               = 'Collision'  # Identifier to use in the context of collitions. | ||||
| COLLISION               = 'Collisions'  # Identifier to use in the context of collitions. | ||||
| # LAST_POS                = 'LAST_POS'  # Identifiert for retrieving an enitites last pos. | ||||
| VALIDITY                = 'VALIDITY'  # Identifiert for retrieving the Validity of Action, Tick, etc. ... | ||||
|  | ||||
| @@ -54,3 +55,5 @@ NOOP                    = 'Noop' | ||||
| # Result Identifier | ||||
| MOVEMENTS_VALID = 'motion_valid' | ||||
| MOVEMENTS_FAIL  = 'motion_not_valid' | ||||
| DEFAULT_PATH = 'environment' | ||||
| MODULE_PATH = 'modules' | ||||
|   | ||||
| @@ -12,14 +12,6 @@ from marl_factory_grid.environment import constants as c | ||||
|  | ||||
| class Agent(Entity): | ||||
|  | ||||
|     @property | ||||
|     def var_is_blocking_light(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_can_move(self): | ||||
|         return True | ||||
|  | ||||
|     @property | ||||
|     def var_is_paralyzed(self): | ||||
|         return len(self._paralyzed) | ||||
| @@ -28,14 +20,6 @@ class Agent(Entity): | ||||
|     def paralyze_reasons(self): | ||||
|         return [x for x in self._paralyzed] | ||||
|  | ||||
|     @property | ||||
|     def var_is_blocking_pos(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_has_position(self): | ||||
|         return True | ||||
|  | ||||
|     @property | ||||
|     def obs_tag(self): | ||||
|         return self.name | ||||
| @@ -48,10 +32,6 @@ class Agent(Entity): | ||||
|     def observations(self): | ||||
|         return self._observations | ||||
|  | ||||
|     @property | ||||
|     def var_can_collide(self): | ||||
|         return True | ||||
|  | ||||
|     def step_result(self): | ||||
|         pass | ||||
|  | ||||
| @@ -60,16 +40,21 @@ class Agent(Entity): | ||||
|         return self._collection | ||||
|  | ||||
|     @property | ||||
|     def state(self): | ||||
|         return self._state or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0) | ||||
|     def var_is_blocking_pos(self): | ||||
|         return self._is_blocking_pos | ||||
|  | ||||
|     def __init__(self, actions: List[Action], observations: List[str], *args, **kwargs): | ||||
|     @property | ||||
|     def state(self): | ||||
|         return self._state or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID) | ||||
|  | ||||
|     def __init__(self, actions: List[Action], observations: List[str], *args, is_blocking_pos=False, **kwargs): | ||||
|         super(Agent, self).__init__(*args, **kwargs) | ||||
|         self._paralyzed = set() | ||||
|         self.step_result = dict() | ||||
|         self._actions = actions | ||||
|         self._observations = observations | ||||
|         self._state: Union[Result, None] = None | ||||
|         self._is_blocking_pos = is_blocking_pos | ||||
|  | ||||
|     # noinspection PyAttributeOutsideInit | ||||
|     def clear_temp_state(self): | ||||
|   | ||||
| @@ -1,20 +1,19 @@ | ||||
| import abc | ||||
| from collections import defaultdict | ||||
|  | ||||
| import numpy as np | ||||
|  | ||||
| from .object import _Object | ||||
| from .object import Object | ||||
| from .. import constants as c | ||||
| from ...utils.results import ActionResult | ||||
| from ...utils.utility_classes import RenderEntity | ||||
|  | ||||
|  | ||||
| class Entity(_Object, abc.ABC): | ||||
| class Entity(Object, abc.ABC): | ||||
|     """Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc...""" | ||||
|  | ||||
|     @property | ||||
|     def state(self): | ||||
|         return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0) | ||||
|         return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID) | ||||
|  | ||||
|     @property | ||||
|     def var_has_position(self): | ||||
| @@ -60,6 +59,10 @@ class Entity(_Object, abc.ABC): | ||||
|     def pos(self): | ||||
|         return self._pos | ||||
|  | ||||
|     def set_pos(self, pos): | ||||
|         assert isinstance(pos, tuple) and len(pos) == 2 | ||||
|         self._pos = pos | ||||
|  | ||||
|     @property | ||||
|     def last_pos(self): | ||||
|         try: | ||||
| @@ -84,7 +87,7 @@ class Entity(_Object, abc.ABC): | ||||
|                 for observer in self.observers: | ||||
|                     observer.notify_del_entity(self) | ||||
|                 self._view_directory = curr_pos[0] - next_pos[0], curr_pos[1] - next_pos[1] | ||||
|                 self._pos = next_pos | ||||
|                 self.set_pos(next_pos) | ||||
|                 for observer in self.observers: | ||||
|                     observer.notify_add_entity(self) | ||||
|             return valid | ||||
| @@ -92,6 +95,7 @@ class Entity(_Object, abc.ABC): | ||||
|  | ||||
|     def __init__(self, pos, bind_to=None, **kwargs): | ||||
|         super().__init__(**kwargs) | ||||
|         self._view_directory = c.VALUE_NO_POS | ||||
|         self._status = None | ||||
|         self._pos = pos | ||||
|         self._last_pos = pos | ||||
| @@ -109,9 +113,6 @@ class Entity(_Object, abc.ABC): | ||||
|     def render(self): | ||||
|         return RenderEntity(self.__class__.__name__.lower(), self.pos) | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return super(Entity, self).__repr__() + f'(@{self.pos})' | ||||
|  | ||||
|     @property | ||||
|     def obs_tag(self): | ||||
|         try: | ||||
| @@ -128,25 +129,3 @@ class Entity(_Object, abc.ABC): | ||||
|         self._collection.delete_env_object(self) | ||||
|         self._collection = other_collection | ||||
|         return self._collection == other_collection | ||||
|  | ||||
|     @classmethod | ||||
|     def from_coordinates(cls, positions: [(int, int)], *args, entity_kwargs=None, **kwargs, ): | ||||
|         collection = cls(*args, **kwargs) | ||||
|         collection.add_items( | ||||
|             [cls._entity(tuple(pos), **entity_kwargs if entity_kwargs is not None else {}) for pos in positions]) | ||||
|         return collection | ||||
|  | ||||
|     def notify_del_entity(self, entity): | ||||
|         try: | ||||
|             self.pos_dict[entity.pos].remove(entity) | ||||
|         except (ValueError, AttributeError): | ||||
|             pass | ||||
|  | ||||
|     def by_pos(self, pos: (int, int)): | ||||
|         pos = tuple(pos) | ||||
|         try: | ||||
|             return self.state.entities.pos_dict[pos] | ||||
|         except StopIteration: | ||||
|             pass | ||||
|         except ValueError: | ||||
|             print() | ||||
|   | ||||
| @@ -1,24 +0,0 @@ | ||||
|  | ||||
|  | ||||
| # noinspection PyAttributeOutsideInit | ||||
| class BoundEntityMixin: | ||||
|  | ||||
|     @property | ||||
|     def bound_entity(self): | ||||
|         return self._bound_entity | ||||
|  | ||||
|     @property | ||||
|     def name(self): | ||||
|         if self.bound_entity: | ||||
|             return f'{self.__class__.__name__}({self.bound_entity.name})' | ||||
|         else: | ||||
|             pass | ||||
|  | ||||
|     def belongs_to_entity(self, entity): | ||||
|         return entity == self.bound_entity | ||||
|  | ||||
|     def bind_to(self, entity): | ||||
|         self._bound_entity = entity | ||||
|  | ||||
|     def unbind(self): | ||||
|         self._bound_entity = None | ||||
| @@ -5,7 +5,7 @@ from marl_factory_grid.environment import constants as c | ||||
| import marl_factory_grid.utils.helpers as h | ||||
|  | ||||
|  | ||||
| class _Object: | ||||
| class Object: | ||||
|     """Generell Objects for Organisation and Maintanance such as Actions etc...""" | ||||
|  | ||||
|     _u_idx = defaultdict(lambda: 0) | ||||
| @@ -13,10 +13,6 @@ class _Object: | ||||
|     def __bool__(self): | ||||
|         return True | ||||
|  | ||||
|     @property | ||||
|     def var_has_position(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_can_be_bound(self): | ||||
|         try: | ||||
| @@ -30,22 +26,14 @@ class _Object: | ||||
|  | ||||
|     @property | ||||
|     def name(self): | ||||
|         if self._str_ident is not None: | ||||
|             name = f'{self.__class__.__name__}[{self._str_ident}]' | ||||
|         else: | ||||
|             name = f'{self.__class__.__name__}#{self.u_int}' | ||||
|         if self.bound_entity: | ||||
|             name = h.add_bound_name(name, self.bound_entity) | ||||
|         if self.var_has_position: | ||||
|             name = h.add_pos_name(name, self) | ||||
|         return name | ||||
|         return f'{self.__class__.__name__}[{self.identifier}]' | ||||
|  | ||||
|     @property | ||||
|     def identifier(self): | ||||
|         if self._str_ident is not None: | ||||
|             return self._str_ident | ||||
|         else: | ||||
|             return self.name | ||||
|             return self.u_int | ||||
|  | ||||
|     def reset_uid(self): | ||||
|         self._u_idx = defaultdict(lambda: 0) | ||||
| @@ -62,7 +50,15 @@ class _Object: | ||||
|             print(f'Following kwargs were passed, but ignored: {kwargs}') | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return f'{self.name}' | ||||
|         name = self.name | ||||
|         if self.bound_entity: | ||||
|             name = h.add_bound_name(name, self.bound_entity) | ||||
|         try: | ||||
|             if self.var_has_position: | ||||
|                 name = h.add_pos_name(name, self) | ||||
|         except AttributeError: | ||||
|             pass | ||||
|         return name | ||||
|  | ||||
|     def __eq__(self, other) -> bool: | ||||
|         return other == self.identifier | ||||
| @@ -71,8 +67,8 @@ class _Object: | ||||
|         return hash(self.identifier) | ||||
|  | ||||
|     def _identify_and_count_up(self): | ||||
|         idx = _Object._u_idx[self.__class__.__name__] | ||||
|         _Object._u_idx[self.__class__.__name__] += 1 | ||||
|         idx = Object._u_idx[self.__class__.__name__] | ||||
|         Object._u_idx[self.__class__.__name__] += 1 | ||||
|         return idx | ||||
|  | ||||
|     def set_collection(self, collection): | ||||
| @@ -88,7 +84,7 @@ class _Object: | ||||
|     def summarize_state(self): | ||||
|         return dict() | ||||
|  | ||||
|     def bind(self, entity): | ||||
|     def bind_to(self, entity): | ||||
|         # noinspection PyAttributeOutsideInit | ||||
|         self._bound_entity = entity | ||||
|         return c.VALID | ||||
| @@ -100,84 +96,5 @@ class _Object: | ||||
|     def bound_entity(self): | ||||
|         return self._bound_entity | ||||
|  | ||||
|     def bind_to(self, entity): | ||||
|         self._bound_entity = entity | ||||
|  | ||||
|     def unbind(self): | ||||
|         self._bound_entity = None | ||||
|  | ||||
|  | ||||
| # class EnvObject(_Object): | ||||
| #     """Objects that hold Information that are observable, but have no position on the environment grid. Inventories etc...""" | ||||
| # | ||||
|     # _u_idx = defaultdict(lambda: 0) | ||||
| # | ||||
| #     @property | ||||
| #     def obs_tag(self): | ||||
| #         try: | ||||
| #             return self._collection.name or self.name | ||||
| #         except AttributeError: | ||||
| #             return self.name | ||||
| # | ||||
| #     @property | ||||
| #     def var_is_blocking_light(self): | ||||
| #         try: | ||||
| #             return self._collection.var_is_blocking_light or False | ||||
| #         except AttributeError: | ||||
| #             return False | ||||
| # | ||||
| #     @property | ||||
| #     def var_can_be_bound(self): | ||||
| #         try: | ||||
| #             return self._collection.var_can_be_bound or False | ||||
| #         except AttributeError: | ||||
| #             return False | ||||
| # | ||||
| #     @property | ||||
| #     def var_can_move(self): | ||||
| #         try: | ||||
| #             return self._collection.var_can_move or False | ||||
| #         except AttributeError: | ||||
| #             return False | ||||
| # | ||||
| #     @property | ||||
| #     def var_is_blocking_pos(self): | ||||
| #         try: | ||||
| #             return self._collection.var_is_blocking_pos or False | ||||
| #         except AttributeError: | ||||
| #             return False | ||||
| # | ||||
| #     @property | ||||
| #     def var_has_position(self): | ||||
| #         try: | ||||
| #             return self._collection.var_has_position or False | ||||
| #         except AttributeError: | ||||
| #             return False | ||||
| # | ||||
| # @property | ||||
| # def var_can_collide(self): | ||||
| #     try: | ||||
| #         return self._collection.var_can_collide or False | ||||
| #     except AttributeError: | ||||
| #         return False | ||||
| # | ||||
| # | ||||
| # @property | ||||
| # def encoding(self): | ||||
| #     return c.VALUE_OCCUPIED_CELL | ||||
| # | ||||
| # | ||||
| # def __init__(self, **kwargs): | ||||
| #     self._bound_entity = None | ||||
| #     super(EnvObject, self).__init__(**kwargs) | ||||
| # | ||||
| # | ||||
| # def change_parent_collection(self, other_collection): | ||||
| #     other_collection.add_item(self) | ||||
| #     self._collection.delete_env_object(self) | ||||
| #     self._collection = other_collection | ||||
| #     return self._collection == other_collection | ||||
| # | ||||
| # | ||||
| # def summarize_state(self): | ||||
| #     return dict(name=str(self.name)) | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| import numpy as np | ||||
|  | ||||
| from marl_factory_grid.environment.entity.object import _Object | ||||
| from marl_factory_grid.environment.entity.object import Object | ||||
|  | ||||
|  | ||||
| ########################################################################## | ||||
| @@ -8,7 +8,7 @@ from marl_factory_grid.environment.entity.object import _Object | ||||
| ########################################################################## | ||||
|  | ||||
|  | ||||
| class PlaceHolder(_Object): | ||||
| class PlaceHolder(Object): | ||||
|  | ||||
|     def __init__(self, *args, fill_value=0, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
| @@ -24,10 +24,10 @@ class PlaceHolder(_Object): | ||||
|  | ||||
|     @property | ||||
|     def name(self): | ||||
|         return "PlaceHolder" | ||||
|         return self.__class__.__name__ | ||||
|  | ||||
|  | ||||
| class GlobalPosition(_Object): | ||||
| class GlobalPosition(Object): | ||||
|  | ||||
|     @property | ||||
|     def encoding(self): | ||||
| @@ -36,7 +36,8 @@ class GlobalPosition(_Object): | ||||
|         else: | ||||
|             return self.bound_entity.pos | ||||
|  | ||||
|     def __init__(self, level_shape, *args, normalized: bool = True, **kwargs): | ||||
|     def __init__(self, agent, level_shape, *args, normalized: bool = True, **kwargs): | ||||
|         super(GlobalPosition, self).__init__(*args, **kwargs) | ||||
|         self.bind_to(agent) | ||||
|         self._normalized = normalized | ||||
|         self._shape = level_shape | ||||
|   | ||||
| @@ -5,13 +5,8 @@ from marl_factory_grid.utils.utility_classes import RenderEntity | ||||
|  | ||||
| class Wall(Entity): | ||||
|  | ||||
|     @property | ||||
|     def var_has_position(self): | ||||
|         return True | ||||
|  | ||||
|     @property | ||||
|     def var_can_collide(self): | ||||
|         return True | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|  | ||||
|     @property | ||||
|     def encoding(self): | ||||
| @@ -19,11 +14,3 @@ class Wall(Entity): | ||||
|  | ||||
|     def render(self): | ||||
|         return RenderEntity(c.WALL, self.pos) | ||||
|  | ||||
|     @property | ||||
|     def var_is_blocking_pos(self): | ||||
|         return True | ||||
|  | ||||
|     @property | ||||
|     def var_is_blocking_light(self): | ||||
|         return True | ||||
|   | ||||
| @@ -56,15 +56,18 @@ class Factory(gym.Env): | ||||
|             self.level_filepath = Path(custom_level_path) | ||||
|         else: | ||||
|             self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt' | ||||
|         self._renderer = None  # expensive - don't use; unless required ! | ||||
|  | ||||
|         parsed_entities = self.conf.load_entities() | ||||
|         self.map = LevelParser(self.level_filepath, parsed_entities, self.conf.pomdp_r) | ||||
|  | ||||
|         # Init for later usage: | ||||
|         self.state: Gamestate | ||||
|         self.map: LevelParser | ||||
|         self.obs_builder: OBSBuilder | ||||
|         # noinspection PyTypeChecker | ||||
|         self.state: Gamestate = None | ||||
|         # noinspection PyTypeChecker | ||||
|         self.obs_builder: OBSBuilder = None | ||||
|  | ||||
|         # expensive - don't use; unless required ! | ||||
|         self._renderer = None | ||||
|  | ||||
|         # reset env to initial state, preparing env for new episode. | ||||
|         # returns tuple where the first dict contains initial observation for each agent in the env | ||||
| @@ -74,7 +77,7 @@ class Factory(gym.Env): | ||||
|         return self.state.entities[item] | ||||
|  | ||||
|     def reset(self) -> (dict, dict): | ||||
|         if hasattr(self, 'state'): | ||||
|         if self.state is not None: | ||||
|             for entity_group in self.state.entities: | ||||
|                 try: | ||||
|                     entity_group[0].reset_uid() | ||||
| @@ -87,12 +90,16 @@ class Factory(gym.Env): | ||||
|         entities = self.map.do_init() | ||||
|  | ||||
|         # Init rules | ||||
|         rules = self.conf.load_env_rules() | ||||
|         env_rules = self.conf.load_env_rules() | ||||
|         entity_rules = self.conf.load_entity_spawn_rules(entities) | ||||
|         env_rules.extend(entity_rules) | ||||
|  | ||||
|         env_tests = self.conf.load_env_tests() if self.conf.tests else [] | ||||
|  | ||||
|         # Parse the agent conf | ||||
|         parsed_agents_conf = self.conf.parse_agents_conf() | ||||
|         self.state = Gamestate(entities, parsed_agents_conf, rules, env_tests, self.conf.env_seed, self.conf.verbose) | ||||
|         self.state = Gamestate(entities, parsed_agents_conf, env_rules, env_tests, self.map.level_shape, | ||||
|                                self.conf.env_seed, self.conf.verbose) | ||||
|  | ||||
|         # All is set up, trigger entity init with variable pos | ||||
|         # All is set up, trigger additional init (after agent entity spawn etc) | ||||
| @@ -160,7 +167,7 @@ class Factory(gym.Env): | ||||
|         # Finalize | ||||
|         reward, reward_info, done = self.summarize_step_results(tick_result, done_results) | ||||
|  | ||||
|         info = reward_info | ||||
|         info = dict(reward_info) | ||||
|  | ||||
|         info.update(step_reward=sum(reward), step=self.state.curr_step) | ||||
|  | ||||
|   | ||||
| @@ -1,10 +1,15 @@ | ||||
| from marl_factory_grid.environment.entity.agent import Agent | ||||
| from marl_factory_grid.environment.groups.collection import Collection | ||||
| from marl_factory_grid.environment.rules import SpawnAgents | ||||
|  | ||||
|  | ||||
| class Agents(Collection): | ||||
|     _entity = Agent | ||||
|  | ||||
|     @property | ||||
|     def spawn_rule(self): | ||||
|         return {SpawnAgents.__name__: {}} | ||||
|  | ||||
|     @property | ||||
|     def var_is_blocking_light(self): | ||||
|         return False | ||||
|   | ||||
| @@ -1,18 +1,25 @@ | ||||
| from typing import List, Tuple, Union | ||||
| from typing import List, Tuple, Union, Dict | ||||
|  | ||||
| from marl_factory_grid.environment.entity.entity import Entity | ||||
| from marl_factory_grid.environment.groups.objects import _Objects | ||||
| from marl_factory_grid.environment.entity.object import _Object | ||||
| from marl_factory_grid.environment.groups.objects import Objects | ||||
| # noinspection PyProtectedMember | ||||
| from marl_factory_grid.environment.entity.object import Object | ||||
| import marl_factory_grid.environment.constants as c | ||||
| from marl_factory_grid.utils.results import Result | ||||
|  | ||||
|  | ||||
| class Collection(_Objects): | ||||
|     _entity = _Object  # entity? | ||||
| class Collection(Objects): | ||||
|     _entity = Object  # entity? | ||||
|     symbol = None | ||||
|  | ||||
|     @property | ||||
|     def var_is_blocking_light(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_is_blocking_pos(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_can_collide(self): | ||||
|         return False | ||||
| @@ -23,33 +30,65 @@ class Collection(_Objects): | ||||
|  | ||||
|     @property | ||||
|     def var_has_position(self): | ||||
|         return False | ||||
|  | ||||
|     # @property | ||||
|     # def var_has_bound(self): | ||||
|     #     return False  # batteries, globalpos, inventories true | ||||
|  | ||||
|     @property | ||||
|     def var_can_be_bound(self): | ||||
|         return False | ||||
|         return True | ||||
|  | ||||
|     @property | ||||
|     def encodings(self): | ||||
|         return [x.encoding for x in self] | ||||
|  | ||||
|     def __init__(self, size, *args, **kwargs): | ||||
|         super(Collection, self).__init__(*args, **kwargs) | ||||
|         self.size = size | ||||
|  | ||||
|     def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):  # woihn mit den args | ||||
|         if isinstance(coords_or_quantity, int): | ||||
|             self.add_items([self._entity() for _ in range(coords_or_quantity)]) | ||||
|     @property | ||||
|     def spawn_rule(self): | ||||
|         """Prevent SpawnRule creation if Objects are spawned by map, Doors e.g.""" | ||||
|         if self.symbol: | ||||
|             return None | ||||
|         elif self._spawnrule: | ||||
|             return self._spawnrule | ||||
|         else: | ||||
|             self.add_items([self._entity(pos) for pos in coords_or_quantity]) | ||||
|             return {c.SPAWN_ENTITY_RULE: dict(collection=self, coords_or_quantity=self._coords_or_quantity)} | ||||
|  | ||||
|     def __init__(self, size, *args, coords_or_quantity: int = None, ignore_blocking=False, | ||||
|                  spawnrule: Union[None, Dict[str, dict]] = None, | ||||
|                  **kwargs): | ||||
|         super(Collection, self).__init__(*args, **kwargs) | ||||
|         self._coords_or_quantity = coords_or_quantity | ||||
|         self.size = size | ||||
|         self._spawnrule = spawnrule | ||||
|         self._ignore_blocking = ignore_blocking | ||||
|  | ||||
|     def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, ignore_blocking=False,  **entity_kwargs): | ||||
|         coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity | ||||
|         if self.var_has_position: | ||||
|             if self.var_has_position and isinstance(coords_or_quantity, int): | ||||
|                 if ignore_blocking or self._ignore_blocking: | ||||
|                     coords_or_quantity = state.entities.floorlist[:coords_or_quantity] | ||||
|                 else: | ||||
|                     coords_or_quantity = state.get_n_random_free_positions(coords_or_quantity) | ||||
|             self.spawn(coords_or_quantity, *entity_args,  **entity_kwargs) | ||||
|             state.print(f'{len(coords_or_quantity)} new {self.name} have been spawned at {coords_or_quantity}') | ||||
|             return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(coords_or_quantity)) | ||||
|         else: | ||||
|             if isinstance(coords_or_quantity, int): | ||||
|                 self.spawn(coords_or_quantity, *entity_args,  **entity_kwargs) | ||||
|                 state.print(f'{coords_or_quantity} new {self.name} have been spawned randomly.') | ||||
|                 return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=coords_or_quantity) | ||||
|             else: | ||||
|                 raise ValueError(f'{self._entity.__name__} has no position!') | ||||
|  | ||||
|     def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args, **entity_kwargs): | ||||
|         if self.var_has_position: | ||||
|             if isinstance(coords_or_quantity, int): | ||||
|                 raise ValueError(f'{self._entity.__name__} should have a position!') | ||||
|             else: | ||||
|                 self.add_items([self._entity(pos, *entity_args, **entity_kwargs) for pos in coords_or_quantity]) | ||||
|         else: | ||||
|             if isinstance(coords_or_quantity, int): | ||||
|                 self.add_items([self._entity(*entity_args, **entity_kwargs) for _ in range(coords_or_quantity)]) | ||||
|             else: | ||||
|                 raise ValueError(f'{self._entity.__name__} has no  position!') | ||||
|         return c.VALID | ||||
|  | ||||
|     def despawn(self, items: List[_Object]): | ||||
|         items = [items] if isinstance(items, _Object) else items | ||||
|     def despawn(self, items: List[Object]): | ||||
|         items = [items] if isinstance(items, Object) else items | ||||
|         for item in items: | ||||
|             del self[item] | ||||
|  | ||||
| @@ -115,7 +154,7 @@ class Collection(_Objects): | ||||
|         except StopIteration: | ||||
|             pass | ||||
|         except ValueError: | ||||
|             print() | ||||
|             pass | ||||
|  | ||||
|     @property | ||||
|     def positions(self): | ||||
|   | ||||
| @@ -1,21 +1,21 @@ | ||||
| from collections import defaultdict | ||||
| from operator import itemgetter | ||||
| from random import shuffle, random | ||||
| from random import shuffle | ||||
| from typing import Dict | ||||
|  | ||||
| from marl_factory_grid.environment.groups.objects import _Objects | ||||
| from marl_factory_grid.environment.groups.objects import Objects | ||||
| from marl_factory_grid.utils.helpers import POS_MASK | ||||
|  | ||||
|  | ||||
| class Entities(_Objects): | ||||
|     _entity = _Objects | ||||
| class Entities(Objects): | ||||
|     _entity = Objects | ||||
|  | ||||
|     @staticmethod | ||||
|     def neighboring_positions(pos): | ||||
|         return (POS_MASK + pos).reshape(-1, 2) | ||||
|         return [tuple(x) for x in (POS_MASK + pos).reshape(-1, 2)] | ||||
|  | ||||
|     def get_entities_near_pos(self, pos): | ||||
|         return [y for x in itemgetter(*(tuple(x) for x in self.neighboring_positions(pos)))(self.pos_dict) for y in x] | ||||
|         return [y for x in itemgetter(*self.neighboring_positions(pos))(self.pos_dict) for y in x] | ||||
|  | ||||
|     def render(self): | ||||
|         return [y for x in self for y in x.render() if x is not None] | ||||
| @@ -35,8 +35,9 @@ class Entities(_Objects): | ||||
|         super().__init__() | ||||
|  | ||||
|     def guests_that_can_collide(self, pos): | ||||
|         return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide] | ||||
|         return [x for val in self.pos_dict[pos] for x in val if x.var_can_collide] | ||||
|  | ||||
|     @property | ||||
|     def empty_positions(self): | ||||
|         empty_positions = [key for key in self.floorlist if not self.pos_dict[key]] | ||||
|         shuffle(empty_positions) | ||||
| @@ -48,11 +49,23 @@ class Entities(_Objects): | ||||
|         shuffle(empty_positions) | ||||
|         return empty_positions | ||||
|  | ||||
|     def is_blocked(self): | ||||
|         return[key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])] | ||||
|     @property | ||||
|     def blocked_positions(self): | ||||
|         blocked_positions = [key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])] | ||||
|         shuffle(blocked_positions) | ||||
|         return blocked_positions | ||||
|  | ||||
|     def is_not_blocked(self): | ||||
|         return[key for key, val in self.pos_dict.items() if not all([x.var_is_blocking_pos for x in val])] | ||||
|     @property | ||||
|     def free_positions_generator(self): | ||||
|         generator = ( | ||||
|             key for key in self.floorlist if all(not x.var_can_collide and not x.var_is_blocking_pos | ||||
|                                                  for x in self.pos_dict[key]) | ||||
|                      ) | ||||
|         return generator | ||||
|  | ||||
|     @property | ||||
|     def free_positions_list(self): | ||||
|         return [x for x in self.free_positions_generator] | ||||
|  | ||||
|     def iter_entities(self): | ||||
|         return iter((x for sublist in self.values() for x in sublist)) | ||||
| @@ -74,7 +87,7 @@ class Entities(_Objects): | ||||
|     def __delitem__(self, name): | ||||
|         assert_str = 'This group of entity does not exist in this collection!' | ||||
|         assert any([key for key in name.keys() if key in self.keys()]), assert_str | ||||
|         self[name]._observers.delete(self) | ||||
|         self[name].del_observer(self) | ||||
|         for entity in self[name]: | ||||
|             entity.del_observer(self) | ||||
|         return super(Entities, self).__delitem__(name) | ||||
| @@ -92,3 +105,6 @@ class Entities(_Objects): | ||||
|     @property | ||||
|     def positions(self): | ||||
|         return [k for k, v in self.pos_dict.items() for _ in v] | ||||
|  | ||||
|     def is_occupied(self, pos): | ||||
|         return len([x for x in self.pos_dict[pos] if x.var_can_collide or x.var_is_blocking_pos]) >= 1 | ||||
|   | ||||
| @@ -4,10 +4,6 @@ from marl_factory_grid.environment import constants as c | ||||
| # noinspection PyUnresolvedReferences,PyTypeChecker | ||||
| class IsBoundMixin: | ||||
|  | ||||
|     @property | ||||
|     def name(self): | ||||
|         return f'{self.__class__.__name__}({self._bound_entity.name})' | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return f'{self.__class__.__name__}#{self._bound_entity.name}({self._data})' | ||||
|  | ||||
|   | ||||
| @@ -1,14 +1,19 @@ | ||||
| from collections import defaultdict | ||||
| from typing import List | ||||
| from typing import List, Iterator, Union | ||||
|  | ||||
| import numpy as np | ||||
|  | ||||
| from marl_factory_grid.environment.entity.object import _Object | ||||
| from marl_factory_grid.environment.entity.object import Object | ||||
| import marl_factory_grid.environment.constants as c | ||||
| from marl_factory_grid.utils import helpers as h | ||||
|  | ||||
|  | ||||
| class _Objects: | ||||
|     _entity = _Object | ||||
| class Objects: | ||||
|     _entity = Object | ||||
|  | ||||
|     @property | ||||
|     def var_can_be_bound(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def observers(self): | ||||
| @@ -45,7 +50,7 @@ class _Objects: | ||||
|     def __len__(self): | ||||
|         return len(self._data) | ||||
|  | ||||
|     def __iter__(self): | ||||
|     def __iter__(self) -> Iterator[Union[Object, None]]: | ||||
|         return iter(self.values()) | ||||
|  | ||||
|     def add_item(self, item: _entity): | ||||
| @@ -125,13 +130,14 @@ class _Objects: | ||||
|         repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]} | ||||
|         return f'{self.__class__.__name__}[{repr_dict}]' | ||||
|  | ||||
|     def notify_del_entity(self, entity: _Object): | ||||
|     def notify_del_entity(self, entity: Object): | ||||
|         try: | ||||
|             # noinspection PyUnresolvedReferences | ||||
|             self.pos_dict[entity.pos].remove(entity) | ||||
|         except (AttributeError, ValueError, IndexError): | ||||
|             pass | ||||
|  | ||||
|     def notify_add_entity(self, entity: _Object): | ||||
|     def notify_add_entity(self, entity: Object): | ||||
|         try: | ||||
|             if self not in entity.observers: | ||||
|                 entity.add_observer(self) | ||||
| @@ -148,12 +154,12 @@ class _Objects: | ||||
|  | ||||
|     def by_entity(self, entity): | ||||
|         try: | ||||
|             return next((x for x in self if x.belongs_to_entity(entity))) | ||||
|             return h.get_first(self, filter_by=lambda x: x.belongs_to_entity(entity)) | ||||
|         except (StopIteration, AttributeError): | ||||
|             return None | ||||
|  | ||||
|     def idx_by_entity(self, entity): | ||||
|         try: | ||||
|             return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity))) | ||||
|             return h.get_first_index(self, filter_by=lambda x: x.belongs_to_entity(entity)) | ||||
|         except (StopIteration, AttributeError): | ||||
|             return None | ||||
|   | ||||
| @@ -1,7 +1,10 @@ | ||||
| from typing import List, Union | ||||
|  | ||||
| from marl_factory_grid.environment import constants as c | ||||
| from marl_factory_grid.environment.entity.util import GlobalPosition | ||||
| from marl_factory_grid.environment.groups.collection import Collection | ||||
| from marl_factory_grid.utils.results import Result | ||||
| from marl_factory_grid.utils.states import Gamestate | ||||
|  | ||||
|  | ||||
| class Combined(Collection): | ||||
| @@ -36,17 +39,17 @@ class GlobalPositions(Collection): | ||||
|  | ||||
|     _entity = GlobalPosition | ||||
|  | ||||
|     @property | ||||
|     def var_is_blocking_light(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_can_collide(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_can_be_bound(self): | ||||
|         return True | ||||
|     var_is_blocking_light = False | ||||
|     var_can_be_bound = True | ||||
|     var_can_collide = False | ||||
|     var_has_position = False | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super(GlobalPositions, self).__init__(*args, **kwargs) | ||||
|  | ||||
|     def spawn(self, agents, level_shape, *args, **kwargs): | ||||
|         self.add_items([self._entity(agent, level_shape, *args, **kwargs) for agent in agents]) | ||||
|         return [Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))] | ||||
|  | ||||
|     def trigger_spawn(self, state: Gamestate, *args, **kwargs) -> [Result]: | ||||
|         return self.spawn(state[c.AGENT], state.lvl_shape, *args, **kwargs) | ||||
|   | ||||
| @@ -7,9 +7,12 @@ class Walls(Collection): | ||||
|     _entity = Wall | ||||
|     symbol = c.SYMBOL_WALL | ||||
|  | ||||
|     @property | ||||
|     def var_has_position(self): | ||||
|         return True | ||||
|     var_can_collide = True | ||||
|     var_is_blocking_light = True | ||||
|     var_can_move = False | ||||
|     var_has_position = True | ||||
|     var_can_be_bound = False | ||||
|     var_is_blocking_pos = True | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super(Walls, self).__init__(*args, **kwargs) | ||||
|   | ||||
| @@ -2,3 +2,4 @@ MOVEMENTS_VALID: float = -0.001 | ||||
| MOVEMENTS_FAIL: float  = -0.05 | ||||
| NOOP: float            = -0.01 | ||||
| COLLISION: float       = -0.5 | ||||
| COLLISION_DONE: float  = -1 | ||||
|   | ||||
| @@ -1,11 +1,11 @@ | ||||
| import abc | ||||
| from random import shuffle | ||||
| from typing import List | ||||
| from typing import List, Collection | ||||
|  | ||||
| from marl_factory_grid.environment import rewards as r, constants as c | ||||
| from marl_factory_grid.environment.entity.agent import Agent | ||||
| from marl_factory_grid.utils import helpers as h | ||||
| from marl_factory_grid.utils.results import TickResult, DoneResult | ||||
| from marl_factory_grid.environment import rewards as r, constants as c | ||||
|  | ||||
|  | ||||
| class Rule(abc.ABC): | ||||
| @@ -39,6 +39,29 @@ class Rule(abc.ABC): | ||||
|         return [] | ||||
|  | ||||
|  | ||||
| class SpawnEntity(Rule): | ||||
|  | ||||
|     @property | ||||
|     def _collection(self) -> Collection: | ||||
|         return Collection() | ||||
|  | ||||
|     @property | ||||
|     def name(self): | ||||
|         return f'{self.__class__.__name__}({self.collection.name})' | ||||
|  | ||||
|     def __init__(self, collection, coords_or_quantity, ignore_blocking=False): | ||||
|         super().__init__() | ||||
|         self.coords_or_quantity = coords_or_quantity | ||||
|         self.collection = collection | ||||
|         self.ignore_blocking = ignore_blocking | ||||
|  | ||||
|     def on_init(self, state, lvl_map) -> [TickResult]: | ||||
|         results = self.collection.trigger_spawn(state, ignore_blocking=self.ignore_blocking) | ||||
|         pos_str = f' on: {[x.pos for x in self.collection]}' if self.collection.var_has_position else '' | ||||
|         state.print(f'Initial {self.collection.__class__.__name__} were spawned{pos_str}') | ||||
|         return results | ||||
|  | ||||
|  | ||||
| class SpawnAgents(Rule): | ||||
|  | ||||
|     def __init__(self): | ||||
| @@ -46,14 +69,14 @@ class SpawnAgents(Rule): | ||||
|         pass | ||||
|  | ||||
|     def on_init(self, state, lvl_map): | ||||
|         agent_conf = state.agents_conf | ||||
|         # agents = Agents(lvl_map.size) | ||||
|         agents = state[c.AGENT] | ||||
|         empty_positions = state.entities.empty_positions()[:len(agent_conf)] | ||||
|         for agent_name in agent_conf: | ||||
|             actions = agent_conf[agent_name]['actions'].copy() | ||||
|             observations = agent_conf[agent_name]['observations'].copy() | ||||
|             positions = agent_conf[agent_name]['positions'].copy() | ||||
|         empty_positions = state.entities.empty_positions[:len(state.agents_conf)] | ||||
|         for agent_name, agent_conf in state.agents_conf.items(): | ||||
|             actions = agent_conf['actions'].copy() | ||||
|             observations = agent_conf['observations'].copy() | ||||
|             positions = agent_conf['positions'].copy() | ||||
|             other = agent_conf['other'].copy() | ||||
|             if positions: | ||||
|                 shuffle(positions) | ||||
|                 while True: | ||||
| @@ -61,18 +84,18 @@ class SpawnAgents(Rule): | ||||
|                         pos = positions.pop() | ||||
|                     except IndexError: | ||||
|                         raise ValueError(f'It was not possible to spawn an Agent on the available position: ' | ||||
|                                          f'\n{agent_name[agent_name]["positions"].copy()}') | ||||
|                     if agents.by_pos(pos) and state.check_pos_validity(pos): | ||||
|                                          f'\n{agent_conf["positions"].copy()}') | ||||
|                     if bool(agents.by_pos(pos)) or not state.check_pos_validity(pos): | ||||
|                         continue | ||||
|                     else: | ||||
|                         agents.add_item(Agent(actions, observations, pos, str_ident=agent_name)) | ||||
|                         agents.add_item(Agent(actions, observations, pos, str_ident=agent_name, **other)) | ||||
|                     break | ||||
|             else: | ||||
|                 agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name)) | ||||
|                 agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name, **other)) | ||||
|         pass | ||||
|  | ||||
|  | ||||
| class MaxStepsReached(Rule): | ||||
| class DoneAtMaxStepsReached(Rule): | ||||
|  | ||||
|     def __init__(self, max_steps: int = 500): | ||||
|         super().__init__() | ||||
| @@ -83,8 +106,8 @@ class MaxStepsReached(Rule): | ||||
|  | ||||
|     def on_check_done(self, state): | ||||
|         if self.max_steps <= state.curr_step: | ||||
|             return [DoneResult(validity=c.VALID, identifier=self.name, reward=0)] | ||||
|         return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)] | ||||
|             return [DoneResult(validity=c.VALID, identifier=self.name)] | ||||
|         return [DoneResult(validity=c.NOT_VALID, identifier=self.name)] | ||||
|  | ||||
|  | ||||
| class AssignGlobalPositions(Rule): | ||||
| @@ -95,16 +118,17 @@ class AssignGlobalPositions(Rule): | ||||
|     def on_init(self, state, lvl_map): | ||||
|         from marl_factory_grid.environment.entity.util import GlobalPosition | ||||
|         for agent in state[c.AGENT]: | ||||
|             gp = GlobalPosition(lvl_map.level_shape) | ||||
|             gp.bind_to(agent) | ||||
|             gp = GlobalPosition(agent, lvl_map.level_shape) | ||||
|             state[c.GLOBALPOSITIONS].add_item(gp) | ||||
|         return [] | ||||
|  | ||||
|  | ||||
| class Collision(Rule): | ||||
| class WatchCollisions(Rule): | ||||
|  | ||||
|     def __init__(self, done_at_collisions: bool = False): | ||||
|     def __init__(self, reward=r.COLLISION, done_at_collisions: bool = False, reward_at_done=r.COLLISION_DONE): | ||||
|         super().__init__() | ||||
|         self.reward_at_done = reward_at_done | ||||
|         self.reward = reward | ||||
|         self.done_at_collisions = done_at_collisions | ||||
|         self.curr_done = False | ||||
|  | ||||
| @@ -117,12 +141,12 @@ class Collision(Rule): | ||||
|             if len(guests) >= 2: | ||||
|                 for i, guest in enumerate(guests): | ||||
|                     try: | ||||
|                         guest.set_state(TickResult(identifier=c.COLLISION, reward=r.COLLISION, | ||||
|                         guest.set_state(TickResult(identifier=c.COLLISION, reward=self.reward, | ||||
|                                                    validity=c.NOT_VALID, entity=self)) | ||||
|                     except AttributeError: | ||||
|                         pass | ||||
|                     results.append(TickResult(entity=guest, identifier=c.COLLISION, | ||||
|                                               reward=r.COLLISION, validity=c.VALID)) | ||||
|                                               reward=self.reward, validity=c.VALID)) | ||||
|                 self.curr_done = True if self.done_at_collisions else False | ||||
|         return results | ||||
|  | ||||
| @@ -131,5 +155,5 @@ class Collision(Rule): | ||||
|             inter_entity_collision_detected = self.curr_done | ||||
|             move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT]) | ||||
|             if inter_entity_collision_detected or move_failed: | ||||
|                 return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)] | ||||
|         return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)] | ||||
|                 return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=self.reward_at_done)] | ||||
|         return [] | ||||
|   | ||||
| @@ -6,7 +6,9 @@ from marl_factory_grid.utils.results import TickResult, DoneResult | ||||
| class TemplateRule(Rule): | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super(TemplateRule, self).__init__(*args, **kwargs) | ||||
|         super(TemplateRule, self).__init__() | ||||
|         self.args = args | ||||
|         self.kwargs = kwargs | ||||
|  | ||||
|     def on_init(self, state, lvl_map): | ||||
|         pass | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| from .actions import BtryCharge | ||||
| from .entitites import Pod, Battery | ||||
| from .entitites import ChargePod, Battery | ||||
| from .groups import ChargePods, Batteries | ||||
| from .rules import DoneAtBatteryDischarge, BatteryDecharge | ||||
|   | ||||
| @@ -1,11 +1,11 @@ | ||||
| from typing import Union | ||||
|  | ||||
| import marl_factory_grid.modules.batteries.constants | ||||
| from marl_factory_grid.environment.actions import Action | ||||
| from marl_factory_grid.utils.results import ActionResult | ||||
|  | ||||
| from marl_factory_grid.modules.batteries import constants as b | ||||
| from marl_factory_grid.environment import constants as c | ||||
| from marl_factory_grid.utils import helpers as h | ||||
|  | ||||
|  | ||||
| class BtryCharge(Action): | ||||
| @@ -14,8 +14,8 @@ class BtryCharge(Action): | ||||
|         super().__init__(b.ACTION_CHARGE) | ||||
|  | ||||
|     def do(self, entity, state) -> Union[None, ActionResult]: | ||||
|         if charge_pod := state[b.CHARGE_PODS].by_pos(entity.pos): | ||||
|             valid = charge_pod.charge_battery(state[b.BATTERIES].by_entity(entity)) | ||||
|         if charge_pod := h.get_first(state[b.CHARGE_PODS].by_pos(entity.pos)): | ||||
|             valid = h.get_first(charge_pod.charge_battery(state[b.BATTERIES].by_entity(entity))) | ||||
|             if valid: | ||||
|                 state.print(f'{entity.name} just charged batteries at {charge_pod.name}.') | ||||
|             else: | ||||
| @@ -23,5 +23,6 @@ class BtryCharge(Action): | ||||
|         else: | ||||
|             valid = c.NOT_VALID | ||||
|             state.print(f'{entity.name} failed to charged batteries at {entity.pos}.') | ||||
|  | ||||
|         return ActionResult(entity=entity, identifier=self._identifier, validity=valid, | ||||
|                             reward=marl_factory_grid.modules.batteries.constants.REWARD_CHARGE_VALID if valid else marl_factory_grid.modules.batteries.constants.Reward_CHARGE_FAIL) | ||||
|                             reward=b.REWARD_CHARGE_VALID if valid else b.Reward_CHARGE_FAIL) | ||||
|   | ||||
							
								
								
									
										
											BIN
										
									
								
								marl_factory_grid/modules/batteries/chargepods.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								marl_factory_grid/modules/batteries/chargepods.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 7.9 KiB | 
| @@ -1,11 +1,11 @@ | ||||
| from marl_factory_grid.environment import constants as c | ||||
| from marl_factory_grid.environment.entity.entity import Entity | ||||
| from marl_factory_grid.environment.entity.object import _Object | ||||
| from marl_factory_grid.environment.entity.object import Object | ||||
| from marl_factory_grid.modules.batteries import constants as b | ||||
| from marl_factory_grid.utils.utility_classes import RenderEntity | ||||
|  | ||||
|  | ||||
| class Battery(_Object): | ||||
| class Battery(Object): | ||||
|  | ||||
|     @property | ||||
|     def var_can_be_bound(self): | ||||
| @@ -50,7 +50,7 @@ class Battery(_Object): | ||||
|         return summary | ||||
|  | ||||
|  | ||||
| class Pod(Entity): | ||||
| class ChargePod(Entity): | ||||
|  | ||||
|     @property | ||||
|     def encoding(self): | ||||
| @@ -58,7 +58,7 @@ class Pod(Entity): | ||||
|  | ||||
|     def __init__(self, *args, charge_rate: float = 0.4, | ||||
|                  multi_charge: bool = False, **kwargs): | ||||
|         super(Pod, self).__init__(*args, **kwargs) | ||||
|         super(ChargePod, self).__init__(*args, **kwargs) | ||||
|         self.charge_rate = charge_rate | ||||
|         self.multi_charge = multi_charge | ||||
|  | ||||
|   | ||||
| @@ -1,52 +1,36 @@ | ||||
| from typing import Union, List, Tuple | ||||
|  | ||||
| from marl_factory_grid.environment import constants as c | ||||
| from marl_factory_grid.environment.groups.collection import Collection | ||||
| from marl_factory_grid.modules.batteries.entitites import Pod, Battery | ||||
| from marl_factory_grid.modules.batteries.entitites import ChargePod, Battery | ||||
| from marl_factory_grid.utils.results import Result | ||||
|  | ||||
|  | ||||
| class Batteries(Collection): | ||||
|     _entity = Battery | ||||
|  | ||||
|     @property | ||||
|     def var_is_blocking_light(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_can_collide(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_can_move(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_has_position(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_can_be_bound(self): | ||||
|         return True | ||||
|     var_has_position = False | ||||
|     var_can_be_bound = True | ||||
|  | ||||
|     @property | ||||
|     def obs_tag(self): | ||||
|         return self.__class__.__name__ | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super(Batteries, self).__init__(*args, **kwargs) | ||||
|     def __init__(self, size, initial_charge_level: float=1.0, *args, **kwargs): | ||||
|         super(Batteries, self).__init__(size, *args, **kwargs) | ||||
|         self.initial_charge_level = initial_charge_level | ||||
|  | ||||
|     def spawn(self, agents, initial_charge_level): | ||||
|         batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)] | ||||
|     def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], agents, *entity_args, **entity_kwargs): | ||||
|         batteries = [self._entity(self.initial_charge_level, agent) for _, agent in enumerate(agents)] | ||||
|         self.add_items(batteries) | ||||
|  | ||||
|     # def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):           hat keine pos | ||||
|     #     agents = entity_args[0] | ||||
|     #     initial_charge_level = entity_args[1] | ||||
|     #     batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)] | ||||
|     #     self.add_items(batteries) | ||||
|     def trigger_spawn(self, state, *entity_args, coords_or_quantity=None,  **entity_kwargs): | ||||
|         self.spawn(0, state[c.AGENT]) | ||||
|         return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self)) | ||||
|  | ||||
|  | ||||
| class ChargePods(Collection): | ||||
|     _entity = Pod | ||||
|     _entity = ChargePod | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super(ChargePods, self).__init__(*args, **kwargs) | ||||
|   | ||||
| @@ -1,11 +1,9 @@ | ||||
| from typing import List, Union | ||||
|  | ||||
| import marl_factory_grid.modules.batteries.constants | ||||
| from marl_factory_grid.environment.rules import Rule | ||||
| from marl_factory_grid.utils.results import TickResult, DoneResult | ||||
|  | ||||
| from marl_factory_grid.environment import constants as c | ||||
| from marl_factory_grid.environment.rules import Rule | ||||
| from marl_factory_grid.modules.batteries import constants as b | ||||
| from marl_factory_grid.utils.results import TickResult, DoneResult | ||||
|  | ||||
|  | ||||
| class BatteryDecharge(Rule): | ||||
| @@ -49,10 +47,6 @@ class BatteryDecharge(Rule): | ||||
|         self.per_action_costs = per_action_costs | ||||
|         self.initial_charge = initial_charge | ||||
|  | ||||
|     def on_init(self, state, lvl_map):  # on reset? | ||||
|         assert len(state[c.AGENT]), "There are no agents, did you already spawn them?" | ||||
|         state[b.BATTERIES].spawn(state[c.AGENT], self.initial_charge) | ||||
|  | ||||
|     def tick_step(self, state) -> List[TickResult]: | ||||
|         # Decharge | ||||
|         batteries = state[b.BATTERIES] | ||||
| @@ -66,7 +60,7 @@ class BatteryDecharge(Rule): | ||||
|  | ||||
|             batteries.by_entity(agent).decharge(energy_consumption) | ||||
|  | ||||
|             results.append(TickResult(self.name, reward=0, entity=agent, validity=c.VALID)) | ||||
|             results.append(TickResult(self.name, entity=agent, validity=c.VALID)) | ||||
|  | ||||
|         return results | ||||
|  | ||||
| @@ -82,13 +76,13 @@ class BatteryDecharge(Rule): | ||||
|                 if self.paralyze_agents_on_discharge: | ||||
|                     btry.bound_entity.paralyze(self.name) | ||||
|                     results.append( | ||||
|                         TickResult("Paralyzed", entity=btry.bound_entity, reward=0, validity=c.VALID) | ||||
|                         TickResult("Paralyzed", entity=btry.bound_entity, validity=c.VALID) | ||||
|                     ) | ||||
|                     state.print(f'{btry.bound_entity.name} has just been paralyzed!') | ||||
|             if btry.bound_entity.var_is_paralyzed and not btry.is_discharged: | ||||
|                 btry.bound_entity.de_paralyze(self.name) | ||||
|                 results.append( | ||||
|                     TickResult("De-Paralyzed", entity=btry.bound_entity, reward=0, validity=c.VALID) | ||||
|                     TickResult("De-Paralyzed", entity=btry.bound_entity, validity=c.VALID) | ||||
|                 ) | ||||
|                 state.print(f'{btry.bound_entity.name} has just been de-paralyzed!') | ||||
|         return results | ||||
| @@ -132,7 +126,7 @@ class DoneAtBatteryDischarge(BatteryDecharge): | ||||
|         if any_discharged or all_discharged: | ||||
|             return [DoneResult(self.name, validity=c.VALID, reward=self.reward_discharge_done)] | ||||
|         else: | ||||
|             return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)] | ||||
|             return [DoneResult(self.name, validity=c.NOT_VALID)] | ||||
|  | ||||
|  | ||||
| class SpawnChargePods(Rule): | ||||
| @@ -155,7 +149,7 @@ class SpawnChargePods(Rule): | ||||
|  | ||||
|     def on_init(self, state, lvl_map): | ||||
|         pod_collection = state[b.CHARGE_PODS] | ||||
|         empty_positions = state.entities.empty_positions() | ||||
|         empty_positions = state.entities.empty_positions | ||||
|         pods = pod_collection.from_coordinates(empty_positions, entity_kwargs=dict( | ||||
|             multi_charge=self.multi_charge, charge_rate=self.charge_rate) | ||||
|                                                ) | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| from .actions import CleanUp | ||||
| from .entitites import DirtPile | ||||
| from .groups import DirtPiles | ||||
| from .rules import SpawnDirt, EntitiesSmearDirtOnMove, DoneOnAllDirtCleaned | ||||
| from .rules import EntitiesSmearDirtOnMove, DoneOnAllDirtCleaned | ||||
|   | ||||
| @@ -1,5 +1,3 @@ | ||||
| from numpy import random | ||||
|  | ||||
| from marl_factory_grid.environment.entity.entity import Entity | ||||
| from marl_factory_grid.utils.utility_classes import RenderEntity | ||||
| from marl_factory_grid.modules.clean_up import constants as d | ||||
| @@ -7,22 +5,6 @@ from marl_factory_grid.modules.clean_up import constants as d | ||||
|  | ||||
| class DirtPile(Entity): | ||||
|  | ||||
|     @property | ||||
|     def var_can_collide(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_can_move(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_is_blocking_light(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_has_position(self): | ||||
|         return True | ||||
|  | ||||
|     @property | ||||
|     def amount(self): | ||||
|         return self._amount | ||||
|   | ||||
| @@ -1,76 +1,61 @@ | ||||
| from typing import Union, List, Tuple | ||||
|  | ||||
| from marl_factory_grid.environment import constants as c | ||||
| from marl_factory_grid.utils.results import Result | ||||
| from marl_factory_grid.environment.groups.collection import Collection | ||||
| from marl_factory_grid.modules.clean_up.entitites import DirtPile | ||||
| from marl_factory_grid.utils.results import Result | ||||
|  | ||||
|  | ||||
| class DirtPiles(Collection): | ||||
|     _entity = DirtPile | ||||
|  | ||||
|     @property | ||||
|     def var_is_blocking_light(self): | ||||
|         return False | ||||
|     var_is_blocking_light = False | ||||
|     var_can_collide = False | ||||
|     var_can_move = False | ||||
|     var_has_position = True | ||||
|  | ||||
|     @property | ||||
|     def var_can_collide(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_can_move(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_has_position(self): | ||||
|         return True | ||||
|  | ||||
|     @property | ||||
|     def amount(self): | ||||
|     def global_amount(self): | ||||
|         return sum([dirt.amount for dirt in self]) | ||||
|  | ||||
|     def __init__(self, *args, | ||||
|                  max_local_amount=5, | ||||
|                  clean_amount=1, | ||||
|                  max_global_amount: int = 20, **kwargs): | ||||
|                  max_global_amount: int = 20, | ||||
|                  coords_or_quantity=10, | ||||
|                  initial_amount=2, | ||||
|                  amount_var=0.2, | ||||
|                  n_var=0.2, | ||||
|                  **kwargs): | ||||
|         super(DirtPiles, self).__init__(*args, **kwargs) | ||||
|         self.amount_var = amount_var | ||||
|         self.n_var = n_var | ||||
|         self.clean_amount = clean_amount | ||||
|         self.max_global_amount = max_global_amount | ||||
|         self.max_local_amount = max_local_amount | ||||
|         self.coords_or_quantity = coords_or_quantity | ||||
|         self.initial_amount = initial_amount | ||||
|  | ||||
|     def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): | ||||
|         amount_s = entity_args[0] | ||||
|     def trigger_spawn(self, state, coords_or_quantity=0, amount=0) -> [Result]: | ||||
|         coords_or_quantity = coords_or_quantity if coords_or_quantity else self.coords_or_quantity | ||||
|         n_new = int(abs(coords_or_quantity + (state.rng.uniform(-self.n_var, self.n_var)))) | ||||
|         n_new = state.get_n_random_free_positions(n_new) | ||||
|  | ||||
|         amounts = [amount if amount else (self.initial_amount + state.rng.uniform(-self.amount_var, self.amount_var)) | ||||
|                    for _ in range(coords_or_quantity)] | ||||
|         spawn_counter = 0 | ||||
|         for idx, pos in enumerate(coords_or_quantity): | ||||
|             if not self.amount > self.max_global_amount: | ||||
|                 amount = amount_s[idx] if isinstance(amount_s, list) else amount_s | ||||
|         for idx, (pos, a) in enumerate(zip(n_new, amounts)): | ||||
|             if not self.global_amount > self.max_global_amount: | ||||
|                 if dirt := self.by_pos(pos): | ||||
|                     dirt = next(dirt.iter()) | ||||
|                     new_value = dirt.amount + amount | ||||
|                     new_value = dirt.amount + a | ||||
|                     dirt.set_new_amount(new_value) | ||||
|                 else: | ||||
|                     dirt = DirtPile(pos, amount=amount) | ||||
|                     self.add_item(dirt) | ||||
|                     super().spawn([pos], amount=a) | ||||
|                     spawn_counter += 1 | ||||
|             else: | ||||
|                 return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, reward=0, | ||||
|                               value=spawn_counter) | ||||
|         return Result(identifier=f'{self.name}_spawn', validity=c.VALID, reward=0, value=spawn_counter) | ||||
|                 return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, value=spawn_counter) | ||||
|  | ||||
|     def trigger_dirt_spawn(self, n, amount, state, n_var=0.2, amount_var=0.2) -> Result: | ||||
|         free_for_dirt = [x for x in state.entities.floorlist if len(state.entities.pos_dict[x]) == 0 or ( | ||||
|                 len(state.entities.pos_dict[x]) >= 1 and isinstance(next(y for y in x), DirtPile))] | ||||
|         # free_for_dirt = [x for x in state[c.FLOOR] | ||||
|         #                  if len(x.guests) == 0 or ( | ||||
|         #                          len(x.guests) == 1 and | ||||
|         #                          isinstance(next(y for y in x.guests), DirtPile))] | ||||
|         state.rng.shuffle(free_for_dirt) | ||||
|  | ||||
|         new_spawn = int(abs(n + (state.rng.uniform(-n_var, n_var)))) | ||||
|         new_amount_s = [abs(amount + (amount*state.rng.uniform(-amount_var, amount_var))) for _ in range(new_spawn)] | ||||
|         n_dirty_positions = free_for_dirt[:new_spawn] | ||||
|         return self.spawn(n_dirty_positions, new_amount_s) | ||||
|         return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=spawn_counter) | ||||
|  | ||||
|     def __repr__(self): | ||||
|         s = super(DirtPiles, self).__repr__() | ||||
|         return f'{s[:-1]}, {self.amount})' | ||||
|         return f'{s[:-1]}, {self.global_amount}]' | ||||
|   | ||||
| @@ -22,58 +22,37 @@ class DoneOnAllDirtCleaned(Rule): | ||||
|     def on_check_done(self, state) -> [DoneResult]: | ||||
|         if len(state[d.DIRT]) == 0 and state.curr_step: | ||||
|             return [DoneResult(validity=c.VALID, identifier=self.name, reward=self.reward)] | ||||
|         return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)] | ||||
|         return [DoneResult(validity=c.NOT_VALID, identifier=self.name)] | ||||
|  | ||||
|  | ||||
| class SpawnDirt(Rule): | ||||
| class RespawnDirt(Rule): | ||||
|  | ||||
|     def __init__(self, initial_n: int = 5, initial_amount: float = 1.3, | ||||
|                  respawn_n: int = 3, respawn_amount: float = 0.8, | ||||
|                  n_var: float = 0.2, amount_var: float = 0.2, spawn_freq: int = 15): | ||||
|     def __init__(self, respawn_freq: int = 15, respawn_n: int = 5, respawn_amount: float = 1.0): | ||||
|         """ | ||||
|         Defines the spawn pattern of intial and additional 'Dirt'-entitites. | ||||
|         First chooses positions, then trys to spawn dirt until 'respawn_n' or the maximal global amount is reached. | ||||
|         If there is allready some, it is topped up to min(max_local_amount, amount). | ||||
|  | ||||
|         :type spawn_freq: int | ||||
|         :parameter spawn_freq: In which frequency should this Rule try to spawn new 'Dirt'? | ||||
|         :type respawn_freq: int | ||||
|         :parameter respawn_freq: In which frequency should this Rule try to spawn new 'Dirt'? | ||||
|         :type respawn_n: int | ||||
|         :parameter respawn_n: How many respawn positions are considered. | ||||
|         :type initial_n: int | ||||
|         :parameter initial_n: How much initial positions are considered. | ||||
|         :type amount_var: float | ||||
|         :parameter amount_var: Variance of amount to spawn. | ||||
|         :type n_var: float | ||||
|         :parameter n_var: Variance of n to spawn. | ||||
|         :type respawn_amount: float | ||||
|         :parameter respawn_amount: Defines how much dirt 'amount' is placed every 'spawn_freq' ticks. | ||||
|         :type initial_amount: float | ||||
|         :parameter initial_amount: Defines how much dirt 'amount' is initially placed. | ||||
|  | ||||
|         """ | ||||
|         super().__init__() | ||||
|         self.amount_var = amount_var | ||||
|         self.n_var = n_var | ||||
|         self.respawn_amount = respawn_amount | ||||
|         self.respawn_n = respawn_n | ||||
|         self.initial_amount = initial_amount | ||||
|         self.initial_n = initial_n | ||||
|         self.spawn_freq = spawn_freq | ||||
|         self._next_dirt_spawn = spawn_freq | ||||
|  | ||||
|     def on_init(self, state, lvl_map) -> str: | ||||
|         result = state[d.DIRT].trigger_dirt_spawn(self.initial_n, self.initial_amount, state, | ||||
|                                                   n_var=self.n_var, amount_var=self.amount_var) | ||||
|         state.print(f'Initial Dirt was spawned on: {[x.pos for x in state[d.DIRT]]}') | ||||
|         return result | ||||
|         self.respawn_amount = respawn_amount | ||||
|         self.respawn_freq = respawn_freq | ||||
|         self._next_dirt_spawn = respawn_freq | ||||
|  | ||||
|     def tick_step(self, state): | ||||
|         collection = state[d.DIRT] | ||||
|         if self._next_dirt_spawn < 0: | ||||
|             pass  # No DirtPile Spawn | ||||
|             result = []  # No DirtPile Spawn | ||||
|         elif not self._next_dirt_spawn: | ||||
|             result = [state[d.DIRT].trigger_dirt_spawn(self.respawn_n, self.respawn_amount, state, | ||||
|                                                        n_var=self.n_var, amount_var=self.amount_var)] | ||||
|             self._next_dirt_spawn = self.spawn_freq | ||||
|             result = [collection.trigger_spawn(state, coords_or_quantity=self.respawn_n, amount=self.respawn_amount)] | ||||
|             self._next_dirt_spawn = self.respawn_freq | ||||
|         else: | ||||
|             self._next_dirt_spawn -= 1 | ||||
|             result = [] | ||||
| @@ -99,8 +78,8 @@ class EntitiesSmearDirtOnMove(Rule): | ||||
|         for entity in state.moving_entites: | ||||
|             if is_move(entity.state.identifier) and entity.state.validity == c.VALID: | ||||
|                 if old_pos_dirt := state[d.DIRT].by_pos(entity.last_pos): | ||||
|                     old_pos_dirt = next(iter(old_pos_dirt)) | ||||
|                     if smeared_dirt := round(old_pos_dirt.amount * self.smear_ratio, 2): | ||||
|                         if state[d.DIRT].spawn(entity.pos, amount=smeared_dirt): | ||||
|                             results.append(TickResult(identifier=self.name, entity=entity, | ||||
|                                                       reward=0, validity=c.VALID)) | ||||
|                             results.append(TickResult(identifier=self.name, entity=entity, validity=c.VALID)) | ||||
|         return results | ||||
|   | ||||
| @@ -1,4 +1,7 @@ | ||||
| from .actions import DestAction | ||||
| from .entitites import Destination | ||||
| from .groups import Destinations | ||||
| from .rules import DoneAtDestinationReachAll, SpawnDestinations | ||||
| from .rules import (DoneAtDestinationReachAll, | ||||
|                     DoneAtDestinationReachAny, | ||||
|                     SpawnDestinationsPerAgent, | ||||
|                     DestinationReachReward) | ||||
|   | ||||
| @@ -21,4 +21,4 @@ class DestAction(Action): | ||||
|             valid = c.NOT_VALID | ||||
|             state.print(f'{entity.name} just tried to do_wait_action do_wait_action at {entity.pos} but failed') | ||||
|         return ActionResult(entity=entity, identifier=self._identifier, validity=valid, | ||||
|                             reward=marl_factory_grid.modules.destinations.constants.REWARD_WAIT_VALID if valid else marl_factory_grid.modules.destinations.constants.REWARD_WAIT_FAIL) | ||||
|                             reward=d.REWARD_WAIT_VALID if valid else d.REWARD_WAIT_FAIL) | ||||
|   | ||||
| @@ -9,30 +9,6 @@ from marl_factory_grid.utils.utility_classes import RenderEntity | ||||
|  | ||||
| class Destination(Entity): | ||||
|  | ||||
|     @property | ||||
|     def var_can_move(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_can_collide(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_has_position(self): | ||||
|         return True | ||||
|  | ||||
|     @property | ||||
|     def var_is_blocking_pos(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_is_blocking_light(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_can_be_bound(self): | ||||
|         return True | ||||
|  | ||||
|     def was_reached(self): | ||||
|         return self._was_reached | ||||
|  | ||||
|   | ||||
| @@ -1,43 +1,18 @@ | ||||
| from marl_factory_grid.environment.groups.collection import Collection | ||||
| from marl_factory_grid.modules.destinations.entitites import Destination | ||||
| from marl_factory_grid.environment import constants as c | ||||
| from marl_factory_grid.modules.destinations import constants as d | ||||
|  | ||||
|  | ||||
| class Destinations(Collection): | ||||
|     _entity = Destination | ||||
|  | ||||
|     @property | ||||
|     def var_is_blocking_light(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_can_collide(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_can_move(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_has_position(self): | ||||
|         return True | ||||
|     var_is_blocking_light = False | ||||
|     var_can_collide = False | ||||
|     var_can_move = False | ||||
|     var_has_position = True | ||||
|     var_can_be_bound = True | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return super(Destinations, self).__repr__() | ||||
|  | ||||
|     @staticmethod | ||||
|     def trigger_destination_spawn(n_dests, state): | ||||
|         coordinates = state.entities.floorlist[:n_dests] | ||||
|         if destinations := [Destination(pos) for pos in coordinates]: | ||||
|             state[d.DESTINATION].add_items(destinations) | ||||
|             state.print(f'{n_dests} new destinations have been spawned') | ||||
|             return c.VALID | ||||
|         else: | ||||
|             state.print('No Destiantions are spawning, limit is reached.') | ||||
|             return c.NOT_VALID | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -2,8 +2,8 @@ import ast | ||||
| from random import shuffle | ||||
| from typing import List, Dict, Tuple | ||||
|  | ||||
| import marl_factory_grid.modules.destinations.constants | ||||
| from marl_factory_grid.environment.rules import Rule | ||||
| from marl_factory_grid.utils import helpers as h | ||||
| from marl_factory_grid.utils.results import TickResult, DoneResult | ||||
| from marl_factory_grid.environment import constants as c | ||||
|  | ||||
| @@ -54,7 +54,7 @@ class DoneAtDestinationReachAll(DestinationReachReward): | ||||
|         """ | ||||
|         This rule triggers and sets the done flag if ALL Destinations have been reached. | ||||
|  | ||||
|         :type reward_at_done: object | ||||
|         :type reward_at_done: float | ||||
|         :param reward_at_done: Specifies the reward, agent get, whenn all destinations are reached. | ||||
|         :type dest_reach_reward: float | ||||
|         :param dest_reach_reward: Specify the reward, agents get when reaching a single destination. | ||||
| @@ -65,7 +65,7 @@ class DoneAtDestinationReachAll(DestinationReachReward): | ||||
|     def on_check_done(self, state) -> List[DoneResult]: | ||||
|         if all(x.was_reached() for x in state[d.DESTINATION]): | ||||
|             return [DoneResult(self.name, validity=c.VALID, reward=self.reward)] | ||||
|         return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)] | ||||
|         return [DoneResult(self.name, validity=c.NOT_VALID)] | ||||
|  | ||||
|  | ||||
| class DoneAtDestinationReachAny(DestinationReachReward): | ||||
| @@ -75,7 +75,7 @@ class DoneAtDestinationReachAny(DestinationReachReward): | ||||
|         This rule triggers and sets the done flag if ANY Destinations has been reached. | ||||
|         !!! IMPORTANT: 'reward_at_done' is shared between the agents; 'dest_reach_reward' is bound to a specific one. | ||||
|                  | ||||
|         :type reward_at_done: object | ||||
|         :type reward_at_done: float | ||||
|         :param reward_at_done: Specifies the reward, all agent get, when any destinations has been reached.  | ||||
|                                 Default {d.REWARD_DEST_DONE} | ||||
|         :type dest_reach_reward: float | ||||
| @@ -87,67 +87,29 @@ class DoneAtDestinationReachAny(DestinationReachReward): | ||||
|  | ||||
|     def on_check_done(self, state) -> List[DoneResult]: | ||||
|         if any(x.was_reached() for x in state[d.DESTINATION]): | ||||
|             return [DoneResult(self.name, validity=c.VALID, reward=marl_factory_grid.modules.destinations.constants.REWARD_DEST_REACHED)] | ||||
|             return [DoneResult(self.name, validity=c.VALID, reward=d.REWARD_DEST_REACHED)] | ||||
|         return [] | ||||
|  | ||||
|  | ||||
| class SpawnDestinations(Rule): | ||||
|  | ||||
|     def __init__(self, n_dests: int = 1, spawn_mode: str = d.MODE_GROUPED): | ||||
|         f""" | ||||
|         Defines how destinations are initially spawned and respawned in addition. | ||||
|         !!! This rule introduces no kind of reward or Env.-Done condition! | ||||
|                  | ||||
|         :type n_dests: int | ||||
|         :param n_dests: How many destiantions should be maintained (and initally spawnewd) on the map? | ||||
|         :type spawn_mode: str  | ||||
|         :param spawn_mode: One of {d.SPAWN_MODES}. {d.MODE_GROUPED}: Always wait for all Dstiantions do be gone,  | ||||
|                            then respawn after the given time. {d.MODE_SINGLE}: Just spawn every destination,  | ||||
|                            that has been reached, after the given time | ||||
|                              | ||||
|         """ | ||||
|         super(SpawnDestinations, self).__init__() | ||||
|         self.n_dests = n_dests | ||||
|         self.spawn_mode = spawn_mode | ||||
|  | ||||
|     def on_init(self, state, lvl_map): | ||||
|         # noinspection PyAttributeOutsideInit | ||||
|         state[d.DESTINATION].trigger_destination_spawn(self.n_dests, state) | ||||
|         pass | ||||
|  | ||||
|     def tick_pre_step(self, state) -> List[TickResult]: | ||||
|         pass | ||||
|  | ||||
|     def tick_step(self, state) -> List[TickResult]: | ||||
|         if n_dest_spawn := max(0, self.n_dests - len(state[d.DESTINATION])): | ||||
|             if self.spawn_mode == d.MODE_GROUPED and n_dest_spawn == self.n_dests: | ||||
|                 validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state) | ||||
|                 return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)] | ||||
|             elif self.spawn_mode == d.MODE_SINGLE and n_dest_spawn: | ||||
|                 validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state) | ||||
|                 return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)] | ||||
|             else: | ||||
|                 pass | ||||
|  | ||||
|  | ||||
| class SpawnDestinationsPerAgent(Rule): | ||||
|     def __init__(self, per_agent_positions: Dict[str, List[Tuple[int, int]]]): | ||||
|     def __init__(self, coords_or_quantity: Dict[str, List[Tuple[int, int]]]): | ||||
|         """ | ||||
|         Special rule, that spawn distinations, that are bound to a single agent a fixed set of positions. | ||||
|         Usefull for introducing specialists, etc. .. | ||||
|  | ||||
|         !!! This rule does not introduce any reward or done condition. | ||||
|  | ||||
|         :type per_agent_positions:  Dict[str, List[Tuple[int, int]] | ||||
|         :param per_agent_positions: Please provide a dictionary with agent names as keys; and a list of possible | ||||
|         :type coords_or_quantity:  Dict[str, List[Tuple[int, int]] | ||||
|         :param coords_or_quantity: Please provide a dictionary with agent names as keys; and a list of possible | ||||
|                                      destiantion coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]} | ||||
|         """ | ||||
|         super(Rule, self).__init__() | ||||
|         self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in per_agent_positions.items()} | ||||
|         self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in coords_or_quantity.items()} | ||||
|  | ||||
|     def on_init(self, state, lvl_map): | ||||
|         for (agent_name, position_list) in self.per_agent_positions.items(): | ||||
|             agent = next(x for x in state[c.AGENT] if agent_name in x.name)  # Fixme: Ugly AF | ||||
|             agent = h.get_first(state[c.AGENT], lambda x: agent_name in x.name) | ||||
|             assert agent | ||||
|             position_list = position_list.copy() | ||||
|             shuffle(position_list) | ||||
|             while True: | ||||
| @@ -155,7 +117,7 @@ class SpawnDestinationsPerAgent(Rule): | ||||
|                     pos = position_list.pop() | ||||
|                 except IndexError: | ||||
|                     print(f"Could not spawn Destinations at: {self.per_agent_positions[agent_name]}") | ||||
|                     print(f'Check your agent palcement: {state[c.AGENT]} ... Exit ...') | ||||
|                     print(f'Check your agent placement: {state[c.AGENT]} ... Exit ...') | ||||
|                     exit(9999) | ||||
|                 if (not pos == agent.pos) and (not state[d.DESTINATION].by_pos(pos)): | ||||
|                     destination = Destination(pos, bind_to=agent) | ||||
|   | ||||
| @@ -1,4 +1,5 @@ | ||||
| from marl_factory_grid.environment.entity.entity import Entity | ||||
| from marl_factory_grid.utils import Result | ||||
| from marl_factory_grid.utils.utility_classes import RenderEntity | ||||
| from marl_factory_grid.environment import constants as c | ||||
|  | ||||
| @@ -41,21 +42,6 @@ class Door(Entity): | ||||
|     def str_state(self): | ||||
|         return 'open' if self.is_open else 'closed' | ||||
|  | ||||
|     def __init__(self, *args, closed_on_init=True, auto_close_interval=10, **kwargs): | ||||
|         self._status = d.STATE_CLOSED | ||||
|         super(Door, self).__init__(*args, **kwargs) | ||||
|         self.auto_close_interval = auto_close_interval | ||||
|         self.time_to_close = 0 | ||||
|         if not closed_on_init: | ||||
|             self._open() | ||||
|         else: | ||||
|             self._close() | ||||
|  | ||||
|     def summarize_state(self): | ||||
|         state_dict = super().summarize_state() | ||||
|         state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close)) | ||||
|         return state_dict | ||||
|  | ||||
|     @property | ||||
|     def is_closed(self): | ||||
|         return self._status == d.STATE_CLOSED | ||||
| @@ -68,6 +54,25 @@ class Door(Entity): | ||||
|     def status(self): | ||||
|         return self._status | ||||
|  | ||||
|     @property | ||||
|     def time_to_close(self): | ||||
|         return self._time_to_close | ||||
|  | ||||
|     def __init__(self, *args, closed_on_init=True, auto_close_interval=10, **kwargs): | ||||
|         self._status = d.STATE_CLOSED | ||||
|         super(Door, self).__init__(*args, **kwargs) | ||||
|         self._auto_close_interval = auto_close_interval | ||||
|         self._time_to_close = 0 | ||||
|         if not closed_on_init: | ||||
|             self._open() | ||||
|         else: | ||||
|             self._close() | ||||
|  | ||||
|     def summarize_state(self): | ||||
|         state_dict = super().summarize_state() | ||||
|         state_dict.update(state=str(self.str_state), time_to_close=self.time_to_close) | ||||
|         return state_dict | ||||
|  | ||||
|     def render(self): | ||||
|         name, state = 'door_open' if self.is_open else 'door_closed', 'blank' | ||||
|         return RenderEntity(name, self.pos, 1, 'none', state, self.u_int + 1) | ||||
| @@ -80,18 +85,35 @@ class Door(Entity): | ||||
|         return c.VALID | ||||
|  | ||||
|     def tick(self, state): | ||||
|         if self.is_open and len(state.entities.pos_dict[self.pos]) == 2 and self.time_to_close: | ||||
|             self.time_to_close -= 1 | ||||
|             return c.NOT_VALID | ||||
|         elif self.is_open and not self.time_to_close and len(state.entities.pos_dict[self.pos]) == 2: | ||||
|         # Check if no entity is standing in the door | ||||
|         if len(state.entities.pos_dict[self.pos]) <= 2: | ||||
|             if self.is_open and self.time_to_close: | ||||
|                 self._decrement_timer() | ||||
|                 return Result(f"{d.DOOR}_tick", c.VALID, entity=self) | ||||
|             elif self.is_open and not self.time_to_close: | ||||
|                 self.use() | ||||
|             return c.VALID | ||||
|                 return Result(f"{d.DOOR}_closed", c.VALID, entity=self) | ||||
|             else: | ||||
|             return c.NOT_VALID | ||||
|                 # No one is in door, but it is closed... Nothing to do.... | ||||
|                 return None | ||||
|         else: | ||||
|             # Entity is standing in the door, reset timer | ||||
|             self._reset_timer() | ||||
|             return Result(f"{d.DOOR}_reset", c.VALID, entity=self) | ||||
|  | ||||
|     def _open(self): | ||||
|         self._status = d.STATE_OPEN | ||||
|         self.time_to_close = self.auto_close_interval | ||||
|         self._reset_timer() | ||||
|         return True | ||||
|  | ||||
|     def _close(self): | ||||
|         self._status = d.STATE_CLOSED | ||||
|         return True | ||||
|  | ||||
|     def _decrement_timer(self): | ||||
|         self._time_to_close -= 1 | ||||
|         return True | ||||
|  | ||||
|     def _reset_timer(self): | ||||
|         self._time_to_close = self._auto_close_interval | ||||
|         return True | ||||
|   | ||||
| @@ -1,5 +1,3 @@ | ||||
| from typing import Union | ||||
|  | ||||
| from marl_factory_grid.environment.groups.collection import Collection | ||||
| from marl_factory_grid.modules.doors import constants as d | ||||
| from marl_factory_grid.modules.doors.entitites import Door | ||||
| @@ -18,8 +16,10 @@ class Doors(Collection): | ||||
|         super(Doors, self).__init__(*args, can_collide=True, **kwargs) | ||||
|  | ||||
|     def tick_doors(self, state): | ||||
|         result_dict = dict() | ||||
|         results = list() | ||||
|         for door in self: | ||||
|             did_tick = door.tick(state) | ||||
|             result_dict.update({door.name: did_tick}) | ||||
|         return result_dict | ||||
|             tick_result = door.tick(state) | ||||
|             if tick_result is not None: | ||||
|                 results.append(tick_result) | ||||
|         # TODO: Should return a Result object, not a random dict. | ||||
|         return results | ||||
|   | ||||
| @@ -19,10 +19,10 @@ class DoorAutoClose(Rule): | ||||
|  | ||||
|     def tick_step(self, state): | ||||
|         if doors := state[d.DOORS]: | ||||
|             doors_tick_result = doors.tick_doors(state) | ||||
|             doors_that_ticked = [key for key, val in doors_tick_result.items() if val] | ||||
|             state.print(f'{doors_that_ticked} were auto-closed' | ||||
|                         if doors_that_ticked else 'No Doors were auto-closed') | ||||
|             doors_tick_results = doors.tick_doors(state) | ||||
|             doors_that_closed = [x.entity.name for x in doors_tick_results if 'closed' in x.identifier] | ||||
|             door_str = doors_that_closed if doors_that_closed else "No Doors" | ||||
|             state.print(f'{door_str} were auto-closed') | ||||
|             return [TickResult(self.name, validity=c.VALID, value=1)] | ||||
|         state.print('There are no doors, but you loaded the corresponding Module') | ||||
|         return [] | ||||
|   | ||||
| @@ -1,8 +1,8 @@ | ||||
| import random | ||||
| from typing import List, Union | ||||
| from typing import List | ||||
|  | ||||
| from marl_factory_grid.environment.rules import Rule | ||||
| from marl_factory_grid.environment import constants as c | ||||
| from marl_factory_grid.environment.rules import Rule | ||||
| from marl_factory_grid.utils.results import TickResult | ||||
|  | ||||
|  | ||||
| @@ -14,8 +14,6 @@ class AgentSingleZonePlacementBeta(Rule): | ||||
|         super().__init__() | ||||
|  | ||||
|     def on_init(self, state, lvl_map): | ||||
|         zones = state[c.ZONES] | ||||
|         n_zones = state[c.ZONES] | ||||
|         agents = state[c.AGENT] | ||||
|         if len(self.coordinates) == len(agents): | ||||
|             coordinates = self.coordinates | ||||
|   | ||||
| @@ -1,4 +1,3 @@ | ||||
| from .actions import ItemAction | ||||
| from .entitites import Item, DropOffLocation | ||||
| from .groups import DropOffLocations, Items, Inventory, Inventories | ||||
| from .rules import ItemRules | ||||
|   | ||||
| @@ -29,7 +29,7 @@ class ItemAction(Action): | ||||
|         elif items := state[i.ITEM].by_pos(entity.pos): | ||||
|             item = items[0] | ||||
|             item.change_parent_collection(inventory) | ||||
|             item.set_pos_to(c.VALUE_NO_POS) | ||||
|             item.set_pos(c.VALUE_NO_POS) | ||||
|             state.print(f'{entity.name} just picked up an item at {entity.pos}') | ||||
|             return ActionResult(entity=entity, identifier=self._identifier, validity=c.VALID, reward=r.PICK_UP_VALID) | ||||
|  | ||||
|   | ||||
| @@ -1,6 +1,3 @@ | ||||
| from typing import NamedTuple | ||||
|  | ||||
|  | ||||
| SYMBOL_NO_ITEM      = 0 | ||||
| SYMBOL_DROP_OFF     = 1 | ||||
| # Item Env | ||||
|   | ||||
| @@ -8,56 +8,20 @@ from marl_factory_grid.modules.items import constants as i | ||||
|  | ||||
| class Item(Entity): | ||||
|  | ||||
|     @property | ||||
|     def var_can_collide(self): | ||||
|         return False | ||||
|  | ||||
|     def render(self): | ||||
|         return RenderEntity(i.ITEM, self.pos) if self.pos != c.VALUE_NO_POS else None | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self._auto_despawn = -1 | ||||
|  | ||||
|     @property | ||||
|     def auto_despawn(self): | ||||
|         return self._auto_despawn | ||||
|  | ||||
|     @property | ||||
|     def encoding(self): | ||||
|         # Edit this if you want items to be drawn in the ops differently | ||||
|         return 1 | ||||
|  | ||||
|     def set_auto_despawn(self, auto_despawn): | ||||
|         self._auto_despawn = auto_despawn | ||||
|  | ||||
|     def set_pos_to(self, no_pos): | ||||
|         self._pos = no_pos | ||||
|  | ||||
|     def summarize_state(self) -> dict: | ||||
|         super_summarization = super(Item, self).summarize_state() | ||||
|         super_summarization.update(dict(auto_despawn=self.auto_despawn)) | ||||
|         return super_summarization | ||||
|  | ||||
|  | ||||
| class DropOffLocation(Entity): | ||||
|  | ||||
|     @property | ||||
|     def var_can_collide(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_can_move(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_is_blocking_light(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_has_position(self): | ||||
|         return True | ||||
|  | ||||
|     def render(self): | ||||
|         return RenderEntity(i.DROP_OFF, self.pos) | ||||
|  | ||||
| @@ -65,18 +29,16 @@ class DropOffLocation(Entity): | ||||
|     def encoding(self): | ||||
|         return i.SYMBOL_DROP_OFF | ||||
|  | ||||
|     def __init__(self, *args, storage_size_until_full: int = 5, auto_item_despawn_interval: int = 5, **kwargs): | ||||
|     def __init__(self, *args, storage_size_until_full: int = 5, **kwargs): | ||||
|         super(DropOffLocation, self).__init__(*args, **kwargs) | ||||
|         self.auto_item_despawn_interval = auto_item_despawn_interval | ||||
|         self.storage = deque(maxlen=storage_size_until_full or None) | ||||
|  | ||||
|     def place_item(self, item: Item): | ||||
|         if self.is_full: | ||||
|             raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.") | ||||
|             return bc.NOT_VALID  # in Zeile 81 verschieben? | ||||
|             return bc.NOT_VALID | ||||
|         else: | ||||
|             self.storage.append(item) | ||||
|             item.set_auto_despawn(self.auto_item_despawn_interval) | ||||
|             return c.VALID | ||||
|  | ||||
|     @property | ||||
|   | ||||
| @@ -1,13 +1,11 @@ | ||||
| from random import shuffle | ||||
|  | ||||
| from marl_factory_grid.modules.items import constants as i | ||||
| from marl_factory_grid.environment import constants as c | ||||
|  | ||||
| from marl_factory_grid.environment.groups.collection import Collection | ||||
| from marl_factory_grid.environment.groups.objects import _Objects | ||||
| from marl_factory_grid.environment.groups.mixins import IsBoundMixin | ||||
| from marl_factory_grid.environment.entity.agent import Agent | ||||
| from marl_factory_grid.environment.groups.collection import Collection | ||||
| from marl_factory_grid.environment.groups.mixins import IsBoundMixin | ||||
| from marl_factory_grid.environment.groups.objects import Objects | ||||
| from marl_factory_grid.modules.items import constants as i | ||||
| from marl_factory_grid.modules.items.entitites import Item, DropOffLocation | ||||
| from marl_factory_grid.utils.results import Result | ||||
|  | ||||
|  | ||||
| class Items(Collection): | ||||
| @@ -15,7 +13,7 @@ class Items(Collection): | ||||
|  | ||||
|     @property | ||||
|     def var_has_position(self): | ||||
|         return False | ||||
|         return True | ||||
|  | ||||
|     @property | ||||
|     def is_blocking_light(self): | ||||
| @@ -28,18 +26,18 @@ class Items(Collection): | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|  | ||||
|     @staticmethod | ||||
|     def trigger_item_spawn(state, n_items, spawn_frequency): | ||||
|         if item_to_spawns := max(0, (n_items - len(state[i.ITEM]))): | ||||
|             position_list = [x for x in state.entities.floorlist] | ||||
|             shuffle(position_list) | ||||
|             position_list = state.entities.floorlist[:item_to_spawns] | ||||
|             state[i.ITEM].spawn(position_list) | ||||
|             state.print(f'{item_to_spawns} new items have been spawned; next spawn in {spawn_frequency}') | ||||
|             return len(position_list) | ||||
|     def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, **entity_kwargs) -> [Result]: | ||||
|         coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity | ||||
|         assert coords_or_quantity | ||||
|  | ||||
|         if item_to_spawns := max(0, (coords_or_quantity - len(self))): | ||||
|             return super().trigger_spawn(state, | ||||
|                                          *entity_args, | ||||
|                                          coords_or_quantity=item_to_spawns, | ||||
|                                          **entity_kwargs) | ||||
|         else: | ||||
|             state.print('No Items are spawning, limit is reached.') | ||||
|             return 0 | ||||
|             return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, value=coords_or_quantity) | ||||
|  | ||||
|  | ||||
| class Inventory(IsBoundMixin, Collection): | ||||
| @@ -73,12 +71,17 @@ class Inventory(IsBoundMixin, Collection): | ||||
|         self._collection = collection | ||||
|  | ||||
|  | ||||
| class Inventories(_Objects): | ||||
| class Inventories(Objects): | ||||
|     _entity = Inventory | ||||
|  | ||||
|     var_can_move = False | ||||
|     var_has_position = False | ||||
|  | ||||
|     symbol = None | ||||
|  | ||||
|     @property | ||||
|     def var_can_move(self): | ||||
|         return False | ||||
|     def spawn_rule(self): | ||||
|         return {c.SPAWN_ENTITY_RULE: dict(collection=self, coords_or_quantity=None)} | ||||
|  | ||||
|     def __init__(self, size: int, *args, **kwargs): | ||||
|         super(Inventories, self).__init__(*args, **kwargs) | ||||
| @@ -86,10 +89,12 @@ class Inventories(_Objects): | ||||
|         self._obs = None | ||||
|         self._lazy_eval_transforms = [] | ||||
|  | ||||
|     def spawn(self, agents): | ||||
|         inventories = [self._entity(agent, self.size, ) | ||||
|                        for _, agent in enumerate(agents)] | ||||
|         self.add_items(inventories) | ||||
|     def spawn(self, agents, *args, **kwargs): | ||||
|         self.add_items([self._entity(agent, self.size, *args, **kwargs) for _, agent in enumerate(agents)]) | ||||
|         return [Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))] | ||||
|  | ||||
|     def trigger_spawn(self, state, *args, **kwargs) -> [Result]: | ||||
|         return self.spawn(state[c.AGENT], *args, **kwargs) | ||||
|  | ||||
|     def idx_by_entity(self, entity): | ||||
|         try: | ||||
| @@ -106,10 +111,6 @@ class Inventories(_Objects): | ||||
|     def summarize_states(self, **kwargs): | ||||
|         return [val.summarize_states(**kwargs) for key, val in self.items()] | ||||
|  | ||||
|     @staticmethod | ||||
|     def trigger_inventory_spawn(state): | ||||
|         state[i.INVENTORY].spawn(state[c.AGENT]) | ||||
|  | ||||
|  | ||||
| class DropOffLocations(Collection): | ||||
|     _entity = DropOffLocation | ||||
| @@ -135,7 +136,7 @@ class DropOffLocations(Collection): | ||||
|  | ||||
|     @staticmethod | ||||
|     def trigger_drop_off_location_spawn(state, n_locations): | ||||
|         empty_positions = state.entities.empty_positions()[:n_locations] | ||||
|         empty_positions = state.entities.empty_positions[:n_locations] | ||||
|         do_entites = state[i.DROP_OFF] | ||||
|         drop_offs = [DropOffLocation(pos) for pos in empty_positions] | ||||
|         do_entites.add_items(drop_offs) | ||||
|   | ||||
| @@ -6,52 +6,28 @@ from marl_factory_grid.utils.results import TickResult | ||||
| from marl_factory_grid.modules.items import constants as i | ||||
|  | ||||
|  | ||||
| class ItemRules(Rule): | ||||
| class RespawnItems(Rule): | ||||
|  | ||||
|     def __init__(self, n_items: int = 5, spawn_frequency: int = 15, | ||||
|                  n_locations: int = 5, max_dropoff_storage_size: int = 0): | ||||
|     def __init__(self, n_items: int = 5, respawn_freq: int = 15, n_locations: int = 5): | ||||
|         super().__init__() | ||||
|         self.spawn_frequency = spawn_frequency | ||||
|         self._next_item_spawn = spawn_frequency | ||||
|         self.spawn_frequency = respawn_freq | ||||
|         self._next_item_spawn = respawn_freq | ||||
|         self.n_items = n_items | ||||
|         self.max_dropoff_storage_size = max_dropoff_storage_size | ||||
|         self.n_locations = n_locations | ||||
|  | ||||
|     def on_init(self, state, lvl_map): | ||||
|         state[i.DROP_OFF].trigger_drop_off_location_spawn(state, self.n_locations) | ||||
|         self._next_item_spawn = self.spawn_frequency | ||||
|         state[i.INVENTORY].trigger_inventory_spawn(state) | ||||
|         state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency) | ||||
|  | ||||
|     def tick_step(self, state): | ||||
|         for item in list(state[i.ITEM].values()): | ||||
|             if item.auto_despawn >= 1: | ||||
|                 item.set_auto_despawn(item.auto_despawn - 1) | ||||
|             elif not item.auto_despawn: | ||||
|                 state[i.ITEM].delete_env_object(item) | ||||
|             else: | ||||
|                 pass | ||||
|  | ||||
|         if not self._next_item_spawn: | ||||
|             state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency) | ||||
|             state[i.ITEM].trigger_spawn(state, self.n_items, self.spawn_frequency) | ||||
|         else: | ||||
|             self._next_item_spawn = max(0, self._next_item_spawn - 1) | ||||
|         return [] | ||||
|  | ||||
|     def tick_post_step(self, state) -> List[TickResult]: | ||||
|         for item in list(state[i.ITEM].values()): | ||||
|             if item.auto_despawn >= 1: | ||||
|                 item.set_auto_despawn(item.auto_despawn-1) | ||||
|             elif not item.auto_despawn: | ||||
|                 state[i.ITEM].delete_env_object(item) | ||||
|             else: | ||||
|                 pass | ||||
|  | ||||
|         if not self._next_item_spawn: | ||||
|             if spawned_items := state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency): | ||||
|                 return [TickResult(self.name, validity=c.VALID, value=spawned_items, entity=None)] | ||||
|             if spawned_items := state[i.ITEM].trigger_spawn(state, self.n_items, self.spawn_frequency): | ||||
|                 return [TickResult(self.name, validity=c.VALID, value=spawned_items.value)] | ||||
|             else: | ||||
|                 return [TickResult(self.name, validity=c.NOT_VALID, value=0, entity=None)] | ||||
|                 return [TickResult(self.name, validity=c.NOT_VALID, value=0)] | ||||
|         else: | ||||
|             self._next_item_spawn = max(0, self._next_item_spawn-1) | ||||
|             return [] | ||||
|   | ||||
| @@ -1,3 +1,2 @@ | ||||
| from .entitites import Machine | ||||
| from .groups import Machines | ||||
| from .rules import MachineRule | ||||
|   | ||||
| @@ -1,10 +1,12 @@ | ||||
| from typing import Union | ||||
|  | ||||
| import marl_factory_grid.modules.machines.constants | ||||
| from marl_factory_grid.environment.actions import Action | ||||
| from marl_factory_grid.utils.results import ActionResult | ||||
|  | ||||
| from marl_factory_grid.modules.machines import constants as m, rewards as r | ||||
| from marl_factory_grid.modules.machines import constants as m | ||||
| from marl_factory_grid.environment import constants as c | ||||
| from marl_factory_grid.utils import helpers as h | ||||
|  | ||||
|  | ||||
| class MachineAction(Action): | ||||
| @@ -13,13 +15,12 @@ class MachineAction(Action): | ||||
|         super().__init__(m.MACHINE_ACTION) | ||||
|  | ||||
|     def do(self, entity, state) -> Union[None, ActionResult]: | ||||
|         if machine := state[m.MACHINES].by_pos(entity.pos): | ||||
|         if machine := h.get_first(state[m.MACHINES].by_pos(entity.pos)): | ||||
|             if valid := machine.maintain(): | ||||
|                 return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_VALID) | ||||
|                 return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_VALID) | ||||
|             else: | ||||
|                 return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_FAIL) | ||||
|                 return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_FAIL) | ||||
|         else: | ||||
|             return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.MAINTAIN_FAIL) | ||||
|  | ||||
|  | ||||
|  | ||||
|             return ActionResult(entity=entity, identifier=self._identifier, | ||||
|                                 validity=c.NOT_VALID, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_FAIL | ||||
|                                 ) | ||||
|   | ||||
| @@ -11,3 +11,7 @@ STATE_MAINTAIN = 'maintenance' | ||||
| SYMBOL_WORK = 1 | ||||
| SYMBOL_IDLE = 0.6 | ||||
| SYMBOL_MAINTAIN = 0.3 | ||||
| MAINTAIN_VALID: float = 0.5 | ||||
| MAINTAIN_FAIL: float = -0.1 | ||||
| FAIL_MISSING_MAINTENANCE: float = -0.5 | ||||
| NONE: float = 0 | ||||
|   | ||||
| @@ -8,22 +8,6 @@ from . import constants as m | ||||
|  | ||||
| class Machine(Entity): | ||||
|  | ||||
|     @property | ||||
|     def var_can_collide(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_can_move(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_is_blocking_light(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_has_position(self): | ||||
|         return True | ||||
|  | ||||
|     @property | ||||
|     def encoding(self): | ||||
|         return self._encodings[self.status] | ||||
| @@ -46,12 +30,11 @@ class Machine(Entity): | ||||
|         else: | ||||
|             return c.NOT_VALID | ||||
|  | ||||
|     def tick(self): | ||||
|         # if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]): | ||||
|         if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.state.entities.pos_dict[self.pos]]): | ||||
|             return TickResult(identifier=self.name, validity=c.VALID, reward=0, entity=self) | ||||
|         # elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]): | ||||
|         elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.state.entities.pos_dict[self.pos]]): | ||||
|     def tick(self, state): | ||||
|         others = state.entities.pos_dict[self.pos] | ||||
|         if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in others]): | ||||
|             return TickResult(identifier=self.name, validity=c.VALID, entity=self) | ||||
|         elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in others]): | ||||
|             self.status = m.STATE_WORK | ||||
|             self.reset_counter() | ||||
|             return None | ||||
|   | ||||
| @@ -1,5 +1,3 @@ | ||||
| from typing import Union, List, Tuple | ||||
|  | ||||
| from marl_factory_grid.environment.groups.collection import Collection | ||||
|  | ||||
| from .entitites import Machine | ||||
|   | ||||
| @@ -1,5 +0,0 @@ | ||||
| MAINTAIN_VALID: float = 0.5 | ||||
| MAINTAIN_FAIL: float = -0.1 | ||||
| FAIL_MISSING_MAINTENANCE: float = -0.5 | ||||
|  | ||||
| NONE: float = 0 | ||||
| @@ -1,28 +0,0 @@ | ||||
| from typing import List | ||||
| from marl_factory_grid.environment.rules import Rule | ||||
| from marl_factory_grid.utils.results import TickResult, DoneResult | ||||
| from marl_factory_grid.environment import constants as c | ||||
| from marl_factory_grid.modules.machines import constants as m | ||||
| from marl_factory_grid.modules.machines.entitites import Machine | ||||
|  | ||||
|  | ||||
| class MachineRule(Rule): | ||||
|  | ||||
|     def __init__(self, n_machines: int = 2): | ||||
|         super(MachineRule, self).__init__() | ||||
|         self.n_machines = n_machines | ||||
|  | ||||
|     def on_init(self, state, lvl_map): | ||||
|         state[m.MACHINES].spawn(state.entities.empty_positions()) | ||||
|  | ||||
|     def tick_pre_step(self, state) -> List[TickResult]: | ||||
|         pass | ||||
|  | ||||
|     def tick_step(self, state) -> List[TickResult]: | ||||
|         pass | ||||
|  | ||||
|     def tick_post_step(self, state) -> List[TickResult]: | ||||
|         pass | ||||
|  | ||||
|     def on_check_done(self, state) -> List[DoneResult]: | ||||
|         pass | ||||
|   | ||||
| @@ -1,3 +1,4 @@ | ||||
| MAINTAINER         = 'Maintainer'                    # TEMPLATE _identifier. Define your own! | ||||
| MAINTAINERS        = 'Maintainers'                   # TEMPLATE _identifier. Define your own! | ||||
|  | ||||
| MAINTAINER_COLLISION_REWARD = -5 | ||||
|   | ||||
| @@ -1,48 +1,35 @@ | ||||
| from random import shuffle | ||||
|  | ||||
| import networkx as nx | ||||
| import numpy as np | ||||
|  | ||||
|  | ||||
| from ...algorithms.static.utils import points_to_graph | ||||
| from ...environment import constants as c | ||||
| from ...environment.actions import Action, ALL_BASEACTIONS | ||||
| from ...environment.entity.entity import Entity | ||||
| from ..doors import constants as do | ||||
| from ..maintenance import constants as mi | ||||
| from ...utils.helpers import MOVEMAP | ||||
| from ...utils.utility_classes import RenderEntity | ||||
| from ...utils.states import Gamestate | ||||
| from ...utils import helpers as h | ||||
| from ...utils.utility_classes import RenderEntity, Floor | ||||
| from ..doors import DoorUse | ||||
|  | ||||
|  | ||||
| class Maintainer(Entity): | ||||
|  | ||||
|     @property | ||||
|     def var_can_collide(self): | ||||
|         return True | ||||
|  | ||||
|     @property | ||||
|     def var_can_move(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_is_blocking_light(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_has_position(self): | ||||
|         return True | ||||
|  | ||||
|     def __init__(self, state: Gamestate, objective: str, action: Action, *args, **kwargs): | ||||
|     def __init__(self, objective: str, action: Action, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self.action = action | ||||
|         self.actions = [x() for x in ALL_BASEACTIONS] | ||||
|         self.actions = [x() for x in ALL_BASEACTIONS] + [DoorUse()] | ||||
|         self.objective = objective | ||||
|         self._path = None | ||||
|         self._next = [] | ||||
|         self._last = [] | ||||
|         self._last_serviced = 'None' | ||||
|         self._floortile_graph = points_to_graph(state.entities.floorlist) | ||||
|         self._floortile_graph = None | ||||
|  | ||||
|     def tick(self, state): | ||||
|         if found_objective := state[self.objective].by_pos(self.pos): | ||||
|         if found_objective := h.get_first(state[self.objective].by_pos(self.pos)): | ||||
|             if found_objective.name != self._last_serviced: | ||||
|                 self.action.do(self, state) | ||||
|                 self._last_serviced = found_objective.name | ||||
| @@ -54,24 +41,27 @@ class Maintainer(Entity): | ||||
|             return action.do(self, state) | ||||
|  | ||||
|     def get_move_action(self, state) -> Action: | ||||
|         if not self._floortile_graph: | ||||
|             state.print("Generating Floorgraph....") | ||||
|             self._floortile_graph = points_to_graph(state.entities.floorlist) | ||||
|         if self._path is None or not self._path: | ||||
|             if not self._next: | ||||
|                 self._next = list(state[self.objective].values()) | ||||
|                 self._next = list(state[self.objective].values()) + [Floor(*state.random_free_position)] | ||||
|                 shuffle(self._next) | ||||
|                 self._last = [] | ||||
|             self._last.append(self._next.pop()) | ||||
|             state.print("Calculating shortest path....") | ||||
|             self._path = self.calculate_route(self._last[-1]) | ||||
|  | ||||
|         if door := self._door_is_close(state): | ||||
|             if door.is_closed: | ||||
|         if door := self._closed_door_in_path(state): | ||||
|             state.print(f"{self} found {door} that is closed. Attempt to open.") | ||||
|             # Translate the action_object to an integer to have the same output as any other model | ||||
|             action = do.ACTION_DOOR_USE | ||||
|         else: | ||||
|             action = self._predict_move(state) | ||||
|         else: | ||||
|             action = self._predict_move(state) | ||||
|         # Translate the action_object to an integer to have the same output as any other model | ||||
|         try: | ||||
|             action_obj = next(x for x in self.actions if x.name == action) | ||||
|             action_obj = h.get_first(self.actions, lambda x: x.name == action) | ||||
|         except (StopIteration, UnboundLocalError): | ||||
|             print('Will not happen') | ||||
|             raise EnvironmentError | ||||
| @@ -81,11 +71,10 @@ class Maintainer(Entity): | ||||
|         route = nx.shortest_path(self._floortile_graph, self.pos, entity.pos) | ||||
|         return route[1:] | ||||
|  | ||||
|     def _door_is_close(self, state): | ||||
|         state.print("Found a door that is close.") | ||||
|         try: | ||||
|             return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name) | ||||
|         except StopIteration: | ||||
|     def _closed_door_in_path(self, state): | ||||
|         if self._path: | ||||
|             return h.get_first(state[do.DOORS].by_pos(self._path[0]), lambda x: x.is_closed) | ||||
|         else: | ||||
|             return None | ||||
|  | ||||
|     def _predict_move(self, state): | ||||
| @@ -96,7 +85,7 @@ class Maintainer(Entity): | ||||
|             next_pos = self._path.pop(0) | ||||
|             diff = np.subtract(next_pos, self.pos) | ||||
|             # Retrieve action based on the pos dif (like in: What do I have to do to get there?) | ||||
|             action = next(action for action, pos_diff in MOVEMAP.items() if np.all(diff == pos_diff)) | ||||
|             action = next(action for action, pos_diff in h.MOVEMAP.items() if np.all(diff == pos_diff)) | ||||
|         return action | ||||
|  | ||||
|     def render(self): | ||||
|   | ||||
| @@ -1,34 +1,27 @@ | ||||
| from typing import Union, List, Tuple | ||||
| from typing import Union, List, Tuple, Dict | ||||
|  | ||||
| from marl_factory_grid.environment.groups.collection import Collection | ||||
| from .entities import Maintainer | ||||
| from ..machines import constants as mc | ||||
| from ..machines.actions import MachineAction | ||||
| from ...utils.states import Gamestate | ||||
|  | ||||
|  | ||||
| class Maintainers(Collection): | ||||
|     _entity = Maintainer | ||||
|  | ||||
|     @property | ||||
|     def var_can_collide(self): | ||||
|         return True | ||||
|     var_can_collide = True | ||||
|     var_can_move = True | ||||
|     var_is_blocking_light = False | ||||
|     var_has_position = True | ||||
|  | ||||
|     @property | ||||
|     def var_can_move(self): | ||||
|         return True | ||||
|     def __init__(self, size, *args, coords_or_quantity: int = None, | ||||
|                  spawnrule: Union[None, Dict[str, dict]] = None, | ||||
|                  **kwargs): | ||||
|         super(Collection, self).__init__(*args, **kwargs) | ||||
|         self._coords_or_quantity = coords_or_quantity | ||||
|         self.size = size | ||||
|         self._spawnrule = spawnrule | ||||
|  | ||||
|     @property | ||||
|     def var_is_blocking_light(self): | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def var_has_position(self): | ||||
|         return True | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|  | ||||
|     def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): | ||||
|         state = entity_args[0] | ||||
|         self.add_items([self._entity(state, mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity]) | ||||
|         self.add_items([self._entity(mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity]) | ||||
|   | ||||
| @@ -1 +0,0 @@ | ||||
| MAINTAINER_COLLISION_REWARD = -5 | ||||
| @@ -1,32 +1,28 @@ | ||||
| from typing import List | ||||
|  | ||||
| import marl_factory_grid.modules.maintenance.constants | ||||
| from marl_factory_grid.environment.rules import Rule | ||||
| from marl_factory_grid.utils.results import TickResult, DoneResult | ||||
| from marl_factory_grid.environment import constants as c | ||||
| from . import rewards as r | ||||
| from . import constants as M | ||||
| from marl_factory_grid.utils.states import Gamestate | ||||
|  | ||||
|  | ||||
| class MaintenanceRule(Rule): | ||||
| class MoveMaintainers(Rule): | ||||
|  | ||||
|     def __init__(self, n_maintainer: int = 1, *args, **kwargs): | ||||
|         super(MaintenanceRule, self).__init__(*args, **kwargs) | ||||
|         self.n_maintainer = n_maintainer | ||||
|  | ||||
|     def on_init(self, state: Gamestate, lvl_map): | ||||
|         state[M.MAINTAINERS].spawn(state.entities.empty_positions[:self.n_maintainer], state) | ||||
|         pass | ||||
|  | ||||
|     def tick_pre_step(self, state) -> List[TickResult]: | ||||
|         pass | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|  | ||||
|     def tick_step(self, state) -> List[TickResult]: | ||||
|         for maintainer in state[M.MAINTAINERS]: | ||||
|             maintainer.tick(state) | ||||
|         # Todo: Return a Result Object. | ||||
|         return [] | ||||
|  | ||||
|     def tick_post_step(self, state) -> List[TickResult]: | ||||
|         pass | ||||
|  | ||||
| class DoneAtMaintainerCollision(Rule): | ||||
|  | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|  | ||||
|     def on_check_done(self, state) -> List[DoneResult]: | ||||
|         agents = list(state[c.AGENT].values()) | ||||
| @@ -35,5 +31,5 @@ class MaintenanceRule(Rule): | ||||
|         for agent in agents: | ||||
|             if agent.pos in m_pos: | ||||
|                 done_results.append(DoneResult(entity=agent, validity=c.VALID, identifier=self.name, | ||||
|                                                reward=r.MAINTAINER_COLLISION_REWARD)) | ||||
|                                                reward=marl_factory_grid.modules.maintenance.constants.MAINTAINER_COLLISION_REWARD)) | ||||
|         return done_results | ||||
|   | ||||
| @@ -1,10 +1,10 @@ | ||||
| import random | ||||
| from typing import List, Tuple | ||||
|  | ||||
| from marl_factory_grid.environment.entity.object import _Object | ||||
| from marl_factory_grid.environment.entity.object import Object | ||||
|  | ||||
|  | ||||
| class Zone(_Object): | ||||
| class Zone(Object): | ||||
|  | ||||
|     @property | ||||
|     def positions(self): | ||||
|   | ||||
| @@ -1,8 +1,8 @@ | ||||
| from marl_factory_grid.environment.groups.objects import _Objects | ||||
| from marl_factory_grid.environment.groups.objects import Objects | ||||
| from marl_factory_grid.modules.zones import Zone | ||||
|  | ||||
|  | ||||
| class Zones(_Objects): | ||||
| class Zones(Objects): | ||||
|     symbol = None | ||||
|     _entity = Zone | ||||
|  | ||||
|   | ||||
| @@ -1,8 +1,8 @@ | ||||
| from random import choices, choice | ||||
|  | ||||
| from . import constants as z, Zone | ||||
| from .. import Destination | ||||
| from ..destinations import constants as d | ||||
| from ... import Destination | ||||
| from ...environment.rules import Rule | ||||
| from ...environment import constants as c | ||||
|  | ||||
|   | ||||
| @@ -0,0 +1,3 @@ | ||||
| from . import helpers as h | ||||
| from . import helpers | ||||
| from .results import Result, DoneResult, ActionResult, TickResult | ||||
|   | ||||
| @@ -1,4 +1,5 @@ | ||||
| import ast | ||||
|  | ||||
| from os import PathLike | ||||
| from pathlib import Path | ||||
| from typing import Union, List | ||||
| @@ -9,18 +10,17 @@ from marl_factory_grid.environment import constants as c | ||||
| from marl_factory_grid.environment.rules import Rule | ||||
| from marl_factory_grid.environment.tests import Test | ||||
| from marl_factory_grid.utils.helpers import locate_and_import_class | ||||
|  | ||||
| DEFAULT_PATH = 'environment' | ||||
| MODULE_PATH = 'modules' | ||||
| from marl_factory_grid.environment.constants import DEFAULT_PATH, MODULE_PATH | ||||
| from marl_factory_grid.environment import constants as c | ||||
|  | ||||
|  | ||||
| class FactoryConfigParser(object): | ||||
|     default_entites = [] | ||||
|     default_rules = ['MaxStepsReached', 'Collision'] | ||||
|     default_rules = ['DoneAtMaxStepsReached', 'WatchCollision'] | ||||
|     default_actions = [c.MOVE8, c.NOOP] | ||||
|     default_observations = [c.WALLS, c.AGENT] | ||||
|  | ||||
|     def __init__(self, config_path, custom_modules_path: Union[None, PathLike] = None): | ||||
|     def __init__(self, config_path, custom_modules_path: Union[PathLike] = None): | ||||
|         self.config_path = Path(config_path) | ||||
|         self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path | ||||
|         self.config = yaml.safe_load(self.config_path.open()) | ||||
| @@ -44,6 +44,10 @@ class FactoryConfigParser(object): | ||||
|     def rules(self): | ||||
|         return self.config['Rules'] | ||||
|  | ||||
|     @property | ||||
|     def tests(self): | ||||
|         return self.config.get('Tests', []) | ||||
|  | ||||
|     @property | ||||
|     def agents(self): | ||||
|         return self.config['Agents'] | ||||
| @@ -56,10 +60,12 @@ class FactoryConfigParser(object): | ||||
|         return str(self.config) | ||||
|  | ||||
|     def __getitem__(self, item): | ||||
|         try: | ||||
|             return self.config[item] | ||||
|         except KeyError: | ||||
|             print(f'The mandatory {item} section could not be found in your .config gile. Check Spelling!') | ||||
|  | ||||
|     def load_entities(self): | ||||
|         # entites = Entities() | ||||
|         entity_classes = dict() | ||||
|         entities = [] | ||||
|         if c.DEFAULTS in self.entities: | ||||
| @@ -67,28 +73,40 @@ class FactoryConfigParser(object): | ||||
|         entities.extend(x for x in self.entities if x != c.DEFAULTS) | ||||
|  | ||||
|         for entity in entities: | ||||
|             e1 = e2 = e3 = None | ||||
|             try: | ||||
|                 folder_path = Path(__file__).parent.parent / DEFAULT_PATH | ||||
|                 entity_class = locate_and_import_class(entity, folder_path) | ||||
|             except AttributeError as e1: | ||||
|             except AttributeError as e: | ||||
|                 e1 = e | ||||
|                 try: | ||||
|                     folder_path = Path(__file__).parent.parent / MODULE_PATH | ||||
|                     entity_class = locate_and_import_class(entity, folder_path) | ||||
|                 except AttributeError as e2: | ||||
|                     module_path = Path(__file__).parent.parent / MODULE_PATH | ||||
|                     entity_class = locate_and_import_class(entity, module_path) | ||||
|                 except AttributeError as e: | ||||
|                     e2 = e | ||||
|                     if self.custom_modules_path: | ||||
|                         try: | ||||
|                         folder_path = self.custom_modules_path | ||||
|                         entity_class = locate_and_import_class(entity, folder_path) | ||||
|                     except AttributeError as e3: | ||||
|                         ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x] | ||||
|                             entity_class = locate_and_import_class(entity, self.custom_modules_path) | ||||
|                         except AttributeError as e: | ||||
|                             e3 = e | ||||
|                             pass | ||||
|             if (e1 and e2) or e3: | ||||
|                 ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]] | ||||
|                 print('##############################################################') | ||||
|                 print('### Error  ###  Error  ###  Error  ###  Error  ###  Error  ###') | ||||
|                         print() | ||||
|                 print('##############################################################') | ||||
|                 print(f'Class "{entity}" was not found in "{module_path.name}"') | ||||
|                 print(f'Class "{entity}" was not found in "{folder_path.name}"') | ||||
|                 print('##############################################################') | ||||
|                 if self.custom_modules_path: | ||||
|                     print(f'Class "{entity}" was not found in "{self.custom_modules_path}"') | ||||
|                 print('Possible Entitys are:', str(ents)) | ||||
|                         print() | ||||
|                 print('##############################################################') | ||||
|                 print('Goodbye') | ||||
|                         print() | ||||
|                         exit() | ||||
|                         # raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are:', str(ents)) | ||||
|                 print('##############################################################') | ||||
|                 print('### Error  ###  Error  ###  Error  ###  Error  ###  Error  ###') | ||||
|                 print('##############################################################') | ||||
|                 exit(-99999) | ||||
|  | ||||
|             entity_kwargs = self.entities.get(entity, {}) | ||||
|             entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None | ||||
| @@ -126,7 +144,12 @@ class FactoryConfigParser(object): | ||||
|                 observations.extend(self.default_observations) | ||||
|             observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS) | ||||
|             positions = [ast.literal_eval(x) for x in self.agents[name].get('Positions', [])] | ||||
|             parsed_agents_conf[name] = dict(actions=parsed_actions, observations=observations, positions=positions) | ||||
|             other_kwargs = {k: v for k, v in self.agents[name].items() if k not in | ||||
|                             ['Actions', 'Observations', 'Positions']} | ||||
|             parsed_agents_conf[name] = dict( | ||||
|                 actions=parsed_actions, observations=observations, positions=positions, other=other_kwargs | ||||
|                                             ) | ||||
|  | ||||
|         return parsed_agents_conf | ||||
|  | ||||
|     def load_env_rules(self) -> List[Rule]: | ||||
| @@ -137,28 +160,69 @@ class FactoryConfigParser(object): | ||||
|                     rules.append({rule: {}}) | ||||
|  | ||||
|         return self._load_smth(rules, Rule) | ||||
|         pass | ||||
|  | ||||
|     def load_env_tests(self) -> List[Test]: | ||||
|     def load_env_tests(self) -> List[Rule]: | ||||
|         return self._load_smth(self.tests, None)  # Test | ||||
|         pass | ||||
|  | ||||
|     def _load_smth(self, config, class_obj): | ||||
|         rules = list() | ||||
|         rules_names = list() | ||||
|  | ||||
|         for rule in rules_names: | ||||
|         for rule in config: | ||||
|             e1 = e2 = e3 = None | ||||
|             try: | ||||
|                 folder_path = (Path(__file__).parent.parent / DEFAULT_PATH) | ||||
|                 rule_class = locate_and_import_class(rule, folder_path) | ||||
|             except AttributeError as e: | ||||
|                 e1 = e | ||||
|                 try: | ||||
|                     module_path = (Path(__file__).parent.parent / MODULE_PATH) | ||||
|                     rule_class = locate_and_import_class(rule, module_path) | ||||
|                 except AttributeError as e: | ||||
|                     e2 = e | ||||
|                     if self.custom_modules_path: | ||||
|                         try: | ||||
|                             rule_class = locate_and_import_class(rule, self.custom_modules_path) | ||||
|                         except AttributeError as e: | ||||
|                             e3 = e | ||||
|                             pass | ||||
|             if (e1 and e2) or e3: | ||||
|                 ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]] | ||||
|                 print('### Error  ###  Error  ###  Error  ###  Error  ###  Error  ###') | ||||
|                 print('') | ||||
|                 print(f'Class "{rule}" was not found in "{module_path.name}"') | ||||
|                 print(f'Class "{rule}" was not found in "{folder_path.name}"') | ||||
|                 if self.custom_modules_path: | ||||
|                     print(f'Class "{rule}" was not found in "{self.custom_modules_path}"') | ||||
|                 print('Possible Entitys are:', str(ents)) | ||||
|                 print('') | ||||
|                 print('Goodbye') | ||||
|                 print('') | ||||
|                 exit(-99999) | ||||
|  | ||||
|             if issubclass(rule_class, class_obj): | ||||
|                 rule_kwargs = config.get(rule, {}) | ||||
|                 rules.append(rule_class(**(rule_kwargs or {}))) | ||||
|         return rules | ||||
|  | ||||
|     def load_entity_spawn_rules(self, entities) -> List[Rule]: | ||||
|         rules = list() | ||||
|         rules_dicts = list() | ||||
|         for e in entities: | ||||
|             try: | ||||
|                 if spawn_rule := e.spawn_rule: | ||||
|                     rules_dicts.append(spawn_rule) | ||||
|             except AttributeError: | ||||
|                 pass | ||||
|  | ||||
|         for rule_dict in rules_dicts: | ||||
|             for rule_name, rule_kwargs in rule_dict.items(): | ||||
|                 try: | ||||
|                     folder_path = (Path(__file__).parent.parent / DEFAULT_PATH) | ||||
|                     rule_class = locate_and_import_class(rule_name, folder_path) | ||||
|                 except AttributeError: | ||||
|                     try: | ||||
|                         folder_path = (Path(__file__).parent.parent / MODULE_PATH) | ||||
|                     rule_class = locate_and_import_class(rule, folder_path) | ||||
|                         rule_class = locate_and_import_class(rule_name, folder_path) | ||||
|                     except AttributeError: | ||||
|                     rule_class = locate_and_import_class(rule, self.custom_modules_path) | ||||
|             # Fixme This check does not work! | ||||
|             #  assert isinstance(rule_class, class_obj), f'{rule_class.__name__} is no valid "class_obj.__name__".' | ||||
|             rule_kwargs = config.get(rule, {}) | ||||
|                         rule_class = locate_and_import_class(rule_name, self.custom_modules_path) | ||||
|                 rules.append(rule_class(**rule_kwargs)) | ||||
|         return rules | ||||
|   | ||||
| @@ -2,7 +2,7 @@ import importlib | ||||
|  | ||||
| from collections import defaultdict | ||||
| from pathlib import PurePath, Path | ||||
| from typing import Union, Dict, List | ||||
| from typing import Union, Dict, List, Iterable, Callable | ||||
|  | ||||
| import numpy as np | ||||
| from numpy.typing import ArrayLike | ||||
| @@ -61,8 +61,8 @@ class ObservationTranslator: | ||||
|         :param per_agent_named_obs_spaces: `Named observation space` one for each agent. Overloaded. | ||||
|         type  per_agent_named_obs_spaces: Dict[str, dict] | ||||
|  | ||||
|         :param placeholder_fill_value: Currently not fully implemented!!! | ||||
|         :type  placeholder_fill_value: Union[int, str] = 'N') | ||||
|         :param placeholder_fill_value: Currently, not fully implemented!!! | ||||
|         :type  placeholder_fill_value: Union[int, str] = 'N' | ||||
|         """ | ||||
|  | ||||
|         if isinstance(placeholder_fill_value, str): | ||||
| @@ -222,7 +222,7 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''): | ||||
|         mod = importlib.import_module('.'.join(module_parts)) | ||||
|         all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle()) | ||||
|                                   and x not in ['Entity',  'NamedTuple', 'List', 'Rule', 'Union', | ||||
|                                                 'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin', | ||||
|                                                 'TickResult', 'ActionResult', 'Action', 'Agent', | ||||
|                                                 'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin', | ||||
|                                                 'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any' | ||||
|                                                 ]]) | ||||
| @@ -240,7 +240,13 @@ def add_bound_name(name_str, bound_e): | ||||
|  | ||||
| def add_pos_name(name_str, bound_e): | ||||
|     if bound_e.var_has_position: | ||||
|         return f'{name_str}({bound_e.pos})' | ||||
|         return f'{name_str}@{bound_e.pos}' | ||||
|     return name_str | ||||
|  | ||||
|  | ||||
| def get_first(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True): | ||||
|     return next((x for x in iterable if filter_by(x)), None) | ||||
|  | ||||
|  | ||||
| def get_first_index(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True): | ||||
|     return next((idx for idx, x in enumerate(iterable) if filter_by(x)), None) | ||||
|   | ||||
| @@ -47,6 +47,7 @@ class LevelParser(object): | ||||
|         # All other | ||||
|         for es_name in self.e_p_dict: | ||||
|             e_class, e_kwargs = self.e_p_dict[es_name]['class'], self.e_p_dict[es_name]['kwargs'] | ||||
|             e_kwargs = e_kwargs if e_kwargs else {} | ||||
|  | ||||
|             if hasattr(e_class, 'symbol') and e_class.symbol is not None: | ||||
|                 symbols = e_class.symbol | ||||
|   | ||||
| @@ -9,7 +9,7 @@ from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS | ||||
|  | ||||
| import pandas as pd | ||||
|  | ||||
| from marl_factory_grid.utils.plotting.compare_runs import plot_single_run | ||||
| from marl_factory_grid.utils.plotting.plot_single_runs import plot_single_run | ||||
|  | ||||
|  | ||||
| class EnvMonitor(Wrapper): | ||||
| @@ -22,7 +22,6 @@ class EnvMonitor(Wrapper): | ||||
|         self._monitor_df = pd.DataFrame() | ||||
|         self._monitor_dict = dict() | ||||
|  | ||||
|  | ||||
|     def step(self, action): | ||||
|         obs_type, obs, reward, done, info = self.env.step(action) | ||||
|         self._read_info(info) | ||||
|   | ||||
| @@ -2,11 +2,9 @@ from os import PathLike | ||||
| from pathlib import Path | ||||
| from typing import Union, List | ||||
|  | ||||
| import yaml | ||||
| from gymnasium import Wrapper | ||||
|  | ||||
| import numpy as np | ||||
| import pandas as pd | ||||
| from gymnasium import Wrapper | ||||
|  | ||||
|  | ||||
| class EnvRecorder(Wrapper): | ||||
| @@ -106,7 +104,7 @@ class EnvRecorder(Wrapper): | ||||
|                 out_dict = {'episodes': self._recorder_out_list} | ||||
|             out_dict.update( | ||||
|                 {'n_episodes': self._curr_episode, | ||||
|                  'metadata':dict( | ||||
|                  'metadata': dict( | ||||
|                      level_name=self.env.params['General']['level_name'], | ||||
|                      verbose=False, | ||||
|                      n_agents=len(self.env.params['Agents']), | ||||
|   | ||||
| @@ -1,17 +1,16 @@ | ||||
| import math | ||||
| import re | ||||
| from collections import defaultdict | ||||
| from itertools import product | ||||
| from typing import Dict, List | ||||
|  | ||||
| import numpy as np | ||||
| from numba import njit | ||||
|  | ||||
| from marl_factory_grid.environment import constants as c | ||||
| from marl_factory_grid.environment.entity.object import Object | ||||
| from marl_factory_grid.environment.groups.utils import Combined | ||||
| import marl_factory_grid.utils.helpers as h | ||||
| from marl_factory_grid.utils.states import Gamestate | ||||
| from marl_factory_grid.utils.utility_classes import Floor | ||||
| from marl_factory_grid.utils.ray_caster import RayCaster | ||||
| from marl_factory_grid.utils.states import Gamestate | ||||
| from marl_factory_grid.utils import helpers as h | ||||
|  | ||||
|  | ||||
| class OBSBuilder(object): | ||||
| @@ -77,11 +76,13 @@ class OBSBuilder(object): | ||||
|  | ||||
|     def place_entity_in_observation(self, obs_array, agent, e): | ||||
|         x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r | ||||
|         if not min([y, x]) < 0: | ||||
|             try: | ||||
|                 obs_array[x, y] += e.encoding | ||||
|             except IndexError: | ||||
|                 # Seemded to be visible but is out of range | ||||
|                 pass | ||||
|         pass | ||||
|  | ||||
|     def build_for_agent(self, agent, state) -> (List[str], np.ndarray): | ||||
|         assert self._curr_env_step == state.curr_step, ( | ||||
| @@ -121,18 +122,24 @@ class OBSBuilder(object): | ||||
|                         e = self.all_obs[l_name] | ||||
|                     except KeyError: | ||||
|                         try: | ||||
|                             # Look for bound entity names! | ||||
|                             pattern = re.compile(f'{re.escape(l_name)}(.*){re.escape(agent.name)}') | ||||
|                             name = next((x for x in self.all_obs if pattern.search(x)), None) | ||||
|                             # Look for bound entity REPRs! | ||||
|                             pattern = re.compile(f'{re.escape(l_name)}' | ||||
|                                                  f'{re.escape("[")}(.*){re.escape("]")}' | ||||
|                                                  f'{re.escape("(")}{re.escape(agent.name)}{re.escape(")")}') | ||||
|                             name = next((key for key, val in self.all_obs.items() | ||||
|                                          if pattern.search(str(val)) and isinstance(val, Object)), None) | ||||
|                             e = self.all_obs[name] | ||||
|                         except KeyError: | ||||
|                             try: | ||||
|                                 e = next(v for k, v in self.all_obs.items() if l_name in k and agent.name in k) | ||||
|                             except StopIteration: | ||||
|                                 raise KeyError( | ||||
|                                     f'Check for spelling errors! \n ' | ||||
|                                     f'No combination of "{l_name} and {agent.name}" could not be found in:\n ' | ||||
|                                     f'{list(dict(self.all_obs).keys())}') | ||||
|                                 print(f'# Check for spelling errors!') | ||||
|                                 print(f'# No combination of "{l_name}" and "{agent.name}" could not be found in:') | ||||
|                                 print(f'# {list(dict(self.all_obs).keys())}') | ||||
|                                 print('#') | ||||
|                                 print('# exiting...') | ||||
|                                 print('#') | ||||
|                                 exit(-99999) | ||||
|  | ||||
|                     try: | ||||
|                         positional = e.var_has_position | ||||
| @@ -161,31 +168,30 @@ class OBSBuilder(object): | ||||
|             try: | ||||
|                 light_map = np.zeros(self.obs_shape) | ||||
|                 visible_floor = self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False) | ||||
|                 if self.pomdp_r: | ||||
|  | ||||
|                 for f in set(visible_floor): | ||||
|                     self.place_entity_in_observation(light_map, agent, f) | ||||
|                 else: | ||||
|                     for f in set(visible_floor): | ||||
|                         light_map[f.x, f.y] += f.encoding | ||||
|                 # else: | ||||
|                 #     for f in set(visible_floor): | ||||
|                 #         light_map[f.x, f.y] += f.encoding | ||||
|                 self.curr_lightmaps[agent.name] = light_map | ||||
|             except (KeyError, ValueError): | ||||
|                 print() | ||||
|                 pass | ||||
|         return obs, self.obs_layers[agent.name] | ||||
|  | ||||
|     def _sort_and_name_observation_conf(self, agent): | ||||
|         ''' | ||||
|         """ | ||||
|         Builds the useable observation scheme per agent from conf.yaml. | ||||
|         :param agent: | ||||
|         :return: | ||||
|         ''' | ||||
|         """ | ||||
|         # Fixme: no asymetric shapes possible. | ||||
|         self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape)) | ||||
|         obs_layers = [] | ||||
|  | ||||
|         for obs_str in agent.observations: | ||||
|             if isinstance(obs_str, dict): | ||||
|                 obs_str, vals = next(obs_str.items().__iter__()) | ||||
|                 obs_str, vals = h.get_first(obs_str.items()) | ||||
|             else: | ||||
|                 vals = None | ||||
|             if obs_str == c.SELF: | ||||
| @@ -214,129 +220,3 @@ class OBSBuilder(object): | ||||
|                 obs_layers.append(obs_str) | ||||
|         self.obs_layers[agent.name] = obs_layers | ||||
|         self.curr_lightmaps[agent.name] = np.zeros(self.obs_shape) | ||||
|  | ||||
|  | ||||
| class RayCaster: | ||||
|     def __init__(self, agent, pomdp_r, degs=360): | ||||
|         self.agent = agent | ||||
|         self.pomdp_r = pomdp_r | ||||
|         self.n_rays = (self.pomdp_r + 1) * 8 | ||||
|         self.degs = degs | ||||
|         self.ray_targets = self.build_ray_targets() | ||||
|         self.obs_shape_cube = np.array([self.pomdp_r, self.pomdp_r]) | ||||
|         self._cache_dict = {} | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return f'{self.__class__.__name__}({self.agent.name})' | ||||
|  | ||||
|     def build_ray_targets(self): | ||||
|         north = np.array([0, -1]) * self.pomdp_r | ||||
|         thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]] | ||||
|         rot_M = [ | ||||
|             [[math.cos(theta), -math.sin(theta)], | ||||
|              [math.sin(theta), math.cos(theta)]] for theta in thetas | ||||
|         ] | ||||
|         rot_M = np.stack(rot_M, 0) | ||||
|         rot_M = np.unique(np.round(rot_M @ north), axis=0) | ||||
|         return rot_M.astype(int) | ||||
|  | ||||
|     def ray_block_cache(self, key, callback): | ||||
|         if key not in self._cache_dict: | ||||
|             self._cache_dict[key] = callback() | ||||
|         return self._cache_dict[key] | ||||
|  | ||||
|     def visible_entities(self, pos_dict, reset_cache=True): | ||||
|         visible = list() | ||||
|         if reset_cache: | ||||
|             self._cache_dict = {} | ||||
|  | ||||
|         for ray in self.get_rays(): | ||||
|             rx, ry = ray[0] | ||||
|             for x, y in ray: | ||||
|                 cx, cy = x - rx, y - ry | ||||
|  | ||||
|                 entities_hit = pos_dict[(x, y)] | ||||
|                 hits = self.ray_block_cache((x, y), | ||||
|                                             lambda: any(True for e in entities_hit if e.var_is_blocking_light) | ||||
|                                             ) | ||||
|  | ||||
|                 diag_hits = all([ | ||||
|                     self.ray_block_cache( | ||||
|                         key, | ||||
|                         lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) and bool( | ||||
|                             pos_dict[key])) | ||||
|                     for key in ((x, y - cy), (x - cx, y)) | ||||
|                 ]) if (cx != 0 and cy != 0) else False | ||||
|  | ||||
|                 visible += entities_hit if not diag_hits else [] | ||||
|                 if hits or diag_hits: | ||||
|                     break | ||||
|                 rx, ry = x, y | ||||
|         return visible | ||||
|  | ||||
|     def get_rays(self): | ||||
|         a_pos = self.agent.pos | ||||
|         outline = self.ray_targets + a_pos | ||||
|         return self.bresenham_loop(a_pos, outline) | ||||
|  | ||||
|     # todo do this once and cache the points! | ||||
|     def get_fov_outline(self) -> np.ndarray: | ||||
|         return self.ray_targets + self.agent.pos | ||||
|  | ||||
|     def get_square_outline(self): | ||||
|         agent = self.agent | ||||
|         x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1) | ||||
|         y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1) | ||||
|         outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \ | ||||
|                   + list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords)) | ||||
|         return outline | ||||
|  | ||||
|     @staticmethod | ||||
|     @njit | ||||
|     def bresenham_loop(a_pos, points): | ||||
|         results = [] | ||||
|         for end in points: | ||||
|             x1, y1 = a_pos | ||||
|             x2, y2 = end | ||||
|             dx = x2 - x1 | ||||
|             dy = y2 - y1 | ||||
|  | ||||
|             # Determine how steep the line is | ||||
|             is_steep = abs(dy) > abs(dx) | ||||
|  | ||||
|             # Rotate line | ||||
|             if is_steep: | ||||
|                 x1, y1 = y1, x1 | ||||
|                 x2, y2 = y2, x2 | ||||
|  | ||||
|             # Swap start and end points if necessary and store swap state | ||||
|             swapped = False | ||||
|             if x1 > x2: | ||||
|                 x1, x2 = x2, x1 | ||||
|                 y1, y2 = y2, y1 | ||||
|                 swapped = True | ||||
|  | ||||
|             # Recalculate differentials | ||||
|             dx = x2 - x1 | ||||
|             dy = y2 - y1 | ||||
|  | ||||
|             # Calculate error | ||||
|             error = int(dx / 2.0) | ||||
|             ystep = 1 if y1 < y2 else -1 | ||||
|  | ||||
|             # Iterate over bounding box generating points between start and end | ||||
|             y = y1 | ||||
|             points = [] | ||||
|             for x in range(int(x1), int(x2) + 1): | ||||
|                 coord = [y, x] if is_steep else [x, y] | ||||
|                 points.append(coord) | ||||
|                 error -= abs(dy) | ||||
|                 if error < 0: | ||||
|                     y += ystep | ||||
|                     error += dx | ||||
|  | ||||
|             # Reverse the list if the coordinates were swapped | ||||
|             if swapped: | ||||
|                 points.reverse() | ||||
|             results.append(points) | ||||
|         return results | ||||
|   | ||||
| @@ -7,50 +7,11 @@ from typing import Union, List | ||||
| import pandas as pd | ||||
| 
 | ||||
| from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS | ||||
| from marl_factory_grid.utils.plotting.plotting import prepare_plot | ||||
| from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot | ||||
| 
 | ||||
| MODEL_MAP = None | ||||
| 
 | ||||
| 
 | ||||
| def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None): | ||||
|     run_path = Path(run_path) | ||||
|     df_list = list() | ||||
|     if run_path.is_dir(): | ||||
|         monitor_file = next(run_path.glob('*monitor*.pick')) | ||||
|     elif run_path.exists() and run_path.is_file(): | ||||
|         monitor_file = run_path | ||||
|     else: | ||||
|         raise ValueError | ||||
| 
 | ||||
|     with monitor_file.open('rb') as f: | ||||
|         monitor_df = pickle.load(f) | ||||
| 
 | ||||
|         monitor_df = monitor_df.fillna(0) | ||||
|         df_list.append(monitor_df) | ||||
| 
 | ||||
|     df = pd.concat(df_list,  ignore_index=True) | ||||
|     df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode']) | ||||
|     if column_keys is not None: | ||||
|         columns = [col for col in column_keys if col in df.columns] | ||||
|     else: | ||||
|         columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS] | ||||
| 
 | ||||
|     roll_n = 50 | ||||
| 
 | ||||
|     non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean() | ||||
| 
 | ||||
|     df_melted = df[columns + ['Episode']].reset_index().melt( | ||||
|         id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score" | ||||
|     ) | ||||
| 
 | ||||
|     if df_melted['Episode'].max() > 800: | ||||
|         skip_n = round(df_melted['Episode'].max() * 0.02) | ||||
|         df_melted = df_melted[df_melted['Episode'] % skip_n == 0] | ||||
| 
 | ||||
|     prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex) | ||||
|     print('Plotting done.') | ||||
| 
 | ||||
| 
 | ||||
| def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False): | ||||
|     run_path = Path(run_path) | ||||
|     df_list = list() | ||||
							
								
								
									
										48
									
								
								marl_factory_grid/utils/plotting/plot_single_runs.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								marl_factory_grid/utils/plotting/plot_single_runs.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,48 @@ | ||||
| import pickle | ||||
| from os import PathLike | ||||
| from pathlib import Path | ||||
| from typing import Union | ||||
|  | ||||
| import pandas as pd | ||||
|  | ||||
| from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS | ||||
| from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot | ||||
|  | ||||
|  | ||||
| def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None, | ||||
|                     file_key: str ='monitor', file_ext: str ='pkl'): | ||||
|     run_path = Path(run_path) | ||||
|     df_list = list() | ||||
|     if run_path.is_dir(): | ||||
|         monitor_file = next(run_path.glob(f'*{file_key}*.{file_ext}')) | ||||
|     elif run_path.exists() and run_path.is_file(): | ||||
|         monitor_file = run_path | ||||
|     else: | ||||
|         raise ValueError | ||||
|  | ||||
|     with monitor_file.open('rb') as f: | ||||
|         monitor_df = pickle.load(f) | ||||
|  | ||||
|         monitor_df = monitor_df.fillna(0) | ||||
|         df_list.append(monitor_df) | ||||
|  | ||||
|     df = pd.concat(df_list,  ignore_index=True) | ||||
|     df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode']) | ||||
|     if column_keys is not None: | ||||
|         columns = [col for col in column_keys if col in df.columns] | ||||
|     else: | ||||
|         columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS] | ||||
|  | ||||
|     # roll_n = 50 | ||||
|     # non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean() | ||||
|  | ||||
|     df_melted = df[columns + ['Episode']].reset_index().melt( | ||||
|         id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score" | ||||
|     ) | ||||
|  | ||||
|     if df_melted['Episode'].max() > 800: | ||||
|         skip_n = round(df_melted['Episode'].max() * 0.02) | ||||
|         df_melted = df_melted[df_melted['Episode'] % skip_n == 0] | ||||
|  | ||||
|     prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex) | ||||
|     print('Plotting done.') | ||||
| @@ -60,7 +60,7 @@ def prepare_center_double_column_legend(df, hue, style, hue_order): | ||||
|     print('Struggling to plot Figure using LaTeX - going back to normal.') | ||||
|     plt.close('all') | ||||
|     sns.set(rc={'text.usetex': False}, style='whitegrid') | ||||
|     fig = plt.figure(figsize=(10, 11)) | ||||
|     _ = plt.figure(figsize=(10, 11)) | ||||
|     lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style, | ||||
|                             ci=95, palette=PALETTE, hue_order=hue_order, legend=False) | ||||
|     # plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0) | ||||
| @@ -19,7 +19,7 @@ class RayCaster: | ||||
|         return f'{self.__class__.__name__}({self.agent.name})' | ||||
|  | ||||
|     def build_ray_targets(self): | ||||
|         north = np.array([0, -1])*self.pomdp_r | ||||
|         north = np.array([0, -1]) * self.pomdp_r | ||||
|         thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]] | ||||
|         rot_M = [ | ||||
|             [[math.cos(theta), -math.sin(theta)], | ||||
| @@ -39,8 +39,9 @@ class RayCaster: | ||||
|         if reset_cache: | ||||
|             self._cache_dict = dict() | ||||
|  | ||||
|         for ray in self.get_rays(): | ||||
|         for ray in self.get_rays():  # Do not check, just trust. | ||||
|             rx, ry = ray[0] | ||||
|             # self.ray_block_cache(ray[0], lambda: False) We do not do that, because of doors etc... | ||||
|             for x, y in ray: | ||||
|                 cx, cy = x - rx, y - ry | ||||
|  | ||||
| @@ -52,8 +53,9 @@ class RayCaster: | ||||
|                 diag_hits = all([ | ||||
|                     self.ray_block_cache( | ||||
|                         key, | ||||
|                         lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light)) | ||||
|                     for key in ((x, y-cy), (x-cx, y)) | ||||
|                         # lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) | ||||
|                         lambda: any(True for e in pos_dict[key] if e.var_is_blocking_light)) | ||||
|                     for key in ((x, y - cy), (x - cx, y)) | ||||
|                 ]) if (cx != 0 and cy != 0) else False | ||||
|  | ||||
|                 visible += entities_hit if not diag_hits else [] | ||||
| @@ -75,8 +77,8 @@ class RayCaster: | ||||
|         agent = self.agent | ||||
|         x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1) | ||||
|         y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1) | ||||
|         outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \ | ||||
|                   + list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords)) | ||||
|         outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) | ||||
|         outline += list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords)) | ||||
|         return outline | ||||
|  | ||||
|     @staticmethod | ||||
|   | ||||
| @@ -31,7 +31,7 @@ class Renderer: | ||||
|  | ||||
|     def __init__(self, lvl_shape: Tuple[int, int] = (16, 16), | ||||
|                  lvl_padded_shape: Union[Tuple[int, int], None] = None, | ||||
|                  cell_size: int = 40, fps: int = 7, | ||||
|                  cell_size: int = 40, fps: int = 7, factor: float = 0.9, | ||||
|                  grid_lines: bool = True, view_radius: int = 2): | ||||
|         # TODO: Customn_assets paths | ||||
|         self.grid_h, self.grid_w = lvl_shape | ||||
| @@ -45,7 +45,7 @@ class Renderer: | ||||
|         self.screen = pygame.display.set_mode(self.screen_size) | ||||
|         self.clock = pygame.time.Clock() | ||||
|         assets = list(self.ASSETS.rglob('*.png')) | ||||
|         self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets} | ||||
|         self.assets = {path.stem: self.load_asset(str(path), factor) for path in assets} | ||||
|         self.fill_bg() | ||||
|  | ||||
|         now = time.time() | ||||
| @@ -110,22 +110,22 @@ class Renderer: | ||||
|                 pygame.quit() | ||||
|                 sys.exit() | ||||
|         self.fill_bg() | ||||
|         blits = deque() | ||||
|         for entity in [x for x in entities]: | ||||
|             bp = self.blit_params(entity) | ||||
|             blits.append(bp) | ||||
|             if entity.name.lower() == AGENT: | ||||
|         # First all others | ||||
|         blits = deque(self.blit_params(x) for x in entities if not x.name.lower() == AGENT) | ||||
|         # Then Agents, so that agents are rendered on top. | ||||
|         for agent in (x for x in entities if x.name.lower() == AGENT): | ||||
|             agent_blit = self.blit_params(agent) | ||||
|             if self.view_radius > 0: | ||||
|                     vis_rects = self.visibility_rects(bp, entity.aux) | ||||
|                 vis_rects = self.visibility_rects(agent_blit, agent.aux) | ||||
|                 blits.extendleft(vis_rects) | ||||
|                 if entity.state != BLANK: | ||||
|                     agent_state_blits = self.blit_params( | ||||
|                         RenderEntity(entity.state, (entity.pos[0] + 0.12, entity.pos[1]), 0.48, SCALE) | ||||
|             if agent.state != BLANK: | ||||
|                 state_blit = self.blit_params( | ||||
|                     RenderEntity(agent.state, (agent.pos[0] + 0.12, agent.pos[1]), 0.48, SCALE) | ||||
|                 ) | ||||
|                     textsurface = self.font.render(str(entity.id), False, (0, 0, 0)) | ||||
|                     text_blit = dict(source=textsurface, dest=(bp['dest'].center[0]-.07*self.cell_size, | ||||
|                                                                bp['dest'].center[1])) | ||||
|                     blits += [agent_state_blits, text_blit] | ||||
|                 textsurface = self.font.render(str(agent.id), False, (0, 0, 0)) | ||||
|                 text_blit = dict(source=textsurface, dest=(agent_blit['dest'].center[0]-.07*self.cell_size, | ||||
|                                                            agent_blit['dest'].center[1])) | ||||
|                 blits += [agent_blit, state_blit, text_blit] | ||||
|  | ||||
|         for blit in blits: | ||||
|             self.screen.blit(**blit) | ||||
|   | ||||
| @@ -1,9 +1,12 @@ | ||||
| from typing import Union | ||||
| from dataclasses import dataclass | ||||
|  | ||||
| from marl_factory_grid.environment.entity.object import Object | ||||
|  | ||||
| TYPE_VALUE  = 'value' | ||||
| TYPE_REWARD = 'reward' | ||||
| types = [TYPE_VALUE, TYPE_REWARD] | ||||
| TYPES = [TYPE_VALUE, TYPE_REWARD] | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class InfoObject: | ||||
| @@ -18,17 +21,21 @@ class Result: | ||||
|     validity: bool | ||||
|     reward: Union[float, None] = None | ||||
|     value: Union[float, None] = None | ||||
|     entity: None = None | ||||
|     entity: Object = None | ||||
|  | ||||
|     def get_infos(self): | ||||
|         n = self.entity.name if self.entity is not None else "Global" | ||||
|         return [InfoObject(identifier=f'{n}_{self.identifier}_{t}', | ||||
|                            val_type=t, value=self.__getattribute__(t)) for t in types | ||||
|         # Return multiple Info Dicts | ||||
|         return [InfoObject(identifier=f'{n}_{self.identifier}', | ||||
|                            val_type=t, value=self.__getattribute__(t)) for t in TYPES | ||||
|                 if self.__getattribute__(t) is not None] | ||||
|  | ||||
|     def __repr__(self): | ||||
|         valid = "not " if not self.validity else "" | ||||
|         return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid: {self.reward})' | ||||
|         reward = f" | Reward: {self.reward}" if self.reward is not None else "" | ||||
|         value = f" | Value: {self.value}" if self.value is not None else "" | ||||
|         entity = f" | by: {self.entity.name}" if self.entity is not None else "" | ||||
|         return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value}{entity})' | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
|   | ||||
| @@ -1,9 +1,12 @@ | ||||
| from typing import List, Dict, Tuple | ||||
| from itertools import islice | ||||
| from typing import List, Tuple | ||||
|  | ||||
| import numpy as np | ||||
|  | ||||
| from marl_factory_grid.environment import constants as c | ||||
| from marl_factory_grid.environment.entity.entity import Entity | ||||
| from marl_factory_grid.environment.rules import Rule | ||||
| from marl_factory_grid.utils.results import Result, DoneResult | ||||
| from marl_factory_grid.environment.tests import Test | ||||
| from marl_factory_grid.utils.results import Result | ||||
|  | ||||
| @@ -60,7 +63,8 @@ class Gamestate(object): | ||||
|     def moving_entites(self): | ||||
|         return [y for x in self.entities for y in x if x.var_can_move] | ||||
|  | ||||
|     def __init__(self, entities, agents_conf, rules: [Rule], tests: [Test], env_seed=69, verbose=False): | ||||
|     def __init__(self, entities, agents_conf, rules: List[Rule], tests: [Test], lvl_shape, env_seed=69, verbose=False): | ||||
|         self.lvl_shape = lvl_shape | ||||
|         self.entities = entities | ||||
|         self.curr_step = 0 | ||||
|         self.curr_actions = None | ||||
| @@ -82,7 +86,52 @@ class Gamestate(object): | ||||
|     def __repr__(self): | ||||
|         return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})' | ||||
|  | ||||
|     def tick(self, actions) -> List[Result]: | ||||
|     @property | ||||
|     def random_free_position(self) -> (int, int): | ||||
|         """ | ||||
|         Returns a single **free** position (x, y), which is **free** for spawning or walking. | ||||
|         No Entity at this position posses *var_is_blocking_pos* or *var_can_collide*. | ||||
|  | ||||
|         :return:    Single **free** position. | ||||
|         """ | ||||
|         return self.get_n_random_free_positions(1)[0] | ||||
|  | ||||
|     def get_n_random_free_positions(self, n) -> list[tuple[int, int]]: | ||||
|         """ | ||||
|         Returns a list of *n* **free** positions [(x, y), ... ], which are **free** for spawning or walking. | ||||
|         No Entity at this position posses *var_is_blocking_pos* or *var_can_collide*. | ||||
|  | ||||
|         :return:    List of n **free** position. | ||||
|         """ | ||||
|         return list(islice(self.entities.free_positions_generator, n)) | ||||
|  | ||||
|     @property | ||||
|     def random_position(self) -> (int, int): | ||||
|         """ | ||||
|         Returns a single available position (x, y), ignores all entity attributes. | ||||
|  | ||||
|         :return:    Single random position. | ||||
|         """ | ||||
|         return self.get_n_random_positions(1)[0] | ||||
|  | ||||
|     def get_n_random_positions(self, n) -> list[tuple[int, int]]: | ||||
|         """ | ||||
|         Returns a list of *n* available positions [(x, y), ... ], ignores all entity attributes. | ||||
|  | ||||
|         :return:    List of n random positions. | ||||
|         """ | ||||
|         return list(islice(self.entities.floorlist, n)) | ||||
|  | ||||
|     def tick(self, actions) -> list[Result]: | ||||
|         """ | ||||
|         Performs a single **Gamestate Tick**by calling the inner rule hooks in sequential order. | ||||
|         - tick_pre_step_all:    Things to do before the agents do their actions. Statechange, Moving, Spawning etc... | ||||
|         - agent tick:           Agents do their actions. | ||||
|         - tick_step_all:        Things to do after the agents did their actions. Statechange, Moving, Spawning etc... | ||||
|         - tick_post_step_all:   Things to do at the very end of each step. Counting, Reward calculations etc... | ||||
|  | ||||
|         :return:    List of *Result*-objects. | ||||
|         """ | ||||
|         results = list() | ||||
|         test_results = list() | ||||
|         self.curr_step += 1 | ||||
| @@ -112,11 +161,23 @@ class Gamestate(object): | ||||
|  | ||||
|         return results | ||||
|  | ||||
|     def print(self, string): | ||||
|     def print(self, string) -> None: | ||||
|         """ | ||||
|         When *verbose* is active, print stuff. | ||||
|  | ||||
|         :param string:      *String* to print. | ||||
|         :type string:       str | ||||
|         :return: Nothing | ||||
|         """ | ||||
|         if self.verbose: | ||||
|             print(string) | ||||
|  | ||||
|     def check_done(self): | ||||
|     def check_done(self) -> List[DoneResult]: | ||||
|         """ | ||||
|         Iterate all **Rules** that override tehe *on_ckeck_done* hook. | ||||
|  | ||||
|         :return:    List of Results | ||||
|         """ | ||||
|         results = list() | ||||
|         for rule in self.rules: | ||||
|             if on_check_done_result := rule.on_check_done(self): | ||||
| @@ -124,24 +185,47 @@ class Gamestate(object): | ||||
|         return results | ||||
|  | ||||
|     def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]: | ||||
|         positions = [pos for pos, entity_list_for_position in self.entities.pos_dict.items() | ||||
|                      if any([e.var_can_collide for e in entity_list_for_position])] | ||||
|         """ | ||||
|         Returns a list positions [(x, y), ... ] on which collisions occur. This does not include agents, | ||||
|         that were unable to move because their target direction was blocked, also a form of collision. | ||||
|  | ||||
|         :return:    List of positions. | ||||
|         """ | ||||
|         positions = [pos for pos, entities in self.entities.pos_dict.items() if | ||||
|                      len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2) | ||||
|                      ] | ||||
|         return positions | ||||
|  | ||||
|     def check_move_validity(self, moving_entity, position): | ||||
|         if moving_entity.pos != position and not any( | ||||
|                 entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]) and not ( | ||||
|                 moving_entity.var_is_blocking_pos and self.entities.is_occupied(position)): | ||||
|             return True | ||||
|         else: | ||||
|             return False | ||||
|     def check_move_validity(self, moving_entity: Entity, target_position: (int, int)) -> bool: | ||||
|         """ | ||||
|         Whether it is safe to move to the target positions and moving entity does not introduce a blocking attribute, | ||||
|         when position is allready occupied. | ||||
|  | ||||
|     def check_pos_validity(self, position): | ||||
|         if not any(entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]): | ||||
|             return True | ||||
|         else: | ||||
|             return False | ||||
|         :param moving_entity: Entity | ||||
|         :param target_position: pos | ||||
|         :return:    Safe to move to | ||||
|         """ | ||||
|  | ||||
|         is_not_blocked = self.check_pos_validity(target_position) | ||||
|         will_not_block_others = moving_entity.var_is_blocking_pos and self.entities.is_occupied(target_position) | ||||
|  | ||||
|         if moving_entity.pos != target_position and is_not_blocked and not will_not_block_others: | ||||
|             return c.VALID | ||||
|         else: | ||||
|             return c.NOT_VALID | ||||
|  | ||||
|     def check_pos_validity(self, pos: (int, int)) -> bool: | ||||
|         """ | ||||
|         Check if *pos* is a valid position to move or spawn to. | ||||
|  | ||||
|         :param pos: position to check | ||||
|         :return: Wheter pos is a valid target. | ||||
|         """ | ||||
|  | ||||
|         if not any(e.var_is_blocking_pos for e in self.entities.pos_dict[pos]) and pos in self.entities.floorlist: | ||||
|             return c.VALID | ||||
|         else: | ||||
|             return c.NOT_VALID | ||||
|  | ||||
| class StepTests: | ||||
|     def __init__(self, *args): | ||||
|   | ||||
| @@ -28,7 +28,9 @@ class ConfigExplainer: | ||||
|  | ||||
|     def explain_module(self, class_to_explain): | ||||
|         parameters = inspect.signature(class_to_explain).parameters | ||||
|         explained = {class_to_explain.__name__: {key: val.default for key, val in parameters.items() if key not in EXCLUDED}} | ||||
|         explained = {class_to_explain.__name__: | ||||
|                          {key: val.default for key, val in parameters.items() if key not in EXCLUDED} | ||||
|                      } | ||||
|         return explained | ||||
|  | ||||
|     def _load_and_compare(self, compare_class, paths): | ||||
| @@ -135,4 +137,3 @@ if __name__ == '__main__': | ||||
|     ce.get_observations() | ||||
|     ce.get_assets() | ||||
|     all_conf = ce.get_all() | ||||
|     print() | ||||
|   | ||||
| @@ -52,3 +52,6 @@ class Floor: | ||||
|  | ||||
|     def __hash__(self): | ||||
|         return hash(self.name) | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return f"Floor{self.pos}" | ||||
|   | ||||
| @@ -6,18 +6,21 @@ from marl_factory_grid.environment.factory import Factory | ||||
|  | ||||
| from marl_factory_grid.utils.logging.envmonitor import EnvMonitor | ||||
| from marl_factory_grid.utils.logging.recorder import EnvRecorder | ||||
| from marl_factory_grid.utils.plotting.plot_single_runs import plot_single_run | ||||
| from marl_factory_grid.utils.tools import ConfigExplainer | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     # Render at each step? | ||||
|     render = True | ||||
|     render = False | ||||
|     # Reveal all possible Modules (Entities, Rules, Agents[Actions, Observations], etc.) | ||||
|     explain_config = False | ||||
|     # Collect statistics? | ||||
|     monitor = False | ||||
|     monitor = True | ||||
|     # Record as Protobuf? | ||||
|     record = False | ||||
|     # Plot Results? | ||||
|     plotting = True | ||||
|  | ||||
|     run_path = Path('study_out') | ||||
|  | ||||
| @@ -38,7 +41,7 @@ if __name__ == '__main__': | ||||
|         factory = EnvRecorder(factory) | ||||
|  | ||||
|     # RL learn Loop | ||||
|     for episode in trange(500): | ||||
|     for episode in trange(10): | ||||
|         _ = factory.reset() | ||||
|         done = False | ||||
|         if render: | ||||
| @@ -54,7 +57,10 @@ if __name__ == '__main__': | ||||
|                 break | ||||
|  | ||||
|     if monitor: | ||||
|         factory.save_run(run_path / 'test.pkl') | ||||
|         factory.save_run(run_path / 'test_monitor.pkl') | ||||
|     if record: | ||||
|         factory.save_records(run_path / 'test.pb') | ||||
|     if plotting: | ||||
|         plot_single_run(run_path) | ||||
|  | ||||
|     print('Done!!! Goodbye....') | ||||
|   | ||||
| @@ -6,6 +6,7 @@ import yaml | ||||
| from marl_factory_grid.environment.factory import Factory | ||||
| from marl_factory_grid.utils.logging.envmonitor import EnvMonitor | ||||
| from marl_factory_grid.utils.logging.recorder import EnvRecorder | ||||
| from marl_factory_grid.utils import helpers as h | ||||
|  | ||||
| from marl_factory_grid.modules.doors import constants as d | ||||
|  | ||||
| @@ -55,13 +56,14 @@ if __name__ == '__main__': | ||||
|                                for model_idx, model in enumerate(models)] | ||||
|                 else: | ||||
|                     actions = models[0].predict(env_state, deterministic=determin)[0] | ||||
|                 # noinspection PyTupleAssignmentBalance | ||||
|                 env_state, step_r, done_bool, info_obj = env.step(actions) | ||||
|  | ||||
|                 rew += step_r | ||||
|                 if render: | ||||
|                     env.render() | ||||
|                 try: | ||||
|                     door = next(x for x in env.unwrapped.unwrapped[d.DOORS] if x.is_open) | ||||
|                     door = h.get_first([x for x in env.unwrapped.unwrapped[d.DOORS] if x.is_open]) | ||||
|                     print('openDoor found') | ||||
|                 except StopIteration: | ||||
|                     pass | ||||
|   | ||||
| @@ -1,8 +1,8 @@ | ||||
| from algorithms.utils import Checkpointer | ||||
| from pathlib import Path | ||||
| from algorithms.utils import load_yaml_file, add_env_props, instantiate_class, load_class | ||||
| #from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC | ||||
|  | ||||
| # from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC | ||||
|  | ||||
|  | ||||
| for i in range(0, 5): | ||||
|   | ||||
							
								
								
									
										41
									
								
								transform_wg_to_json_no_priv.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								transform_wg_to_json_no_priv.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,41 @@ | ||||
| import configparser | ||||
| import json | ||||
| from datetime import datetime | ||||
| from pathlib import Path | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|  | ||||
|     conf_path = Path('wg0') | ||||
|     wg0_conf = configparser.ConfigParser() | ||||
|     wg0_conf.read(conf_path/'wg0.conf') | ||||
|     interface = wg0_conf['Interface'] | ||||
|     # Iterate all pears | ||||
|     for client_name in wg0_conf.sections(): | ||||
|         if client_name == 'Interface': | ||||
|             continue | ||||
|         # Delete any old conf.json for the current peer | ||||
|         (conf_path / f'{client_name}.json').unlink(missing_ok=True) | ||||
|  | ||||
|         peer = wg0_conf[client_name] | ||||
|  | ||||
|         date_time = datetime.now().strftime('%Y-%m-%dT%H:%M:%S.%f000Z') | ||||
|  | ||||
|         jdict = dict( | ||||
|             id=client_name, | ||||
|             private_key=peer['PublicKey'], | ||||
|             public_key=peer['PublicKey'], | ||||
|             # preshared_key=wg0_conf[client_name_wg0]['PresharedKey'], | ||||
|             name=client_name, | ||||
|             email=f"sysadmin@mobile.ifi.lmu.de", | ||||
|             allocated_ips=[interface['Address'].replace('/24', '')], | ||||
|             allowed_ips=['10.4.0.0/24', '10.153.199.0/24'], | ||||
|             extra_allowed_ips=[], | ||||
|             use_server_dns=True, | ||||
|             enabled=True, | ||||
|             created_at=date_time, | ||||
|             updated_at=date_time | ||||
|         ) | ||||
|  | ||||
|         with (conf_path / f'{client_name}.json').open('w+') as f: | ||||
|             json.dump(jdict, f, indent='\t', separators=(',', ': ')) | ||||
|         print(client_name, ' written...') | ||||
		Reference in New Issue
	
	Block a user
	 Chanumask
					Chanumask