mirror of
				https://github.com/illiumst/marl-factory-grid.git
				synced 2025-10-31 12:37:27 +01:00 
			
		
		
		
	Merge branch 'main' into unit_testing
# Conflicts: # marl_factory_grid/environment/factory.py # marl_factory_grid/utils/config_parser.py # marl_factory_grid/utils/states.py
This commit is contained in:
		
							
								
								
									
										5
									
								
								.idea/.gitignore
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								.idea/.gitignore
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | |||||||
|  | # Default ignored files | ||||||
|  | /shelf/ | ||||||
|  | /workspace.xml | ||||||
|  | # Editor-based HTTP Client requests | ||||||
|  | /httpRequests/ | ||||||
| @@ -56,7 +56,7 @@ Just define what your environment needs in a *yaml*-configfile like: | |||||||
|                     - Items |                     - Items | ||||||
|     Rules: |     Rules: | ||||||
|         Defaults: {} |         Defaults: {} | ||||||
|         Collision: |         WatchCollisions: | ||||||
|             done_at_collisions: !!bool True |             done_at_collisions: !!bool True | ||||||
|         ItemRespawn: |         ItemRespawn: | ||||||
|             spawn_freq: 5 |             spawn_freq: 5 | ||||||
| @@ -94,7 +94,7 @@ All [Entites](marl_factory_grid/environment/entity/global_entities.py) are avail | |||||||
|  |  | ||||||
|  |  | ||||||
| #### Rules | #### Rules | ||||||
| [Rules](marl_factory_grid/environment/entity/object.py) define how the environment behaves on micro-scale. | [Rules](marl_factory_grid/environment/entity/object.py) define how the environment behaves on microscale. | ||||||
| Each of the hooks (`on_init`, `pre_step`, `on_step`, `post_step`, `on_done`)  | Each of the hooks (`on_init`, `pre_step`, `on_step`, `post_step`, `on_done`)  | ||||||
| provide env-access to implement custom logic, calculate rewards, or gather information. | provide env-access to implement custom logic, calculate rewards, or gather information. | ||||||
|  |  | ||||||
| @@ -107,6 +107,7 @@ Make sure to bring your own assets for each Entity living in the Gridworld as th | |||||||
| PNG-files (transparent background) of square aspect-ratio should do the job, in general. | PNG-files (transparent background) of square aspect-ratio should do the job, in general. | ||||||
|  |  | ||||||
| <img src="/marl_factory_grid/environment/assets/wall.png"  width="5%">  | <img src="/marl_factory_grid/environment/assets/wall.png"  width="5%">  | ||||||
|  | <!--suppress HtmlUnknownAttribute --> | ||||||
| <html      html>  | <html      html>  | ||||||
| <img src="/marl_factory_grid/environment/assets/agent/agent.png"  width="5%"> | <img src="/marl_factory_grid/environment/assets/agent/agent.png"  width="5%"> | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,6 +1 @@ | |||||||
| from .environment import * |  | ||||||
| from .modules import * |  | ||||||
| from .utils import * |  | ||||||
|  |  | ||||||
| from .quickstart import init | from .quickstart import init | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1 +1,4 @@ | |||||||
| import os, sys; sys.path.append(os.path.dirname(os.path.realpath(__file__))) | import os | ||||||
|  | import sys | ||||||
|  |  | ||||||
|  | sys.path.append(os.path.dirname(os.path.realpath(__file__))) | ||||||
|   | |||||||
| @@ -28,6 +28,7 @@ class Names: | |||||||
|     BATCH_SIZE      = 'bnatch_size' |     BATCH_SIZE      = 'bnatch_size' | ||||||
|     N_ACTIONS       = 'n_actions' |     N_ACTIONS       = 'n_actions' | ||||||
|  |  | ||||||
|  |  | ||||||
| nms = Names | nms = Names | ||||||
| ListOrTensor = Union[List, torch.Tensor] | ListOrTensor = Union[List, torch.Tensor] | ||||||
|  |  | ||||||
| @@ -112,10 +113,9 @@ class BaseActorCritic: | |||||||
|                 next_obs, reward, done, info = env.step(action) |                 next_obs, reward, done, info = env.step(action) | ||||||
|                 done = [done] * self.n_agents if isinstance(done, bool) else done |                 done = [done] * self.n_agents if isinstance(done, bool) else done | ||||||
|  |  | ||||||
|                 last_hiddens = dict(hidden_actor =out[nms.HIDDEN_ACTOR], |                 last_hiddens = dict(hidden_actor=out[nms.HIDDEN_ACTOR], | ||||||
|                                     hidden_critic=out[nms.HIDDEN_CRITIC]) |                                     hidden_critic=out[nms.HIDDEN_CRITIC]) | ||||||
|  |  | ||||||
|  |  | ||||||
|                 tm.add(observation=obs, action=action, reward=reward, done=done, |                 tm.add(observation=obs, action=action, reward=reward, done=done, | ||||||
|                        logits=out.get(nms.LOGITS, None), values=out.get(nms.CRITIC, None), |                        logits=out.get(nms.LOGITS, None), values=out.get(nms.CRITIC, None), | ||||||
|                        **last_hiddens) |                        **last_hiddens) | ||||||
| @@ -142,7 +142,9 @@ class BaseActorCritic: | |||||||
|             print(f'reward at episode: {episode} = {rew_log}') |             print(f'reward at episode: {episode} = {rew_log}') | ||||||
|             episode += 1 |             episode += 1 | ||||||
|             df_results.append([episode, rew_log, *reward]) |             df_results.append([episode, rew_log, *reward]) | ||||||
|         df_results = pd.DataFrame(df_results, columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]]) |         df_results = pd.DataFrame(df_results, | ||||||
|  |                                   columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]] | ||||||
|  |                                   ) | ||||||
|         if checkpointer is not None: |         if checkpointer is not None: | ||||||
|             df_results.to_csv(checkpointer.path / 'results.csv', index=False) |             df_results.to_csv(checkpointer.path / 'results.csv', index=False) | ||||||
|         return df_results |         return df_results | ||||||
| @@ -157,24 +159,27 @@ class BaseActorCritic: | |||||||
|             last_action, reward    = [-1] * self.n_agents, [0.] * self.n_agents |             last_action, reward    = [-1] * self.n_agents, [0.] * self.n_agents | ||||||
|             done, rew_log, eps_rew = [False] * self.n_agents, 0, torch.zeros(self.n_agents) |             done, rew_log, eps_rew = [False] * self.n_agents, 0, torch.zeros(self.n_agents) | ||||||
|             while not all(done): |             while not all(done): | ||||||
|                 if render: env.render() |                 if render: | ||||||
|  |                     env.render() | ||||||
|  |  | ||||||
|                 out    = self.forward(obs, last_action, **last_hiddens) |                 out    = self.forward(obs, last_action, **last_hiddens) | ||||||
|                 action = self.get_actions(out) |                 action = self.get_actions(out) | ||||||
|                 next_obs, reward, done, info = env.step(action) |                 next_obs, reward, done, info = env.step(action) | ||||||
|  |  | ||||||
|                 if isinstance(done, bool): done = [done] * obs.shape[0] |                 if isinstance(done, bool): | ||||||
|  |                     done = [done] * obs.shape[0] | ||||||
|                 obs = next_obs |                 obs = next_obs | ||||||
|                 last_action = action |                 last_action = action | ||||||
|                 last_hiddens = dict(hidden_actor=out.get(nms.HIDDEN_ACTOR,   None), |                 last_hiddens = dict(hidden_actor=out.get(nms.HIDDEN_ACTOR,   None), | ||||||
|                                     hidden_critic=out.get(nms.HIDDEN_CRITIC, None) |                                     hidden_critic=out.get(nms.HIDDEN_CRITIC, None) | ||||||
|                                     ) |                                     ) | ||||||
|                 eps_rew += torch.tensor(reward) |                 eps_rew += torch.tensor(reward) | ||||||
|             results.append(eps_rew.tolist() + [sum(eps_rew).item()] + [episode]) |             results.append(eps_rew.tolist() + [np.sum(eps_rew).item()] + [episode]) | ||||||
|             episode += 1 |             episode += 1 | ||||||
|         agent_columns = [f'agent#{i}' for i in range(self.cfg['environment']['n_agents'])] |         agent_columns = [f'agent#{i}' for i in range(self.cfg['environment']['n_agents'])] | ||||||
|         results = pd.DataFrame(results, columns=agent_columns + ['sum', 'episode']) |         results = pd.DataFrame(results, columns=agent_columns + ['sum', 'episode']) | ||||||
|         results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'], value_name='reward', var_name='agent') |         results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'], | ||||||
|  |                           value_name='reward', var_name='agent') | ||||||
|         return results |         return results | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|   | |||||||
| @@ -36,7 +36,7 @@ class LoopMAPPO(LoopSNAC): | |||||||
|         rewards_ = torch.stack(rewards_, dim=1) |         rewards_ = torch.stack(rewards_, dim=1) | ||||||
|         return rewards_ |         return rewards_ | ||||||
|  |  | ||||||
|     def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **kwargs): |     def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **__): | ||||||
|         out = network(batch[nms.OBSERVATION], batch[nms.ACTION], batch[nms.HIDDEN_ACTOR], batch[nms.HIDDEN_CRITIC]) |         out = network(batch[nms.OBSERVATION], batch[nms.ACTION], batch[nms.HIDDEN_ACTOR], batch[nms.HIDDEN_CRITIC]) | ||||||
|         logits = out[nms.LOGITS][:, :-1]  # last one only needed for v_{t+1} |         logits = out[nms.LOGITS][:, :-1]  # last one only needed for v_{t+1} | ||||||
|  |  | ||||||
| @@ -45,7 +45,7 @@ class LoopMAPPO(LoopSNAC): | |||||||
|  |  | ||||||
|         # monte carlo returns |         # monte carlo returns | ||||||
|         mc_returns = self.monte_carlo_returns(batch[nms.REWARD], batch[nms.DONE], gamma) |         mc_returns = self.monte_carlo_returns(batch[nms.REWARD], batch[nms.DONE], gamma) | ||||||
|         mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8) #todo: norm across agent ok? |         mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8)  # todo: norm across agent ok? | ||||||
|         advantages =  mc_returns - out[nms.CRITIC][:, :-1] |         advantages =  mc_returns - out[nms.CRITIC][:, :-1] | ||||||
|  |  | ||||||
|         # policy loss |         # policy loss | ||||||
|   | |||||||
| @@ -1,8 +1,7 @@ | |||||||
|  | import numpy as np | ||||||
| import torch | import torch | ||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
| import numpy as np |  | ||||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||||
| from torch.nn.utils import spectral_norm |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class RecurrentAC(nn.Module): | class RecurrentAC(nn.Module): | ||||||
| @@ -88,8 +87,8 @@ class NormalizedLinear(nn.Linear): | |||||||
|         self.trainable_magnitude = trainable_magnitude |         self.trainable_magnitude = trainable_magnitude | ||||||
|         self.scale = nn.Parameter(torch.tensor([1.]), requires_grad=trainable_magnitude) |         self.scale = nn.Parameter(torch.tensor([1.]), requires_grad=trainable_magnitude) | ||||||
|  |  | ||||||
|     def forward(self, input): |     def forward(self, in_array): | ||||||
|         normalized_input = F.normalize(input, dim=-1, p=2, eps=1e-5) |         normalized_input = F.normalize(in_array, dim=-1, p=2, eps=1e-5) | ||||||
|         normalized_weight = F.normalize(self.weight, dim=-1, p=2, eps=1e-5) |         normalized_weight = F.normalize(self.weight, dim=-1, p=2, eps=1e-5) | ||||||
|         return F.linear(normalized_input, normalized_weight) * self.d_sqrt * self.scale |         return F.linear(normalized_input, normalized_weight) * self.d_sqrt * self.scale | ||||||
|  |  | ||||||
|   | |||||||
| @@ -38,7 +38,6 @@ class LoopSEAC(LoopIAC): | |||||||
|  |  | ||||||
|             a2c_loss = (-iw*log_ap * advantages.detach()).mean(-1) |             a2c_loss = (-iw*log_ap * advantages.detach()).mean(-1) | ||||||
|  |  | ||||||
|  |  | ||||||
|             value_loss = (iw*advantages.pow(2)).mean(-1)  # n_agent |             value_loss = (iw*advantages.pow(2)).mean(-1)  # n_agent | ||||||
|  |  | ||||||
|             # weighted loss |             # weighted loss | ||||||
|   | |||||||
| @@ -56,8 +56,8 @@ class TSPBaseAgent(ABC): | |||||||
|  |  | ||||||
|     def _door_is_close(self, state): |     def _door_is_close(self, state): | ||||||
|         try: |         try: | ||||||
|             # return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name) |             return next(y for x in state.entities.neighboring_positions(self.state.pos) | ||||||
|             return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name) |                         for y in state.entities.pos_dict[x] if do.DOOR in y.name) | ||||||
|         except StopIteration: |         except StopIteration: | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
|   | |||||||
| @@ -14,8 +14,8 @@ class TSPTargetAgent(TSPBaseAgent): | |||||||
|     def _handle_doors(self, state): |     def _handle_doors(self, state): | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             # return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name) |             return next(y for x in state.entities.neighboring_positions(self.state.pos) | ||||||
|             return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name) |                         for y in state.entities.pos_dict[x] if do.DOOR in y.name) | ||||||
|         except StopIteration: |         except StopIteration: | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
| @@ -30,4 +30,3 @@ class TSPTargetAgent(TSPBaseAgent): | |||||||
|         except (StopIteration, UnboundLocalError): |         except (StopIteration, UnboundLocalError): | ||||||
|             print('Will not happen') |             print('Will not happen') | ||||||
|         return action_obj |         return action_obj | ||||||
|  |  | ||||||
|   | |||||||
| @@ -26,12 +26,16 @@ def points_to_graph(coordiniates, allow_euclidean_connections=True, allow_manhat | |||||||
|     assert allow_euclidean_connections or allow_manhattan_connections |     assert allow_euclidean_connections or allow_manhattan_connections | ||||||
|     possible_connections = itertools.combinations(coordiniates, 2) |     possible_connections = itertools.combinations(coordiniates, 2) | ||||||
|     graph = nx.Graph() |     graph = nx.Graph() | ||||||
|     for a, b in possible_connections: |     if allow_manhattan_connections and allow_euclidean_connections: | ||||||
|         diff = np.linalg.norm(np.asarray(a)-np.asarray(b)) |         graph.add_edges_from( | ||||||
|         if allow_manhattan_connections and allow_euclidean_connections and diff <= np.sqrt(2): |             (a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) <= np.sqrt(2) | ||||||
|             graph.add_edge(a, b) |         ) | ||||||
|         elif not allow_manhattan_connections and allow_euclidean_connections and diff == np.sqrt(2): |     elif not allow_manhattan_connections and allow_euclidean_connections: | ||||||
|             graph.add_edge(a, b) |         graph.add_edges_from( | ||||||
|         elif allow_manhattan_connections and not allow_euclidean_connections and diff == 1: |             (a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) == np.sqrt(2) | ||||||
|             graph.add_edge(a, b) |         ) | ||||||
|  |     elif allow_manhattan_connections and not allow_euclidean_connections: | ||||||
|  |         graph.add_edges_from( | ||||||
|  |             (a, b) for a, b in possible_connections if np.linalg.norm(np.asarray(a) - np.asarray(b)) == 1 | ||||||
|  |         ) | ||||||
|     return graph |     return graph | ||||||
|   | |||||||
| @@ -1,8 +1,9 @@ | |||||||
| import torch |  | ||||||
| import numpy as np |  | ||||||
| import yaml |  | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
|  |  | ||||||
|  | import numpy as np | ||||||
|  | import torch | ||||||
|  | import yaml | ||||||
|  |  | ||||||
|  |  | ||||||
| def load_class(classname): | def load_class(classname): | ||||||
|     from importlib import import_module |     from importlib import import_module | ||||||
| @@ -42,7 +43,6 @@ def get_class(arguments): | |||||||
|  |  | ||||||
|  |  | ||||||
| def get_arguments(arguments): | def get_arguments(arguments): | ||||||
|     from importlib import import_module |  | ||||||
|     d = dict(arguments) |     d = dict(arguments) | ||||||
|     if "classname" in d: |     if "classname" in d: | ||||||
|         del d["classname"] |         del d["classname"] | ||||||
|   | |||||||
| @@ -22,26 +22,41 @@ Agents: | |||||||
|     - Inventory |     - Inventory | ||||||
|     - DropOffLocations |     - DropOffLocations | ||||||
|     - Maintainers |     - Maintainers | ||||||
|  |     # This is special for agents, as each one is different and can act as an adversary, e.g. | ||||||
|  |     Positions: | ||||||
|  |       - (16, 7) | ||||||
|  |       - (16, 6) | ||||||
|  |       - (16, 3) | ||||||
|  |       - (16, 4) | ||||||
|  |       - (16, 5) | ||||||
| Entities: | Entities: | ||||||
|   Batteries: |   Batteries: | ||||||
|     initial_charge: 0.8 |     initial_charge: 0.8 | ||||||
|     per_action_costs: 0.02 |     per_action_costs: 0.02 | ||||||
|   ChargePods: {} |   ChargePods: | ||||||
|   Destinations: {} |     coords_or_quantity: 2 | ||||||
|  |   Destinations: | ||||||
|  |     coords_or_quantity: 1 | ||||||
|  |     spawn_mode: GROUPED | ||||||
|   DirtPiles: |   DirtPiles: | ||||||
|  |     coords_or_quantity: 10 | ||||||
|  |     initial_amount: 2 | ||||||
|     clean_amount: 1 |     clean_amount: 1 | ||||||
|     dirt_spawn_r_var: 0.1 |     dirt_spawn_r_var: 0.1 | ||||||
|     initial_amount: 2 |  | ||||||
|     initial_dirt_ratio: 0.05 |  | ||||||
|     max_global_amount: 20 |     max_global_amount: 20 | ||||||
|     max_local_amount: 5 |     max_local_amount: 5 | ||||||
|   Doors: {} |   Doors: | ||||||
|   DropOffLocations: {} |   DropOffLocations: | ||||||
|  |     coords_or_quantity: 1 | ||||||
|  |     max_dropoff_storage_size: 0 | ||||||
|   GlobalPositions: {} |   GlobalPositions: {} | ||||||
|   Inventories: {} |   Inventories: {} | ||||||
|   Items: {} |   Items: | ||||||
|   Machines: {} |     coords_or_quantity: 5 | ||||||
|   Maintainers: {} |   Machines: | ||||||
|  |     coords_or_quantity: 2 | ||||||
|  |   Maintainers: | ||||||
|  |     coords_or_quantity: 1 | ||||||
|   Zones: {} |   Zones: {} | ||||||
|  |  | ||||||
| General: | General: | ||||||
| @@ -49,32 +64,31 @@ General: | |||||||
|   individual_rewards: true |   individual_rewards: true | ||||||
|   level_name: large |   level_name: large | ||||||
|   pomdp_r: 3 |   pomdp_r: 3 | ||||||
|   verbose: false |   verbose: True | ||||||
|  |   tests: false | ||||||
|  |  | ||||||
| Rules: | Rules: | ||||||
|   SpawnAgents: {} |   # Environment Dynamics | ||||||
|   DoneAtBatteryDischarge: {} |  | ||||||
|   Collision: |  | ||||||
|     done_at_collisions: false |  | ||||||
|   AssignGlobalPositions: {} |  | ||||||
|   DoneAtDestinationReachAny: {} |  | ||||||
|   DestinationReachReward: {} |  | ||||||
|   SpawnDestinations: |  | ||||||
|     n_dests: 1 |  | ||||||
|     spawn_mode: GROUPED |  | ||||||
|   DoneOnAllDirtCleaned: {} |  | ||||||
|   SpawnDirt: |  | ||||||
|     spawn_freq: 15 |  | ||||||
|   EntitiesSmearDirtOnMove: |   EntitiesSmearDirtOnMove: | ||||||
|     smear_ratio: 0.2 |     smear_ratio: 0.2 | ||||||
|   DoorAutoClose: |   DoorAutoClose: | ||||||
|     close_frequency: 10 |     close_frequency: 10 | ||||||
|   ItemRules: |   MoveMaintainers: | ||||||
|     max_dropoff_storage_size: 0 |  | ||||||
|     n_items: 5 |   # Respawn Stuff | ||||||
|     n_locations: 5 |   RespawnDirt: | ||||||
|     spawn_frequency: 15 |     respawn_freq: 15 | ||||||
|   MaxStepsReached: |   RespawnItems: | ||||||
|  |     respawn_freq: 15 | ||||||
|  |  | ||||||
|  |   # Utilities | ||||||
|  |   WatchCollisions: | ||||||
|  |     done_at_collisions: false | ||||||
|  |  | ||||||
|  |   # Done Conditions | ||||||
|  |   DoneAtDestinationReachAny: | ||||||
|  |   DoneOnAllDirtCleaned: | ||||||
|  |   DoneAtBatteryDischarge: | ||||||
|  |   DoneAtMaintainerCollision: | ||||||
|  |   DoneAtMaxStepsReached: | ||||||
|     max_steps: 500 |     max_steps: 500 | ||||||
| #  AgentSingleZonePlacement: |  | ||||||
| #    n_zones: 4 |  | ||||||
|   | |||||||
| @@ -1,46 +1,89 @@ | |||||||
| Agents: |  | ||||||
|   Wolfgang: |  | ||||||
|     Actions: |  | ||||||
|     - Noop |  | ||||||
|     - Move8 |  | ||||||
|     Observations: |  | ||||||
|     - Walls |  | ||||||
|     - Other |  | ||||||
|     - Destination |  | ||||||
|     Positions: |  | ||||||
|       - (2, 1) |  | ||||||
|       - (2, 5) |  | ||||||
|   Karl-Heinz: |  | ||||||
|     Actions: |  | ||||||
|       - Noop |  | ||||||
|       - Move8 |  | ||||||
|     Observations: |  | ||||||
|       - Walls |  | ||||||
|       - Other |  | ||||||
|       - Destination |  | ||||||
|     Positions: |  | ||||||
|       - (2, 1) |  | ||||||
|       - (2, 5) |  | ||||||
| Entities: |  | ||||||
|   Destinations: {} |  | ||||||
|  |  | ||||||
| General: | General: | ||||||
|  |   # Your Seed | ||||||
|   env_seed: 69 |   env_seed: 69 | ||||||
|  |   # Individual or global rewards? | ||||||
|   individual_rewards: true |   individual_rewards: true | ||||||
|  |   # The level.txt file to load | ||||||
|   level_name: narrow_corridor |   level_name: narrow_corridor | ||||||
|  |   # View Radius; 0 = full observability | ||||||
|   pomdp_r: 0 |   pomdp_r: 0 | ||||||
|  |   # print all messages and events | ||||||
|   verbose: true |   verbose: true | ||||||
|  |  | ||||||
| Rules: | Agents: | ||||||
|   SpawnAgents: {} |   # Agents are identified by their name  | ||||||
|   Collision: |   Wolfgang: | ||||||
|     done_at_collisions: false |     # The available actions for this particular agent | ||||||
|   FixedDestinationSpawn: |     Actions: | ||||||
|     per_agent_positions: |     # Able to do nothing | ||||||
|  |     - Noop | ||||||
|  |     # Able to move in all 8 directions | ||||||
|  |     - Move8 | ||||||
|  |     # Stuff the agent can observe (per 2d slice) | ||||||
|  |     #   use "Combined" if you want to merge multiple slices into one | ||||||
|  |     Observations: | ||||||
|  |     # He sees walls | ||||||
|  |     - Walls | ||||||
|  |     # He sees other agents; "Karl-Heinz" in this setting would be fine, too | ||||||
|  |     - Other | ||||||
|  |     # He can see Destinations, that are assigned to him (hence the singular)  | ||||||
|  |     - Destination | ||||||
|  |     # Available Spawn Positions as list | ||||||
|  |     Positions: | ||||||
|  |       - (2, 1) | ||||||
|  |       - (2, 5) | ||||||
|  |     # It is okay to collide with other agents, so that  | ||||||
|  |     #   they end up on the same position | ||||||
|  |     is_blocking_pos: true | ||||||
|  |   # See Above.... | ||||||
|  |   Karl-Heinz: | ||||||
|  |     Actions: | ||||||
|  |       - Noop | ||||||
|  |       - Move8 | ||||||
|  |     Observations: | ||||||
|  |       - Walls | ||||||
|  |       - Other | ||||||
|  |       - Destination | ||||||
|  |     Positions: | ||||||
|  |       - (2, 1) | ||||||
|  |       - (2, 5) | ||||||
|  |     is_blocking_pos: true | ||||||
|  |  | ||||||
|  | # Other noteworthy Entities | ||||||
|  | Entities: | ||||||
|  |   # The destinations or positional targets to reach | ||||||
|  |   Destinations: | ||||||
|  |     # Let them spawn on closed doors and agent positions | ||||||
|  |     ignore_blocking: true | ||||||
|  |     # We need a special spawn rule... | ||||||
|  |     spawnrule: | ||||||
|  |       # ...which assigns the destinations per agent | ||||||
|  |       SpawnDestinationsPerAgent: | ||||||
|  |         # we use this parameter | ||||||
|  |         coords_or_quantity: | ||||||
|  |           # to enable and assign special positions per agent | ||||||
|           Wolfgang: |           Wolfgang: | ||||||
|               - (2, 1) |               - (2, 1) | ||||||
|               - (2, 5) |               - (2, 5) | ||||||
|           Karl-Heinz: |           Karl-Heinz: | ||||||
|               - (2, 1) |               - (2, 1) | ||||||
|               - (2, 5) |               - (2, 5) | ||||||
|   DestinationReachAll: {} |     # Whether you want to provide a numeric Position observation. | ||||||
|  |     # GlobalPositions: | ||||||
|  |     #   normalized: false | ||||||
|  |  | ||||||
|  | # Define the env. dynamics | ||||||
|  | Rules: | ||||||
|  |   # Utilities | ||||||
|  |   #  This rule Checks for Collision, also it assigns the (negative) reward | ||||||
|  |   WatchCollisions: | ||||||
|  |     reward: -0.1 | ||||||
|  |     reward_at_done: -1 | ||||||
|  |     done_at_collisions: false | ||||||
|  |   # Done Conditions | ||||||
|  |   #   Load any of the rules, to check for done conditions.  | ||||||
|  |   # DoneAtDestinationReachAny: | ||||||
|  |   DoneAtDestinationReachAll: | ||||||
|  |   #  reward_at_done: 1 | ||||||
|  |   DoneAtMaxStepsReached: | ||||||
|  |     max_steps: 200 | ||||||
|   | |||||||
| @@ -48,9 +48,9 @@ class Move(Action, abc.ABC): | |||||||
|             reward = r.MOVEMENTS_VALID if move_validity else r.MOVEMENTS_FAIL |             reward = r.MOVEMENTS_VALID if move_validity else r.MOVEMENTS_FAIL | ||||||
|             return ActionResult(entity=entity, identifier=self._identifier, validity=move_validity, reward=reward) |             return ActionResult(entity=entity, identifier=self._identifier, validity=move_validity, reward=reward) | ||||||
|         else:  # There is no place to go, probably collision |         else:  # There is no place to go, probably collision | ||||||
|             # This is currently handled by the Collision rule, so that it can be switched on and off by conf.yml |             # This is currently handled by the WatchCollisions rule, so that it can be switched on and off by conf.yml | ||||||
|             # return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION) |             # return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION) | ||||||
|             return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=0) |             return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID) | ||||||
|  |  | ||||||
|     def _calc_new_pos(self, pos): |     def _calc_new_pos(self, pos): | ||||||
|         x_diff, y_diff = MOVEMAP[self._identifier] |         x_diff, y_diff = MOVEMAP[self._identifier] | ||||||
|   | |||||||
| @@ -10,6 +10,7 @@ AGENT                   = 'Agent'               # Identifier of Agent-objects an | |||||||
| OTHERS                  = 'Other' | OTHERS                  = 'Other' | ||||||
| COMBINED                = 'Combined' | COMBINED                = 'Combined' | ||||||
| GLOBALPOSITIONS         = 'GlobalPositions'     # Identifier of the global position slice | GLOBALPOSITIONS         = 'GlobalPositions'     # Identifier of the global position slice | ||||||
|  | SPAWN_ENTITY_RULE       = 'SpawnEntity' | ||||||
|  |  | ||||||
| # Attributes | # Attributes | ||||||
| IS_BLOCKING_LIGHT       = 'var_is_blocking_light' | IS_BLOCKING_LIGHT       = 'var_is_blocking_light' | ||||||
| @@ -29,7 +30,7 @@ VALUE_NO_POS            = (-9999, -9999)  # Invalid Position value used in the e | |||||||
|  |  | ||||||
|  |  | ||||||
| ACTION                  = 'action'  # Identifier of Action-objects and groups (groups). | ACTION                  = 'action'  # Identifier of Action-objects and groups (groups). | ||||||
| COLLISION               = 'Collision'  # Identifier to use in the context of collisions. | COLLISION               = 'Collisions'  # Identifier to use in the context of collisions. | ||||||
| # LAST_POS                = 'LAST_POS'  # Identifier for retrieving an entity's last pos. | # LAST_POS                = 'LAST_POS'  # Identifier for retrieving an entity's last pos. | ||||||
| VALIDITY                = 'VALIDITY'  # Identifier for retrieving the Validity of Action, Tick, etc. ... | VALIDITY                = 'VALIDITY'  # Identifier for retrieving the Validity of Action, Tick, etc. ... | ||||||
|  |  | ||||||
| @@ -54,3 +55,5 @@ NOOP                    = 'Noop' | |||||||
| # Result Identifier | # Result Identifier | ||||||
| MOVEMENTS_VALID = 'motion_valid' | MOVEMENTS_VALID = 'motion_valid' | ||||||
| MOVEMENTS_FAIL  = 'motion_not_valid' | MOVEMENTS_FAIL  = 'motion_not_valid' | ||||||
|  | DEFAULT_PATH = 'environment' | ||||||
|  | MODULE_PATH = 'modules' | ||||||
|   | |||||||
| @@ -12,14 +12,6 @@ from marl_factory_grid.environment import constants as c | |||||||
|  |  | ||||||
| class Agent(Entity): | class Agent(Entity): | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_is_blocking_light(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_can_move(self): |  | ||||||
|         return True |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def var_is_paralyzed(self): |     def var_is_paralyzed(self): | ||||||
|         return len(self._paralyzed) |         return len(self._paralyzed) | ||||||
| @@ -28,14 +20,6 @@ class Agent(Entity): | |||||||
|     def paralyze_reasons(self): |     def paralyze_reasons(self): | ||||||
|         return [x for x in self._paralyzed] |         return [x for x in self._paralyzed] | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_is_blocking_pos(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_has_position(self): |  | ||||||
|         return True |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def obs_tag(self): |     def obs_tag(self): | ||||||
|         return self.name |         return self.name | ||||||
| @@ -48,10 +32,6 @@ class Agent(Entity): | |||||||
|     def observations(self): |     def observations(self): | ||||||
|         return self._observations |         return self._observations | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_can_collide(self): |  | ||||||
|         return True |  | ||||||
|  |  | ||||||
|     def step_result(self): |     def step_result(self): | ||||||
|         pass |         pass | ||||||
|  |  | ||||||
| @@ -60,16 +40,21 @@ class Agent(Entity): | |||||||
|         return self._collection |         return self._collection | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def state(self): |     def var_is_blocking_pos(self): | ||||||
|         return self._state or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0) |         return self._is_blocking_pos | ||||||
|  |  | ||||||
|     def __init__(self, actions: List[Action], observations: List[str], *args, **kwargs): |     @property | ||||||
|  |     def state(self): | ||||||
|  |         return self._state or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID) | ||||||
|  |  | ||||||
|  |     def __init__(self, actions: List[Action], observations: List[str], *args, is_blocking_pos=False, **kwargs): | ||||||
|         super(Agent, self).__init__(*args, **kwargs) |         super(Agent, self).__init__(*args, **kwargs) | ||||||
|         self._paralyzed = set() |         self._paralyzed = set() | ||||||
|         self.step_result = dict() |         self.step_result = dict() | ||||||
|         self._actions = actions |         self._actions = actions | ||||||
|         self._observations = observations |         self._observations = observations | ||||||
|         self._state: Union[Result, None] = None |         self._state: Union[Result, None] = None | ||||||
|  |         self._is_blocking_pos = is_blocking_pos | ||||||
|  |  | ||||||
|     # noinspection PyAttributeOutsideInit |     # noinspection PyAttributeOutsideInit | ||||||
|     def clear_temp_state(self): |     def clear_temp_state(self): | ||||||
|   | |||||||
| @@ -1,20 +1,19 @@ | |||||||
| import abc | import abc | ||||||
| from collections import defaultdict |  | ||||||
|  |  | ||||||
| import numpy as np | import numpy as np | ||||||
|  |  | ||||||
| from .object import _Object | from .object import Object | ||||||
| from .. import constants as c | from .. import constants as c | ||||||
| from ...utils.results import ActionResult | from ...utils.results import ActionResult | ||||||
| from ...utils.utility_classes import RenderEntity | from ...utils.utility_classes import RenderEntity | ||||||
|  |  | ||||||
|  |  | ||||||
| class Entity(_Object, abc.ABC): | class Entity(Object, abc.ABC): | ||||||
|     """Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc...""" |     """Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc...""" | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def state(self): |     def state(self): | ||||||
|         return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0) |         return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID) | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def var_has_position(self): |     def var_has_position(self): | ||||||
| @@ -60,6 +59,10 @@ class Entity(_Object, abc.ABC): | |||||||
|     def pos(self): |     def pos(self): | ||||||
|         return self._pos |         return self._pos | ||||||
|  |  | ||||||
|  |     def set_pos(self, pos): | ||||||
|  |         assert isinstance(pos, tuple) and len(pos) == 2 | ||||||
|  |         self._pos = pos | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def last_pos(self): |     def last_pos(self): | ||||||
|         try: |         try: | ||||||
| @@ -84,7 +87,7 @@ class Entity(_Object, abc.ABC): | |||||||
|                 for observer in self.observers: |                 for observer in self.observers: | ||||||
|                     observer.notify_del_entity(self) |                     observer.notify_del_entity(self) | ||||||
|                 self._view_directory = curr_pos[0] - next_pos[0], curr_pos[1] - next_pos[1] |                 self._view_directory = curr_pos[0] - next_pos[0], curr_pos[1] - next_pos[1] | ||||||
|                 self._pos = next_pos |                 self.set_pos(next_pos) | ||||||
|                 for observer in self.observers: |                 for observer in self.observers: | ||||||
|                     observer.notify_add_entity(self) |                     observer.notify_add_entity(self) | ||||||
|             return valid |             return valid | ||||||
| @@ -92,6 +95,7 @@ class Entity(_Object, abc.ABC): | |||||||
|  |  | ||||||
|     def __init__(self, pos, bind_to=None, **kwargs): |     def __init__(self, pos, bind_to=None, **kwargs): | ||||||
|         super().__init__(**kwargs) |         super().__init__(**kwargs) | ||||||
|  |         self._view_directory = c.VALUE_NO_POS | ||||||
|         self._status = None |         self._status = None | ||||||
|         self._pos = pos |         self._pos = pos | ||||||
|         self._last_pos = pos |         self._last_pos = pos | ||||||
| @@ -109,9 +113,6 @@ class Entity(_Object, abc.ABC): | |||||||
|     def render(self): |     def render(self): | ||||||
|         return RenderEntity(self.__class__.__name__.lower(), self.pos) |         return RenderEntity(self.__class__.__name__.lower(), self.pos) | ||||||
|  |  | ||||||
|     def __repr__(self): |  | ||||||
|         return super(Entity, self).__repr__() + f'(@{self.pos})' |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def obs_tag(self): |     def obs_tag(self): | ||||||
|         try: |         try: | ||||||
| @@ -128,25 +129,3 @@ class Entity(_Object, abc.ABC): | |||||||
|         self._collection.delete_env_object(self) |         self._collection.delete_env_object(self) | ||||||
|         self._collection = other_collection |         self._collection = other_collection | ||||||
|         return self._collection == other_collection |         return self._collection == other_collection | ||||||
|  |  | ||||||
|     @classmethod |  | ||||||
|     def from_coordinates(cls, positions: [(int, int)], *args, entity_kwargs=None, **kwargs, ): |  | ||||||
|         collection = cls(*args, **kwargs) |  | ||||||
|         collection.add_items( |  | ||||||
|             [cls._entity(tuple(pos), **entity_kwargs if entity_kwargs is not None else {}) for pos in positions]) |  | ||||||
|         return collection |  | ||||||
|  |  | ||||||
|     def notify_del_entity(self, entity): |  | ||||||
|         try: |  | ||||||
|             self.pos_dict[entity.pos].remove(entity) |  | ||||||
|         except (ValueError, AttributeError): |  | ||||||
|             pass |  | ||||||
|  |  | ||||||
|     def by_pos(self, pos: (int, int)): |  | ||||||
|         pos = tuple(pos) |  | ||||||
|         try: |  | ||||||
|             return self.state.entities.pos_dict[pos] |  | ||||||
|         except StopIteration: |  | ||||||
|             pass |  | ||||||
|         except ValueError: |  | ||||||
|             print() |  | ||||||
|   | |||||||
| @@ -1,24 +0,0 @@ | |||||||
|  |  | ||||||
|  |  | ||||||
| # noinspection PyAttributeOutsideInit |  | ||||||
| class BoundEntityMixin: |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def bound_entity(self): |  | ||||||
|         return self._bound_entity |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def name(self): |  | ||||||
|         if self.bound_entity: |  | ||||||
|             return f'{self.__class__.__name__}({self.bound_entity.name})' |  | ||||||
|         else: |  | ||||||
|             pass |  | ||||||
|  |  | ||||||
|     def belongs_to_entity(self, entity): |  | ||||||
|         return entity == self.bound_entity |  | ||||||
|  |  | ||||||
|     def bind_to(self, entity): |  | ||||||
|         self._bound_entity = entity |  | ||||||
|  |  | ||||||
|     def unbind(self): |  | ||||||
|         self._bound_entity = None |  | ||||||
| @@ -5,7 +5,7 @@ from marl_factory_grid.environment import constants as c | |||||||
| import marl_factory_grid.utils.helpers as h | import marl_factory_grid.utils.helpers as h | ||||||
|  |  | ||||||
|  |  | ||||||
| class _Object: | class Object: | ||||||
|     """Generell Objects for Organisation and Maintanance such as Actions etc...""" |     """Generell Objects for Organisation and Maintanance such as Actions etc...""" | ||||||
|  |  | ||||||
|     _u_idx = defaultdict(lambda: 0) |     _u_idx = defaultdict(lambda: 0) | ||||||
| @@ -13,10 +13,6 @@ class _Object: | |||||||
|     def __bool__(self): |     def __bool__(self): | ||||||
|         return True |         return True | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_has_position(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def var_can_be_bound(self): |     def var_can_be_bound(self): | ||||||
|         try: |         try: | ||||||
| @@ -30,22 +26,14 @@ class _Object: | |||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def name(self): |     def name(self): | ||||||
|         if self._str_ident is not None: |         return f'{self.__class__.__name__}[{self.identifier}]' | ||||||
|             name = f'{self.__class__.__name__}[{self._str_ident}]' |  | ||||||
|         else: |  | ||||||
|             name = f'{self.__class__.__name__}#{self.u_int}' |  | ||||||
|         if self.bound_entity: |  | ||||||
|             name = h.add_bound_name(name, self.bound_entity) |  | ||||||
|         if self.var_has_position: |  | ||||||
|             name = h.add_pos_name(name, self) |  | ||||||
|         return name |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def identifier(self): |     def identifier(self): | ||||||
|         if self._str_ident is not None: |         if self._str_ident is not None: | ||||||
|             return self._str_ident |             return self._str_ident | ||||||
|         else: |         else: | ||||||
|             return self.name |             return self.u_int | ||||||
|  |  | ||||||
|     def reset_uid(self): |     def reset_uid(self): | ||||||
|         self._u_idx = defaultdict(lambda: 0) |         self._u_idx = defaultdict(lambda: 0) | ||||||
| @@ -62,7 +50,15 @@ class _Object: | |||||||
|             print(f'Following kwargs were passed, but ignored: {kwargs}') |             print(f'Following kwargs were passed, but ignored: {kwargs}') | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return f'{self.name}' |         name = self.name | ||||||
|  |         if self.bound_entity: | ||||||
|  |             name = h.add_bound_name(name, self.bound_entity) | ||||||
|  |         try: | ||||||
|  |             if self.var_has_position: | ||||||
|  |                 name = h.add_pos_name(name, self) | ||||||
|  |         except AttributeError: | ||||||
|  |             pass | ||||||
|  |         return name | ||||||
|  |  | ||||||
|     def __eq__(self, other) -> bool: |     def __eq__(self, other) -> bool: | ||||||
|         return other == self.identifier |         return other == self.identifier | ||||||
| @@ -71,8 +67,8 @@ class _Object: | |||||||
|         return hash(self.identifier) |         return hash(self.identifier) | ||||||
|  |  | ||||||
|     def _identify_and_count_up(self): |     def _identify_and_count_up(self): | ||||||
|         idx = _Object._u_idx[self.__class__.__name__] |         idx = Object._u_idx[self.__class__.__name__] | ||||||
|         _Object._u_idx[self.__class__.__name__] += 1 |         Object._u_idx[self.__class__.__name__] += 1 | ||||||
|         return idx |         return idx | ||||||
|  |  | ||||||
|     def set_collection(self, collection): |     def set_collection(self, collection): | ||||||
| @@ -88,7 +84,7 @@ class _Object: | |||||||
|     def summarize_state(self): |     def summarize_state(self): | ||||||
|         return dict() |         return dict() | ||||||
|  |  | ||||||
|     def bind(self, entity): |     def bind_to(self, entity): | ||||||
|         # noinspection PyAttributeOutsideInit |         # noinspection PyAttributeOutsideInit | ||||||
|         self._bound_entity = entity |         self._bound_entity = entity | ||||||
|         return c.VALID |         return c.VALID | ||||||
| @@ -100,84 +96,5 @@ class _Object: | |||||||
|     def bound_entity(self): |     def bound_entity(self): | ||||||
|         return self._bound_entity |         return self._bound_entity | ||||||
|  |  | ||||||
|     def bind_to(self, entity): |  | ||||||
|         self._bound_entity = entity |  | ||||||
|  |  | ||||||
|     def unbind(self): |     def unbind(self): | ||||||
|         self._bound_entity = None |         self._bound_entity = None | ||||||
|  |  | ||||||
|  |  | ||||||
| # class EnvObject(_Object): |  | ||||||
| #     """Objects that hold Information that are observable, but have no position on the environment grid. Inventories etc...""" |  | ||||||
| # |  | ||||||
|     # _u_idx = defaultdict(lambda: 0) |  | ||||||
| # |  | ||||||
| #     @property |  | ||||||
| #     def obs_tag(self): |  | ||||||
| #         try: |  | ||||||
| #             return self._collection.name or self.name |  | ||||||
| #         except AttributeError: |  | ||||||
| #             return self.name |  | ||||||
| # |  | ||||||
| #     @property |  | ||||||
| #     def var_is_blocking_light(self): |  | ||||||
| #         try: |  | ||||||
| #             return self._collection.var_is_blocking_light or False |  | ||||||
| #         except AttributeError: |  | ||||||
| #             return False |  | ||||||
| # |  | ||||||
| #     @property |  | ||||||
| #     def var_can_be_bound(self): |  | ||||||
| #         try: |  | ||||||
| #             return self._collection.var_can_be_bound or False |  | ||||||
| #         except AttributeError: |  | ||||||
| #             return False |  | ||||||
| # |  | ||||||
| #     @property |  | ||||||
| #     def var_can_move(self): |  | ||||||
| #         try: |  | ||||||
| #             return self._collection.var_can_move or False |  | ||||||
| #         except AttributeError: |  | ||||||
| #             return False |  | ||||||
| # |  | ||||||
| #     @property |  | ||||||
| #     def var_is_blocking_pos(self): |  | ||||||
| #         try: |  | ||||||
| #             return self._collection.var_is_blocking_pos or False |  | ||||||
| #         except AttributeError: |  | ||||||
| #             return False |  | ||||||
| # |  | ||||||
| #     @property |  | ||||||
| #     def var_has_position(self): |  | ||||||
| #         try: |  | ||||||
| #             return self._collection.var_has_position or False |  | ||||||
| #         except AttributeError: |  | ||||||
| #             return False |  | ||||||
| # |  | ||||||
| # @property |  | ||||||
| # def var_can_collide(self): |  | ||||||
| #     try: |  | ||||||
| #         return self._collection.var_can_collide or False |  | ||||||
| #     except AttributeError: |  | ||||||
| #         return False |  | ||||||
| # |  | ||||||
| # |  | ||||||
| # @property |  | ||||||
| # def encoding(self): |  | ||||||
| #     return c.VALUE_OCCUPIED_CELL |  | ||||||
| # |  | ||||||
| # |  | ||||||
| # def __init__(self, **kwargs): |  | ||||||
| #     self._bound_entity = None |  | ||||||
| #     super(EnvObject, self).__init__(**kwargs) |  | ||||||
| # |  | ||||||
| # |  | ||||||
| # def change_parent_collection(self, other_collection): |  | ||||||
| #     other_collection.add_item(self) |  | ||||||
| #     self._collection.delete_env_object(self) |  | ||||||
| #     self._collection = other_collection |  | ||||||
| #     return self._collection == other_collection |  | ||||||
| # |  | ||||||
| # |  | ||||||
| # def summarize_state(self): |  | ||||||
| #     return dict(name=str(self.name)) |  | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| import numpy as np | import numpy as np | ||||||
|  |  | ||||||
| from marl_factory_grid.environment.entity.object import _Object | from marl_factory_grid.environment.entity.object import Object | ||||||
|  |  | ||||||
|  |  | ||||||
| ########################################################################## | ########################################################################## | ||||||
| @@ -8,7 +8,7 @@ from marl_factory_grid.environment.entity.object import _Object | |||||||
| ########################################################################## | ########################################################################## | ||||||
|  |  | ||||||
|  |  | ||||||
| class PlaceHolder(_Object): | class PlaceHolder(Object): | ||||||
|  |  | ||||||
|     def __init__(self, *args, fill_value=0, **kwargs): |     def __init__(self, *args, fill_value=0, **kwargs): | ||||||
|         super().__init__(*args, **kwargs) |         super().__init__(*args, **kwargs) | ||||||
| @@ -24,10 +24,10 @@ class PlaceHolder(_Object): | |||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def name(self): |     def name(self): | ||||||
|         return "PlaceHolder" |         return self.__class__.__name__ | ||||||
|  |  | ||||||
|  |  | ||||||
| class GlobalPosition(_Object): | class GlobalPosition(Object): | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def encoding(self): |     def encoding(self): | ||||||
| @@ -36,7 +36,8 @@ class GlobalPosition(_Object): | |||||||
|         else: |         else: | ||||||
|             return self.bound_entity.pos |             return self.bound_entity.pos | ||||||
|  |  | ||||||
|     def __init__(self, level_shape, *args, normalized: bool = True, **kwargs): |     def __init__(self, agent, level_shape, *args, normalized: bool = True, **kwargs): | ||||||
|         super(GlobalPosition, self).__init__(*args, **kwargs) |         super(GlobalPosition, self).__init__(*args, **kwargs) | ||||||
|  |         self.bind_to(agent) | ||||||
|         self._normalized = normalized |         self._normalized = normalized | ||||||
|         self._shape = level_shape |         self._shape = level_shape | ||||||
|   | |||||||
| @@ -5,13 +5,8 @@ from marl_factory_grid.utils.utility_classes import RenderEntity | |||||||
|  |  | ||||||
| class Wall(Entity): | class Wall(Entity): | ||||||
|  |  | ||||||
|     @property |     def __init__(self, *args, **kwargs): | ||||||
|     def var_has_position(self): |         super().__init__(*args, **kwargs) | ||||||
|         return True |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_can_collide(self): |  | ||||||
|         return True |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def encoding(self): |     def encoding(self): | ||||||
| @@ -19,11 +14,3 @@ class Wall(Entity): | |||||||
|  |  | ||||||
|     def render(self): |     def render(self): | ||||||
|         return RenderEntity(c.WALL, self.pos) |         return RenderEntity(c.WALL, self.pos) | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_is_blocking_pos(self): |  | ||||||
|         return True |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_is_blocking_light(self): |  | ||||||
|         return True |  | ||||||
|   | |||||||
| @@ -56,15 +56,18 @@ class Factory(gym.Env): | |||||||
|             self.level_filepath = Path(custom_level_path) |             self.level_filepath = Path(custom_level_path) | ||||||
|         else: |         else: | ||||||
|             self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt' |             self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt' | ||||||
|         self._renderer = None  # expensive - don't use; unless required ! |  | ||||||
|  |  | ||||||
|         parsed_entities = self.conf.load_entities() |         parsed_entities = self.conf.load_entities() | ||||||
|         self.map = LevelParser(self.level_filepath, parsed_entities, self.conf.pomdp_r) |         self.map = LevelParser(self.level_filepath, parsed_entities, self.conf.pomdp_r) | ||||||
|  |  | ||||||
|         # Init for later usage: |         # Init for later usage: | ||||||
|         self.state: Gamestate |         # noinspection PyTypeChecker | ||||||
|         self.map: LevelParser |         self.state: Gamestate = None | ||||||
|         self.obs_builder: OBSBuilder |         # noinspection PyTypeChecker | ||||||
|  |         self.obs_builder: OBSBuilder = None | ||||||
|  |  | ||||||
|  |         # expensive - don't use; unless required ! | ||||||
|  |         self._renderer = None | ||||||
|  |  | ||||||
|         # reset env to initial state, preparing env for new episode. |         # reset env to initial state, preparing env for new episode. | ||||||
|         # returns tuple where the first dict contains initial observation for each agent in the env |         # returns tuple where the first dict contains initial observation for each agent in the env | ||||||
| @@ -74,7 +77,7 @@ class Factory(gym.Env): | |||||||
|         return self.state.entities[item] |         return self.state.entities[item] | ||||||
|  |  | ||||||
|     def reset(self) -> (dict, dict): |     def reset(self) -> (dict, dict): | ||||||
|         if hasattr(self, 'state'): |         if self.state is not None: | ||||||
|             for entity_group in self.state.entities: |             for entity_group in self.state.entities: | ||||||
|                 try: |                 try: | ||||||
|                     entity_group[0].reset_uid() |                     entity_group[0].reset_uid() | ||||||
| @@ -87,12 +90,16 @@ class Factory(gym.Env): | |||||||
|         entities = self.map.do_init() |         entities = self.map.do_init() | ||||||
|  |  | ||||||
|         # Init rules |         # Init rules | ||||||
|         rules = self.conf.load_env_rules() |         env_rules = self.conf.load_env_rules() | ||||||
|  |         entity_rules = self.conf.load_entity_spawn_rules(entities) | ||||||
|  |         env_rules.extend(entity_rules) | ||||||
|  |  | ||||||
|         env_tests = self.conf.load_env_tests() if self.conf.tests else [] |         env_tests = self.conf.load_env_tests() if self.conf.tests else [] | ||||||
|  |  | ||||||
|         # Parse the agent conf |         # Parse the agent conf | ||||||
|         parsed_agents_conf = self.conf.parse_agents_conf() |         parsed_agents_conf = self.conf.parse_agents_conf() | ||||||
|         self.state = Gamestate(entities, parsed_agents_conf, rules, env_tests, self.conf.env_seed, self.conf.verbose) |         self.state = Gamestate(entities, parsed_agents_conf, env_rules, env_tests, self.map.level_shape, | ||||||
|  |                                self.conf.env_seed, self.conf.verbose) | ||||||
|  |  | ||||||
|         # All is set up, trigger entity init with variable pos |         # All is set up, trigger entity init with variable pos | ||||||
|         # All is set up, trigger additional init (after agent entity spawn etc) |         # All is set up, trigger additional init (after agent entity spawn etc) | ||||||
| @@ -160,7 +167,7 @@ class Factory(gym.Env): | |||||||
|         # Finalize |         # Finalize | ||||||
|         reward, reward_info, done = self.summarize_step_results(tick_result, done_results) |         reward, reward_info, done = self.summarize_step_results(tick_result, done_results) | ||||||
|  |  | ||||||
|         info = reward_info |         info = dict(reward_info) | ||||||
|  |  | ||||||
|         info.update(step_reward=sum(reward), step=self.state.curr_step) |         info.update(step_reward=sum(reward), step=self.state.curr_step) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,10 +1,15 @@ | |||||||
| from marl_factory_grid.environment.entity.agent import Agent | from marl_factory_grid.environment.entity.agent import Agent | ||||||
| from marl_factory_grid.environment.groups.collection import Collection | from marl_factory_grid.environment.groups.collection import Collection | ||||||
|  | from marl_factory_grid.environment.rules import SpawnAgents | ||||||
|  |  | ||||||
|  |  | ||||||
| class Agents(Collection): | class Agents(Collection): | ||||||
|     _entity = Agent |     _entity = Agent | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def spawn_rule(self): | ||||||
|  |         return {SpawnAgents.__name__: {}} | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def var_is_blocking_light(self): |     def var_is_blocking_light(self): | ||||||
|         return False |         return False | ||||||
|   | |||||||
| @@ -1,18 +1,25 @@ | |||||||
| from typing import List, Tuple, Union | from typing import List, Tuple, Union, Dict | ||||||
|  |  | ||||||
| from marl_factory_grid.environment.entity.entity import Entity | from marl_factory_grid.environment.entity.entity import Entity | ||||||
| from marl_factory_grid.environment.groups.objects import _Objects | from marl_factory_grid.environment.groups.objects import Objects | ||||||
| from marl_factory_grid.environment.entity.object import _Object | # noinspection PyProtectedMember | ||||||
|  | from marl_factory_grid.environment.entity.object import Object | ||||||
| import marl_factory_grid.environment.constants as c | import marl_factory_grid.environment.constants as c | ||||||
|  | from marl_factory_grid.utils.results import Result | ||||||
|  |  | ||||||
|  |  | ||||||
| class Collection(_Objects): | class Collection(Objects): | ||||||
|     _entity = _Object  # entity? |     _entity = Object  # entity? | ||||||
|  |     symbol = None | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def var_is_blocking_light(self): |     def var_is_blocking_light(self): | ||||||
|         return False |         return False | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def var_is_blocking_pos(self): | ||||||
|  |         return False | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def var_can_collide(self): |     def var_can_collide(self): | ||||||
|         return False |         return False | ||||||
| @@ -23,33 +30,65 @@ class Collection(_Objects): | |||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def var_has_position(self): |     def var_has_position(self): | ||||||
|         return False |         return True | ||||||
|  |  | ||||||
|     # @property |  | ||||||
|     # def var_has_bound(self): |  | ||||||
|     #     return False  # batteries, globalpos, inventories true |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_can_be_bound(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def encodings(self): |     def encodings(self): | ||||||
|         return [x.encoding for x in self] |         return [x.encoding for x in self] | ||||||
|  |  | ||||||
|     def __init__(self, size, *args, **kwargs): |     @property | ||||||
|         super(Collection, self).__init__(*args, **kwargs) |     def spawn_rule(self): | ||||||
|         self.size = size |         """Prevent SpawnRule creation if Objects are spawned by map, Doors e.g.""" | ||||||
|  |         if self.symbol: | ||||||
|     def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):  # woihn mit den args |             return None | ||||||
|         if isinstance(coords_or_quantity, int): |         elif self._spawnrule: | ||||||
|             self.add_items([self._entity() for _ in range(coords_or_quantity)]) |             return self._spawnrule | ||||||
|         else: |         else: | ||||||
|             self.add_items([self._entity(pos) for pos in coords_or_quantity]) |             return {c.SPAWN_ENTITY_RULE: dict(collection=self, coords_or_quantity=self._coords_or_quantity)} | ||||||
|  |  | ||||||
|  |     def __init__(self, size, *args, coords_or_quantity: int = None, ignore_blocking=False, | ||||||
|  |                  spawnrule: Union[None, Dict[str, dict]] = None, | ||||||
|  |                  **kwargs): | ||||||
|  |         super(Collection, self).__init__(*args, **kwargs) | ||||||
|  |         self._coords_or_quantity = coords_or_quantity | ||||||
|  |         self.size = size | ||||||
|  |         self._spawnrule = spawnrule | ||||||
|  |         self._ignore_blocking = ignore_blocking | ||||||
|  |  | ||||||
|  |     def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, ignore_blocking=False,  **entity_kwargs): | ||||||
|  |         coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity | ||||||
|  |         if self.var_has_position: | ||||||
|  |             if self.var_has_position and isinstance(coords_or_quantity, int): | ||||||
|  |                 if ignore_blocking or self._ignore_blocking: | ||||||
|  |                     coords_or_quantity = state.entities.floorlist[:coords_or_quantity] | ||||||
|  |                 else: | ||||||
|  |                     coords_or_quantity = state.get_n_random_free_positions(coords_or_quantity) | ||||||
|  |             self.spawn(coords_or_quantity, *entity_args,  **entity_kwargs) | ||||||
|  |             state.print(f'{len(coords_or_quantity)} new {self.name} have been spawned at {coords_or_quantity}') | ||||||
|  |             return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(coords_or_quantity)) | ||||||
|  |         else: | ||||||
|  |             if isinstance(coords_or_quantity, int): | ||||||
|  |                 self.spawn(coords_or_quantity, *entity_args,  **entity_kwargs) | ||||||
|  |                 state.print(f'{coords_or_quantity} new {self.name} have been spawned randomly.') | ||||||
|  |                 return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=coords_or_quantity) | ||||||
|  |             else: | ||||||
|  |                 raise ValueError(f'{self._entity.__name__} has no position!') | ||||||
|  |  | ||||||
|  |     def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args, **entity_kwargs): | ||||||
|  |         if self.var_has_position: | ||||||
|  |             if isinstance(coords_or_quantity, int): | ||||||
|  |                 raise ValueError(f'{self._entity.__name__} should have a position!') | ||||||
|  |             else: | ||||||
|  |                 self.add_items([self._entity(pos, *entity_args, **entity_kwargs) for pos in coords_or_quantity]) | ||||||
|  |         else: | ||||||
|  |             if isinstance(coords_or_quantity, int): | ||||||
|  |                 self.add_items([self._entity(*entity_args, **entity_kwargs) for _ in range(coords_or_quantity)]) | ||||||
|  |             else: | ||||||
|  |                 raise ValueError(f'{self._entity.__name__} has no  position!') | ||||||
|         return c.VALID |         return c.VALID | ||||||
|  |  | ||||||
|     def despawn(self, items: List[_Object]): |     def despawn(self, items: List[Object]): | ||||||
|         items = [items] if isinstance(items, _Object) else items |         items = [items] if isinstance(items, Object) else items | ||||||
|         for item in items: |         for item in items: | ||||||
|             del self[item] |             del self[item] | ||||||
|  |  | ||||||
| @@ -115,7 +154,7 @@ class Collection(_Objects): | |||||||
|         except StopIteration: |         except StopIteration: | ||||||
|             pass |             pass | ||||||
|         except ValueError: |         except ValueError: | ||||||
|             print() |             pass | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def positions(self): |     def positions(self): | ||||||
|   | |||||||
| @@ -1,21 +1,21 @@ | |||||||
| from collections import defaultdict | from collections import defaultdict | ||||||
| from operator import itemgetter | from operator import itemgetter | ||||||
| from random import shuffle, random | from random import shuffle | ||||||
| from typing import Dict | from typing import Dict | ||||||
|  |  | ||||||
| from marl_factory_grid.environment.groups.objects import _Objects | from marl_factory_grid.environment.groups.objects import Objects | ||||||
| from marl_factory_grid.utils.helpers import POS_MASK | from marl_factory_grid.utils.helpers import POS_MASK | ||||||
|  |  | ||||||
|  |  | ||||||
| class Entities(_Objects): | class Entities(Objects): | ||||||
|     _entity = _Objects |     _entity = Objects | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def neighboring_positions(pos): |     def neighboring_positions(pos): | ||||||
|         return (POS_MASK + pos).reshape(-1, 2) |         return [tuple(x) for x in (POS_MASK + pos).reshape(-1, 2)] | ||||||
|  |  | ||||||
|     def get_entities_near_pos(self, pos): |     def get_entities_near_pos(self, pos): | ||||||
|         return [y for x in itemgetter(*(tuple(x) for x in self.neighboring_positions(pos)))(self.pos_dict) for y in x] |         return [y for x in itemgetter(*self.neighboring_positions(pos))(self.pos_dict) for y in x] | ||||||
|  |  | ||||||
|     def render(self): |     def render(self): | ||||||
|         return [y for x in self for y in x.render() if x is not None] |         return [y for x in self for y in x.render() if x is not None] | ||||||
| @@ -35,8 +35,9 @@ class Entities(_Objects): | |||||||
|         super().__init__() |         super().__init__() | ||||||
|  |  | ||||||
|     def guests_that_can_collide(self, pos): |     def guests_that_can_collide(self, pos): | ||||||
|         return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide] |         return [x for val in self.pos_dict[pos] for x in val if x.var_can_collide] | ||||||
|  |  | ||||||
|  |     @property | ||||||
|     def empty_positions(self): |     def empty_positions(self): | ||||||
|         empty_positions = [key for key in self.floorlist if not self.pos_dict[key]] |         empty_positions = [key for key in self.floorlist if not self.pos_dict[key]] | ||||||
|         shuffle(empty_positions) |         shuffle(empty_positions) | ||||||
| @@ -48,11 +49,23 @@ class Entities(_Objects): | |||||||
|         shuffle(empty_positions) |         shuffle(empty_positions) | ||||||
|         return empty_positions |         return empty_positions | ||||||
|  |  | ||||||
|     def is_blocked(self): |     @property | ||||||
|         return[key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])] |     def blocked_positions(self): | ||||||
|  |         blocked_positions = [key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])] | ||||||
|  |         shuffle(blocked_positions) | ||||||
|  |         return blocked_positions | ||||||
|  |  | ||||||
|     def is_not_blocked(self): |     @property | ||||||
|         return[key for key, val in self.pos_dict.items() if not all([x.var_is_blocking_pos for x in val])] |     def free_positions_generator(self): | ||||||
|  |         generator = ( | ||||||
|  |             key for key in self.floorlist if all(not x.var_can_collide and not x.var_is_blocking_pos | ||||||
|  |                                                  for x in self.pos_dict[key]) | ||||||
|  |                      ) | ||||||
|  |         return generator | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def free_positions_list(self): | ||||||
|  |         return [x for x in self.free_positions_generator] | ||||||
|  |  | ||||||
|     def iter_entities(self): |     def iter_entities(self): | ||||||
|         return iter((x for sublist in self.values() for x in sublist)) |         return iter((x for sublist in self.values() for x in sublist)) | ||||||
| @@ -74,7 +87,7 @@ class Entities(_Objects): | |||||||
|     def __delitem__(self, name): |     def __delitem__(self, name): | ||||||
|         assert_str = 'This group of entity does not exist in this collection!' |         assert_str = 'This group of entity does not exist in this collection!' | ||||||
|         assert any([key for key in name.keys() if key in self.keys()]), assert_str |         assert any([key for key in name.keys() if key in self.keys()]), assert_str | ||||||
|         self[name]._observers.delete(self) |         self[name].del_observer(self) | ||||||
|         for entity in self[name]: |         for entity in self[name]: | ||||||
|             entity.del_observer(self) |             entity.del_observer(self) | ||||||
|         return super(Entities, self).__delitem__(name) |         return super(Entities, self).__delitem__(name) | ||||||
| @@ -92,3 +105,6 @@ class Entities(_Objects): | |||||||
|     @property |     @property | ||||||
|     def positions(self): |     def positions(self): | ||||||
|         return [k for k, v in self.pos_dict.items() for _ in v] |         return [k for k, v in self.pos_dict.items() for _ in v] | ||||||
|  |  | ||||||
|  |     def is_occupied(self, pos): | ||||||
|  |         return len([x for x in self.pos_dict[pos] if x.var_can_collide or x.var_is_blocking_pos]) >= 1 | ||||||
|   | |||||||
| @@ -4,10 +4,6 @@ from marl_factory_grid.environment import constants as c | |||||||
| # noinspection PyUnresolvedReferences,PyTypeChecker | # noinspection PyUnresolvedReferences,PyTypeChecker | ||||||
| class IsBoundMixin: | class IsBoundMixin: | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def name(self): |  | ||||||
|         return f'{self.__class__.__name__}({self._bound_entity.name})' |  | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return f'{self.__class__.__name__}#{self._bound_entity.name}({self._data})' |         return f'{self.__class__.__name__}#{self._bound_entity.name}({self._data})' | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,14 +1,19 @@ | |||||||
| from collections import defaultdict | from collections import defaultdict | ||||||
| from typing import List | from typing import List, Iterator, Union | ||||||
|  |  | ||||||
| import numpy as np | import numpy as np | ||||||
|  |  | ||||||
| from marl_factory_grid.environment.entity.object import _Object | from marl_factory_grid.environment.entity.object import Object | ||||||
| import marl_factory_grid.environment.constants as c | import marl_factory_grid.environment.constants as c | ||||||
|  | from marl_factory_grid.utils import helpers as h | ||||||
|  |  | ||||||
|  |  | ||||||
| class _Objects: | class Objects: | ||||||
|     _entity = _Object |     _entity = Object | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def var_can_be_bound(self): | ||||||
|  |         return False | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def observers(self): |     def observers(self): | ||||||
| @@ -45,7 +50,7 @@ class _Objects: | |||||||
|     def __len__(self): |     def __len__(self): | ||||||
|         return len(self._data) |         return len(self._data) | ||||||
|  |  | ||||||
|     def __iter__(self): |     def __iter__(self) -> Iterator[Union[Object, None]]: | ||||||
|         return iter(self.values()) |         return iter(self.values()) | ||||||
|  |  | ||||||
|     def add_item(self, item: _entity): |     def add_item(self, item: _entity): | ||||||
| @@ -125,13 +130,14 @@ class _Objects: | |||||||
|         repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]} |         repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]} | ||||||
|         return f'{self.__class__.__name__}[{repr_dict}]' |         return f'{self.__class__.__name__}[{repr_dict}]' | ||||||
|  |  | ||||||
|     def notify_del_entity(self, entity: _Object): |     def notify_del_entity(self, entity: Object): | ||||||
|         try: |         try: | ||||||
|  |             # noinspection PyUnresolvedReferences | ||||||
|             self.pos_dict[entity.pos].remove(entity) |             self.pos_dict[entity.pos].remove(entity) | ||||||
|         except (AttributeError, ValueError, IndexError): |         except (AttributeError, ValueError, IndexError): | ||||||
|             pass |             pass | ||||||
|  |  | ||||||
|     def notify_add_entity(self, entity: _Object): |     def notify_add_entity(self, entity: Object): | ||||||
|         try: |         try: | ||||||
|             if self not in entity.observers: |             if self not in entity.observers: | ||||||
|                 entity.add_observer(self) |                 entity.add_observer(self) | ||||||
| @@ -148,12 +154,12 @@ class _Objects: | |||||||
|  |  | ||||||
|     def by_entity(self, entity): |     def by_entity(self, entity): | ||||||
|         try: |         try: | ||||||
|             return next((x for x in self if x.belongs_to_entity(entity))) |             return h.get_first(self, filter_by=lambda x: x.belongs_to_entity(entity)) | ||||||
|         except (StopIteration, AttributeError): |         except (StopIteration, AttributeError): | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
|     def idx_by_entity(self, entity): |     def idx_by_entity(self, entity): | ||||||
|         try: |         try: | ||||||
|             return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity))) |             return h.get_first_index(self, filter_by=lambda x: x.belongs_to_entity(entity)) | ||||||
|         except (StopIteration, AttributeError): |         except (StopIteration, AttributeError): | ||||||
|             return None |             return None | ||||||
|   | |||||||
| @@ -1,7 +1,10 @@ | |||||||
| from typing import List, Union | from typing import List, Union | ||||||
|  |  | ||||||
|  | from marl_factory_grid.environment import constants as c | ||||||
| from marl_factory_grid.environment.entity.util import GlobalPosition | from marl_factory_grid.environment.entity.util import GlobalPosition | ||||||
| from marl_factory_grid.environment.groups.collection import Collection | from marl_factory_grid.environment.groups.collection import Collection | ||||||
|  | from marl_factory_grid.utils.results import Result | ||||||
|  | from marl_factory_grid.utils.states import Gamestate | ||||||
|  |  | ||||||
|  |  | ||||||
| class Combined(Collection): | class Combined(Collection): | ||||||
| @@ -36,17 +39,17 @@ class GlobalPositions(Collection): | |||||||
|  |  | ||||||
|     _entity = GlobalPosition |     _entity = GlobalPosition | ||||||
|  |  | ||||||
|     @property |     var_is_blocking_light = False | ||||||
|     def var_is_blocking_light(self): |     var_can_be_bound = True | ||||||
|         return False |     var_can_collide = False | ||||||
|  |     var_has_position = False | ||||||
|     @property |  | ||||||
|     def var_can_collide(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_can_be_bound(self): |  | ||||||
|         return True |  | ||||||
|  |  | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|         super(GlobalPositions, self).__init__(*args, **kwargs) |         super(GlobalPositions, self).__init__(*args, **kwargs) | ||||||
|  |  | ||||||
|  |     def spawn(self, agents, level_shape, *args, **kwargs): | ||||||
|  |         self.add_items([self._entity(agent, level_shape, *args, **kwargs) for agent in agents]) | ||||||
|  |         return [Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))] | ||||||
|  |  | ||||||
|  |     def trigger_spawn(self, state: Gamestate, *args, **kwargs) -> [Result]: | ||||||
|  |         return self.spawn(state[c.AGENT], state.lvl_shape, *args, **kwargs) | ||||||
|   | |||||||
| @@ -7,9 +7,12 @@ class Walls(Collection): | |||||||
|     _entity = Wall |     _entity = Wall | ||||||
|     symbol = c.SYMBOL_WALL |     symbol = c.SYMBOL_WALL | ||||||
|  |  | ||||||
|     @property |     var_can_collide = True | ||||||
|     def var_has_position(self): |     var_is_blocking_light = True | ||||||
|         return True |     var_can_move = False | ||||||
|  |     var_has_position = True | ||||||
|  |     var_can_be_bound = False | ||||||
|  |     var_is_blocking_pos = True | ||||||
|  |  | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|         super(Walls, self).__init__(*args, **kwargs) |         super(Walls, self).__init__(*args, **kwargs) | ||||||
|   | |||||||
| @@ -2,3 +2,4 @@ MOVEMENTS_VALID: float = -0.001 | |||||||
| MOVEMENTS_FAIL: float  = -0.05 | MOVEMENTS_FAIL: float  = -0.05 | ||||||
| NOOP: float            = -0.01 | NOOP: float            = -0.01 | ||||||
| COLLISION: float       = -0.5 | COLLISION: float       = -0.5 | ||||||
|  | COLLISION_DONE: float  = -1 | ||||||
|   | |||||||
| @@ -1,11 +1,11 @@ | |||||||
| import abc | import abc | ||||||
| from random import shuffle | from random import shuffle | ||||||
| from typing import List | from typing import List, Collection | ||||||
|  |  | ||||||
|  | from marl_factory_grid.environment import rewards as r, constants as c | ||||||
| from marl_factory_grid.environment.entity.agent import Agent | from marl_factory_grid.environment.entity.agent import Agent | ||||||
| from marl_factory_grid.utils import helpers as h | from marl_factory_grid.utils import helpers as h | ||||||
| from marl_factory_grid.utils.results import TickResult, DoneResult | from marl_factory_grid.utils.results import TickResult, DoneResult | ||||||
| from marl_factory_grid.environment import rewards as r, constants as c |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Rule(abc.ABC): | class Rule(abc.ABC): | ||||||
| @@ -39,6 +39,29 @@ class Rule(abc.ABC): | |||||||
|         return [] |         return [] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SpawnEntity(Rule): | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def _collection(self) -> Collection: | ||||||
|  |         return Collection() | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def name(self): | ||||||
|  |         return f'{self.__class__.__name__}({self.collection.name})' | ||||||
|  |  | ||||||
|  |     def __init__(self, collection, coords_or_quantity, ignore_blocking=False): | ||||||
|  |         super().__init__() | ||||||
|  |         self.coords_or_quantity = coords_or_quantity | ||||||
|  |         self.collection = collection | ||||||
|  |         self.ignore_blocking = ignore_blocking | ||||||
|  |  | ||||||
|  |     def on_init(self, state, lvl_map) -> [TickResult]: | ||||||
|  |         results = self.collection.trigger_spawn(state, ignore_blocking=self.ignore_blocking) | ||||||
|  |         pos_str = f' on: {[x.pos for x in self.collection]}' if self.collection.var_has_position else '' | ||||||
|  |         state.print(f'Initial {self.collection.__class__.__name__} were spawned{pos_str}') | ||||||
|  |         return results | ||||||
|  |  | ||||||
|  |  | ||||||
| class SpawnAgents(Rule): | class SpawnAgents(Rule): | ||||||
|  |  | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
| @@ -46,14 +69,14 @@ class SpawnAgents(Rule): | |||||||
|         pass |         pass | ||||||
|  |  | ||||||
|     def on_init(self, state, lvl_map): |     def on_init(self, state, lvl_map): | ||||||
|         agent_conf = state.agents_conf |  | ||||||
|         # agents = Agents(lvl_map.size) |         # agents = Agents(lvl_map.size) | ||||||
|         agents = state[c.AGENT] |         agents = state[c.AGENT] | ||||||
|         empty_positions = state.entities.empty_positions()[:len(agent_conf)] |         empty_positions = state.entities.empty_positions[:len(state.agents_conf)] | ||||||
|         for agent_name in agent_conf: |         for agent_name, agent_conf in state.agents_conf.items(): | ||||||
|             actions = agent_conf[agent_name]['actions'].copy() |             actions = agent_conf['actions'].copy() | ||||||
|             observations = agent_conf[agent_name]['observations'].copy() |             observations = agent_conf['observations'].copy() | ||||||
|             positions = agent_conf[agent_name]['positions'].copy() |             positions = agent_conf['positions'].copy() | ||||||
|  |             other = agent_conf['other'].copy() | ||||||
|             if positions: |             if positions: | ||||||
|                 shuffle(positions) |                 shuffle(positions) | ||||||
|                 while True: |                 while True: | ||||||
| @@ -61,18 +84,18 @@ class SpawnAgents(Rule): | |||||||
|                         pos = positions.pop() |                         pos = positions.pop() | ||||||
|                     except IndexError: |                     except IndexError: | ||||||
|                         raise ValueError(f'It was not possible to spawn an Agent on the available position: ' |                         raise ValueError(f'It was not possible to spawn an Agent on the available position: ' | ||||||
|                                          f'\n{agent_name[agent_name]["positions"].copy()}') |                                          f'\n{agent_conf["positions"].copy()}') | ||||||
|                     if agents.by_pos(pos) and state.check_pos_validity(pos): |                     if bool(agents.by_pos(pos)) or not state.check_pos_validity(pos): | ||||||
|                         continue |                         continue | ||||||
|                     else: |                     else: | ||||||
|                         agents.add_item(Agent(actions, observations, pos, str_ident=agent_name)) |                         agents.add_item(Agent(actions, observations, pos, str_ident=agent_name, **other)) | ||||||
|                     break |                     break | ||||||
|             else: |             else: | ||||||
|                 agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name)) |                 agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name, **other)) | ||||||
|         pass |         pass | ||||||
|  |  | ||||||
|  |  | ||||||
| class MaxStepsReached(Rule): | class DoneAtMaxStepsReached(Rule): | ||||||
|  |  | ||||||
|     def __init__(self, max_steps: int = 500): |     def __init__(self, max_steps: int = 500): | ||||||
|         super().__init__() |         super().__init__() | ||||||
| @@ -83,8 +106,8 @@ class MaxStepsReached(Rule): | |||||||
|  |  | ||||||
|     def on_check_done(self, state): |     def on_check_done(self, state): | ||||||
|         if self.max_steps <= state.curr_step: |         if self.max_steps <= state.curr_step: | ||||||
|             return [DoneResult(validity=c.VALID, identifier=self.name, reward=0)] |             return [DoneResult(validity=c.VALID, identifier=self.name)] | ||||||
|         return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)] |         return [DoneResult(validity=c.NOT_VALID, identifier=self.name)] | ||||||
|  |  | ||||||
|  |  | ||||||
| class AssignGlobalPositions(Rule): | class AssignGlobalPositions(Rule): | ||||||
| @@ -95,16 +118,17 @@ class AssignGlobalPositions(Rule): | |||||||
|     def on_init(self, state, lvl_map): |     def on_init(self, state, lvl_map): | ||||||
|         from marl_factory_grid.environment.entity.util import GlobalPosition |         from marl_factory_grid.environment.entity.util import GlobalPosition | ||||||
|         for agent in state[c.AGENT]: |         for agent in state[c.AGENT]: | ||||||
|             gp = GlobalPosition(lvl_map.level_shape) |             gp = GlobalPosition(agent, lvl_map.level_shape) | ||||||
|             gp.bind_to(agent) |  | ||||||
|             state[c.GLOBALPOSITIONS].add_item(gp) |             state[c.GLOBALPOSITIONS].add_item(gp) | ||||||
|         return [] |         return [] | ||||||
|  |  | ||||||
|  |  | ||||||
| class Collision(Rule): | class WatchCollisions(Rule): | ||||||
|  |  | ||||||
|     def __init__(self, done_at_collisions: bool = False): |     def __init__(self, reward=r.COLLISION, done_at_collisions: bool = False, reward_at_done=r.COLLISION_DONE): | ||||||
|         super().__init__() |         super().__init__() | ||||||
|  |         self.reward_at_done = reward_at_done | ||||||
|  |         self.reward = reward | ||||||
|         self.done_at_collisions = done_at_collisions |         self.done_at_collisions = done_at_collisions | ||||||
|         self.curr_done = False |         self.curr_done = False | ||||||
|  |  | ||||||
| @@ -117,12 +141,12 @@ class Collision(Rule): | |||||||
|             if len(guests) >= 2: |             if len(guests) >= 2: | ||||||
|                 for i, guest in enumerate(guests): |                 for i, guest in enumerate(guests): | ||||||
|                     try: |                     try: | ||||||
|                         guest.set_state(TickResult(identifier=c.COLLISION, reward=r.COLLISION, |                         guest.set_state(TickResult(identifier=c.COLLISION, reward=self.reward, | ||||||
|                                                    validity=c.NOT_VALID, entity=self)) |                                                    validity=c.NOT_VALID, entity=self)) | ||||||
|                     except AttributeError: |                     except AttributeError: | ||||||
|                         pass |                         pass | ||||||
|                     results.append(TickResult(entity=guest, identifier=c.COLLISION, |                     results.append(TickResult(entity=guest, identifier=c.COLLISION, | ||||||
|                                               reward=r.COLLISION, validity=c.VALID)) |                                               reward=self.reward, validity=c.VALID)) | ||||||
|                 self.curr_done = True if self.done_at_collisions else False |                 self.curr_done = True if self.done_at_collisions else False | ||||||
|         return results |         return results | ||||||
|  |  | ||||||
| @@ -131,5 +155,5 @@ class Collision(Rule): | |||||||
|             inter_entity_collision_detected = self.curr_done |             inter_entity_collision_detected = self.curr_done | ||||||
|             move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT]) |             move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT]) | ||||||
|             if inter_entity_collision_detected or move_failed: |             if inter_entity_collision_detected or move_failed: | ||||||
|                 return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)] |                 return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=self.reward_at_done)] | ||||||
|         return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)] |         return [] | ||||||
|   | |||||||
| @@ -6,7 +6,9 @@ from marl_factory_grid.utils.results import TickResult, DoneResult | |||||||
| class TemplateRule(Rule): | class TemplateRule(Rule): | ||||||
|  |  | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|         super(TemplateRule, self).__init__(*args, **kwargs) |         super(TemplateRule, self).__init__() | ||||||
|  |         self.args = args | ||||||
|  |         self.kwargs = kwargs | ||||||
|  |  | ||||||
|     def on_init(self, state, lvl_map): |     def on_init(self, state, lvl_map): | ||||||
|         pass |         pass | ||||||
|   | |||||||
| @@ -1,4 +1,4 @@ | |||||||
| from .actions import BtryCharge | from .actions import BtryCharge | ||||||
| from .entitites import Pod, Battery | from .entitites import ChargePod, Battery | ||||||
| from .groups import ChargePods, Batteries | from .groups import ChargePods, Batteries | ||||||
| from .rules import DoneAtBatteryDischarge, BatteryDecharge | from .rules import DoneAtBatteryDischarge, BatteryDecharge | ||||||
|   | |||||||
| @@ -1,11 +1,11 @@ | |||||||
| from typing import Union | from typing import Union | ||||||
|  |  | ||||||
| import marl_factory_grid.modules.batteries.constants |  | ||||||
| from marl_factory_grid.environment.actions import Action | from marl_factory_grid.environment.actions import Action | ||||||
| from marl_factory_grid.utils.results import ActionResult | from marl_factory_grid.utils.results import ActionResult | ||||||
|  |  | ||||||
| from marl_factory_grid.modules.batteries import constants as b | from marl_factory_grid.modules.batteries import constants as b | ||||||
| from marl_factory_grid.environment import constants as c | from marl_factory_grid.environment import constants as c | ||||||
|  | from marl_factory_grid.utils import helpers as h | ||||||
|  |  | ||||||
|  |  | ||||||
| class BtryCharge(Action): | class BtryCharge(Action): | ||||||
| @@ -14,8 +14,8 @@ class BtryCharge(Action): | |||||||
|         super().__init__(b.ACTION_CHARGE) |         super().__init__(b.ACTION_CHARGE) | ||||||
|  |  | ||||||
|     def do(self, entity, state) -> Union[None, ActionResult]: |     def do(self, entity, state) -> Union[None, ActionResult]: | ||||||
|         if charge_pod := state[b.CHARGE_PODS].by_pos(entity.pos): |         if charge_pod := h.get_first(state[b.CHARGE_PODS].by_pos(entity.pos)): | ||||||
|             valid = charge_pod.charge_battery(state[b.BATTERIES].by_entity(entity)) |             valid = h.get_first(charge_pod.charge_battery(state[b.BATTERIES].by_entity(entity))) | ||||||
|             if valid: |             if valid: | ||||||
|                 state.print(f'{entity.name} just charged batteries at {charge_pod.name}.') |                 state.print(f'{entity.name} just charged batteries at {charge_pod.name}.') | ||||||
|             else: |             else: | ||||||
| @@ -23,5 +23,6 @@ class BtryCharge(Action): | |||||||
|         else: |         else: | ||||||
|             valid = c.NOT_VALID |             valid = c.NOT_VALID | ||||||
|             state.print(f'{entity.name} failed to charged batteries at {entity.pos}.') |             state.print(f'{entity.name} failed to charged batteries at {entity.pos}.') | ||||||
|  |  | ||||||
|         return ActionResult(entity=entity, identifier=self._identifier, validity=valid, |         return ActionResult(entity=entity, identifier=self._identifier, validity=valid, | ||||||
|                             reward=marl_factory_grid.modules.batteries.constants.REWARD_CHARGE_VALID if valid else marl_factory_grid.modules.batteries.constants.Reward_CHARGE_FAIL) |                             reward=b.REWARD_CHARGE_VALID if valid else b.Reward_CHARGE_FAIL) | ||||||
|   | |||||||
							
								
								
									
										
											BIN
										
									
								
								marl_factory_grid/modules/batteries/chargepods.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								marl_factory_grid/modules/batteries/chargepods.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 7.9 KiB | 
| @@ -1,11 +1,11 @@ | |||||||
| from marl_factory_grid.environment import constants as c | from marl_factory_grid.environment import constants as c | ||||||
| from marl_factory_grid.environment.entity.entity import Entity | from marl_factory_grid.environment.entity.entity import Entity | ||||||
| from marl_factory_grid.environment.entity.object import _Object | from marl_factory_grid.environment.entity.object import Object | ||||||
| from marl_factory_grid.modules.batteries import constants as b | from marl_factory_grid.modules.batteries import constants as b | ||||||
| from marl_factory_grid.utils.utility_classes import RenderEntity | from marl_factory_grid.utils.utility_classes import RenderEntity | ||||||
|  |  | ||||||
|  |  | ||||||
| class Battery(_Object): | class Battery(Object): | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def var_can_be_bound(self): |     def var_can_be_bound(self): | ||||||
| @@ -50,7 +50,7 @@ class Battery(_Object): | |||||||
|         return summary |         return summary | ||||||
|  |  | ||||||
|  |  | ||||||
| class Pod(Entity): | class ChargePod(Entity): | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def encoding(self): |     def encoding(self): | ||||||
| @@ -58,7 +58,7 @@ class Pod(Entity): | |||||||
|  |  | ||||||
|     def __init__(self, *args, charge_rate: float = 0.4, |     def __init__(self, *args, charge_rate: float = 0.4, | ||||||
|                  multi_charge: bool = False, **kwargs): |                  multi_charge: bool = False, **kwargs): | ||||||
|         super(Pod, self).__init__(*args, **kwargs) |         super(ChargePod, self).__init__(*args, **kwargs) | ||||||
|         self.charge_rate = charge_rate |         self.charge_rate = charge_rate | ||||||
|         self.multi_charge = multi_charge |         self.multi_charge = multi_charge | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,52 +1,36 @@ | |||||||
| from typing import Union, List, Tuple | from typing import Union, List, Tuple | ||||||
|  |  | ||||||
|  | from marl_factory_grid.environment import constants as c | ||||||
| from marl_factory_grid.environment.groups.collection import Collection | from marl_factory_grid.environment.groups.collection import Collection | ||||||
| from marl_factory_grid.modules.batteries.entitites import Pod, Battery | from marl_factory_grid.modules.batteries.entitites import ChargePod, Battery | ||||||
|  | from marl_factory_grid.utils.results import Result | ||||||
|  |  | ||||||
|  |  | ||||||
| class Batteries(Collection): | class Batteries(Collection): | ||||||
|     _entity = Battery |     _entity = Battery | ||||||
|  |  | ||||||
|     @property |     var_has_position = False | ||||||
|     def var_is_blocking_light(self): |     var_can_be_bound = True | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_can_collide(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_can_move(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_has_position(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_can_be_bound(self): |  | ||||||
|         return True |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def obs_tag(self): |     def obs_tag(self): | ||||||
|         return self.__class__.__name__ |         return self.__class__.__name__ | ||||||
|  |  | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, size, initial_charge_level: float=1.0, *args, **kwargs): | ||||||
|         super(Batteries, self).__init__(*args, **kwargs) |         super(Batteries, self).__init__(size, *args, **kwargs) | ||||||
|  |         self.initial_charge_level = initial_charge_level | ||||||
|  |  | ||||||
|     def spawn(self, agents, initial_charge_level): |     def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], agents, *entity_args, **entity_kwargs): | ||||||
|         batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)] |         batteries = [self._entity(self.initial_charge_level, agent) for _, agent in enumerate(agents)] | ||||||
|         self.add_items(batteries) |         self.add_items(batteries) | ||||||
|  |  | ||||||
|     # def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):           hat keine pos |     def trigger_spawn(self, state, *entity_args, coords_or_quantity=None,  **entity_kwargs): | ||||||
|     #     agents = entity_args[0] |         self.spawn(0, state[c.AGENT]) | ||||||
|     #     initial_charge_level = entity_args[1] |         return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self)) | ||||||
|     #     batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)] |  | ||||||
|     #     self.add_items(batteries) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ChargePods(Collection): | class ChargePods(Collection): | ||||||
|     _entity = Pod |     _entity = ChargePod | ||||||
|  |  | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|         super(ChargePods, self).__init__(*args, **kwargs) |         super(ChargePods, self).__init__(*args, **kwargs) | ||||||
|   | |||||||
| @@ -1,11 +1,9 @@ | |||||||
| from typing import List, Union | from typing import List, Union | ||||||
|  |  | ||||||
| import marl_factory_grid.modules.batteries.constants |  | ||||||
| from marl_factory_grid.environment.rules import Rule |  | ||||||
| from marl_factory_grid.utils.results import TickResult, DoneResult |  | ||||||
|  |  | ||||||
| from marl_factory_grid.environment import constants as c | from marl_factory_grid.environment import constants as c | ||||||
|  | from marl_factory_grid.environment.rules import Rule | ||||||
| from marl_factory_grid.modules.batteries import constants as b | from marl_factory_grid.modules.batteries import constants as b | ||||||
|  | from marl_factory_grid.utils.results import TickResult, DoneResult | ||||||
|  |  | ||||||
|  |  | ||||||
| class BatteryDecharge(Rule): | class BatteryDecharge(Rule): | ||||||
| @@ -49,10 +47,6 @@ class BatteryDecharge(Rule): | |||||||
|         self.per_action_costs = per_action_costs |         self.per_action_costs = per_action_costs | ||||||
|         self.initial_charge = initial_charge |         self.initial_charge = initial_charge | ||||||
|  |  | ||||||
|     def on_init(self, state, lvl_map):  # on reset? |  | ||||||
|         assert len(state[c.AGENT]), "There are no agents, did you already spawn them?" |  | ||||||
|         state[b.BATTERIES].spawn(state[c.AGENT], self.initial_charge) |  | ||||||
|  |  | ||||||
|     def tick_step(self, state) -> List[TickResult]: |     def tick_step(self, state) -> List[TickResult]: | ||||||
|         # Decharge |         # Decharge | ||||||
|         batteries = state[b.BATTERIES] |         batteries = state[b.BATTERIES] | ||||||
| @@ -66,7 +60,7 @@ class BatteryDecharge(Rule): | |||||||
|  |  | ||||||
|             batteries.by_entity(agent).decharge(energy_consumption) |             batteries.by_entity(agent).decharge(energy_consumption) | ||||||
|  |  | ||||||
|             results.append(TickResult(self.name, reward=0, entity=agent, validity=c.VALID)) |             results.append(TickResult(self.name, entity=agent, validity=c.VALID)) | ||||||
|  |  | ||||||
|         return results |         return results | ||||||
|  |  | ||||||
| @@ -82,13 +76,13 @@ class BatteryDecharge(Rule): | |||||||
|                 if self.paralyze_agents_on_discharge: |                 if self.paralyze_agents_on_discharge: | ||||||
|                     btry.bound_entity.paralyze(self.name) |                     btry.bound_entity.paralyze(self.name) | ||||||
|                     results.append( |                     results.append( | ||||||
|                         TickResult("Paralyzed", entity=btry.bound_entity, reward=0, validity=c.VALID) |                         TickResult("Paralyzed", entity=btry.bound_entity, validity=c.VALID) | ||||||
|                     ) |                     ) | ||||||
|                     state.print(f'{btry.bound_entity.name} has just been paralyzed!') |                     state.print(f'{btry.bound_entity.name} has just been paralyzed!') | ||||||
|             if btry.bound_entity.var_is_paralyzed and not btry.is_discharged: |             if btry.bound_entity.var_is_paralyzed and not btry.is_discharged: | ||||||
|                 btry.bound_entity.de_paralyze(self.name) |                 btry.bound_entity.de_paralyze(self.name) | ||||||
|                 results.append( |                 results.append( | ||||||
|                     TickResult("De-Paralyzed", entity=btry.bound_entity, reward=0, validity=c.VALID) |                     TickResult("De-Paralyzed", entity=btry.bound_entity, validity=c.VALID) | ||||||
|                 ) |                 ) | ||||||
|                 state.print(f'{btry.bound_entity.name} has just been de-paralyzed!') |                 state.print(f'{btry.bound_entity.name} has just been de-paralyzed!') | ||||||
|         return results |         return results | ||||||
| @@ -132,7 +126,7 @@ class DoneAtBatteryDischarge(BatteryDecharge): | |||||||
|         if any_discharged or all_discharged: |         if any_discharged or all_discharged: | ||||||
|             return [DoneResult(self.name, validity=c.VALID, reward=self.reward_discharge_done)] |             return [DoneResult(self.name, validity=c.VALID, reward=self.reward_discharge_done)] | ||||||
|         else: |         else: | ||||||
|             return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)] |             return [DoneResult(self.name, validity=c.NOT_VALID)] | ||||||
|  |  | ||||||
|  |  | ||||||
| class SpawnChargePods(Rule): | class SpawnChargePods(Rule): | ||||||
| @@ -155,7 +149,7 @@ class SpawnChargePods(Rule): | |||||||
|  |  | ||||||
|     def on_init(self, state, lvl_map): |     def on_init(self, state, lvl_map): | ||||||
|         pod_collection = state[b.CHARGE_PODS] |         pod_collection = state[b.CHARGE_PODS] | ||||||
|         empty_positions = state.entities.empty_positions() |         empty_positions = state.entities.empty_positions | ||||||
|         pods = pod_collection.from_coordinates(empty_positions, entity_kwargs=dict( |         pods = pod_collection.from_coordinates(empty_positions, entity_kwargs=dict( | ||||||
|             multi_charge=self.multi_charge, charge_rate=self.charge_rate) |             multi_charge=self.multi_charge, charge_rate=self.charge_rate) | ||||||
|                                                ) |                                                ) | ||||||
|   | |||||||
| @@ -1,4 +1,4 @@ | |||||||
| from .actions import CleanUp | from .actions import CleanUp | ||||||
| from .entitites import DirtPile | from .entitites import DirtPile | ||||||
| from .groups import DirtPiles | from .groups import DirtPiles | ||||||
| from .rules import SpawnDirt, EntitiesSmearDirtOnMove, DoneOnAllDirtCleaned | from .rules import EntitiesSmearDirtOnMove, DoneOnAllDirtCleaned | ||||||
|   | |||||||
| @@ -1,5 +1,3 @@ | |||||||
| from numpy import random |  | ||||||
|  |  | ||||||
| from marl_factory_grid.environment.entity.entity import Entity | from marl_factory_grid.environment.entity.entity import Entity | ||||||
| from marl_factory_grid.utils.utility_classes import RenderEntity | from marl_factory_grid.utils.utility_classes import RenderEntity | ||||||
| from marl_factory_grid.modules.clean_up import constants as d | from marl_factory_grid.modules.clean_up import constants as d | ||||||
| @@ -7,22 +5,6 @@ from marl_factory_grid.modules.clean_up import constants as d | |||||||
|  |  | ||||||
| class DirtPile(Entity): | class DirtPile(Entity): | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_can_collide(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_can_move(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_is_blocking_light(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_has_position(self): |  | ||||||
|         return True |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def amount(self): |     def amount(self): | ||||||
|         return self._amount |         return self._amount | ||||||
|   | |||||||
| @@ -1,76 +1,61 @@ | |||||||
| from typing import Union, List, Tuple |  | ||||||
|  |  | ||||||
| from marl_factory_grid.environment import constants as c | from marl_factory_grid.environment import constants as c | ||||||
| from marl_factory_grid.utils.results import Result |  | ||||||
| from marl_factory_grid.environment.groups.collection import Collection | from marl_factory_grid.environment.groups.collection import Collection | ||||||
| from marl_factory_grid.modules.clean_up.entitites import DirtPile | from marl_factory_grid.modules.clean_up.entitites import DirtPile | ||||||
|  | from marl_factory_grid.utils.results import Result | ||||||
|  |  | ||||||
|  |  | ||||||
| class DirtPiles(Collection): | class DirtPiles(Collection): | ||||||
|     _entity = DirtPile |     _entity = DirtPile | ||||||
|  |  | ||||||
|     @property |     var_is_blocking_light = False | ||||||
|     def var_is_blocking_light(self): |     var_can_collide = False | ||||||
|         return False |     var_can_move = False | ||||||
|  |     var_has_position = True | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def var_can_collide(self): |     def global_amount(self): | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_can_move(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_has_position(self): |  | ||||||
|         return True |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def amount(self): |  | ||||||
|         return sum([dirt.amount for dirt in self]) |         return sum([dirt.amount for dirt in self]) | ||||||
|  |  | ||||||
|     def __init__(self, *args, |     def __init__(self, *args, | ||||||
|                  max_local_amount=5, |                  max_local_amount=5, | ||||||
|                  clean_amount=1, |                  clean_amount=1, | ||||||
|                  max_global_amount: int = 20, **kwargs): |                  max_global_amount: int = 20, | ||||||
|  |                  coords_or_quantity=10, | ||||||
|  |                  initial_amount=2, | ||||||
|  |                  amount_var=0.2, | ||||||
|  |                  n_var=0.2, | ||||||
|  |                  **kwargs): | ||||||
|         super(DirtPiles, self).__init__(*args, **kwargs) |         super(DirtPiles, self).__init__(*args, **kwargs) | ||||||
|  |         self.amount_var = amount_var | ||||||
|  |         self.n_var = n_var | ||||||
|         self.clean_amount = clean_amount |         self.clean_amount = clean_amount | ||||||
|         self.max_global_amount = max_global_amount |         self.max_global_amount = max_global_amount | ||||||
|         self.max_local_amount = max_local_amount |         self.max_local_amount = max_local_amount | ||||||
|  |         self.coords_or_quantity = coords_or_quantity | ||||||
|  |         self.initial_amount = initial_amount | ||||||
|  |  | ||||||
|     def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): |     def trigger_spawn(self, state, coords_or_quantity=0, amount=0) -> [Result]: | ||||||
|         amount_s = entity_args[0] |         coords_or_quantity = coords_or_quantity if coords_or_quantity else self.coords_or_quantity | ||||||
|  |         n_new = int(abs(coords_or_quantity + (state.rng.uniform(-self.n_var, self.n_var)))) | ||||||
|  |         n_new = state.get_n_random_free_positions(n_new) | ||||||
|  |  | ||||||
|  |         amounts = [amount if amount else (self.initial_amount + state.rng.uniform(-self.amount_var, self.amount_var)) | ||||||
|  |                    for _ in range(coords_or_quantity)] | ||||||
|         spawn_counter = 0 |         spawn_counter = 0 | ||||||
|         for idx, pos in enumerate(coords_or_quantity): |         for idx, (pos, a) in enumerate(zip(n_new, amounts)): | ||||||
|             if not self.amount > self.max_global_amount: |             if not self.global_amount > self.max_global_amount: | ||||||
|                 amount = amount_s[idx] if isinstance(amount_s, list) else amount_s |  | ||||||
|                 if dirt := self.by_pos(pos): |                 if dirt := self.by_pos(pos): | ||||||
|                     dirt = next(dirt.iter()) |                     dirt = next(dirt.iter()) | ||||||
|                     new_value = dirt.amount + amount |                     new_value = dirt.amount + a | ||||||
|                     dirt.set_new_amount(new_value) |                     dirt.set_new_amount(new_value) | ||||||
|                 else: |                 else: | ||||||
|                     dirt = DirtPile(pos, amount=amount) |                     super().spawn([pos], amount=a) | ||||||
|                     self.add_item(dirt) |  | ||||||
|                     spawn_counter += 1 |                     spawn_counter += 1 | ||||||
|             else: |             else: | ||||||
|                 return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, reward=0, |                 return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, value=spawn_counter) | ||||||
|                               value=spawn_counter) |  | ||||||
|         return Result(identifier=f'{self.name}_spawn', validity=c.VALID, reward=0, value=spawn_counter) |  | ||||||
|  |  | ||||||
|     def trigger_dirt_spawn(self, n, amount, state, n_var=0.2, amount_var=0.2) -> Result: |         return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=spawn_counter) | ||||||
|         free_for_dirt = [x for x in state.entities.floorlist if len(state.entities.pos_dict[x]) == 0 or ( |  | ||||||
|                 len(state.entities.pos_dict[x]) >= 1 and isinstance(next(y for y in x), DirtPile))] |  | ||||||
|         # free_for_dirt = [x for x in state[c.FLOOR] |  | ||||||
|         #                  if len(x.guests) == 0 or ( |  | ||||||
|         #                          len(x.guests) == 1 and |  | ||||||
|         #                          isinstance(next(y for y in x.guests), DirtPile))] |  | ||||||
|         state.rng.shuffle(free_for_dirt) |  | ||||||
|  |  | ||||||
|         new_spawn = int(abs(n + (state.rng.uniform(-n_var, n_var)))) |  | ||||||
|         new_amount_s = [abs(amount + (amount*state.rng.uniform(-amount_var, amount_var))) for _ in range(new_spawn)] |  | ||||||
|         n_dirty_positions = free_for_dirt[:new_spawn] |  | ||||||
|         return self.spawn(n_dirty_positions, new_amount_s) |  | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         s = super(DirtPiles, self).__repr__() |         s = super(DirtPiles, self).__repr__() | ||||||
|         return f'{s[:-1]}, {self.amount})' |         return f'{s[:-1]}, {self.global_amount}]' | ||||||
|   | |||||||
| @@ -22,58 +22,37 @@ class DoneOnAllDirtCleaned(Rule): | |||||||
|     def on_check_done(self, state) -> [DoneResult]: |     def on_check_done(self, state) -> [DoneResult]: | ||||||
|         if len(state[d.DIRT]) == 0 and state.curr_step: |         if len(state[d.DIRT]) == 0 and state.curr_step: | ||||||
|             return [DoneResult(validity=c.VALID, identifier=self.name, reward=self.reward)] |             return [DoneResult(validity=c.VALID, identifier=self.name, reward=self.reward)] | ||||||
|         return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)] |         return [DoneResult(validity=c.NOT_VALID, identifier=self.name)] | ||||||
|  |  | ||||||
|  |  | ||||||
| class SpawnDirt(Rule): | class RespawnDirt(Rule): | ||||||
|  |  | ||||||
|     def __init__(self, initial_n: int = 5, initial_amount: float = 1.3, |     def __init__(self, respawn_freq: int = 15, respawn_n: int = 5, respawn_amount: float = 1.0): | ||||||
|                  respawn_n: int = 3, respawn_amount: float = 0.8, |  | ||||||
|                  n_var: float = 0.2, amount_var: float = 0.2, spawn_freq: int = 15): |  | ||||||
|         """ |         """ | ||||||
|         Defines the spawn pattern of intial and additional 'Dirt'-entitites. |         Defines the spawn pattern of intial and additional 'Dirt'-entitites. | ||||||
|         First chooses positions, then trys to spawn dirt until 'respawn_n' or the maximal global amount is reached. |         First chooses positions, then trys to spawn dirt until 'respawn_n' or the maximal global amount is reached. | ||||||
|         If there is allready some, it is topped up to min(max_local_amount, amount). |         If there is allready some, it is topped up to min(max_local_amount, amount). | ||||||
|  |  | ||||||
|         :type spawn_freq: int |         :type respawn_freq: int | ||||||
|         :parameter spawn_freq: In which frequency should this Rule try to spawn new 'Dirt'? |         :parameter respawn_freq: In which frequency should this Rule try to spawn new 'Dirt'? | ||||||
|         :type respawn_n: int |         :type respawn_n: int | ||||||
|         :parameter respawn_n: How many respawn positions are considered. |         :parameter respawn_n: How many respawn positions are considered. | ||||||
|         :type initial_n: int |  | ||||||
|         :parameter initial_n: How much initial positions are considered. |  | ||||||
|         :type amount_var: float |  | ||||||
|         :parameter amount_var: Variance of amount to spawn. |  | ||||||
|         :type n_var: float |  | ||||||
|         :parameter n_var: Variance of n to spawn. |  | ||||||
|         :type respawn_amount: float |         :type respawn_amount: float | ||||||
|         :parameter respawn_amount: Defines how much dirt 'amount' is placed every 'spawn_freq' ticks. |         :parameter respawn_amount: Defines how much dirt 'amount' is placed every 'spawn_freq' ticks. | ||||||
|         :type initial_amount: float |  | ||||||
|         :parameter initial_amount: Defines how much dirt 'amount' is initially placed. |  | ||||||
|  |  | ||||||
|         """ |         """ | ||||||
|         super().__init__() |         super().__init__() | ||||||
|         self.amount_var = amount_var |  | ||||||
|         self.n_var = n_var |  | ||||||
|         self.respawn_amount = respawn_amount |  | ||||||
|         self.respawn_n = respawn_n |         self.respawn_n = respawn_n | ||||||
|         self.initial_amount = initial_amount |         self.respawn_amount = respawn_amount | ||||||
|         self.initial_n = initial_n |         self.respawn_freq = respawn_freq | ||||||
|         self.spawn_freq = spawn_freq |         self._next_dirt_spawn = respawn_freq | ||||||
|         self._next_dirt_spawn = spawn_freq |  | ||||||
|  |  | ||||||
|     def on_init(self, state, lvl_map) -> str: |  | ||||||
|         result = state[d.DIRT].trigger_dirt_spawn(self.initial_n, self.initial_amount, state, |  | ||||||
|                                                   n_var=self.n_var, amount_var=self.amount_var) |  | ||||||
|         state.print(f'Initial Dirt was spawned on: {[x.pos for x in state[d.DIRT]]}') |  | ||||||
|         return result |  | ||||||
|  |  | ||||||
|     def tick_step(self, state): |     def tick_step(self, state): | ||||||
|  |         collection = state[d.DIRT] | ||||||
|         if self._next_dirt_spawn < 0: |         if self._next_dirt_spawn < 0: | ||||||
|             pass  # No DirtPile Spawn |             result = []  # No DirtPile Spawn | ||||||
|         elif not self._next_dirt_spawn: |         elif not self._next_dirt_spawn: | ||||||
|             result = [state[d.DIRT].trigger_dirt_spawn(self.respawn_n, self.respawn_amount, state, |             result = [collection.trigger_spawn(state, coords_or_quantity=self.respawn_n, amount=self.respawn_amount)] | ||||||
|                                                        n_var=self.n_var, amount_var=self.amount_var)] |             self._next_dirt_spawn = self.respawn_freq | ||||||
|             self._next_dirt_spawn = self.spawn_freq |  | ||||||
|         else: |         else: | ||||||
|             self._next_dirt_spawn -= 1 |             self._next_dirt_spawn -= 1 | ||||||
|             result = [] |             result = [] | ||||||
| @@ -99,8 +78,8 @@ class EntitiesSmearDirtOnMove(Rule): | |||||||
|         for entity in state.moving_entites: |         for entity in state.moving_entites: | ||||||
|             if is_move(entity.state.identifier) and entity.state.validity == c.VALID: |             if is_move(entity.state.identifier) and entity.state.validity == c.VALID: | ||||||
|                 if old_pos_dirt := state[d.DIRT].by_pos(entity.last_pos): |                 if old_pos_dirt := state[d.DIRT].by_pos(entity.last_pos): | ||||||
|  |                     old_pos_dirt = next(iter(old_pos_dirt)) | ||||||
|                     if smeared_dirt := round(old_pos_dirt.amount * self.smear_ratio, 2): |                     if smeared_dirt := round(old_pos_dirt.amount * self.smear_ratio, 2): | ||||||
|                         if state[d.DIRT].spawn(entity.pos, amount=smeared_dirt): |                         if state[d.DIRT].spawn(entity.pos, amount=smeared_dirt): | ||||||
|                             results.append(TickResult(identifier=self.name, entity=entity, |                             results.append(TickResult(identifier=self.name, entity=entity, validity=c.VALID)) | ||||||
|                                                       reward=0, validity=c.VALID)) |  | ||||||
|         return results |         return results | ||||||
|   | |||||||
| @@ -1,4 +1,7 @@ | |||||||
| from .actions import DestAction | from .actions import DestAction | ||||||
| from .entitites import Destination | from .entitites import Destination | ||||||
| from .groups import Destinations | from .groups import Destinations | ||||||
| from .rules import DoneAtDestinationReachAll, SpawnDestinations | from .rules import (DoneAtDestinationReachAll, | ||||||
|  |                     DoneAtDestinationReachAny, | ||||||
|  |                     SpawnDestinationsPerAgent, | ||||||
|  |                     DestinationReachReward) | ||||||
|   | |||||||
| @@ -21,4 +21,4 @@ class DestAction(Action): | |||||||
|             valid = c.NOT_VALID |             valid = c.NOT_VALID | ||||||
|             state.print(f'{entity.name} just tried to do_wait_action do_wait_action at {entity.pos} but failed') |             state.print(f'{entity.name} just tried to do_wait_action do_wait_action at {entity.pos} but failed') | ||||||
|         return ActionResult(entity=entity, identifier=self._identifier, validity=valid, |         return ActionResult(entity=entity, identifier=self._identifier, validity=valid, | ||||||
|                             reward=marl_factory_grid.modules.destinations.constants.REWARD_WAIT_VALID if valid else marl_factory_grid.modules.destinations.constants.REWARD_WAIT_FAIL) |                             reward=d.REWARD_WAIT_VALID if valid else d.REWARD_WAIT_FAIL) | ||||||
|   | |||||||
| @@ -9,30 +9,6 @@ from marl_factory_grid.utils.utility_classes import RenderEntity | |||||||
|  |  | ||||||
| class Destination(Entity): | class Destination(Entity): | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_can_move(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_can_collide(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_has_position(self): |  | ||||||
|         return True |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_is_blocking_pos(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_is_blocking_light(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_can_be_bound(self): |  | ||||||
|         return True |  | ||||||
|  |  | ||||||
|     def was_reached(self): |     def was_reached(self): | ||||||
|         return self._was_reached |         return self._was_reached | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,43 +1,18 @@ | |||||||
| from marl_factory_grid.environment.groups.collection import Collection | from marl_factory_grid.environment.groups.collection import Collection | ||||||
| from marl_factory_grid.modules.destinations.entitites import Destination | from marl_factory_grid.modules.destinations.entitites import Destination | ||||||
| from marl_factory_grid.environment import constants as c |  | ||||||
| from marl_factory_grid.modules.destinations import constants as d |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Destinations(Collection): | class Destinations(Collection): | ||||||
|     _entity = Destination |     _entity = Destination | ||||||
|  |  | ||||||
|     @property |     var_is_blocking_light = False | ||||||
|     def var_is_blocking_light(self): |     var_can_collide = False | ||||||
|         return False |     var_can_move = False | ||||||
|  |     var_has_position = True | ||||||
|     @property |     var_can_be_bound = True | ||||||
|     def var_can_collide(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_can_move(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_has_position(self): |  | ||||||
|         return True |  | ||||||
|  |  | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|         super().__init__(*args, **kwargs) |         super().__init__(*args, **kwargs) | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return super(Destinations, self).__repr__() |         return super(Destinations, self).__repr__() | ||||||
|  |  | ||||||
|     @staticmethod |  | ||||||
|     def trigger_destination_spawn(n_dests, state): |  | ||||||
|         coordinates = state.entities.floorlist[:n_dests] |  | ||||||
|         if destinations := [Destination(pos) for pos in coordinates]: |  | ||||||
|             state[d.DESTINATION].add_items(destinations) |  | ||||||
|             state.print(f'{n_dests} new destinations have been spawned') |  | ||||||
|             return c.VALID |  | ||||||
|         else: |  | ||||||
|             state.print('No Destiantions are spawning, limit is reached.') |  | ||||||
|             return c.NOT_VALID |  | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -2,8 +2,8 @@ import ast | |||||||
| from random import shuffle | from random import shuffle | ||||||
| from typing import List, Dict, Tuple | from typing import List, Dict, Tuple | ||||||
|  |  | ||||||
| import marl_factory_grid.modules.destinations.constants |  | ||||||
| from marl_factory_grid.environment.rules import Rule | from marl_factory_grid.environment.rules import Rule | ||||||
|  | from marl_factory_grid.utils import helpers as h | ||||||
| from marl_factory_grid.utils.results import TickResult, DoneResult | from marl_factory_grid.utils.results import TickResult, DoneResult | ||||||
| from marl_factory_grid.environment import constants as c | from marl_factory_grid.environment import constants as c | ||||||
|  |  | ||||||
| @@ -54,7 +54,7 @@ class DoneAtDestinationReachAll(DestinationReachReward): | |||||||
|         """ |         """ | ||||||
|         This rule triggers and sets the done flag if ALL Destinations have been reached. |         This rule triggers and sets the done flag if ALL Destinations have been reached. | ||||||
|  |  | ||||||
|         :type reward_at_done: object |         :type reward_at_done: float | ||||||
|         :param reward_at_done: Specifies the reward, agent get, whenn all destinations are reached. |         :param reward_at_done: Specifies the reward, agent get, whenn all destinations are reached. | ||||||
|         :type dest_reach_reward: float |         :type dest_reach_reward: float | ||||||
|         :param dest_reach_reward: Specify the reward, agents get when reaching a single destination. |         :param dest_reach_reward: Specify the reward, agents get when reaching a single destination. | ||||||
| @@ -65,7 +65,7 @@ class DoneAtDestinationReachAll(DestinationReachReward): | |||||||
|     def on_check_done(self, state) -> List[DoneResult]: |     def on_check_done(self, state) -> List[DoneResult]: | ||||||
|         if all(x.was_reached() for x in state[d.DESTINATION]): |         if all(x.was_reached() for x in state[d.DESTINATION]): | ||||||
|             return [DoneResult(self.name, validity=c.VALID, reward=self.reward)] |             return [DoneResult(self.name, validity=c.VALID, reward=self.reward)] | ||||||
|         return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)] |         return [DoneResult(self.name, validity=c.NOT_VALID)] | ||||||
|  |  | ||||||
|  |  | ||||||
| class DoneAtDestinationReachAny(DestinationReachReward): | class DoneAtDestinationReachAny(DestinationReachReward): | ||||||
| @@ -75,7 +75,7 @@ class DoneAtDestinationReachAny(DestinationReachReward): | |||||||
|         This rule triggers and sets the done flag if ANY Destinations has been reached. |         This rule triggers and sets the done flag if ANY Destinations has been reached. | ||||||
|         !!! IMPORTANT: 'reward_at_done' is shared between the agents; 'dest_reach_reward' is bound to a specific one. |         !!! IMPORTANT: 'reward_at_done' is shared between the agents; 'dest_reach_reward' is bound to a specific one. | ||||||
|                  |                  | ||||||
|         :type reward_at_done: object |         :type reward_at_done: float | ||||||
|         :param reward_at_done: Specifies the reward, all agent get, when any destinations has been reached.  |         :param reward_at_done: Specifies the reward, all agent get, when any destinations has been reached.  | ||||||
|                                 Default {d.REWARD_DEST_DONE} |                                 Default {d.REWARD_DEST_DONE} | ||||||
|         :type dest_reach_reward: float |         :type dest_reach_reward: float | ||||||
| @@ -87,67 +87,29 @@ class DoneAtDestinationReachAny(DestinationReachReward): | |||||||
|  |  | ||||||
|     def on_check_done(self, state) -> List[DoneResult]: |     def on_check_done(self, state) -> List[DoneResult]: | ||||||
|         if any(x.was_reached() for x in state[d.DESTINATION]): |         if any(x.was_reached() for x in state[d.DESTINATION]): | ||||||
|             return [DoneResult(self.name, validity=c.VALID, reward=marl_factory_grid.modules.destinations.constants.REWARD_DEST_REACHED)] |             return [DoneResult(self.name, validity=c.VALID, reward=d.REWARD_DEST_REACHED)] | ||||||
|         return [] |         return [] | ||||||
|  |  | ||||||
|  |  | ||||||
| class SpawnDestinations(Rule): |  | ||||||
|  |  | ||||||
|     def __init__(self, n_dests: int = 1, spawn_mode: str = d.MODE_GROUPED): |  | ||||||
|         f""" |  | ||||||
|         Defines how destinations are initially spawned and respawned in addition. |  | ||||||
|         !!! This rule introduces no kind of reward or Env.-Done condition! |  | ||||||
|                  |  | ||||||
|         :type n_dests: int |  | ||||||
|         :param n_dests: How many destiantions should be maintained (and initally spawnewd) on the map? |  | ||||||
|         :type spawn_mode: str  |  | ||||||
|         :param spawn_mode: One of {d.SPAWN_MODES}. {d.MODE_GROUPED}: Always wait for all Dstiantions do be gone,  |  | ||||||
|                            then respawn after the given time. {d.MODE_SINGLE}: Just spawn every destination,  |  | ||||||
|                            that has been reached, after the given time |  | ||||||
|                              |  | ||||||
|         """ |  | ||||||
|         super(SpawnDestinations, self).__init__() |  | ||||||
|         self.n_dests = n_dests |  | ||||||
|         self.spawn_mode = spawn_mode |  | ||||||
|  |  | ||||||
|     def on_init(self, state, lvl_map): |  | ||||||
|         # noinspection PyAttributeOutsideInit |  | ||||||
|         state[d.DESTINATION].trigger_destination_spawn(self.n_dests, state) |  | ||||||
|         pass |  | ||||||
|  |  | ||||||
|     def tick_pre_step(self, state) -> List[TickResult]: |  | ||||||
|         pass |  | ||||||
|  |  | ||||||
|     def tick_step(self, state) -> List[TickResult]: |  | ||||||
|         if n_dest_spawn := max(0, self.n_dests - len(state[d.DESTINATION])): |  | ||||||
|             if self.spawn_mode == d.MODE_GROUPED and n_dest_spawn == self.n_dests: |  | ||||||
|                 validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state) |  | ||||||
|                 return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)] |  | ||||||
|             elif self.spawn_mode == d.MODE_SINGLE and n_dest_spawn: |  | ||||||
|                 validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state) |  | ||||||
|                 return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)] |  | ||||||
|             else: |  | ||||||
|                 pass |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class SpawnDestinationsPerAgent(Rule): | class SpawnDestinationsPerAgent(Rule): | ||||||
|     def __init__(self, per_agent_positions: Dict[str, List[Tuple[int, int]]]): |     def __init__(self, coords_or_quantity: Dict[str, List[Tuple[int, int]]]): | ||||||
|         """ |         """ | ||||||
|         Special rule, that spawn distinations, that are bound to a single agent a fixed set of positions. |         Special rule, that spawn distinations, that are bound to a single agent a fixed set of positions. | ||||||
|         Usefull for introducing specialists, etc. .. |         Usefull for introducing specialists, etc. .. | ||||||
|  |  | ||||||
|         !!! This rule does not introduce any reward or done condition. |         !!! This rule does not introduce any reward or done condition. | ||||||
|  |  | ||||||
|         :type per_agent_positions:  Dict[str, List[Tuple[int, int]] |         :type coords_or_quantity:  Dict[str, List[Tuple[int, int]] | ||||||
|         :param per_agent_positions: Please provide a dictionary with agent names as keys; and a list of possible |         :param coords_or_quantity: Please provide a dictionary with agent names as keys; and a list of possible | ||||||
|                                      destiantion coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]} |                                      destiantion coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]} | ||||||
|         """ |         """ | ||||||
|         super(Rule, self).__init__() |         super(Rule, self).__init__() | ||||||
|         self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in per_agent_positions.items()} |         self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in coords_or_quantity.items()} | ||||||
|  |  | ||||||
|     def on_init(self, state, lvl_map): |     def on_init(self, state, lvl_map): | ||||||
|         for (agent_name, position_list) in self.per_agent_positions.items(): |         for (agent_name, position_list) in self.per_agent_positions.items(): | ||||||
|             agent = next(x for x in state[c.AGENT] if agent_name in x.name)  # Fixme: Ugly AF |             agent = h.get_first(state[c.AGENT], lambda x: agent_name in x.name) | ||||||
|  |             assert agent | ||||||
|             position_list = position_list.copy() |             position_list = position_list.copy() | ||||||
|             shuffle(position_list) |             shuffle(position_list) | ||||||
|             while True: |             while True: | ||||||
| @@ -155,7 +117,7 @@ class SpawnDestinationsPerAgent(Rule): | |||||||
|                     pos = position_list.pop() |                     pos = position_list.pop() | ||||||
|                 except IndexError: |                 except IndexError: | ||||||
|                     print(f"Could not spawn Destinations at: {self.per_agent_positions[agent_name]}") |                     print(f"Could not spawn Destinations at: {self.per_agent_positions[agent_name]}") | ||||||
|                     print(f'Check your agent palcement: {state[c.AGENT]} ... Exit ...') |                     print(f'Check your agent placement: {state[c.AGENT]} ... Exit ...') | ||||||
|                     exit(9999) |                     exit(9999) | ||||||
|                 if (not pos == agent.pos) and (not state[d.DESTINATION].by_pos(pos)): |                 if (not pos == agent.pos) and (not state[d.DESTINATION].by_pos(pos)): | ||||||
|                     destination = Destination(pos, bind_to=agent) |                     destination = Destination(pos, bind_to=agent) | ||||||
|   | |||||||
| @@ -1,4 +1,5 @@ | |||||||
| from marl_factory_grid.environment.entity.entity import Entity | from marl_factory_grid.environment.entity.entity import Entity | ||||||
|  | from marl_factory_grid.utils import Result | ||||||
| from marl_factory_grid.utils.utility_classes import RenderEntity | from marl_factory_grid.utils.utility_classes import RenderEntity | ||||||
| from marl_factory_grid.environment import constants as c | from marl_factory_grid.environment import constants as c | ||||||
|  |  | ||||||
| @@ -41,21 +42,6 @@ class Door(Entity): | |||||||
|     def str_state(self): |     def str_state(self): | ||||||
|         return 'open' if self.is_open else 'closed' |         return 'open' if self.is_open else 'closed' | ||||||
|  |  | ||||||
|     def __init__(self, *args, closed_on_init=True, auto_close_interval=10, **kwargs): |  | ||||||
|         self._status = d.STATE_CLOSED |  | ||||||
|         super(Door, self).__init__(*args, **kwargs) |  | ||||||
|         self.auto_close_interval = auto_close_interval |  | ||||||
|         self.time_to_close = 0 |  | ||||||
|         if not closed_on_init: |  | ||||||
|             self._open() |  | ||||||
|         else: |  | ||||||
|             self._close() |  | ||||||
|  |  | ||||||
|     def summarize_state(self): |  | ||||||
|         state_dict = super().summarize_state() |  | ||||||
|         state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close)) |  | ||||||
|         return state_dict |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def is_closed(self): |     def is_closed(self): | ||||||
|         return self._status == d.STATE_CLOSED |         return self._status == d.STATE_CLOSED | ||||||
| @@ -68,6 +54,25 @@ class Door(Entity): | |||||||
|     def status(self): |     def status(self): | ||||||
|         return self._status |         return self._status | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def time_to_close(self): | ||||||
|  |         return self._time_to_close | ||||||
|  |  | ||||||
|  |     def __init__(self, *args, closed_on_init=True, auto_close_interval=10, **kwargs): | ||||||
|  |         self._status = d.STATE_CLOSED | ||||||
|  |         super(Door, self).__init__(*args, **kwargs) | ||||||
|  |         self._auto_close_interval = auto_close_interval | ||||||
|  |         self._time_to_close = 0 | ||||||
|  |         if not closed_on_init: | ||||||
|  |             self._open() | ||||||
|  |         else: | ||||||
|  |             self._close() | ||||||
|  |  | ||||||
|  |     def summarize_state(self): | ||||||
|  |         state_dict = super().summarize_state() | ||||||
|  |         state_dict.update(state=str(self.str_state), time_to_close=self.time_to_close) | ||||||
|  |         return state_dict | ||||||
|  |  | ||||||
|     def render(self): |     def render(self): | ||||||
|         name, state = 'door_open' if self.is_open else 'door_closed', 'blank' |         name, state = 'door_open' if self.is_open else 'door_closed', 'blank' | ||||||
|         return RenderEntity(name, self.pos, 1, 'none', state, self.u_int + 1) |         return RenderEntity(name, self.pos, 1, 'none', state, self.u_int + 1) | ||||||
| @@ -80,18 +85,35 @@ class Door(Entity): | |||||||
|         return c.VALID |         return c.VALID | ||||||
|  |  | ||||||
|     def tick(self, state): |     def tick(self, state): | ||||||
|         if self.is_open and len(state.entities.pos_dict[self.pos]) == 2 and self.time_to_close: |         # Check if no entity is standing in the door | ||||||
|             self.time_to_close -= 1 |         if len(state.entities.pos_dict[self.pos]) <= 2: | ||||||
|             return c.NOT_VALID |             if self.is_open and self.time_to_close: | ||||||
|         elif self.is_open and not self.time_to_close and len(state.entities.pos_dict[self.pos]) == 2: |                 self._decrement_timer() | ||||||
|  |                 return Result(f"{d.DOOR}_tick", c.VALID, entity=self) | ||||||
|  |             elif self.is_open and not self.time_to_close: | ||||||
|                 self.use() |                 self.use() | ||||||
|             return c.VALID |                 return Result(f"{d.DOOR}_closed", c.VALID, entity=self) | ||||||
|             else: |             else: | ||||||
|             return c.NOT_VALID |                 # No one is in door, but it is closed... Nothing to do.... | ||||||
|  |                 return None | ||||||
|  |         else: | ||||||
|  |             # Entity is standing in the door, reset timer | ||||||
|  |             self._reset_timer() | ||||||
|  |             return Result(f"{d.DOOR}_reset", c.VALID, entity=self) | ||||||
|  |  | ||||||
|     def _open(self): |     def _open(self): | ||||||
|         self._status = d.STATE_OPEN |         self._status = d.STATE_OPEN | ||||||
|         self.time_to_close = self.auto_close_interval |         self._reset_timer() | ||||||
|  |         return True | ||||||
|  |  | ||||||
|     def _close(self): |     def _close(self): | ||||||
|         self._status = d.STATE_CLOSED |         self._status = d.STATE_CLOSED | ||||||
|  |         return True | ||||||
|  |  | ||||||
|  |     def _decrement_timer(self): | ||||||
|  |         self._time_to_close -= 1 | ||||||
|  |         return True | ||||||
|  |  | ||||||
|  |     def _reset_timer(self): | ||||||
|  |         self._time_to_close = self._auto_close_interval | ||||||
|  |         return True | ||||||
|   | |||||||
| @@ -1,5 +1,3 @@ | |||||||
| from typing import Union |  | ||||||
|  |  | ||||||
| from marl_factory_grid.environment.groups.collection import Collection | from marl_factory_grid.environment.groups.collection import Collection | ||||||
| from marl_factory_grid.modules.doors import constants as d | from marl_factory_grid.modules.doors import constants as d | ||||||
| from marl_factory_grid.modules.doors.entitites import Door | from marl_factory_grid.modules.doors.entitites import Door | ||||||
| @@ -18,8 +16,10 @@ class Doors(Collection): | |||||||
|         super(Doors, self).__init__(*args, can_collide=True, **kwargs) |         super(Doors, self).__init__(*args, can_collide=True, **kwargs) | ||||||
|  |  | ||||||
|     def tick_doors(self, state): |     def tick_doors(self, state): | ||||||
|         result_dict = dict() |         results = list() | ||||||
|         for door in self: |         for door in self: | ||||||
|             did_tick = door.tick(state) |             tick_result = door.tick(state) | ||||||
|             result_dict.update({door.name: did_tick}) |             if tick_result is not None: | ||||||
|         return result_dict |                 results.append(tick_result) | ||||||
|  |         # TODO: Should return a Result object, not a random dict. | ||||||
|  |         return results | ||||||
|   | |||||||
| @@ -19,10 +19,10 @@ class DoorAutoClose(Rule): | |||||||
|  |  | ||||||
|     def tick_step(self, state): |     def tick_step(self, state): | ||||||
|         if doors := state[d.DOORS]: |         if doors := state[d.DOORS]: | ||||||
|             doors_tick_result = doors.tick_doors(state) |             doors_tick_results = doors.tick_doors(state) | ||||||
|             doors_that_ticked = [key for key, val in doors_tick_result.items() if val] |             doors_that_closed = [x.entity.name for x in doors_tick_results if 'closed' in x.identifier] | ||||||
|             state.print(f'{doors_that_ticked} were auto-closed' |             door_str = doors_that_closed if doors_that_closed else "No Doors" | ||||||
|                         if doors_that_ticked else 'No Doors were auto-closed') |             state.print(f'{door_str} were auto-closed') | ||||||
|             return [TickResult(self.name, validity=c.VALID, value=1)] |             return [TickResult(self.name, validity=c.VALID, value=1)] | ||||||
|         state.print('There are no doors, but you loaded the corresponding Module') |         state.print('There are no doors, but you loaded the corresponding Module') | ||||||
|         return [] |         return [] | ||||||
|   | |||||||
| @@ -1,8 +1,8 @@ | |||||||
| import random | import random | ||||||
| from typing import List, Union | from typing import List | ||||||
|  |  | ||||||
| from marl_factory_grid.environment.rules import Rule |  | ||||||
| from marl_factory_grid.environment import constants as c | from marl_factory_grid.environment import constants as c | ||||||
|  | from marl_factory_grid.environment.rules import Rule | ||||||
| from marl_factory_grid.utils.results import TickResult | from marl_factory_grid.utils.results import TickResult | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -14,8 +14,6 @@ class AgentSingleZonePlacementBeta(Rule): | |||||||
|         super().__init__() |         super().__init__() | ||||||
|  |  | ||||||
|     def on_init(self, state, lvl_map): |     def on_init(self, state, lvl_map): | ||||||
|         zones = state[c.ZONES] |  | ||||||
|         n_zones = state[c.ZONES] |  | ||||||
|         agents = state[c.AGENT] |         agents = state[c.AGENT] | ||||||
|         if len(self.coordinates) == len(agents): |         if len(self.coordinates) == len(agents): | ||||||
|             coordinates = self.coordinates |             coordinates = self.coordinates | ||||||
|   | |||||||
| @@ -1,4 +1,3 @@ | |||||||
| from .actions import ItemAction | from .actions import ItemAction | ||||||
| from .entitites import Item, DropOffLocation | from .entitites import Item, DropOffLocation | ||||||
| from .groups import DropOffLocations, Items, Inventory, Inventories | from .groups import DropOffLocations, Items, Inventory, Inventories | ||||||
| from .rules import ItemRules |  | ||||||
|   | |||||||
| @@ -29,7 +29,7 @@ class ItemAction(Action): | |||||||
|         elif items := state[i.ITEM].by_pos(entity.pos): |         elif items := state[i.ITEM].by_pos(entity.pos): | ||||||
|             item = items[0] |             item = items[0] | ||||||
|             item.change_parent_collection(inventory) |             item.change_parent_collection(inventory) | ||||||
|             item.set_pos_to(c.VALUE_NO_POS) |             item.set_pos(c.VALUE_NO_POS) | ||||||
|             state.print(f'{entity.name} just picked up an item at {entity.pos}') |             state.print(f'{entity.name} just picked up an item at {entity.pos}') | ||||||
|             return ActionResult(entity=entity, identifier=self._identifier, validity=c.VALID, reward=r.PICK_UP_VALID) |             return ActionResult(entity=entity, identifier=self._identifier, validity=c.VALID, reward=r.PICK_UP_VALID) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,6 +1,3 @@ | |||||||
| from typing import NamedTuple |  | ||||||
|  |  | ||||||
|  |  | ||||||
| SYMBOL_NO_ITEM      = 0 | SYMBOL_NO_ITEM      = 0 | ||||||
| SYMBOL_DROP_OFF     = 1 | SYMBOL_DROP_OFF     = 1 | ||||||
| # Item Env | # Item Env | ||||||
|   | |||||||
| @@ -8,56 +8,20 @@ from marl_factory_grid.modules.items import constants as i | |||||||
|  |  | ||||||
| class Item(Entity): | class Item(Entity): | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_can_collide(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     def render(self): |     def render(self): | ||||||
|         return RenderEntity(i.ITEM, self.pos) if self.pos != c.VALUE_NO_POS else None |         return RenderEntity(i.ITEM, self.pos) if self.pos != c.VALUE_NO_POS else None | ||||||
|  |  | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|         super().__init__(*args, **kwargs) |         super().__init__(*args, **kwargs) | ||||||
|         self._auto_despawn = -1 |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def auto_despawn(self): |  | ||||||
|         return self._auto_despawn |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def encoding(self): |     def encoding(self): | ||||||
|         # Edit this if you want items to be drawn in the ops differently |         # Edit this if you want items to be drawn in the ops differently | ||||||
|         return 1 |         return 1 | ||||||
|  |  | ||||||
|     def set_auto_despawn(self, auto_despawn): |  | ||||||
|         self._auto_despawn = auto_despawn |  | ||||||
|  |  | ||||||
|     def set_pos_to(self, no_pos): |  | ||||||
|         self._pos = no_pos |  | ||||||
|  |  | ||||||
|     def summarize_state(self) -> dict: |  | ||||||
|         super_summarization = super(Item, self).summarize_state() |  | ||||||
|         super_summarization.update(dict(auto_despawn=self.auto_despawn)) |  | ||||||
|         return super_summarization |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class DropOffLocation(Entity): | class DropOffLocation(Entity): | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_can_collide(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_can_move(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_is_blocking_light(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_has_position(self): |  | ||||||
|         return True |  | ||||||
|  |  | ||||||
|     def render(self): |     def render(self): | ||||||
|         return RenderEntity(i.DROP_OFF, self.pos) |         return RenderEntity(i.DROP_OFF, self.pos) | ||||||
|  |  | ||||||
| @@ -65,18 +29,16 @@ class DropOffLocation(Entity): | |||||||
|     def encoding(self): |     def encoding(self): | ||||||
|         return i.SYMBOL_DROP_OFF |         return i.SYMBOL_DROP_OFF | ||||||
|  |  | ||||||
|     def __init__(self, *args, storage_size_until_full: int = 5, auto_item_despawn_interval: int = 5, **kwargs): |     def __init__(self, *args, storage_size_until_full: int = 5, **kwargs): | ||||||
|         super(DropOffLocation, self).__init__(*args, **kwargs) |         super(DropOffLocation, self).__init__(*args, **kwargs) | ||||||
|         self.auto_item_despawn_interval = auto_item_despawn_interval |  | ||||||
|         self.storage = deque(maxlen=storage_size_until_full or None) |         self.storage = deque(maxlen=storage_size_until_full or None) | ||||||
|  |  | ||||||
|     def place_item(self, item: Item): |     def place_item(self, item: Item): | ||||||
|         if self.is_full: |         if self.is_full: | ||||||
|             raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.") |             raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.") | ||||||
|             return bc.NOT_VALID  # in Zeile 81 verschieben? |             return bc.NOT_VALID | ||||||
|         else: |         else: | ||||||
|             self.storage.append(item) |             self.storage.append(item) | ||||||
|             item.set_auto_despawn(self.auto_item_despawn_interval) |  | ||||||
|             return c.VALID |             return c.VALID | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|   | |||||||
| @@ -1,13 +1,11 @@ | |||||||
| from random import shuffle |  | ||||||
|  |  | ||||||
| from marl_factory_grid.modules.items import constants as i |  | ||||||
| from marl_factory_grid.environment import constants as c | from marl_factory_grid.environment import constants as c | ||||||
|  |  | ||||||
| from marl_factory_grid.environment.groups.collection import Collection |  | ||||||
| from marl_factory_grid.environment.groups.objects import _Objects |  | ||||||
| from marl_factory_grid.environment.groups.mixins import IsBoundMixin |  | ||||||
| from marl_factory_grid.environment.entity.agent import Agent | from marl_factory_grid.environment.entity.agent import Agent | ||||||
|  | from marl_factory_grid.environment.groups.collection import Collection | ||||||
|  | from marl_factory_grid.environment.groups.mixins import IsBoundMixin | ||||||
|  | from marl_factory_grid.environment.groups.objects import Objects | ||||||
|  | from marl_factory_grid.modules.items import constants as i | ||||||
| from marl_factory_grid.modules.items.entitites import Item, DropOffLocation | from marl_factory_grid.modules.items.entitites import Item, DropOffLocation | ||||||
|  | from marl_factory_grid.utils.results import Result | ||||||
|  |  | ||||||
|  |  | ||||||
| class Items(Collection): | class Items(Collection): | ||||||
| @@ -15,7 +13,7 @@ class Items(Collection): | |||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def var_has_position(self): |     def var_has_position(self): | ||||||
|         return False |         return True | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def is_blocking_light(self): |     def is_blocking_light(self): | ||||||
| @@ -28,18 +26,18 @@ class Items(Collection): | |||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|         super().__init__(*args, **kwargs) |         super().__init__(*args, **kwargs) | ||||||
|  |  | ||||||
|     @staticmethod |     def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, **entity_kwargs) -> [Result]: | ||||||
|     def trigger_item_spawn(state, n_items, spawn_frequency): |         coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity | ||||||
|         if item_to_spawns := max(0, (n_items - len(state[i.ITEM]))): |         assert coords_or_quantity | ||||||
|             position_list = [x for x in state.entities.floorlist] |  | ||||||
|             shuffle(position_list) |         if item_to_spawns := max(0, (coords_or_quantity - len(self))): | ||||||
|             position_list = state.entities.floorlist[:item_to_spawns] |             return super().trigger_spawn(state, | ||||||
|             state[i.ITEM].spawn(position_list) |                                          *entity_args, | ||||||
|             state.print(f'{item_to_spawns} new items have been spawned; next spawn in {spawn_frequency}') |                                          coords_or_quantity=item_to_spawns, | ||||||
|             return len(position_list) |                                          **entity_kwargs) | ||||||
|         else: |         else: | ||||||
|             state.print('No Items are spawning, limit is reached.') |             state.print('No Items are spawning, limit is reached.') | ||||||
|             return 0 |             return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, value=coords_or_quantity) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Inventory(IsBoundMixin, Collection): | class Inventory(IsBoundMixin, Collection): | ||||||
| @@ -73,12 +71,17 @@ class Inventory(IsBoundMixin, Collection): | |||||||
|         self._collection = collection |         self._collection = collection | ||||||
|  |  | ||||||
|  |  | ||||||
| class Inventories(_Objects): | class Inventories(Objects): | ||||||
|     _entity = Inventory |     _entity = Inventory | ||||||
|  |  | ||||||
|  |     var_can_move = False | ||||||
|  |     var_has_position = False | ||||||
|  |  | ||||||
|  |     symbol = None | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def var_can_move(self): |     def spawn_rule(self): | ||||||
|         return False |         return {c.SPAWN_ENTITY_RULE: dict(collection=self, coords_or_quantity=None)} | ||||||
|  |  | ||||||
|     def __init__(self, size: int, *args, **kwargs): |     def __init__(self, size: int, *args, **kwargs): | ||||||
|         super(Inventories, self).__init__(*args, **kwargs) |         super(Inventories, self).__init__(*args, **kwargs) | ||||||
| @@ -86,10 +89,12 @@ class Inventories(_Objects): | |||||||
|         self._obs = None |         self._obs = None | ||||||
|         self._lazy_eval_transforms = [] |         self._lazy_eval_transforms = [] | ||||||
|  |  | ||||||
|     def spawn(self, agents): |     def spawn(self, agents, *args, **kwargs): | ||||||
|         inventories = [self._entity(agent, self.size, ) |         self.add_items([self._entity(agent, self.size, *args, **kwargs) for _, agent in enumerate(agents)]) | ||||||
|                        for _, agent in enumerate(agents)] |         return [Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))] | ||||||
|         self.add_items(inventories) |  | ||||||
|  |     def trigger_spawn(self, state, *args, **kwargs) -> [Result]: | ||||||
|  |         return self.spawn(state[c.AGENT], *args, **kwargs) | ||||||
|  |  | ||||||
|     def idx_by_entity(self, entity): |     def idx_by_entity(self, entity): | ||||||
|         try: |         try: | ||||||
| @@ -106,10 +111,6 @@ class Inventories(_Objects): | |||||||
|     def summarize_states(self, **kwargs): |     def summarize_states(self, **kwargs): | ||||||
|         return [val.summarize_states(**kwargs) for key, val in self.items()] |         return [val.summarize_states(**kwargs) for key, val in self.items()] | ||||||
|  |  | ||||||
|     @staticmethod |  | ||||||
|     def trigger_inventory_spawn(state): |  | ||||||
|         state[i.INVENTORY].spawn(state[c.AGENT]) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class DropOffLocations(Collection): | class DropOffLocations(Collection): | ||||||
|     _entity = DropOffLocation |     _entity = DropOffLocation | ||||||
| @@ -135,7 +136,7 @@ class DropOffLocations(Collection): | |||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def trigger_drop_off_location_spawn(state, n_locations): |     def trigger_drop_off_location_spawn(state, n_locations): | ||||||
|         empty_positions = state.entities.empty_positions()[:n_locations] |         empty_positions = state.entities.empty_positions[:n_locations] | ||||||
|         do_entites = state[i.DROP_OFF] |         do_entites = state[i.DROP_OFF] | ||||||
|         drop_offs = [DropOffLocation(pos) for pos in empty_positions] |         drop_offs = [DropOffLocation(pos) for pos in empty_positions] | ||||||
|         do_entites.add_items(drop_offs) |         do_entites.add_items(drop_offs) | ||||||
|   | |||||||
| @@ -6,52 +6,28 @@ from marl_factory_grid.utils.results import TickResult | |||||||
| from marl_factory_grid.modules.items import constants as i | from marl_factory_grid.modules.items import constants as i | ||||||
|  |  | ||||||
|  |  | ||||||
| class ItemRules(Rule): | class RespawnItems(Rule): | ||||||
|  |  | ||||||
|     def __init__(self, n_items: int = 5, spawn_frequency: int = 15, |     def __init__(self, n_items: int = 5, respawn_freq: int = 15, n_locations: int = 5): | ||||||
|                  n_locations: int = 5, max_dropoff_storage_size: int = 0): |  | ||||||
|         super().__init__() |         super().__init__() | ||||||
|         self.spawn_frequency = spawn_frequency |         self.spawn_frequency = respawn_freq | ||||||
|         self._next_item_spawn = spawn_frequency |         self._next_item_spawn = respawn_freq | ||||||
|         self.n_items = n_items |         self.n_items = n_items | ||||||
|         self.max_dropoff_storage_size = max_dropoff_storage_size |  | ||||||
|         self.n_locations = n_locations |         self.n_locations = n_locations | ||||||
|  |  | ||||||
|     def on_init(self, state, lvl_map): |  | ||||||
|         state[i.DROP_OFF].trigger_drop_off_location_spawn(state, self.n_locations) |  | ||||||
|         self._next_item_spawn = self.spawn_frequency |  | ||||||
|         state[i.INVENTORY].trigger_inventory_spawn(state) |  | ||||||
|         state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency) |  | ||||||
|  |  | ||||||
|     def tick_step(self, state): |     def tick_step(self, state): | ||||||
|         for item in list(state[i.ITEM].values()): |  | ||||||
|             if item.auto_despawn >= 1: |  | ||||||
|                 item.set_auto_despawn(item.auto_despawn - 1) |  | ||||||
|             elif not item.auto_despawn: |  | ||||||
|                 state[i.ITEM].delete_env_object(item) |  | ||||||
|             else: |  | ||||||
|                 pass |  | ||||||
|  |  | ||||||
|         if not self._next_item_spawn: |         if not self._next_item_spawn: | ||||||
|             state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency) |             state[i.ITEM].trigger_spawn(state, self.n_items, self.spawn_frequency) | ||||||
|         else: |         else: | ||||||
|             self._next_item_spawn = max(0, self._next_item_spawn - 1) |             self._next_item_spawn = max(0, self._next_item_spawn - 1) | ||||||
|         return [] |         return [] | ||||||
|  |  | ||||||
|     def tick_post_step(self, state) -> List[TickResult]: |     def tick_post_step(self, state) -> List[TickResult]: | ||||||
|         for item in list(state[i.ITEM].values()): |  | ||||||
|             if item.auto_despawn >= 1: |  | ||||||
|                 item.set_auto_despawn(item.auto_despawn-1) |  | ||||||
|             elif not item.auto_despawn: |  | ||||||
|                 state[i.ITEM].delete_env_object(item) |  | ||||||
|             else: |  | ||||||
|                 pass |  | ||||||
|  |  | ||||||
|         if not self._next_item_spawn: |         if not self._next_item_spawn: | ||||||
|             if spawned_items := state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency): |             if spawned_items := state[i.ITEM].trigger_spawn(state, self.n_items, self.spawn_frequency): | ||||||
|                 return [TickResult(self.name, validity=c.VALID, value=spawned_items, entity=None)] |                 return [TickResult(self.name, validity=c.VALID, value=spawned_items.value)] | ||||||
|             else: |             else: | ||||||
|                 return [TickResult(self.name, validity=c.NOT_VALID, value=0, entity=None)] |                 return [TickResult(self.name, validity=c.NOT_VALID, value=0)] | ||||||
|         else: |         else: | ||||||
|             self._next_item_spawn = max(0, self._next_item_spawn-1) |             self._next_item_spawn = max(0, self._next_item_spawn-1) | ||||||
|             return [] |             return [] | ||||||
|   | |||||||
| @@ -1,3 +1,2 @@ | |||||||
| from .entitites import Machine | from .entitites import Machine | ||||||
| from .groups import Machines | from .groups import Machines | ||||||
| from .rules import MachineRule |  | ||||||
|   | |||||||
| @@ -1,10 +1,12 @@ | |||||||
| from typing import Union | from typing import Union | ||||||
|  |  | ||||||
|  | import marl_factory_grid.modules.machines.constants | ||||||
| from marl_factory_grid.environment.actions import Action | from marl_factory_grid.environment.actions import Action | ||||||
| from marl_factory_grid.utils.results import ActionResult | from marl_factory_grid.utils.results import ActionResult | ||||||
|  |  | ||||||
| from marl_factory_grid.modules.machines import constants as m, rewards as r | from marl_factory_grid.modules.machines import constants as m | ||||||
| from marl_factory_grid.environment import constants as c | from marl_factory_grid.environment import constants as c | ||||||
|  | from marl_factory_grid.utils import helpers as h | ||||||
|  |  | ||||||
|  |  | ||||||
| class MachineAction(Action): | class MachineAction(Action): | ||||||
| @@ -13,13 +15,12 @@ class MachineAction(Action): | |||||||
|         super().__init__(m.MACHINE_ACTION) |         super().__init__(m.MACHINE_ACTION) | ||||||
|  |  | ||||||
|     def do(self, entity, state) -> Union[None, ActionResult]: |     def do(self, entity, state) -> Union[None, ActionResult]: | ||||||
|         if machine := state[m.MACHINES].by_pos(entity.pos): |         if machine := h.get_first(state[m.MACHINES].by_pos(entity.pos)): | ||||||
|             if valid := machine.maintain(): |             if valid := machine.maintain(): | ||||||
|                 return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_VALID) |                 return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_VALID) | ||||||
|             else: |             else: | ||||||
|                 return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_FAIL) |                 return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_FAIL) | ||||||
|         else: |         else: | ||||||
|             return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.MAINTAIN_FAIL) |             return ActionResult(entity=entity, identifier=self._identifier, | ||||||
|  |                                 validity=c.NOT_VALID, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_FAIL | ||||||
|  |                                 ) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -11,3 +11,7 @@ STATE_MAINTAIN = 'maintenance' | |||||||
| SYMBOL_WORK = 1 | SYMBOL_WORK = 1 | ||||||
| SYMBOL_IDLE = 0.6 | SYMBOL_IDLE = 0.6 | ||||||
| SYMBOL_MAINTAIN = 0.3 | SYMBOL_MAINTAIN = 0.3 | ||||||
|  | MAINTAIN_VALID: float = 0.5 | ||||||
|  | MAINTAIN_FAIL: float = -0.1 | ||||||
|  | FAIL_MISSING_MAINTENANCE: float = -0.5 | ||||||
|  | NONE: float = 0 | ||||||
|   | |||||||
| @@ -8,22 +8,6 @@ from . import constants as m | |||||||
|  |  | ||||||
| class Machine(Entity): | class Machine(Entity): | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_can_collide(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_can_move(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_is_blocking_light(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_has_position(self): |  | ||||||
|         return True |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def encoding(self): |     def encoding(self): | ||||||
|         return self._encodings[self.status] |         return self._encodings[self.status] | ||||||
| @@ -46,12 +30,11 @@ class Machine(Entity): | |||||||
|         else: |         else: | ||||||
|             return c.NOT_VALID |             return c.NOT_VALID | ||||||
|  |  | ||||||
|     def tick(self): |     def tick(self, state): | ||||||
|         # if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]): |         others = state.entities.pos_dict[self.pos] | ||||||
|         if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.state.entities.pos_dict[self.pos]]): |         if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in others]): | ||||||
|             return TickResult(identifier=self.name, validity=c.VALID, reward=0, entity=self) |             return TickResult(identifier=self.name, validity=c.VALID, entity=self) | ||||||
|         # elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]): |         elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in others]): | ||||||
|         elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.state.entities.pos_dict[self.pos]]): |  | ||||||
|             self.status = m.STATE_WORK |             self.status = m.STATE_WORK | ||||||
|             self.reset_counter() |             self.reset_counter() | ||||||
|             return None |             return None | ||||||
|   | |||||||
| @@ -1,5 +1,3 @@ | |||||||
| from typing import Union, List, Tuple |  | ||||||
|  |  | ||||||
| from marl_factory_grid.environment.groups.collection import Collection | from marl_factory_grid.environment.groups.collection import Collection | ||||||
|  |  | ||||||
| from .entitites import Machine | from .entitites import Machine | ||||||
|   | |||||||
| @@ -1,5 +0,0 @@ | |||||||
| MAINTAIN_VALID: float = 0.5 |  | ||||||
| MAINTAIN_FAIL: float = -0.1 |  | ||||||
| FAIL_MISSING_MAINTENANCE: float = -0.5 |  | ||||||
|  |  | ||||||
| NONE: float = 0 |  | ||||||
| @@ -1,28 +0,0 @@ | |||||||
| from typing import List |  | ||||||
| from marl_factory_grid.environment.rules import Rule |  | ||||||
| from marl_factory_grid.utils.results import TickResult, DoneResult |  | ||||||
| from marl_factory_grid.environment import constants as c |  | ||||||
| from marl_factory_grid.modules.machines import constants as m |  | ||||||
| from marl_factory_grid.modules.machines.entitites import Machine |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class MachineRule(Rule): |  | ||||||
|  |  | ||||||
|     def __init__(self, n_machines: int = 2): |  | ||||||
|         super(MachineRule, self).__init__() |  | ||||||
|         self.n_machines = n_machines |  | ||||||
|  |  | ||||||
|     def on_init(self, state, lvl_map): |  | ||||||
|         state[m.MACHINES].spawn(state.entities.empty_positions()) |  | ||||||
|  |  | ||||||
|     def tick_pre_step(self, state) -> List[TickResult]: |  | ||||||
|         pass |  | ||||||
|  |  | ||||||
|     def tick_step(self, state) -> List[TickResult]: |  | ||||||
|         pass |  | ||||||
|  |  | ||||||
|     def tick_post_step(self, state) -> List[TickResult]: |  | ||||||
|         pass |  | ||||||
|  |  | ||||||
|     def on_check_done(self, state) -> List[DoneResult]: |  | ||||||
|         pass |  | ||||||
|   | |||||||
| @@ -1,3 +1,4 @@ | |||||||
| MAINTAINER         = 'Maintainer'                    # TEMPLATE _identifier. Define your own! | MAINTAINER         = 'Maintainer'                    # TEMPLATE _identifier. Define your own! | ||||||
| MAINTAINERS        = 'Maintainers'                   # TEMPLATE _identifier. Define your own! | MAINTAINERS        = 'Maintainers'                   # TEMPLATE _identifier. Define your own! | ||||||
|  |  | ||||||
|  | MAINTAINER_COLLISION_REWARD = -5 | ||||||
|   | |||||||
| @@ -1,48 +1,35 @@ | |||||||
|  | from random import shuffle | ||||||
|  |  | ||||||
| import networkx as nx | import networkx as nx | ||||||
| import numpy as np | import numpy as np | ||||||
|  |  | ||||||
|  |  | ||||||
| from ...algorithms.static.utils import points_to_graph | from ...algorithms.static.utils import points_to_graph | ||||||
| from ...environment import constants as c | from ...environment import constants as c | ||||||
| from ...environment.actions import Action, ALL_BASEACTIONS | from ...environment.actions import Action, ALL_BASEACTIONS | ||||||
| from ...environment.entity.entity import Entity | from ...environment.entity.entity import Entity | ||||||
| from ..doors import constants as do | from ..doors import constants as do | ||||||
| from ..maintenance import constants as mi | from ..maintenance import constants as mi | ||||||
| from ...utils.helpers import MOVEMAP | from ...utils import helpers as h | ||||||
| from ...utils.utility_classes import RenderEntity | from ...utils.utility_classes import RenderEntity, Floor | ||||||
| from ...utils.states import Gamestate | from ..doors import DoorUse | ||||||
|  |  | ||||||
|  |  | ||||||
| class Maintainer(Entity): | class Maintainer(Entity): | ||||||
|  |  | ||||||
|     @property |     def __init__(self, objective: str, action: Action, *args, **kwargs): | ||||||
|     def var_can_collide(self): |  | ||||||
|         return True |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_can_move(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_is_blocking_light(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_has_position(self): |  | ||||||
|         return True |  | ||||||
|  |  | ||||||
|     def __init__(self, state: Gamestate, objective: str, action: Action, *args, **kwargs): |  | ||||||
|         super().__init__(*args, **kwargs) |         super().__init__(*args, **kwargs) | ||||||
|         self.action = action |         self.action = action | ||||||
|         self.actions = [x() for x in ALL_BASEACTIONS] |         self.actions = [x() for x in ALL_BASEACTIONS] + [DoorUse()] | ||||||
|         self.objective = objective |         self.objective = objective | ||||||
|         self._path = None |         self._path = None | ||||||
|         self._next = [] |         self._next = [] | ||||||
|         self._last = [] |         self._last = [] | ||||||
|         self._last_serviced = 'None' |         self._last_serviced = 'None' | ||||||
|         self._floortile_graph = points_to_graph(state.entities.floorlist) |         self._floortile_graph = None | ||||||
|  |  | ||||||
|     def tick(self, state): |     def tick(self, state): | ||||||
|         if found_objective := state[self.objective].by_pos(self.pos): |         if found_objective := h.get_first(state[self.objective].by_pos(self.pos)): | ||||||
|             if found_objective.name != self._last_serviced: |             if found_objective.name != self._last_serviced: | ||||||
|                 self.action.do(self, state) |                 self.action.do(self, state) | ||||||
|                 self._last_serviced = found_objective.name |                 self._last_serviced = found_objective.name | ||||||
| @@ -54,24 +41,27 @@ class Maintainer(Entity): | |||||||
|             return action.do(self, state) |             return action.do(self, state) | ||||||
|  |  | ||||||
|     def get_move_action(self, state) -> Action: |     def get_move_action(self, state) -> Action: | ||||||
|  |         if not self._floortile_graph: | ||||||
|  |             state.print("Generating Floorgraph....") | ||||||
|  |             self._floortile_graph = points_to_graph(state.entities.floorlist) | ||||||
|         if self._path is None or not self._path: |         if self._path is None or not self._path: | ||||||
|             if not self._next: |             if not self._next: | ||||||
|                 self._next = list(state[self.objective].values()) |                 self._next = list(state[self.objective].values()) + [Floor(*state.random_free_position)] | ||||||
|  |                 shuffle(self._next) | ||||||
|                 self._last = [] |                 self._last = [] | ||||||
|             self._last.append(self._next.pop()) |             self._last.append(self._next.pop()) | ||||||
|  |             state.print("Calculating shortest path....") | ||||||
|             self._path = self.calculate_route(self._last[-1]) |             self._path = self.calculate_route(self._last[-1]) | ||||||
|  |  | ||||||
|         if door := self._door_is_close(state): |         if door := self._closed_door_in_path(state): | ||||||
|             if door.is_closed: |             state.print(f"{self} found {door} that is closed. Attempt to open.") | ||||||
|             # Translate the action_object to an integer to have the same output as any other model |             # Translate the action_object to an integer to have the same output as any other model | ||||||
|             action = do.ACTION_DOOR_USE |             action = do.ACTION_DOOR_USE | ||||||
|         else: |         else: | ||||||
|             action = self._predict_move(state) |             action = self._predict_move(state) | ||||||
|         else: |  | ||||||
|             action = self._predict_move(state) |  | ||||||
|         # Translate the action_object to an integer to have the same output as any other model |         # Translate the action_object to an integer to have the same output as any other model | ||||||
|         try: |         try: | ||||||
|             action_obj = next(x for x in self.actions if x.name == action) |             action_obj = h.get_first(self.actions, lambda x: x.name == action) | ||||||
|         except (StopIteration, UnboundLocalError): |         except (StopIteration, UnboundLocalError): | ||||||
|             print('Will not happen') |             print('Will not happen') | ||||||
|             raise EnvironmentError |             raise EnvironmentError | ||||||
| @@ -81,11 +71,10 @@ class Maintainer(Entity): | |||||||
|         route = nx.shortest_path(self._floortile_graph, self.pos, entity.pos) |         route = nx.shortest_path(self._floortile_graph, self.pos, entity.pos) | ||||||
|         return route[1:] |         return route[1:] | ||||||
|  |  | ||||||
|     def _door_is_close(self, state): |     def _closed_door_in_path(self, state): | ||||||
|         state.print("Found a door that is close.") |         if self._path: | ||||||
|         try: |             return h.get_first(state[do.DOORS].by_pos(self._path[0]), lambda x: x.is_closed) | ||||||
|             return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name) |         else: | ||||||
|         except StopIteration: |  | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
|     def _predict_move(self, state): |     def _predict_move(self, state): | ||||||
| @@ -96,7 +85,7 @@ class Maintainer(Entity): | |||||||
|             next_pos = self._path.pop(0) |             next_pos = self._path.pop(0) | ||||||
|             diff = np.subtract(next_pos, self.pos) |             diff = np.subtract(next_pos, self.pos) | ||||||
|             # Retrieve action based on the pos dif (like in: What do I have to do to get there?) |             # Retrieve action based on the pos dif (like in: What do I have to do to get there?) | ||||||
|             action = next(action for action, pos_diff in MOVEMAP.items() if np.all(diff == pos_diff)) |             action = next(action for action, pos_diff in h.MOVEMAP.items() if np.all(diff == pos_diff)) | ||||||
|         return action |         return action | ||||||
|  |  | ||||||
|     def render(self): |     def render(self): | ||||||
|   | |||||||
| @@ -1,34 +1,27 @@ | |||||||
| from typing import Union, List, Tuple | from typing import Union, List, Tuple, Dict | ||||||
|  |  | ||||||
| from marl_factory_grid.environment.groups.collection import Collection | from marl_factory_grid.environment.groups.collection import Collection | ||||||
| from .entities import Maintainer | from .entities import Maintainer | ||||||
| from ..machines import constants as mc | from ..machines import constants as mc | ||||||
| from ..machines.actions import MachineAction | from ..machines.actions import MachineAction | ||||||
| from ...utils.states import Gamestate |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Maintainers(Collection): | class Maintainers(Collection): | ||||||
|     _entity = Maintainer |     _entity = Maintainer | ||||||
|  |  | ||||||
|     @property |     var_can_collide = True | ||||||
|     def var_can_collide(self): |     var_can_move = True | ||||||
|         return True |     var_is_blocking_light = False | ||||||
|  |     var_has_position = True | ||||||
|  |  | ||||||
|     @property |     def __init__(self, size, *args, coords_or_quantity: int = None, | ||||||
|     def var_can_move(self): |                  spawnrule: Union[None, Dict[str, dict]] = None, | ||||||
|         return True |                  **kwargs): | ||||||
|  |         super(Collection, self).__init__(*args, **kwargs) | ||||||
|  |         self._coords_or_quantity = coords_or_quantity | ||||||
|  |         self.size = size | ||||||
|  |         self._spawnrule = spawnrule | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_is_blocking_light(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def var_has_position(self): |  | ||||||
|         return True |  | ||||||
|  |  | ||||||
|     def __init__(self, *args, **kwargs): |  | ||||||
|         super().__init__(*args, **kwargs) |  | ||||||
|  |  | ||||||
|     def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): |     def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): | ||||||
|         state = entity_args[0] |         self.add_items([self._entity(mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity]) | ||||||
|         self.add_items([self._entity(state, mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity]) |  | ||||||
|   | |||||||
| @@ -1 +0,0 @@ | |||||||
| MAINTAINER_COLLISION_REWARD = -5 |  | ||||||
| @@ -1,32 +1,28 @@ | |||||||
| from typing import List | from typing import List | ||||||
|  |  | ||||||
|  | import marl_factory_grid.modules.maintenance.constants | ||||||
| from marl_factory_grid.environment.rules import Rule | from marl_factory_grid.environment.rules import Rule | ||||||
| from marl_factory_grid.utils.results import TickResult, DoneResult | from marl_factory_grid.utils.results import TickResult, DoneResult | ||||||
| from marl_factory_grid.environment import constants as c | from marl_factory_grid.environment import constants as c | ||||||
| from . import rewards as r |  | ||||||
| from . import constants as M | from . import constants as M | ||||||
| from marl_factory_grid.utils.states import Gamestate |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class MaintenanceRule(Rule): | class MoveMaintainers(Rule): | ||||||
|  |  | ||||||
|     def __init__(self, n_maintainer: int = 1, *args, **kwargs): |     def __init__(self): | ||||||
|         super(MaintenanceRule, self).__init__(*args, **kwargs) |         super().__init__() | ||||||
|         self.n_maintainer = n_maintainer |  | ||||||
|  |  | ||||||
|     def on_init(self, state: Gamestate, lvl_map): |  | ||||||
|         state[M.MAINTAINERS].spawn(state.entities.empty_positions[:self.n_maintainer], state) |  | ||||||
|         pass |  | ||||||
|  |  | ||||||
|     def tick_pre_step(self, state) -> List[TickResult]: |  | ||||||
|         pass |  | ||||||
|  |  | ||||||
|     def tick_step(self, state) -> List[TickResult]: |     def tick_step(self, state) -> List[TickResult]: | ||||||
|         for maintainer in state[M.MAINTAINERS]: |         for maintainer in state[M.MAINTAINERS]: | ||||||
|             maintainer.tick(state) |             maintainer.tick(state) | ||||||
|  |         # Todo: Return a Result Object. | ||||||
|         return [] |         return [] | ||||||
|  |  | ||||||
|     def tick_post_step(self, state) -> List[TickResult]: |  | ||||||
|         pass | class DoneAtMaintainerCollision(Rule): | ||||||
|  |  | ||||||
|  |     def __init__(self): | ||||||
|  |         super().__init__() | ||||||
|  |  | ||||||
|     def on_check_done(self, state) -> List[DoneResult]: |     def on_check_done(self, state) -> List[DoneResult]: | ||||||
|         agents = list(state[c.AGENT].values()) |         agents = list(state[c.AGENT].values()) | ||||||
| @@ -35,5 +31,5 @@ class MaintenanceRule(Rule): | |||||||
|         for agent in agents: |         for agent in agents: | ||||||
|             if agent.pos in m_pos: |             if agent.pos in m_pos: | ||||||
|                 done_results.append(DoneResult(entity=agent, validity=c.VALID, identifier=self.name, |                 done_results.append(DoneResult(entity=agent, validity=c.VALID, identifier=self.name, | ||||||
|                                                reward=r.MAINTAINER_COLLISION_REWARD)) |                                                reward=marl_factory_grid.modules.maintenance.constants.MAINTAINER_COLLISION_REWARD)) | ||||||
|         return done_results |         return done_results | ||||||
|   | |||||||
| @@ -1,10 +1,10 @@ | |||||||
| import random | import random | ||||||
| from typing import List, Tuple | from typing import List, Tuple | ||||||
|  |  | ||||||
| from marl_factory_grid.environment.entity.object import _Object | from marl_factory_grid.environment.entity.object import Object | ||||||
|  |  | ||||||
|  |  | ||||||
| class Zone(_Object): | class Zone(Object): | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def positions(self): |     def positions(self): | ||||||
|   | |||||||
| @@ -1,8 +1,8 @@ | |||||||
| from marl_factory_grid.environment.groups.objects import _Objects | from marl_factory_grid.environment.groups.objects import Objects | ||||||
| from marl_factory_grid.modules.zones import Zone | from marl_factory_grid.modules.zones import Zone | ||||||
|  |  | ||||||
|  |  | ||||||
| class Zones(_Objects): | class Zones(Objects): | ||||||
|     symbol = None |     symbol = None | ||||||
|     _entity = Zone |     _entity = Zone | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,8 +1,8 @@ | |||||||
| from random import choices, choice | from random import choices, choice | ||||||
|  |  | ||||||
| from . import constants as z, Zone | from . import constants as z, Zone | ||||||
|  | from .. import Destination | ||||||
| from ..destinations import constants as d | from ..destinations import constants as d | ||||||
| from ... import Destination |  | ||||||
| from ...environment.rules import Rule | from ...environment.rules import Rule | ||||||
| from ...environment import constants as c | from ...environment import constants as c | ||||||
|  |  | ||||||
|   | |||||||
| @@ -0,0 +1,3 @@ | |||||||
|  | from . import helpers as h | ||||||
|  | from . import helpers | ||||||
|  | from .results import Result, DoneResult, ActionResult, TickResult | ||||||
|   | |||||||
| @@ -1,4 +1,5 @@ | |||||||
| import ast | import ast | ||||||
|  |  | ||||||
| from os import PathLike | from os import PathLike | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from typing import Union, List | from typing import Union, List | ||||||
| @@ -9,18 +10,17 @@ from marl_factory_grid.environment import constants as c | |||||||
| from marl_factory_grid.environment.rules import Rule | from marl_factory_grid.environment.rules import Rule | ||||||
| from marl_factory_grid.environment.tests import Test | from marl_factory_grid.environment.tests import Test | ||||||
| from marl_factory_grid.utils.helpers import locate_and_import_class | from marl_factory_grid.utils.helpers import locate_and_import_class | ||||||
|  | from marl_factory_grid.environment.constants import DEFAULT_PATH, MODULE_PATH | ||||||
| DEFAULT_PATH = 'environment' | from marl_factory_grid.environment import constants as c | ||||||
| MODULE_PATH = 'modules' |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class FactoryConfigParser(object): | class FactoryConfigParser(object): | ||||||
|     default_entites = [] |     default_entites = [] | ||||||
|     default_rules = ['MaxStepsReached', 'Collision'] |     default_rules = ['DoneAtMaxStepsReached', 'WatchCollision'] | ||||||
|     default_actions = [c.MOVE8, c.NOOP] |     default_actions = [c.MOVE8, c.NOOP] | ||||||
|     default_observations = [c.WALLS, c.AGENT] |     default_observations = [c.WALLS, c.AGENT] | ||||||
|  |  | ||||||
|     def __init__(self, config_path, custom_modules_path: Union[None, PathLike] = None): |     def __init__(self, config_path, custom_modules_path: Union[PathLike] = None): | ||||||
|         self.config_path = Path(config_path) |         self.config_path = Path(config_path) | ||||||
|         self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path |         self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path | ||||||
|         self.config = yaml.safe_load(self.config_path.open()) |         self.config = yaml.safe_load(self.config_path.open()) | ||||||
| @@ -44,6 +44,10 @@ class FactoryConfigParser(object): | |||||||
|     def rules(self): |     def rules(self): | ||||||
|         return self.config['Rules'] |         return self.config['Rules'] | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def tests(self): | ||||||
|  |         return self.config.get('Tests', []) | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def agents(self): |     def agents(self): | ||||||
|         return self.config['Agents'] |         return self.config['Agents'] | ||||||
| @@ -56,10 +60,12 @@ class FactoryConfigParser(object): | |||||||
|         return str(self.config) |         return str(self.config) | ||||||
|  |  | ||||||
|     def __getitem__(self, item): |     def __getitem__(self, item): | ||||||
|  |         try: | ||||||
|             return self.config[item] |             return self.config[item] | ||||||
|  |         except KeyError: | ||||||
|  |             print(f'The mandatory {item} section could not be found in your .config gile. Check Spelling!') | ||||||
|  |  | ||||||
|     def load_entities(self): |     def load_entities(self): | ||||||
|         # entites = Entities() |  | ||||||
|         entity_classes = dict() |         entity_classes = dict() | ||||||
|         entities = [] |         entities = [] | ||||||
|         if c.DEFAULTS in self.entities: |         if c.DEFAULTS in self.entities: | ||||||
| @@ -67,28 +73,40 @@ class FactoryConfigParser(object): | |||||||
|         entities.extend(x for x in self.entities if x != c.DEFAULTS) |         entities.extend(x for x in self.entities if x != c.DEFAULTS) | ||||||
|  |  | ||||||
|         for entity in entities: |         for entity in entities: | ||||||
|  |             e1 = e2 = e3 = None | ||||||
|             try: |             try: | ||||||
|                 folder_path = Path(__file__).parent.parent / DEFAULT_PATH |                 folder_path = Path(__file__).parent.parent / DEFAULT_PATH | ||||||
|                 entity_class = locate_and_import_class(entity, folder_path) |                 entity_class = locate_and_import_class(entity, folder_path) | ||||||
|             except AttributeError as e1: |             except AttributeError as e: | ||||||
|  |                 e1 = e | ||||||
|                 try: |                 try: | ||||||
|                     folder_path = Path(__file__).parent.parent / MODULE_PATH |                     module_path = Path(__file__).parent.parent / MODULE_PATH | ||||||
|                     entity_class = locate_and_import_class(entity, folder_path) |                     entity_class = locate_and_import_class(entity, module_path) | ||||||
|                 except AttributeError as e2: |                 except AttributeError as e: | ||||||
|  |                     e2 = e | ||||||
|  |                     if self.custom_modules_path: | ||||||
|                         try: |                         try: | ||||||
|                         folder_path = self.custom_modules_path |                             entity_class = locate_and_import_class(entity, self.custom_modules_path) | ||||||
|                         entity_class = locate_and_import_class(entity, folder_path) |                         except AttributeError as e: | ||||||
|                     except AttributeError as e3: |                             e3 = e | ||||||
|                         ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x] |                             pass | ||||||
|  |             if (e1 and e2) or e3: | ||||||
|  |                 ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]] | ||||||
|  |                 print('##############################################################') | ||||||
|                 print('### Error  ###  Error  ###  Error  ###  Error  ###  Error  ###') |                 print('### Error  ###  Error  ###  Error  ###  Error  ###  Error  ###') | ||||||
|                         print() |                 print('##############################################################') | ||||||
|  |                 print(f'Class "{entity}" was not found in "{module_path.name}"') | ||||||
|                 print(f'Class "{entity}" was not found in "{folder_path.name}"') |                 print(f'Class "{entity}" was not found in "{folder_path.name}"') | ||||||
|  |                 print('##############################################################') | ||||||
|  |                 if self.custom_modules_path: | ||||||
|  |                     print(f'Class "{entity}" was not found in "{self.custom_modules_path}"') | ||||||
|                 print('Possible Entitys are:', str(ents)) |                 print('Possible Entitys are:', str(ents)) | ||||||
|                         print() |                 print('##############################################################') | ||||||
|                 print('Goodbye') |                 print('Goodbye') | ||||||
|                         print() |                 print('##############################################################') | ||||||
|                         exit() |                 print('### Error  ###  Error  ###  Error  ###  Error  ###  Error  ###') | ||||||
|                         # raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are:', str(ents)) |                 print('##############################################################') | ||||||
|  |                 exit(-99999) | ||||||
|  |  | ||||||
|             entity_kwargs = self.entities.get(entity, {}) |             entity_kwargs = self.entities.get(entity, {}) | ||||||
|             entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None |             entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None | ||||||
| @@ -126,7 +144,12 @@ class FactoryConfigParser(object): | |||||||
|                 observations.extend(self.default_observations) |                 observations.extend(self.default_observations) | ||||||
|             observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS) |             observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS) | ||||||
|             positions = [ast.literal_eval(x) for x in self.agents[name].get('Positions', [])] |             positions = [ast.literal_eval(x) for x in self.agents[name].get('Positions', [])] | ||||||
|             parsed_agents_conf[name] = dict(actions=parsed_actions, observations=observations, positions=positions) |             other_kwargs = {k: v for k, v in self.agents[name].items() if k not in | ||||||
|  |                             ['Actions', 'Observations', 'Positions']} | ||||||
|  |             parsed_agents_conf[name] = dict( | ||||||
|  |                 actions=parsed_actions, observations=observations, positions=positions, other=other_kwargs | ||||||
|  |                                             ) | ||||||
|  |  | ||||||
|         return parsed_agents_conf |         return parsed_agents_conf | ||||||
|  |  | ||||||
|     def load_env_rules(self) -> List[Rule]: |     def load_env_rules(self) -> List[Rule]: | ||||||
| @@ -137,28 +160,69 @@ class FactoryConfigParser(object): | |||||||
|                     rules.append({rule: {}}) |                     rules.append({rule: {}}) | ||||||
|  |  | ||||||
|         return self._load_smth(rules, Rule) |         return self._load_smth(rules, Rule) | ||||||
|         pass |  | ||||||
|  |  | ||||||
|     def load_env_tests(self) -> List[Test]: |     def load_env_tests(self) -> List[Rule]: | ||||||
|         return self._load_smth(self.tests, None)  # Test |         return self._load_smth(self.tests, None)  # Test | ||||||
|         pass |  | ||||||
|  |  | ||||||
|     def _load_smth(self, config, class_obj): |     def _load_smth(self, config, class_obj): | ||||||
|         rules = list() |         rules = list() | ||||||
|         rules_names = list() |         for rule in config: | ||||||
|  |             e1 = e2 = e3 = None | ||||||
|         for rule in rules_names: |  | ||||||
|             try: |             try: | ||||||
|                 folder_path = (Path(__file__).parent.parent / DEFAULT_PATH) |                 folder_path = (Path(__file__).parent.parent / DEFAULT_PATH) | ||||||
|                 rule_class = locate_and_import_class(rule, folder_path) |                 rule_class = locate_and_import_class(rule, folder_path) | ||||||
|  |             except AttributeError as e: | ||||||
|  |                 e1 = e | ||||||
|  |                 try: | ||||||
|  |                     module_path = (Path(__file__).parent.parent / MODULE_PATH) | ||||||
|  |                     rule_class = locate_and_import_class(rule, module_path) | ||||||
|  |                 except AttributeError as e: | ||||||
|  |                     e2 = e | ||||||
|  |                     if self.custom_modules_path: | ||||||
|  |                         try: | ||||||
|  |                             rule_class = locate_and_import_class(rule, self.custom_modules_path) | ||||||
|  |                         except AttributeError as e: | ||||||
|  |                             e3 = e | ||||||
|  |                             pass | ||||||
|  |             if (e1 and e2) or e3: | ||||||
|  |                 ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]] | ||||||
|  |                 print('### Error  ###  Error  ###  Error  ###  Error  ###  Error  ###') | ||||||
|  |                 print('') | ||||||
|  |                 print(f'Class "{rule}" was not found in "{module_path.name}"') | ||||||
|  |                 print(f'Class "{rule}" was not found in "{folder_path.name}"') | ||||||
|  |                 if self.custom_modules_path: | ||||||
|  |                     print(f'Class "{rule}" was not found in "{self.custom_modules_path}"') | ||||||
|  |                 print('Possible Entitys are:', str(ents)) | ||||||
|  |                 print('') | ||||||
|  |                 print('Goodbye') | ||||||
|  |                 print('') | ||||||
|  |                 exit(-99999) | ||||||
|  |  | ||||||
|  |             if issubclass(rule_class, class_obj): | ||||||
|  |                 rule_kwargs = config.get(rule, {}) | ||||||
|  |                 rules.append(rule_class(**(rule_kwargs or {}))) | ||||||
|  |         return rules | ||||||
|  |  | ||||||
|  |     def load_entity_spawn_rules(self, entities) -> List[Rule]: | ||||||
|  |         rules = list() | ||||||
|  |         rules_dicts = list() | ||||||
|  |         for e in entities: | ||||||
|  |             try: | ||||||
|  |                 if spawn_rule := e.spawn_rule: | ||||||
|  |                     rules_dicts.append(spawn_rule) | ||||||
|  |             except AttributeError: | ||||||
|  |                 pass | ||||||
|  |  | ||||||
|  |         for rule_dict in rules_dicts: | ||||||
|  |             for rule_name, rule_kwargs in rule_dict.items(): | ||||||
|  |                 try: | ||||||
|  |                     folder_path = (Path(__file__).parent.parent / DEFAULT_PATH) | ||||||
|  |                     rule_class = locate_and_import_class(rule_name, folder_path) | ||||||
|                 except AttributeError: |                 except AttributeError: | ||||||
|                     try: |                     try: | ||||||
|                         folder_path = (Path(__file__).parent.parent / MODULE_PATH) |                         folder_path = (Path(__file__).parent.parent / MODULE_PATH) | ||||||
|                     rule_class = locate_and_import_class(rule, folder_path) |                         rule_class = locate_and_import_class(rule_name, folder_path) | ||||||
|                     except AttributeError: |                     except AttributeError: | ||||||
|                     rule_class = locate_and_import_class(rule, self.custom_modules_path) |                         rule_class = locate_and_import_class(rule_name, self.custom_modules_path) | ||||||
|             # Fixme This check does not work! |  | ||||||
|             #  assert isinstance(rule_class, class_obj), f'{rule_class.__name__} is no valid "class_obj.__name__".' |  | ||||||
|             rule_kwargs = config.get(rule, {}) |  | ||||||
|                 rules.append(rule_class(**rule_kwargs)) |                 rules.append(rule_class(**rule_kwargs)) | ||||||
|         return rules |         return rules | ||||||
|   | |||||||
| @@ -2,7 +2,7 @@ import importlib | |||||||
|  |  | ||||||
| from collections import defaultdict | from collections import defaultdict | ||||||
| from pathlib import PurePath, Path | from pathlib import PurePath, Path | ||||||
| from typing import Union, Dict, List | from typing import Union, Dict, List, Iterable, Callable | ||||||
|  |  | ||||||
| import numpy as np | import numpy as np | ||||||
| from numpy.typing import ArrayLike | from numpy.typing import ArrayLike | ||||||
| @@ -61,8 +61,8 @@ class ObservationTranslator: | |||||||
|         :param per_agent_named_obs_spaces: `Named observation space` one for each agent. Overloaded. |         :param per_agent_named_obs_spaces: `Named observation space` one for each agent. Overloaded. | ||||||
|         type  per_agent_named_obs_spaces: Dict[str, dict] |         type  per_agent_named_obs_spaces: Dict[str, dict] | ||||||
|  |  | ||||||
|         :param placeholder_fill_value: Currently not fully implemented!!! |         :param placeholder_fill_value: Currently, not fully implemented!!! | ||||||
|         :type  placeholder_fill_value: Union[int, str] = 'N') |         :type  placeholder_fill_value: Union[int, str] = 'N' | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
|         if isinstance(placeholder_fill_value, str): |         if isinstance(placeholder_fill_value, str): | ||||||
| @@ -222,7 +222,7 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''): | |||||||
|         mod = importlib.import_module('.'.join(module_parts)) |         mod = importlib.import_module('.'.join(module_parts)) | ||||||
|         all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle()) |         all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle()) | ||||||
|                                   and x not in ['Entity',  'NamedTuple', 'List', 'Rule', 'Union', |                                   and x not in ['Entity',  'NamedTuple', 'List', 'Rule', 'Union', | ||||||
|                                                 'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin', |                                                 'TickResult', 'ActionResult', 'Action', 'Agent', | ||||||
|                                                 'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin', |                                                 'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin', | ||||||
|                                                 'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any' |                                                 'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any' | ||||||
|                                                 ]]) |                                                 ]]) | ||||||
| @@ -240,7 +240,13 @@ def add_bound_name(name_str, bound_e): | |||||||
|  |  | ||||||
| def add_pos_name(name_str, bound_e): | def add_pos_name(name_str, bound_e): | ||||||
|     if bound_e.var_has_position: |     if bound_e.var_has_position: | ||||||
|         return f'{name_str}({bound_e.pos})' |         return f'{name_str}@{bound_e.pos}' | ||||||
|     return name_str |     return name_str | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_first(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True): | ||||||
|  |     return next((x for x in iterable if filter_by(x)), None) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_first_index(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True): | ||||||
|  |     return next((idx for idx, x in enumerate(iterable) if filter_by(x)), None) | ||||||
|   | |||||||
| @@ -47,6 +47,7 @@ class LevelParser(object): | |||||||
|         # All other |         # All other | ||||||
|         for es_name in self.e_p_dict: |         for es_name in self.e_p_dict: | ||||||
|             e_class, e_kwargs = self.e_p_dict[es_name]['class'], self.e_p_dict[es_name]['kwargs'] |             e_class, e_kwargs = self.e_p_dict[es_name]['class'], self.e_p_dict[es_name]['kwargs'] | ||||||
|  |             e_kwargs = e_kwargs if e_kwargs else {} | ||||||
|  |  | ||||||
|             if hasattr(e_class, 'symbol') and e_class.symbol is not None: |             if hasattr(e_class, 'symbol') and e_class.symbol is not None: | ||||||
|                 symbols = e_class.symbol |                 symbols = e_class.symbol | ||||||
|   | |||||||
| @@ -9,7 +9,7 @@ from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS | |||||||
|  |  | ||||||
| import pandas as pd | import pandas as pd | ||||||
|  |  | ||||||
| from marl_factory_grid.utils.plotting.compare_runs import plot_single_run | from marl_factory_grid.utils.plotting.plot_single_runs import plot_single_run | ||||||
|  |  | ||||||
|  |  | ||||||
| class EnvMonitor(Wrapper): | class EnvMonitor(Wrapper): | ||||||
| @@ -22,7 +22,6 @@ class EnvMonitor(Wrapper): | |||||||
|         self._monitor_df = pd.DataFrame() |         self._monitor_df = pd.DataFrame() | ||||||
|         self._monitor_dict = dict() |         self._monitor_dict = dict() | ||||||
|  |  | ||||||
|  |  | ||||||
|     def step(self, action): |     def step(self, action): | ||||||
|         obs_type, obs, reward, done, info = self.env.step(action) |         obs_type, obs, reward, done, info = self.env.step(action) | ||||||
|         self._read_info(info) |         self._read_info(info) | ||||||
|   | |||||||
| @@ -2,11 +2,9 @@ from os import PathLike | |||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from typing import Union, List | from typing import Union, List | ||||||
|  |  | ||||||
| import yaml |  | ||||||
| from gymnasium import Wrapper |  | ||||||
|  |  | ||||||
| import numpy as np | import numpy as np | ||||||
| import pandas as pd | import pandas as pd | ||||||
|  | from gymnasium import Wrapper | ||||||
|  |  | ||||||
|  |  | ||||||
| class EnvRecorder(Wrapper): | class EnvRecorder(Wrapper): | ||||||
| @@ -106,7 +104,7 @@ class EnvRecorder(Wrapper): | |||||||
|                 out_dict = {'episodes': self._recorder_out_list} |                 out_dict = {'episodes': self._recorder_out_list} | ||||||
|             out_dict.update( |             out_dict.update( | ||||||
|                 {'n_episodes': self._curr_episode, |                 {'n_episodes': self._curr_episode, | ||||||
|                  'metadata':dict( |                  'metadata': dict( | ||||||
|                      level_name=self.env.params['General']['level_name'], |                      level_name=self.env.params['General']['level_name'], | ||||||
|                      verbose=False, |                      verbose=False, | ||||||
|                      n_agents=len(self.env.params['Agents']), |                      n_agents=len(self.env.params['Agents']), | ||||||
|   | |||||||
| @@ -1,17 +1,16 @@ | |||||||
| import math |  | ||||||
| import re | import re | ||||||
| from collections import defaultdict | from collections import defaultdict | ||||||
| from itertools import product |  | ||||||
| from typing import Dict, List | from typing import Dict, List | ||||||
|  |  | ||||||
| import numpy as np | import numpy as np | ||||||
| from numba import njit |  | ||||||
|  |  | ||||||
| from marl_factory_grid.environment import constants as c | from marl_factory_grid.environment import constants as c | ||||||
|  | from marl_factory_grid.environment.entity.object import Object | ||||||
| from marl_factory_grid.environment.groups.utils import Combined | from marl_factory_grid.environment.groups.utils import Combined | ||||||
| import marl_factory_grid.utils.helpers as h |  | ||||||
| from marl_factory_grid.utils.states import Gamestate |  | ||||||
| from marl_factory_grid.utils.utility_classes import Floor | from marl_factory_grid.utils.utility_classes import Floor | ||||||
|  | from marl_factory_grid.utils.ray_caster import RayCaster | ||||||
|  | from marl_factory_grid.utils.states import Gamestate | ||||||
|  | from marl_factory_grid.utils import helpers as h | ||||||
|  |  | ||||||
|  |  | ||||||
| class OBSBuilder(object): | class OBSBuilder(object): | ||||||
| @@ -77,11 +76,13 @@ class OBSBuilder(object): | |||||||
|  |  | ||||||
|     def place_entity_in_observation(self, obs_array, agent, e): |     def place_entity_in_observation(self, obs_array, agent, e): | ||||||
|         x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r |         x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r | ||||||
|  |         if not min([y, x]) < 0: | ||||||
|             try: |             try: | ||||||
|                 obs_array[x, y] += e.encoding |                 obs_array[x, y] += e.encoding | ||||||
|             except IndexError: |             except IndexError: | ||||||
|                 # Seemded to be visible but is out of range |                 # Seemded to be visible but is out of range | ||||||
|                 pass |                 pass | ||||||
|  |         pass | ||||||
|  |  | ||||||
|     def build_for_agent(self, agent, state) -> (List[str], np.ndarray): |     def build_for_agent(self, agent, state) -> (List[str], np.ndarray): | ||||||
|         assert self._curr_env_step == state.curr_step, ( |         assert self._curr_env_step == state.curr_step, ( | ||||||
| @@ -121,18 +122,24 @@ class OBSBuilder(object): | |||||||
|                         e = self.all_obs[l_name] |                         e = self.all_obs[l_name] | ||||||
|                     except KeyError: |                     except KeyError: | ||||||
|                         try: |                         try: | ||||||
|                             # Look for bound entity names! |                             # Look for bound entity REPRs! | ||||||
|                             pattern = re.compile(f'{re.escape(l_name)}(.*){re.escape(agent.name)}') |                             pattern = re.compile(f'{re.escape(l_name)}' | ||||||
|                             name = next((x for x in self.all_obs if pattern.search(x)), None) |                                                  f'{re.escape("[")}(.*){re.escape("]")}' | ||||||
|  |                                                  f'{re.escape("(")}{re.escape(agent.name)}{re.escape(")")}') | ||||||
|  |                             name = next((key for key, val in self.all_obs.items() | ||||||
|  |                                          if pattern.search(str(val)) and isinstance(val, Object)), None) | ||||||
|                             e = self.all_obs[name] |                             e = self.all_obs[name] | ||||||
|                         except KeyError: |                         except KeyError: | ||||||
|                             try: |                             try: | ||||||
|                                 e = next(v for k, v in self.all_obs.items() if l_name in k and agent.name in k) |                                 e = next(v for k, v in self.all_obs.items() if l_name in k and agent.name in k) | ||||||
|                             except StopIteration: |                             except StopIteration: | ||||||
|                                 raise KeyError( |                                 print(f'# Check for spelling errors!') | ||||||
|                                     f'Check for spelling errors! \n ' |                                 print(f'# No combination of "{l_name}" and "{agent.name}" could not be found in:') | ||||||
|                                     f'No combination of "{l_name} and {agent.name}" could not be found in:\n ' |                                 print(f'# {list(dict(self.all_obs).keys())}') | ||||||
|                                     f'{list(dict(self.all_obs).keys())}') |                                 print('#') | ||||||
|  |                                 print('# exiting...') | ||||||
|  |                                 print('#') | ||||||
|  |                                 exit(-99999) | ||||||
|  |  | ||||||
|                     try: |                     try: | ||||||
|                         positional = e.var_has_position |                         positional = e.var_has_position | ||||||
| @@ -161,31 +168,30 @@ class OBSBuilder(object): | |||||||
|             try: |             try: | ||||||
|                 light_map = np.zeros(self.obs_shape) |                 light_map = np.zeros(self.obs_shape) | ||||||
|                 visible_floor = self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False) |                 visible_floor = self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False) | ||||||
|                 if self.pomdp_r: |  | ||||||
|                 for f in set(visible_floor): |                 for f in set(visible_floor): | ||||||
|                     self.place_entity_in_observation(light_map, agent, f) |                     self.place_entity_in_observation(light_map, agent, f) | ||||||
|                 else: |                 # else: | ||||||
|                     for f in set(visible_floor): |                 #     for f in set(visible_floor): | ||||||
|                         light_map[f.x, f.y] += f.encoding |                 #         light_map[f.x, f.y] += f.encoding | ||||||
|                 self.curr_lightmaps[agent.name] = light_map |                 self.curr_lightmaps[agent.name] = light_map | ||||||
|             except (KeyError, ValueError): |             except (KeyError, ValueError): | ||||||
|                 print() |  | ||||||
|                 pass |                 pass | ||||||
|         return obs, self.obs_layers[agent.name] |         return obs, self.obs_layers[agent.name] | ||||||
|  |  | ||||||
|     def _sort_and_name_observation_conf(self, agent): |     def _sort_and_name_observation_conf(self, agent): | ||||||
|         ''' |         """ | ||||||
|         Builds the useable observation scheme per agent from conf.yaml. |         Builds the useable observation scheme per agent from conf.yaml. | ||||||
|         :param agent: |         :param agent: | ||||||
|         :return: |         :return: | ||||||
|         ''' |         """ | ||||||
|         # Fixme: no asymetric shapes possible. |         # Fixme: no asymetric shapes possible. | ||||||
|         self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape)) |         self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape)) | ||||||
|         obs_layers = [] |         obs_layers = [] | ||||||
|  |  | ||||||
|         for obs_str in agent.observations: |         for obs_str in agent.observations: | ||||||
|             if isinstance(obs_str, dict): |             if isinstance(obs_str, dict): | ||||||
|                 obs_str, vals = next(obs_str.items().__iter__()) |                 obs_str, vals = h.get_first(obs_str.items()) | ||||||
|             else: |             else: | ||||||
|                 vals = None |                 vals = None | ||||||
|             if obs_str == c.SELF: |             if obs_str == c.SELF: | ||||||
| @@ -214,129 +220,3 @@ class OBSBuilder(object): | |||||||
|                 obs_layers.append(obs_str) |                 obs_layers.append(obs_str) | ||||||
|         self.obs_layers[agent.name] = obs_layers |         self.obs_layers[agent.name] = obs_layers | ||||||
|         self.curr_lightmaps[agent.name] = np.zeros(self.obs_shape) |         self.curr_lightmaps[agent.name] = np.zeros(self.obs_shape) | ||||||
|  |  | ||||||
|  |  | ||||||
| class RayCaster: |  | ||||||
|     def __init__(self, agent, pomdp_r, degs=360): |  | ||||||
|         self.agent = agent |  | ||||||
|         self.pomdp_r = pomdp_r |  | ||||||
|         self.n_rays = (self.pomdp_r + 1) * 8 |  | ||||||
|         self.degs = degs |  | ||||||
|         self.ray_targets = self.build_ray_targets() |  | ||||||
|         self.obs_shape_cube = np.array([self.pomdp_r, self.pomdp_r]) |  | ||||||
|         self._cache_dict = {} |  | ||||||
|  |  | ||||||
|     def __repr__(self): |  | ||||||
|         return f'{self.__class__.__name__}({self.agent.name})' |  | ||||||
|  |  | ||||||
|     def build_ray_targets(self): |  | ||||||
|         north = np.array([0, -1]) * self.pomdp_r |  | ||||||
|         thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]] |  | ||||||
|         rot_M = [ |  | ||||||
|             [[math.cos(theta), -math.sin(theta)], |  | ||||||
|              [math.sin(theta), math.cos(theta)]] for theta in thetas |  | ||||||
|         ] |  | ||||||
|         rot_M = np.stack(rot_M, 0) |  | ||||||
|         rot_M = np.unique(np.round(rot_M @ north), axis=0) |  | ||||||
|         return rot_M.astype(int) |  | ||||||
|  |  | ||||||
|     def ray_block_cache(self, key, callback): |  | ||||||
|         if key not in self._cache_dict: |  | ||||||
|             self._cache_dict[key] = callback() |  | ||||||
|         return self._cache_dict[key] |  | ||||||
|  |  | ||||||
|     def visible_entities(self, pos_dict, reset_cache=True): |  | ||||||
|         visible = list() |  | ||||||
|         if reset_cache: |  | ||||||
|             self._cache_dict = {} |  | ||||||
|  |  | ||||||
|         for ray in self.get_rays(): |  | ||||||
|             rx, ry = ray[0] |  | ||||||
|             for x, y in ray: |  | ||||||
|                 cx, cy = x - rx, y - ry |  | ||||||
|  |  | ||||||
|                 entities_hit = pos_dict[(x, y)] |  | ||||||
|                 hits = self.ray_block_cache((x, y), |  | ||||||
|                                             lambda: any(True for e in entities_hit if e.var_is_blocking_light) |  | ||||||
|                                             ) |  | ||||||
|  |  | ||||||
|                 diag_hits = all([ |  | ||||||
|                     self.ray_block_cache( |  | ||||||
|                         key, |  | ||||||
|                         lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) and bool( |  | ||||||
|                             pos_dict[key])) |  | ||||||
|                     for key in ((x, y - cy), (x - cx, y)) |  | ||||||
|                 ]) if (cx != 0 and cy != 0) else False |  | ||||||
|  |  | ||||||
|                 visible += entities_hit if not diag_hits else [] |  | ||||||
|                 if hits or diag_hits: |  | ||||||
|                     break |  | ||||||
|                 rx, ry = x, y |  | ||||||
|         return visible |  | ||||||
|  |  | ||||||
|     def get_rays(self): |  | ||||||
|         a_pos = self.agent.pos |  | ||||||
|         outline = self.ray_targets + a_pos |  | ||||||
|         return self.bresenham_loop(a_pos, outline) |  | ||||||
|  |  | ||||||
|     # todo do this once and cache the points! |  | ||||||
|     def get_fov_outline(self) -> np.ndarray: |  | ||||||
|         return self.ray_targets + self.agent.pos |  | ||||||
|  |  | ||||||
|     def get_square_outline(self): |  | ||||||
|         agent = self.agent |  | ||||||
|         x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1) |  | ||||||
|         y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1) |  | ||||||
|         outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \ |  | ||||||
|                   + list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords)) |  | ||||||
|         return outline |  | ||||||
|  |  | ||||||
|     @staticmethod |  | ||||||
|     @njit |  | ||||||
|     def bresenham_loop(a_pos, points): |  | ||||||
|         results = [] |  | ||||||
|         for end in points: |  | ||||||
|             x1, y1 = a_pos |  | ||||||
|             x2, y2 = end |  | ||||||
|             dx = x2 - x1 |  | ||||||
|             dy = y2 - y1 |  | ||||||
|  |  | ||||||
|             # Determine how steep the line is |  | ||||||
|             is_steep = abs(dy) > abs(dx) |  | ||||||
|  |  | ||||||
|             # Rotate line |  | ||||||
|             if is_steep: |  | ||||||
|                 x1, y1 = y1, x1 |  | ||||||
|                 x2, y2 = y2, x2 |  | ||||||
|  |  | ||||||
|             # Swap start and end points if necessary and store swap state |  | ||||||
|             swapped = False |  | ||||||
|             if x1 > x2: |  | ||||||
|                 x1, x2 = x2, x1 |  | ||||||
|                 y1, y2 = y2, y1 |  | ||||||
|                 swapped = True |  | ||||||
|  |  | ||||||
|             # Recalculate differentials |  | ||||||
|             dx = x2 - x1 |  | ||||||
|             dy = y2 - y1 |  | ||||||
|  |  | ||||||
|             # Calculate error |  | ||||||
|             error = int(dx / 2.0) |  | ||||||
|             ystep = 1 if y1 < y2 else -1 |  | ||||||
|  |  | ||||||
|             # Iterate over bounding box generating points between start and end |  | ||||||
|             y = y1 |  | ||||||
|             points = [] |  | ||||||
|             for x in range(int(x1), int(x2) + 1): |  | ||||||
|                 coord = [y, x] if is_steep else [x, y] |  | ||||||
|                 points.append(coord) |  | ||||||
|                 error -= abs(dy) |  | ||||||
|                 if error < 0: |  | ||||||
|                     y += ystep |  | ||||||
|                     error += dx |  | ||||||
|  |  | ||||||
|             # Reverse the list if the coordinates were swapped |  | ||||||
|             if swapped: |  | ||||||
|                 points.reverse() |  | ||||||
|             results.append(points) |  | ||||||
|         return results |  | ||||||
|   | |||||||
| @@ -7,50 +7,11 @@ from typing import Union, List | |||||||
| import pandas as pd | import pandas as pd | ||||||
| 
 | 
 | ||||||
| from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS | from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS | ||||||
| from marl_factory_grid.utils.plotting.plotting import prepare_plot | from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot | ||||||
| 
 | 
 | ||||||
| MODEL_MAP = None | MODEL_MAP = None | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None): |  | ||||||
|     run_path = Path(run_path) |  | ||||||
|     df_list = list() |  | ||||||
|     if run_path.is_dir(): |  | ||||||
|         monitor_file = next(run_path.glob('*monitor*.pick')) |  | ||||||
|     elif run_path.exists() and run_path.is_file(): |  | ||||||
|         monitor_file = run_path |  | ||||||
|     else: |  | ||||||
|         raise ValueError |  | ||||||
| 
 |  | ||||||
|     with monitor_file.open('rb') as f: |  | ||||||
|         monitor_df = pickle.load(f) |  | ||||||
| 
 |  | ||||||
|         monitor_df = monitor_df.fillna(0) |  | ||||||
|         df_list.append(monitor_df) |  | ||||||
| 
 |  | ||||||
|     df = pd.concat(df_list,  ignore_index=True) |  | ||||||
|     df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode']) |  | ||||||
|     if column_keys is not None: |  | ||||||
|         columns = [col for col in column_keys if col in df.columns] |  | ||||||
|     else: |  | ||||||
|         columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS] |  | ||||||
| 
 |  | ||||||
|     roll_n = 50 |  | ||||||
| 
 |  | ||||||
|     non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean() |  | ||||||
| 
 |  | ||||||
|     df_melted = df[columns + ['Episode']].reset_index().melt( |  | ||||||
|         id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score" |  | ||||||
|     ) |  | ||||||
| 
 |  | ||||||
|     if df_melted['Episode'].max() > 800: |  | ||||||
|         skip_n = round(df_melted['Episode'].max() * 0.02) |  | ||||||
|         df_melted = df_melted[df_melted['Episode'] % skip_n == 0] |  | ||||||
| 
 |  | ||||||
|     prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex) |  | ||||||
|     print('Plotting done.') |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False): | def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False): | ||||||
|     run_path = Path(run_path) |     run_path = Path(run_path) | ||||||
|     df_list = list() |     df_list = list() | ||||||
							
								
								
									
										48
									
								
								marl_factory_grid/utils/plotting/plot_single_runs.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								marl_factory_grid/utils/plotting/plot_single_runs.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,48 @@ | |||||||
|  | import pickle | ||||||
|  | from os import PathLike | ||||||
|  | from pathlib import Path | ||||||
|  | from typing import Union | ||||||
|  |  | ||||||
|  | import pandas as pd | ||||||
|  |  | ||||||
|  | from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS | ||||||
|  | from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None, | ||||||
|  |                     file_key: str ='monitor', file_ext: str ='pkl'): | ||||||
|  |     run_path = Path(run_path) | ||||||
|  |     df_list = list() | ||||||
|  |     if run_path.is_dir(): | ||||||
|  |         monitor_file = next(run_path.glob(f'*{file_key}*.{file_ext}')) | ||||||
|  |     elif run_path.exists() and run_path.is_file(): | ||||||
|  |         monitor_file = run_path | ||||||
|  |     else: | ||||||
|  |         raise ValueError | ||||||
|  |  | ||||||
|  |     with monitor_file.open('rb') as f: | ||||||
|  |         monitor_df = pickle.load(f) | ||||||
|  |  | ||||||
|  |         monitor_df = monitor_df.fillna(0) | ||||||
|  |         df_list.append(monitor_df) | ||||||
|  |  | ||||||
|  |     df = pd.concat(df_list,  ignore_index=True) | ||||||
|  |     df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode']) | ||||||
|  |     if column_keys is not None: | ||||||
|  |         columns = [col for col in column_keys if col in df.columns] | ||||||
|  |     else: | ||||||
|  |         columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS] | ||||||
|  |  | ||||||
|  |     # roll_n = 50 | ||||||
|  |     # non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean() | ||||||
|  |  | ||||||
|  |     df_melted = df[columns + ['Episode']].reset_index().melt( | ||||||
|  |         id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score" | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     if df_melted['Episode'].max() > 800: | ||||||
|  |         skip_n = round(df_melted['Episode'].max() * 0.02) | ||||||
|  |         df_melted = df_melted[df_melted['Episode'] % skip_n == 0] | ||||||
|  |  | ||||||
|  |     prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex) | ||||||
|  |     print('Plotting done.') | ||||||
| @@ -60,7 +60,7 @@ def prepare_center_double_column_legend(df, hue, style, hue_order): | |||||||
|     print('Struggling to plot Figure using LaTeX - going back to normal.') |     print('Struggling to plot Figure using LaTeX - going back to normal.') | ||||||
|     plt.close('all') |     plt.close('all') | ||||||
|     sns.set(rc={'text.usetex': False}, style='whitegrid') |     sns.set(rc={'text.usetex': False}, style='whitegrid') | ||||||
|     fig = plt.figure(figsize=(10, 11)) |     _ = plt.figure(figsize=(10, 11)) | ||||||
|     lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style, |     lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style, | ||||||
|                             ci=95, palette=PALETTE, hue_order=hue_order, legend=False) |                             ci=95, palette=PALETTE, hue_order=hue_order, legend=False) | ||||||
|     # plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0) |     # plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0) | ||||||
| @@ -19,7 +19,7 @@ class RayCaster: | |||||||
|         return f'{self.__class__.__name__}({self.agent.name})' |         return f'{self.__class__.__name__}({self.agent.name})' | ||||||
|  |  | ||||||
|     def build_ray_targets(self): |     def build_ray_targets(self): | ||||||
|         north = np.array([0, -1])*self.pomdp_r |         north = np.array([0, -1]) * self.pomdp_r | ||||||
|         thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]] |         thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]] | ||||||
|         rot_M = [ |         rot_M = [ | ||||||
|             [[math.cos(theta), -math.sin(theta)], |             [[math.cos(theta), -math.sin(theta)], | ||||||
| @@ -39,8 +39,9 @@ class RayCaster: | |||||||
|         if reset_cache: |         if reset_cache: | ||||||
|             self._cache_dict = dict() |             self._cache_dict = dict() | ||||||
|  |  | ||||||
|         for ray in self.get_rays(): |         for ray in self.get_rays():  # Do not check, just trust. | ||||||
|             rx, ry = ray[0] |             rx, ry = ray[0] | ||||||
|  |             # self.ray_block_cache(ray[0], lambda: False) We do not do that, because of doors etc... | ||||||
|             for x, y in ray: |             for x, y in ray: | ||||||
|                 cx, cy = x - rx, y - ry |                 cx, cy = x - rx, y - ry | ||||||
|  |  | ||||||
| @@ -52,8 +53,9 @@ class RayCaster: | |||||||
|                 diag_hits = all([ |                 diag_hits = all([ | ||||||
|                     self.ray_block_cache( |                     self.ray_block_cache( | ||||||
|                         key, |                         key, | ||||||
|                         lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light)) |                         # lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) | ||||||
|                     for key in ((x, y-cy), (x-cx, y)) |                         lambda: any(True for e in pos_dict[key] if e.var_is_blocking_light)) | ||||||
|  |                     for key in ((x, y - cy), (x - cx, y)) | ||||||
|                 ]) if (cx != 0 and cy != 0) else False |                 ]) if (cx != 0 and cy != 0) else False | ||||||
|  |  | ||||||
|                 visible += entities_hit if not diag_hits else [] |                 visible += entities_hit if not diag_hits else [] | ||||||
| @@ -75,8 +77,8 @@ class RayCaster: | |||||||
|         agent = self.agent |         agent = self.agent | ||||||
|         x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1) |         x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1) | ||||||
|         y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1) |         y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1) | ||||||
|         outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \ |         outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) | ||||||
|                   + list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords)) |         outline += list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords)) | ||||||
|         return outline |         return outline | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|   | |||||||
| @@ -31,7 +31,7 @@ class Renderer: | |||||||
|  |  | ||||||
|     def __init__(self, lvl_shape: Tuple[int, int] = (16, 16), |     def __init__(self, lvl_shape: Tuple[int, int] = (16, 16), | ||||||
|                  lvl_padded_shape: Union[Tuple[int, int], None] = None, |                  lvl_padded_shape: Union[Tuple[int, int], None] = None, | ||||||
|                  cell_size: int = 40, fps: int = 7, |                  cell_size: int = 40, fps: int = 7, factor: float = 0.9, | ||||||
|                  grid_lines: bool = True, view_radius: int = 2): |                  grid_lines: bool = True, view_radius: int = 2): | ||||||
|         # TODO: Customn_assets paths |         # TODO: Customn_assets paths | ||||||
|         self.grid_h, self.grid_w = lvl_shape |         self.grid_h, self.grid_w = lvl_shape | ||||||
| @@ -45,7 +45,7 @@ class Renderer: | |||||||
|         self.screen = pygame.display.set_mode(self.screen_size) |         self.screen = pygame.display.set_mode(self.screen_size) | ||||||
|         self.clock = pygame.time.Clock() |         self.clock = pygame.time.Clock() | ||||||
|         assets = list(self.ASSETS.rglob('*.png')) |         assets = list(self.ASSETS.rglob('*.png')) | ||||||
|         self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets} |         self.assets = {path.stem: self.load_asset(str(path), factor) for path in assets} | ||||||
|         self.fill_bg() |         self.fill_bg() | ||||||
|  |  | ||||||
|         now = time.time() |         now = time.time() | ||||||
| @@ -110,22 +110,22 @@ class Renderer: | |||||||
|                 pygame.quit() |                 pygame.quit() | ||||||
|                 sys.exit() |                 sys.exit() | ||||||
|         self.fill_bg() |         self.fill_bg() | ||||||
|         blits = deque() |         # First all others | ||||||
|         for entity in [x for x in entities]: |         blits = deque(self.blit_params(x) for x in entities if not x.name.lower() == AGENT) | ||||||
|             bp = self.blit_params(entity) |         # Then Agents, so that agents are rendered on top. | ||||||
|             blits.append(bp) |         for agent in (x for x in entities if x.name.lower() == AGENT): | ||||||
|             if entity.name.lower() == AGENT: |             agent_blit = self.blit_params(agent) | ||||||
|             if self.view_radius > 0: |             if self.view_radius > 0: | ||||||
|                     vis_rects = self.visibility_rects(bp, entity.aux) |                 vis_rects = self.visibility_rects(agent_blit, agent.aux) | ||||||
|                 blits.extendleft(vis_rects) |                 blits.extendleft(vis_rects) | ||||||
|                 if entity.state != BLANK: |             if agent.state != BLANK: | ||||||
|                     agent_state_blits = self.blit_params( |                 state_blit = self.blit_params( | ||||||
|                         RenderEntity(entity.state, (entity.pos[0] + 0.12, entity.pos[1]), 0.48, SCALE) |                     RenderEntity(agent.state, (agent.pos[0] + 0.12, agent.pos[1]), 0.48, SCALE) | ||||||
|                 ) |                 ) | ||||||
|                     textsurface = self.font.render(str(entity.id), False, (0, 0, 0)) |                 textsurface = self.font.render(str(agent.id), False, (0, 0, 0)) | ||||||
|                     text_blit = dict(source=textsurface, dest=(bp['dest'].center[0]-.07*self.cell_size, |                 text_blit = dict(source=textsurface, dest=(agent_blit['dest'].center[0]-.07*self.cell_size, | ||||||
|                                                                bp['dest'].center[1])) |                                                            agent_blit['dest'].center[1])) | ||||||
|                     blits += [agent_state_blits, text_blit] |                 blits += [agent_blit, state_blit, text_blit] | ||||||
|  |  | ||||||
|         for blit in blits: |         for blit in blits: | ||||||
|             self.screen.blit(**blit) |             self.screen.blit(**blit) | ||||||
|   | |||||||
| @@ -1,9 +1,12 @@ | |||||||
| from typing import Union | from typing import Union | ||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
|  |  | ||||||
|  | from marl_factory_grid.environment.entity.object import Object | ||||||
|  |  | ||||||
| TYPE_VALUE  = 'value' | TYPE_VALUE  = 'value' | ||||||
| TYPE_REWARD = 'reward' | TYPE_REWARD = 'reward' | ||||||
| types = [TYPE_VALUE, TYPE_REWARD] | TYPES = [TYPE_VALUE, TYPE_REWARD] | ||||||
|  |  | ||||||
|  |  | ||||||
| @dataclass | @dataclass | ||||||
| class InfoObject: | class InfoObject: | ||||||
| @@ -18,17 +21,21 @@ class Result: | |||||||
|     validity: bool |     validity: bool | ||||||
|     reward: Union[float, None] = None |     reward: Union[float, None] = None | ||||||
|     value: Union[float, None] = None |     value: Union[float, None] = None | ||||||
|     entity: None = None |     entity: Object = None | ||||||
|  |  | ||||||
|     def get_infos(self): |     def get_infos(self): | ||||||
|         n = self.entity.name if self.entity is not None else "Global" |         n = self.entity.name if self.entity is not None else "Global" | ||||||
|         return [InfoObject(identifier=f'{n}_{self.identifier}_{t}', |         # Return multiple Info Dicts | ||||||
|                            val_type=t, value=self.__getattribute__(t)) for t in types |         return [InfoObject(identifier=f'{n}_{self.identifier}', | ||||||
|  |                            val_type=t, value=self.__getattribute__(t)) for t in TYPES | ||||||
|                 if self.__getattribute__(t) is not None] |                 if self.__getattribute__(t) is not None] | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         valid = "not " if not self.validity else "" |         valid = "not " if not self.validity else "" | ||||||
|         return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid: {self.reward})' |         reward = f" | Reward: {self.reward}" if self.reward is not None else "" | ||||||
|  |         value = f" | Value: {self.value}" if self.value is not None else "" | ||||||
|  |         entity = f" | by: {self.entity.name}" if self.entity is not None else "" | ||||||
|  |         return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value}{entity})' | ||||||
|  |  | ||||||
|  |  | ||||||
| @dataclass | @dataclass | ||||||
|   | |||||||
| @@ -1,9 +1,12 @@ | |||||||
| from typing import List, Dict, Tuple | from itertools import islice | ||||||
|  | from typing import List, Tuple | ||||||
|  |  | ||||||
| import numpy as np | import numpy as np | ||||||
|  |  | ||||||
| from marl_factory_grid.environment import constants as c | from marl_factory_grid.environment import constants as c | ||||||
|  | from marl_factory_grid.environment.entity.entity import Entity | ||||||
| from marl_factory_grid.environment.rules import Rule | from marl_factory_grid.environment.rules import Rule | ||||||
|  | from marl_factory_grid.utils.results import Result, DoneResult | ||||||
| from marl_factory_grid.environment.tests import Test | from marl_factory_grid.environment.tests import Test | ||||||
| from marl_factory_grid.utils.results import Result | from marl_factory_grid.utils.results import Result | ||||||
|  |  | ||||||
| @@ -60,7 +63,8 @@ class Gamestate(object): | |||||||
|     def moving_entites(self): |     def moving_entites(self): | ||||||
|         return [y for x in self.entities for y in x if x.var_can_move] |         return [y for x in self.entities for y in x if x.var_can_move] | ||||||
|  |  | ||||||
|     def __init__(self, entities, agents_conf, rules: [Rule], tests: [Test], env_seed=69, verbose=False): |     def __init__(self, entities, agents_conf, rules: List[Rule], tests: [Test], lvl_shape, env_seed=69, verbose=False): | ||||||
|  |         self.lvl_shape = lvl_shape | ||||||
|         self.entities = entities |         self.entities = entities | ||||||
|         self.curr_step = 0 |         self.curr_step = 0 | ||||||
|         self.curr_actions = None |         self.curr_actions = None | ||||||
| @@ -82,7 +86,52 @@ class Gamestate(object): | |||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})' |         return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})' | ||||||
|  |  | ||||||
|     def tick(self, actions) -> List[Result]: |     @property | ||||||
|  |     def random_free_position(self) -> (int, int): | ||||||
|  |         """ | ||||||
|  |         Returns a single **free** position (x, y), which is **free** for spawning or walking. | ||||||
|  |         No Entity at this position possesses *var_is_blocking_pos* or *var_can_collide*. | ||||||
|  |  | ||||||
|  |         :return:    Single **free** position. | ||||||
|  |         """ | ||||||
|  |         return self.get_n_random_free_positions(1)[0] | ||||||
|  |  | ||||||
|  |     def get_n_random_free_positions(self, n) -> list[tuple[int, int]]: | ||||||
|  |         """ | ||||||
|  |         Returns a list of *n* **free** positions [(x, y), ... ], which are **free** for spawning or walking. | ||||||
|  |         No Entity at this position possesses *var_is_blocking_pos* or *var_can_collide*. | ||||||
|  |  | ||||||
|  |         :return:    List of n **free** positions. | ||||||
|  |         """ | ||||||
|  |         return list(islice(self.entities.free_positions_generator, n)) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def random_position(self) -> (int, int): | ||||||
|  |         """ | ||||||
|  |         Returns a single available position (x, y), ignores all entity attributes. | ||||||
|  |  | ||||||
|  |         :return:    Single random position. | ||||||
|  |         """ | ||||||
|  |         return self.get_n_random_positions(1)[0] | ||||||
|  |  | ||||||
|  |     def get_n_random_positions(self, n) -> list[tuple[int, int]]: | ||||||
|  |         """ | ||||||
|  |         Returns a list of *n* available positions [(x, y), ... ], ignores all entity attributes. | ||||||
|  |  | ||||||
|  |         :return:    List of n random positions. | ||||||
|  |         """ | ||||||
|  |         return list(islice(self.entities.floorlist, n)) | ||||||
|  |  | ||||||
|  |     def tick(self, actions) -> list[Result]: | ||||||
|  |         """ | ||||||
|  |         Performs a single **Gamestate Tick** by calling the inner rule hooks in sequential order. | ||||||
|  |         - tick_pre_step_all:    Things to do before the agents do their actions. Statechange, Moving, Spawning etc... | ||||||
|  |         - agent tick:           Agents do their actions. | ||||||
|  |         - tick_step_all:        Things to do after the agents did their actions. Statechange, Moving, Spawning etc... | ||||||
|  |         - tick_post_step_all:   Things to do at the very end of each step. Counting, Reward calculations etc... | ||||||
|  |  | ||||||
|  |         :return:    List of *Result*-objects. | ||||||
|  |         """ | ||||||
|         results = list() |         results = list() | ||||||
|         test_results = list() |         test_results = list() | ||||||
|         self.curr_step += 1 |         self.curr_step += 1 | ||||||
| @@ -112,11 +161,23 @@ class Gamestate(object): | |||||||
|  |  | ||||||
|         return results |         return results | ||||||
|  |  | ||||||
|     def print(self, string): |     def print(self, string) -> None: | ||||||
|  |         """ | ||||||
|  |         When *verbose* is active, print stuff. | ||||||
|  |  | ||||||
|  |         :param string:      *String* to print. | ||||||
|  |         :type string:       str | ||||||
|  |         :return: Nothing | ||||||
|  |         """ | ||||||
|         if self.verbose: |         if self.verbose: | ||||||
|             print(string) |             print(string) | ||||||
|  |  | ||||||
|     def check_done(self): |     def check_done(self) -> List[DoneResult]: | ||||||
|  |         """ | ||||||
|  |         Iterate all **Rules** that override the *on_check_done* hook. | ||||||
|  |  | ||||||
|  |         :return:    List of Results | ||||||
|  |         """ | ||||||
|         results = list() |         results = list() | ||||||
|         for rule in self.rules: |         for rule in self.rules: | ||||||
|             if on_check_done_result := rule.on_check_done(self): |             if on_check_done_result := rule.on_check_done(self): | ||||||
| @@ -124,24 +185,47 @@ class Gamestate(object): | |||||||
|         return results |         return results | ||||||
|  |  | ||||||
|     def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]: |     def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]: | ||||||
|         positions = [pos for pos, entity_list_for_position in self.entities.pos_dict.items() |         """ | ||||||
|                      if any([e.var_can_collide for e in entity_list_for_position])] |         Returns a list of positions [(x, y), ... ] on which collisions occur. This does not include agents, | ||||||
|  |         that were unable to move because their target direction was blocked, also a form of collision. | ||||||
|  |  | ||||||
|  |         :return:    List of positions. | ||||||
|  |         """ | ||||||
|  |         positions = [pos for pos, entities in self.entities.pos_dict.items() if | ||||||
|  |                      len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2) | ||||||
|  |                      ] | ||||||
|         return positions |         return positions | ||||||
|  |  | ||||||
|     def check_move_validity(self, moving_entity, position): |     def check_move_validity(self, moving_entity: Entity, target_position: (int, int)) -> bool: | ||||||
|         if moving_entity.pos != position and not any( |         """ | ||||||
|                 entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]) and not ( |         Whether it is safe to move to the target positions and moving entity does not introduce a blocking attribute, | ||||||
|                 moving_entity.var_is_blocking_pos and self.entities.is_occupied(position)): |         when position is already occupied. | ||||||
|             return True |  | ||||||
|         else: |  | ||||||
|             return False |  | ||||||
|  |  | ||||||
|     def check_pos_validity(self, position): |         :param moving_entity: Entity | ||||||
|         if not any(entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]): |         :param target_position: pos | ||||||
|             return True |         :return:    Safe to move to | ||||||
|         else: |         """ | ||||||
|             return False |  | ||||||
|  |  | ||||||
|  |         is_not_blocked = self.check_pos_validity(target_position) | ||||||
|  |         will_not_block_others = moving_entity.var_is_blocking_pos and self.entities.is_occupied(target_position) | ||||||
|  |  | ||||||
|  |         if moving_entity.pos != target_position and is_not_blocked and not will_not_block_others: | ||||||
|  |             return c.VALID | ||||||
|  |         else: | ||||||
|  |             return c.NOT_VALID | ||||||
|  |  | ||||||
|  |     def check_pos_validity(self, pos: (int, int)) -> bool: | ||||||
|  |         """ | ||||||
|  |         Check if *pos* is a valid position to move or spawn to. | ||||||
|  |  | ||||||
|  |         :param pos: position to check | ||||||
|  |         :return: Whether pos is a valid target. | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         if not any(e.var_is_blocking_pos for e in self.entities.pos_dict[pos]) and pos in self.entities.floorlist: | ||||||
|  |             return c.VALID | ||||||
|  |         else: | ||||||
|  |             return c.NOT_VALID | ||||||
|  |  | ||||||
| class StepTests: | class StepTests: | ||||||
|     def __init__(self, *args): |     def __init__(self, *args): | ||||||
|   | |||||||
| @@ -28,7 +28,9 @@ class ConfigExplainer: | |||||||
|  |  | ||||||
|     def explain_module(self, class_to_explain): |     def explain_module(self, class_to_explain): | ||||||
|         parameters = inspect.signature(class_to_explain).parameters |         parameters = inspect.signature(class_to_explain).parameters | ||||||
|         explained = {class_to_explain.__name__: {key: val.default for key, val in parameters.items() if key not in EXCLUDED}} |         explained = {class_to_explain.__name__: | ||||||
|  |                          {key: val.default for key, val in parameters.items() if key not in EXCLUDED} | ||||||
|  |                      } | ||||||
|         return explained |         return explained | ||||||
|  |  | ||||||
|     def _load_and_compare(self, compare_class, paths): |     def _load_and_compare(self, compare_class, paths): | ||||||
| @@ -135,4 +137,3 @@ if __name__ == '__main__': | |||||||
|     ce.get_observations() |     ce.get_observations() | ||||||
|     ce.get_assets() |     ce.get_assets() | ||||||
|     all_conf = ce.get_all() |     all_conf = ce.get_all() | ||||||
|     print() |  | ||||||
|   | |||||||
| @@ -52,3 +52,6 @@ class Floor: | |||||||
|  |  | ||||||
|     def __hash__(self): |     def __hash__(self): | ||||||
|         return hash(self.name) |         return hash(self.name) | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return f"Floor{self.pos}" | ||||||
|   | |||||||
| @@ -6,18 +6,21 @@ from marl_factory_grid.environment.factory import Factory | |||||||
|  |  | ||||||
| from marl_factory_grid.utils.logging.envmonitor import EnvMonitor | from marl_factory_grid.utils.logging.envmonitor import EnvMonitor | ||||||
| from marl_factory_grid.utils.logging.recorder import EnvRecorder | from marl_factory_grid.utils.logging.recorder import EnvRecorder | ||||||
|  | from marl_factory_grid.utils.plotting.plot_single_runs import plot_single_run | ||||||
| from marl_factory_grid.utils.tools import ConfigExplainer | from marl_factory_grid.utils.tools import ConfigExplainer | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     # Render at each step? |     # Render at each step? | ||||||
|     render = True |     render = False | ||||||
|     # Reveal all possible Modules (Entities, Rules, Agents[Actions, Observations], etc.) |     # Reveal all possible Modules (Entities, Rules, Agents[Actions, Observations], etc.) | ||||||
|     explain_config = False |     explain_config = False | ||||||
|     # Collect statistics? |     # Collect statistics? | ||||||
|     monitor = False |     monitor = True | ||||||
|     # Record as Protobuf? |     # Record as Protobuf? | ||||||
|     record = False |     record = False | ||||||
|  |     # Plot Results? | ||||||
|  |     plotting = True | ||||||
|  |  | ||||||
|     run_path = Path('study_out') |     run_path = Path('study_out') | ||||||
|  |  | ||||||
| @@ -38,7 +41,7 @@ if __name__ == '__main__': | |||||||
|         factory = EnvRecorder(factory) |         factory = EnvRecorder(factory) | ||||||
|  |  | ||||||
|     # RL learn Loop |     # RL learn Loop | ||||||
|     for episode in trange(500): |     for episode in trange(10): | ||||||
|         _ = factory.reset() |         _ = factory.reset() | ||||||
|         done = False |         done = False | ||||||
|         if render: |         if render: | ||||||
| @@ -54,7 +57,10 @@ if __name__ == '__main__': | |||||||
|                 break |                 break | ||||||
|  |  | ||||||
|     if monitor: |     if monitor: | ||||||
|         factory.save_run(run_path / 'test.pkl') |         factory.save_run(run_path / 'test_monitor.pkl') | ||||||
|     if record: |     if record: | ||||||
|         factory.save_records(run_path / 'test.pb') |         factory.save_records(run_path / 'test.pb') | ||||||
|  |     if plotting: | ||||||
|  |         plot_single_run(run_path) | ||||||
|  |  | ||||||
|     print('Done!!! Goodbye....') |     print('Done!!! Goodbye....') | ||||||
|   | |||||||
| @@ -6,6 +6,7 @@ import yaml | |||||||
| from marl_factory_grid.environment.factory import Factory | from marl_factory_grid.environment.factory import Factory | ||||||
| from marl_factory_grid.utils.logging.envmonitor import EnvMonitor | from marl_factory_grid.utils.logging.envmonitor import EnvMonitor | ||||||
| from marl_factory_grid.utils.logging.recorder import EnvRecorder | from marl_factory_grid.utils.logging.recorder import EnvRecorder | ||||||
|  | from marl_factory_grid.utils import helpers as h | ||||||
|  |  | ||||||
| from marl_factory_grid.modules.doors import constants as d | from marl_factory_grid.modules.doors import constants as d | ||||||
|  |  | ||||||
| @@ -55,13 +56,14 @@ if __name__ == '__main__': | |||||||
|                                for model_idx, model in enumerate(models)] |                                for model_idx, model in enumerate(models)] | ||||||
|                 else: |                 else: | ||||||
|                     actions = models[0].predict(env_state, deterministic=determin)[0] |                     actions = models[0].predict(env_state, deterministic=determin)[0] | ||||||
|  |                 # noinspection PyTupleAssignmentBalance | ||||||
|                 env_state, step_r, done_bool, info_obj = env.step(actions) |                 env_state, step_r, done_bool, info_obj = env.step(actions) | ||||||
|  |  | ||||||
|                 rew += step_r |                 rew += step_r | ||||||
|                 if render: |                 if render: | ||||||
|                     env.render() |                     env.render() | ||||||
|                 try: |                 try: | ||||||
|                     door = next(x for x in env.unwrapped.unwrapped[d.DOORS] if x.is_open) |                     door = h.get_first([x for x in env.unwrapped.unwrapped[d.DOORS] if x.is_open]) | ||||||
|                     print('openDoor found') |                     print('openDoor found') | ||||||
|                 except StopIteration: |                 except StopIteration: | ||||||
|                     pass |                     pass | ||||||
|   | |||||||
| @@ -1,8 +1,8 @@ | |||||||
| from algorithms.utils import Checkpointer | from algorithms.utils import Checkpointer | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from algorithms.utils import load_yaml_file, add_env_props, instantiate_class, load_class | from algorithms.utils import load_yaml_file, add_env_props, instantiate_class, load_class | ||||||
| #from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC |  | ||||||
|  |  | ||||||
|  | # from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC | ||||||
|  |  | ||||||
|  |  | ||||||
| for i in range(0, 5): | for i in range(0, 5): | ||||||
|   | |||||||
							
								
								
									
										41
									
								
								transform_wg_to_json_no_priv.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								transform_wg_to_json_no_priv.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,41 @@ | |||||||
|  | import configparser | ||||||
|  | import json | ||||||
|  | from datetime import datetime | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |  | ||||||
|  |     conf_path = Path('wg0') | ||||||
|  |     wg0_conf = configparser.ConfigParser() | ||||||
|  |     wg0_conf.read(conf_path/'wg0.conf') | ||||||
|  |     interface = wg0_conf['Interface'] | ||||||
|  |     # Iterate all peers | ||||||
|  |     for client_name in wg0_conf.sections(): | ||||||
|  |         if client_name == 'Interface': | ||||||
|  |             continue | ||||||
|  |         # Delete any old conf.json for the current peer | ||||||
|  |         (conf_path / f'{client_name}.json').unlink(missing_ok=True) | ||||||
|  |  | ||||||
|  |         peer = wg0_conf[client_name] | ||||||
|  |  | ||||||
|  |         date_time = datetime.now().strftime('%Y-%m-%dT%H:%M:%S.%f000Z') | ||||||
|  |  | ||||||
|  |         jdict = dict( | ||||||
|  |             id=client_name, | ||||||
|  |             private_key=peer['PublicKey'], | ||||||
|  |             public_key=peer['PublicKey'], | ||||||
|  |             # preshared_key=wg0_conf[client_name_wg0]['PresharedKey'], | ||||||
|  |             name=client_name, | ||||||
|  |             email=f"sysadmin@mobile.ifi.lmu.de", | ||||||
|  |             allocated_ips=[interface['Address'].replace('/24', '')], | ||||||
|  |             allowed_ips=['10.4.0.0/24', '10.153.199.0/24'], | ||||||
|  |             extra_allowed_ips=[], | ||||||
|  |             use_server_dns=True, | ||||||
|  |             enabled=True, | ||||||
|  |             created_at=date_time, | ||||||
|  |             updated_at=date_time | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         with (conf_path / f'{client_name}.json').open('w+') as f: | ||||||
|  |             json.dump(jdict, f, indent='\t', separators=(',', ': ')) | ||||||
|  |         print(client_name, ' written...') | ||||||
		Reference in New Issue
	
	Block a user
	 Chanumask
					Chanumask