TSP Single Agent

This commit is contained in:
Steffen Illium 2021-11-25 14:48:34 +01:00
parent 3c84ba483b
commit 3d81b7577d
6 changed files with 145 additions and 42 deletions

View File

@ -0,0 +1,66 @@
import numpy as np
from networkx.algorithms.approximation import traveling_salesman as tsp
from environments.factory.base.objects import Agent
from environments.factory.base.registers import FloorTiles, Actions
from environments.helpers import points_to_graph
from environments import helpers as h
class TSPDirtAgent(Agent):
    """Scripted agent that visits all dirt positions along a precomputed TSP route.

    On construction a graph of the floor tiles is built via ``points_to_graph``,
    using the movement options of the given action register.  ``predict`` then
    decides between cleaning, opening a door and taking the next step along
    the route.
    """

    def __init__(self, floortiles: FloorTiles, dirt_register, actions: Actions, *args,
                 static_problem: bool = True, **kwargs):
        """Store the registers and build the floor tile graph.

        :param floortiles: Floor tile register; its ``positions`` become the graph nodes.
        :param dirt_register: Dirt register, queried through ``by_pos`` and ``positions``.
        :param actions: Action register; also provides ``allow_diagonal_movement`` and
                        ``allow_square_movement`` for the graph connectivity.
        :param static_problem: If True, the route is computed once and then followed.
                               The non-static case is not implemented.
        :param args: Forwarded to ``Agent.__init__``.
        :param kwargs: Forwarded to ``Agent.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.static_problem = static_problem
        self._floortiles = floortiles
        self._actions = actions
        self._dirt_register = dirt_register
        self._floortile_graph = points_to_graph(
            self._floortiles.positions,
            allow_euclidean_connections=self._actions.allow_diagonal_movement,
            allow_manhattan_connections=self._actions.allow_square_movement)
        # Filled lazily on the first call of _predict_move.
        self._static_route = None

    def predict(self, *_, **__):
        """Return the index (within the action register) of the next action.

        Priority: clean up if there is dirt on the own position, use the door if
        the agent shares its tile with a closed door, otherwise move along the route.
        All positional and keyword arguments are ignored.

        :return: Integer index of the chosen action in ``self._actions``.
        :raises ValueError: If the chosen action is not part of ``self._actions``.
        """
        if self._dirt_register.by_pos(self.pos) is not None:
            action = h.EnvActions.CLEAN_UP
        else:
            # Single scan of the tile guests; None if there is no door on this tile.
            door = next((guest for guest in self.tile.guests if 'door' in guest.name.lower()), None)
            if door is not None and door.is_closed:
                action = h.EnvActions.USE_DOOR
            else:
                action = self._predict_move()
        # Translate the action object to an integer to have the same output as any other model.
        action_idx = next((idx for idx, candidate in enumerate(self._actions) if candidate == action), None)
        if action_idx is None:
            raise ValueError(f'Action {action} is not part of the action register.')
        return action_idx

    def _predict_move(self):
        """Pop the next waypoint from the route and return the action that leads there.

        :return: The key of ``h.ACTIONMAP`` whose position difference matches the step.
        :raises NotImplementedError: If ``static_problem`` is False.
        :raises IndexError: If the route is exhausted (``pop`` from an empty list).
        :raises RuntimeError: If no entry of ``h.ACTIONMAP`` matches the required step.
        """
        if not self.static_problem:
            raise NotImplementedError
        if self._static_route is None:
            self._static_route = self.calculate_tsp_route()
        next_pos = self._static_route.pop(0)
        # Skip waypoints that equal the current position.
        while next_pos == self.pos:
            next_pos = self._static_route.pop(0)
        diff = np.subtract(next_pos, self.pos)
        # Retrieve the action based on the position difference
        # (as in: what do I have to do to get there?).
        action = next((candidate for candidate, pos_diff in h.ACTIONMAP.items()
                       if np.array_equal(diff, pos_diff)), None)
        if action is None:
            # Previously this case only printed a message and then failed with an
            # UnboundLocalError on the return statement; fail explicitly instead.
            raise RuntimeError(f'No action found to move from {self.pos} to {next_pos} '
                               f'(position difference: {diff}).')
        return action

    def calculate_tsp_route(self):
        """Compute a route over the own position and all current dirt positions.

        :return: The result of ``networkx``' ``traveling_salesman_problem`` on the
                 floor tile graph, starting the node list with the own position.
        """
        return tsp.traveling_salesman_problem(self._floortile_graph,
                                              nodes=[self.pos, *self._dirt_register.positions])

View File

@ -3,7 +3,7 @@ from enum import Enum
from typing import Union from typing import Union
import networkx as nx import networkx as nx
import numpy as np from environments import helpers as h
from environments.helpers import Constants as c from environments.helpers import Constants as c
import itertools import itertools
@ -267,11 +267,7 @@ class Door(Entity):
neighbor_pos = list(itertools.product([-1, 1, 0], repeat=2))[:-1] neighbor_pos = list(itertools.product([-1, 1, 0], repeat=2))[:-1]
neighbor_tiles = [context.by_pos(tuple([sum(x) for x in zip(self.pos, diff)])) for diff in neighbor_pos] neighbor_tiles = [context.by_pos(tuple([sum(x) for x in zip(self.pos, diff)])) for diff in neighbor_pos]
neighbor_pos = [x.pos for x in neighbor_tiles if x] neighbor_pos = [x.pos for x in neighbor_tiles if x]
possible_connections = itertools.combinations(neighbor_pos, 2) self.connectivity = h.points_to_graph(neighbor_pos)
self.connectivity = nx.Graph()
for a, b in possible_connections:
if not max(abs(np.subtract(a, b))) > 1:
self.connectivity.add_edge(a, b)
self.connectivity_subgroups = list(nx.algorithms.components.connected_components(self.connectivity)) self.connectivity_subgroups = list(nx.algorithms.components.connected_components(self.connectivity))
for idx, group in enumerate(self.connectivity_subgroups): for idx, group in enumerate(self.connectivity_subgroups):
for tile_pos in group: for tile_pos in group:

View File

@ -320,6 +320,9 @@ class Agents(MovingEntityObjectRegister):
def positions(self): def positions(self):
return [agent.pos for agent in self] return [agent.pos for agent in self]
def __setitem__(self, key, value):
    """Replace the entry currently found under ``key`` with ``value``.

    The existing entry is looked up via ``self[key]``; ``value`` is then stored
    in ``self._register`` under that existing entry's ``name``.
    """
    current_entry = self[key]
    self._register[current_entry.name] = value
class Doors(EntityObjectRegister): class Doors(EntityObjectRegister):

View File

@ -5,6 +5,7 @@ import random
import numpy as np import numpy as np
from algorithms.TSP_dirt_agent import TSPDirtAgent
from environments.helpers import Constants as c from environments.helpers import Constants as c
from environments import helpers as h from environments import helpers as h
from environments.factory.base.base_factory import BaseFactory from environments.factory.base.base_factory import BaseFactory
@ -262,17 +263,29 @@ if __name__ == '__main__':
from environments.utility_classes import AgentRenderOptions as ARO from environments.utility_classes import AgentRenderOptions as ARO
render = True render = True
dirt_props = DirtProperties(1, 0.05, 0.1, 3, 1, 20, 0) dirt_props = DirtProperties(
initial_dirt_ratio=0.35,
initial_dirt_spawn_r_var=0.1,
clean_amount=0.34,
max_spawn_amount=0.1,
max_global_amount=20,
max_local_amount=1,
spawn_frequency=0,
max_spawn_ratio=0.05,
dirt_smear_amount=0.0,
agent_can_interact=True
)
obs_props = ObservationProperties(render_agents=ARO.COMBINED, omit_agent_self=True, obs_props = ObservationProperties(render_agents=ARO.COMBINED, omit_agent_self=True,
pomdp_r=15, additional_agent_placeholder=None) pomdp_r=2, additional_agent_placeholder=None)
move_props = {'allow_square_movement': True, move_props = {'allow_square_movement': True,
'allow_diagonal_movement': False, 'allow_diagonal_movement': False,
'allow_no_op': False} 'allow_no_op': False}
factory = DirtFactory(n_agents=5, done_at_collision=False, factory = DirtFactory(n_agents=1, done_at_collision=False,
level_name='rooms', max_steps=400, level_name='rooms', max_steps=400,
doors_have_area=False,
obs_prop=obs_props, parse_doors=True, obs_prop=obs_props, parse_doors=True,
record_episodes=True, verbose=True, record_episodes=True, verbose=True,
mv_prop=move_props, dirt_prop=dirt_props mv_prop=move_props, dirt_prop=dirt_props
@ -287,9 +300,15 @@ if __name__ == '__main__':
in range(factory.n_agents)] for _ in range(factory.n_agents)] for _
in range(factory.max_steps+1)] in range(factory.max_steps+1)]
env_state = factory.reset() env_state = factory.reset()
if render:
factory.render()
random_start_position = factory[c.AGENT][0].tile
factory[c.AGENT][0] = tsp_agent = TSPDirtAgent(factory[c.FLOOR], factory[c.DIRT],
factory._actions, random_start_position)
r = 0 r = 0
for agent_i_action in random_actions: for agent_i_action in random_actions:
env_state, step_r, done_bool, info_obj = factory.step(agent_i_action) env_state, step_r, done_bool, info_obj = factory.step(tsp_agent.predict())
r += step_r r += step_r
if render: if render:
factory.render() factory.render()

View File

@ -1,7 +1,9 @@
import itertools
from collections import defaultdict from collections import defaultdict
from enum import Enum, auto from enum import Enum, auto
from typing import Tuple, Union from typing import Tuple, Union
import networkx as nx
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
@ -153,6 +155,23 @@ def asset_str(agent):
return c.AGENT.value, 'idle' return c.AGENT.value, 'idle'
def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True):
    """Build an undirected graph that connects neighbouring points.

    Two points are candidates for an edge if none of their coordinates differs
    by more than 1. Which candidates are connected depends on the flags:

    * both flags set: every candidate pair is connected,
    * only ``allow_euclidean_connections``: pairs in which every coordinate differs,
    * only ``allow_manhattan_connections``: pairs in which some, but not all,
      coordinates differ.

    Only edges are added; a point without any connected neighbour does not
    show up as a node of the returned graph.

    :param coordiniates_or_tiles: Iterable of coordinates, or an object exposing
                                  such an iterable as ``positions``.
    :param allow_euclidean_connections: See above.
    :param allow_manhattan_connections: See above.
    :return: The resulting ``nx.Graph``.
    """
    assert allow_euclidean_connections or allow_manhattan_connections
    points = getattr(coordiniates_or_tiles, 'positions', coordiniates_or_tiles)
    graph = nx.Graph()
    for first, second in itertools.combinations(points, 2):
        offset = abs(np.subtract(first, second))
        if max(offset) > 1:
            # Not adjacent at all.
            continue
        if allow_euclidean_connections and allow_manhattan_connections:
            connect = True
        elif allow_euclidean_connections:
            connect = all(offset)
        elif allow_manhattan_connections:
            connect = any(offset) and not all(offset)
        else:
            connect = False
        if connect:
            graph.add_edge(first, second)
    return graph
if __name__ == '__main__': if __name__ == '__main__':
parsed_level = parse_level(Path(__file__).parent / 'factory' / 'levels' / 'simple.txt') parsed_level = parse_level(Path(__file__).parent / 'factory' / 'levels' / 'simple.txt')
y = one_hot_level(parsed_level) y = one_hot_level(parsed_level)

View File

@ -75,7 +75,7 @@ baseline_monitor_file = 'e_1_baseline'
from stable_baselines3 import A2C from stable_baselines3 import A2C
def policy_model_kwargs(): def policy_model_kwargs():
return dict(gae_lambda=0.25, n_steps=16, max_grad_norm=0, use_rms_prop=False) return dict(gae_lambda=0.25, n_steps=16, max_grad_norm=0, use_rms_prop=True)
def dqn_model_kwargs(): def dqn_model_kwargs():
@ -203,7 +203,7 @@ if __name__ == '__main__':
frames_to_stack = 3 frames_to_stack = 3
# Define a global study save path # Define a global study save path
start_time = 'adam_no_weight_decay' # int(time.time()) start_time = 'rms_weight_decay_0' # int(time.time())
study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}_{start_time}' study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}_{start_time}'
# Define Global Env Parameters # Define Global Env Parameters
@ -285,36 +285,36 @@ if __name__ == '__main__':
pomdp_r=2) pomdp_r=2)
) )
)}) )})
observation_modes.update({ observation_modes.update({
'seperate_N': dict( 'seperate_N': dict(
post_training_kwargs= post_training_kwargs=
dict(obs_prop=ObservationProperties( dict(obs_prop=ObservationProperties(
render_agents=AgentRenderOptions.COMBINED, render_agents=AgentRenderOptions.COMBINED,
additional_agent_placeholder=None, additional_agent_placeholder=None,
omit_agent_self=True, omit_agent_self=True,
frames_to_stack=frames_to_stack, frames_to_stack=frames_to_stack,
pomdp_r=2) pomdp_r=2)
), ),
additional_env_kwargs= additional_env_kwargs=
dict(obs_prop=ObservationProperties( dict(obs_prop=ObservationProperties(
render_agents=AgentRenderOptions.NOT, render_agents=AgentRenderOptions.NOT,
additional_agent_placeholder='N', additional_agent_placeholder='N',
omit_agent_self=True, omit_agent_self=True,
frames_to_stack=frames_to_stack, frames_to_stack=frames_to_stack,
pomdp_r=2) pomdp_r=2)
) )
)}) )})
observation_modes.update({ observation_modes.update({
'in_lvl_obs': dict( 'in_lvl_obs': dict(
post_training_kwargs= post_training_kwargs=
dict(obs_prop=ObservationProperties( dict(obs_prop=ObservationProperties(
render_agents=AgentRenderOptions.LEVEL, render_agents=AgentRenderOptions.LEVEL,
omit_agent_self=True, omit_agent_self=True,
additional_agent_placeholder=None, additional_agent_placeholder=None,
frames_to_stack=frames_to_stack, frames_to_stack=frames_to_stack,
pomdp_r=2) pomdp_r=2)
) )
)}) )})
observation_modes.update({ observation_modes.update({
# No further adjustment needed # No further adjustment needed
'no_obs': dict( 'no_obs': dict(