mirror of
				https://github.com/illiumst/marl-factory-grid.git
				synced 2025-10-31 04:37:25 +01:00 
			
		
		
		
	TSP Single Agent
This commit is contained in:
		
							
								
								
									
										66
									
								
								algorithms/TSP_dirt_agent.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								algorithms/TSP_dirt_agent.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,66 @@ | ||||
| import numpy as np | ||||
|  | ||||
| from networkx.algorithms.approximation import traveling_salesman as tsp | ||||
|  | ||||
| from environments.factory.base.objects import Agent | ||||
| from environments.factory.base.registers import FloorTiles, Actions | ||||
| from environments.helpers import points_to_graph | ||||
| from environments import helpers as h | ||||
|  | ||||
|  | ||||
class TSPDirtAgent(Agent):
    """Scripted agent that walks a precomputed TSP route over all dirt piles.

    On every call to :meth:`predict` the agent either cleans the tile it is
    standing on, uses a closed door on its tile, or takes the next step along
    a route computed once by :meth:`calculate_tsp_route`.
    """

    def __init__(self, floortiles: FloorTiles, dirt_register, actions: Actions, *args,
                 static_problem: bool = True, **kwargs):
        """Store the environment registers and build the floor-tile graph.

        :param floortiles: Register whose ``positions`` become the graph nodes.
        :param dirt_register: Register queried via ``by_pos`` and ``positions``.
        :param actions: Action register; its ``allow_diagonal_movement`` and
            ``allow_square_movement`` flags select the graph connectivity.
        :param static_problem: If ``True`` the route is computed once and
            reused; the dynamic case is not implemented.
        :param args: Forwarded unchanged to ``Agent.__init__``.
        :param kwargs: Forwarded unchanged to ``Agent.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.static_problem = static_problem
        self._floortiles = floortiles
        self._actions = actions
        self._dirt_register = dirt_register
        self._floortile_graph = points_to_graph(self._floortiles.positions,
                                                allow_euclidean_connections=self._actions.allow_diagonal_movement,
                                                allow_manhattan_connections=self._actions.allow_square_movement)
        # Lazily filled on the first movement request, see _predict_move.
        self._static_route = None

    def predict(self, *_, **__):
        """Return the index (within ``self._actions``) of the next action.

        All positional and keyword arguments are ignored; they exist only so
        that the call signature is compatible with other models.

        :raises StopIteration: If the chosen action is not contained in
            ``self._actions``.
        """
        if self._dirt_register.by_pos(self.pos) is not None:
            # There is dirt right here, clean before moving on.
            action = h.EnvActions.CLEAN_UP
        else:
            # First guest on the current tile whose name marks it as a door, if any.
            door = next((guest for guest in self.tile.guests if 'door' in guest.name.lower()), None)
            if door is not None and door.is_closed:
                action = h.EnvActions.USE_DOOR
            else:
                action = self._predict_move()
        # Translate the action object to an integer to have the same output as any other model.
        action_idx = next(idx for idx, candidate in enumerate(self._actions) if candidate == action)
        return action_idx

    def _predict_move(self):
        """Return the movement action leading to the next position of the route.

        :raises NotImplementedError: If ``static_problem`` is ``False``.
        :raises IndexError: If the stored route has no positions left.
        :raises ValueError: If no entry of ``h.ACTIONMAP`` matches the offset
            between the current and the next position.
        """
        if not self.static_problem:
            raise NotImplementedError

        if self._static_route is None:
            self._static_route = self.calculate_tsp_route()

        # Skip route entries that equal the current position (e.g. the start node).
        next_pos = self._static_route.pop(0)
        while next_pos == self.pos:
            next_pos = self._static_route.pop(0)

        diff = np.subtract(next_pos, self.pos)
        # Retrieve the action based on the position difference
        # (as in: what do I have to do to get there?).
        for action, pos_diff in h.ACTIONMAP.items():
            if (diff == pos_diff).all():
                return action
        # Previously this case only printed a message and then failed with an
        # UnboundLocalError; fail with an explicit, descriptive error instead.
        raise ValueError(f'No movement action matches the position difference {diff.tolist()} '
                         f'(from {self.pos} to {next_pos}).')

    def calculate_tsp_route(self):
        """Return a route over the floor-tile graph visiting every dirt position.

        The route is computed by networkx' ``traveling_salesman_problem``
        approximation for the agent's current position plus all positions in
        the dirt register.
        """
        nodes_to_visit = [self.pos, *self._dirt_register.positions]
        return tsp.traveling_salesman_problem(self._floortile_graph, nodes=nodes_to_visit)
| @@ -3,7 +3,7 @@ from enum import Enum | ||||
| from typing import Union | ||||
|  | ||||
| import networkx as nx | ||||
| import numpy as np | ||||
| from environments import helpers as h | ||||
| from environments.helpers import Constants as c | ||||
| import itertools | ||||
|  | ||||
| @@ -267,11 +267,7 @@ class Door(Entity): | ||||
|         neighbor_pos = list(itertools.product([-1, 1, 0], repeat=2))[:-1] | ||||
|         neighbor_tiles = [context.by_pos(tuple([sum(x) for x in zip(self.pos, diff)])) for diff in neighbor_pos] | ||||
|         neighbor_pos = [x.pos for x in neighbor_tiles if x] | ||||
|         possible_connections = itertools.combinations(neighbor_pos, 2) | ||||
|         self.connectivity = nx.Graph() | ||||
|         for a, b in possible_connections: | ||||
|             if not max(abs(np.subtract(a, b))) > 1: | ||||
|                 self.connectivity.add_edge(a, b) | ||||
|         self.connectivity = h.points_to_graph(neighbor_pos) | ||||
|         self.connectivity_subgroups = list(nx.algorithms.components.connected_components(self.connectivity)) | ||||
|         for idx, group in enumerate(self.connectivity_subgroups): | ||||
|             for tile_pos in group: | ||||
|   | ||||
| @@ -320,6 +320,9 @@ class Agents(MovingEntityObjectRegister): | ||||
|     def positions(self): | ||||
|         return [agent.pos for agent in self] | ||||
|  | ||||
|     def __setitem__(self, key, value): | ||||
|         self._register[self[key].name] = value | ||||
|  | ||||
|  | ||||
| class Doors(EntityObjectRegister): | ||||
|  | ||||
|   | ||||
| @@ -5,6 +5,7 @@ import random | ||||
|  | ||||
| import numpy as np | ||||
|  | ||||
| from algorithms.TSP_dirt_agent import TSPDirtAgent | ||||
| from environments.helpers import Constants as c | ||||
| from environments import helpers as h | ||||
| from environments.factory.base.base_factory import BaseFactory | ||||
| @@ -262,17 +263,29 @@ if __name__ == '__main__': | ||||
|     from environments.utility_classes import AgentRenderOptions as ARO | ||||
|     render = True | ||||
|  | ||||
|     dirt_props = DirtProperties(1, 0.05, 0.1, 3, 1, 20, 0) | ||||
|     dirt_props = DirtProperties( | ||||
|         initial_dirt_ratio=0.35, | ||||
|         initial_dirt_spawn_r_var=0.1, | ||||
|         clean_amount=0.34, | ||||
|         max_spawn_amount=0.1, | ||||
|         max_global_amount=20, | ||||
|         max_local_amount=1, | ||||
|         spawn_frequency=0, | ||||
|         max_spawn_ratio=0.05, | ||||
|         dirt_smear_amount=0.0, | ||||
|         agent_can_interact=True | ||||
|     ) | ||||
|  | ||||
|     obs_props = ObservationProperties(render_agents=ARO.COMBINED, omit_agent_self=True, | ||||
|                                       pomdp_r=15, additional_agent_placeholder=None) | ||||
|                                       pomdp_r=2, additional_agent_placeholder=None) | ||||
|  | ||||
|     move_props = {'allow_square_movement': True, | ||||
|                   'allow_diagonal_movement': False, | ||||
|                   'allow_no_op': False} | ||||
|  | ||||
|     factory = DirtFactory(n_agents=5, done_at_collision=False, | ||||
|     factory = DirtFactory(n_agents=1, done_at_collision=False, | ||||
|                           level_name='rooms', max_steps=400, | ||||
|                           doors_have_area=False, | ||||
|                           obs_prop=obs_props, parse_doors=True, | ||||
|                           record_episodes=True, verbose=True, | ||||
|                           mv_prop=move_props, dirt_prop=dirt_props | ||||
| @@ -287,9 +300,15 @@ if __name__ == '__main__': | ||||
|                            in range(factory.n_agents)] for _ | ||||
|                           in range(factory.max_steps+1)] | ||||
|         env_state = factory.reset() | ||||
|         if render: | ||||
|             factory.render() | ||||
|         random_start_position = factory[c.AGENT][0].tile | ||||
|         factory[c.AGENT][0] = tsp_agent = TSPDirtAgent(factory[c.FLOOR], factory[c.DIRT], | ||||
|                                                        factory._actions, random_start_position) | ||||
|  | ||||
|         r = 0 | ||||
|         for agent_i_action in random_actions: | ||||
|             env_state, step_r, done_bool, info_obj = factory.step(agent_i_action) | ||||
|             env_state, step_r, done_bool, info_obj = factory.step(tsp_agent.predict()) | ||||
|             r += step_r | ||||
|             if render: | ||||
|                 factory.render() | ||||
|   | ||||
| @@ -1,7 +1,9 @@ | ||||
| import itertools | ||||
| from collections import defaultdict | ||||
| from enum import Enum, auto | ||||
| from typing import Tuple, Union | ||||
|  | ||||
| import networkx as nx | ||||
| import numpy as np | ||||
| from pathlib import Path | ||||
|  | ||||
| @@ -153,6 +155,23 @@ def asset_str(agent): | ||||
|         return c.AGENT.value, 'idle' | ||||
|  | ||||
|  | ||||
def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True):
    """Build an undirected graph connecting neighbouring grid coordinates.

    Two coordinates are neighbours if they differ by at most one along every
    axis. A pair differing along every axis is a diagonal ("euclidean")
    connection; a pair differing along some but not all axes is a straight
    ("manhattan") connection. Identical coordinates are never connected, and
    coordinates without any neighbour do not appear in the graph because only
    edges are added.

    :param coordiniates_or_tiles: Iterable of coordinates, or an object
        exposing them through a ``positions`` attribute.
    :param allow_euclidean_connections: Add edges between diagonal neighbours.
    :param allow_manhattan_connections: Add edges between straight neighbours.
    :return: ``networkx.Graph`` whose nodes are the connected coordinates.
    """
    assert allow_euclidean_connections or allow_manhattan_connections, \
        'At least one connection type has to be allowed.'
    if hasattr(coordiniates_or_tiles, 'positions'):
        coordiniates_or_tiles = coordiniates_or_tiles.positions
    graph = nx.Graph()
    for a, b in itertools.combinations(coordiniates_or_tiles, 2):
        diff = abs(np.subtract(a, b))
        if max(diff) > 1:
            # Further than one step apart along some axis: not neighbours.
            continue
        is_diagonal = all(diff)
        # any(diff) rules out identical coordinates, which would otherwise
        # produce a self-loop when both connection types are allowed.
        is_straight = any(diff) and not is_diagonal
        if (is_diagonal and allow_euclidean_connections) or (is_straight and allow_manhattan_connections):
            graph.add_edge(a, b)
    return graph
|  | ||||
| if __name__ == '__main__': | ||||
|     parsed_level = parse_level(Path(__file__).parent / 'factory' / 'levels' / 'simple.txt') | ||||
|     y = one_hot_level(parsed_level) | ||||
|   | ||||
| @@ -75,7 +75,7 @@ baseline_monitor_file = 'e_1_baseline' | ||||
| from stable_baselines3 import A2C | ||||
|  | ||||
| def policy_model_kwargs(): | ||||
|     return dict(gae_lambda=0.25, n_steps=16, max_grad_norm=0, use_rms_prop=False) | ||||
|     return dict(gae_lambda=0.25, n_steps=16, max_grad_norm=0, use_rms_prop=True) | ||||
|  | ||||
|  | ||||
| def dqn_model_kwargs(): | ||||
| @@ -203,7 +203,7 @@ if __name__ == '__main__': | ||||
|     frames_to_stack = 3 | ||||
|  | ||||
|     # Define a global study save path | ||||
|     start_time = 'adam_no_weight_decay'  # int(time.time()) | ||||
|     start_time = 'rms_weight_decay_0'  # int(time.time()) | ||||
|     study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}_{start_time}' | ||||
|  | ||||
|     # Define Global Env Parameters | ||||
| @@ -285,36 +285,36 @@ if __name__ == '__main__': | ||||
|                     pomdp_r=2) | ||||
|                 ) | ||||
|             )}) | ||||
|     observation_modes.update({ | ||||
|         'seperate_N': dict( | ||||
|             post_training_kwargs= | ||||
|             dict(obs_prop=ObservationProperties( | ||||
|                 render_agents=AgentRenderOptions.COMBINED, | ||||
|                 additional_agent_placeholder=None, | ||||
|                 omit_agent_self=True, | ||||
|                 frames_to_stack=frames_to_stack, | ||||
|                 pomdp_r=2) | ||||
|             ), | ||||
|             additional_env_kwargs= | ||||
|             dict(obs_prop=ObservationProperties( | ||||
|                 render_agents=AgentRenderOptions.NOT, | ||||
|                 additional_agent_placeholder='N', | ||||
|                 omit_agent_self=True, | ||||
|                 frames_to_stack=frames_to_stack, | ||||
|                 pomdp_r=2) | ||||
|             ) | ||||
|         )}) | ||||
|     observation_modes.update({ | ||||
|         'in_lvl_obs': dict( | ||||
|             post_training_kwargs= | ||||
|             dict(obs_prop=ObservationProperties( | ||||
|                 render_agents=AgentRenderOptions.LEVEL, | ||||
|                 omit_agent_self=True, | ||||
|                 additional_agent_placeholder=None, | ||||
|                 frames_to_stack=frames_to_stack, | ||||
|                 pomdp_r=2) | ||||
|             ) | ||||
|         )}) | ||||
|         observation_modes.update({ | ||||
|             'seperate_N': dict( | ||||
|                 post_training_kwargs= | ||||
|                 dict(obs_prop=ObservationProperties( | ||||
|                     render_agents=AgentRenderOptions.COMBINED, | ||||
|                     additional_agent_placeholder=None, | ||||
|                     omit_agent_self=True, | ||||
|                     frames_to_stack=frames_to_stack, | ||||
|                     pomdp_r=2) | ||||
|                 ), | ||||
|                 additional_env_kwargs= | ||||
|                 dict(obs_prop=ObservationProperties( | ||||
|                     render_agents=AgentRenderOptions.NOT, | ||||
|                     additional_agent_placeholder='N', | ||||
|                     omit_agent_self=True, | ||||
|                     frames_to_stack=frames_to_stack, | ||||
|                     pomdp_r=2) | ||||
|                 ) | ||||
|             )}) | ||||
|         observation_modes.update({ | ||||
|             'in_lvl_obs': dict( | ||||
|                 post_training_kwargs= | ||||
|                 dict(obs_prop=ObservationProperties( | ||||
|                     render_agents=AgentRenderOptions.LEVEL, | ||||
|                     omit_agent_self=True, | ||||
|                     additional_agent_placeholder=None, | ||||
|                     frames_to_stack=frames_to_stack, | ||||
|                     pomdp_r=2) | ||||
|                 ) | ||||
|             )}) | ||||
|     observation_modes.update({ | ||||
|         #  No further adjustment needed | ||||
|         'no_obs': dict( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Steffen Illium
					Steffen Illium