TSP Single Agent
This commit is contained in:
parent
3c84ba483b
commit
3d81b7577d
66
algorithms/TSP_dirt_agent.py
Normal file
66
algorithms/TSP_dirt_agent.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from networkx.algorithms.approximation import traveling_salesman as tsp
|
||||||
|
|
||||||
|
from environments.factory.base.objects import Agent
|
||||||
|
from environments.factory.base.registers import FloorTiles, Actions
|
||||||
|
from environments.helpers import points_to_graph
|
||||||
|
from environments import helpers as h
|
||||||
|
|
||||||
|
|
||||||
|
class TSPDirtAgent(Agent):
    """Scripted agent that cleans dirt by following a travelling-salesman route.

    The route is planned on a graph built from the floor-tile positions and
    covers the agent's own position plus every position reported by the dirt
    register. Only the static variant (plan once, then follow the plan) is
    implemented.
    """

    def __init__(self, floortiles: FloorTiles, dirt_register, actions: Actions, *args,
                 static_problem: bool = True, **kwargs):
        """Keep references to the environment registers and build the tile graph.

        Args:
            floortiles: Floor-tile register; its ``positions`` are turned into
                the graph the route is planned on.
            dirt_register: Register queried through ``by_pos`` and ``positions``
                to find dirt.
            actions: Action register. ``predict`` returns indices into it, and
                its ``allow_diagonal_movement`` / ``allow_square_movement``
                flags decide which tile connections the graph contains.
            *args: Forwarded unchanged to ``Agent.__init__``.
            static_problem: When true the route is computed once and reused.
                The non-static case is not implemented.
            **kwargs: Forwarded unchanged to ``Agent.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.static_problem = static_problem
        self._floortiles = floortiles
        self._actions = actions
        self._dirt_register = dirt_register
        self._floortile_graph = points_to_graph(self._floortiles.positions,
                                                allow_euclidean_connections=self._actions.allow_diagonal_movement,
                                                allow_manhattan_connections=self._actions.allow_square_movement)
        # Planned lazily on the first movement request.
        self._static_route = None

    def predict(self, *_, **__):
        """Choose the next action and return its index within the action register.

        Priority: clean up if the dirt register has an entry at the agent's
        position; otherwise use a closed door that shares the agent's tile;
        otherwise move along the planned route. All positional and keyword
        arguments are ignored.

        Returns:
            The integer index of the chosen action in ``self._actions``, so
            that the output looks like that of any other model.

        Raises:
            StopIteration: If the chosen action is not part of ``self._actions``.
        """
        if self._dirt_register.by_pos(self.pos) is not None:
            action = h.EnvActions.CLEAN_UP
        else:
            # A single pass over the guests both detects and fetches the door.
            door = next((guest for guest in self.tile.guests if 'door' in guest.name.lower()), None)
            if door is not None and door.is_closed:
                action = h.EnvActions.USE_DOOR
            else:
                action = self._predict_move()
        # Translate the action object to an integer to have the same output as any other model.
        action_idx = next(idx for idx, candidate in enumerate(self._actions) if candidate == action)
        return action_idx

    def _predict_move(self):
        """Return the movement action that leads to the next waypoint of the route.

        The route is computed on first use. Waypoints equal to the agent's
        current position are discarded before the step is derived.

        Raises:
            NotImplementedError: If ``static_problem`` is false.
            IndexError: If the route has no waypoint left to pop.
            RuntimeError: If no entry of ``h.ACTIONMAP`` matches the offset
                between the agent and the next waypoint.
        """
        if not self.static_problem:
            raise NotImplementedError

        if self._static_route is None:
            self._static_route = self.calculate_tsp_route()

        next_pos = self._static_route.pop(0)
        while next_pos == self.pos:
            next_pos = self._static_route.pop(0)

        diff = np.subtract(next_pos, self.pos)
        # Retrieve the action based on the position difference
        # (as in: what do I have to do to get there?).
        try:
            action = next(candidate for candidate, pos_diff in h.ACTIONMAP.items()
                          if (diff == pos_diff).all())
        except StopIteration as err:
            # Previously this only printed a message and then failed with an
            # UnboundLocalError on ``return action``; fail with a clear error instead.
            raise RuntimeError(
                f'No action in ACTIONMAP moves the agent from {self.pos} to {next_pos}.'
            ) from err
        return action

    def calculate_tsp_route(self):
        """Plan a route over the tile graph through the agent and all dirt positions.

        Returns:
            Whatever ``traveling_salesman_problem`` returns for the tile graph
            and the node list ``[agent position, *dirt positions]``.
        """
        # NOTE(review): ``_predict_move`` assumes the first waypoint that differs
        # from the agent's position is adjacent to it - confirm the route start.
        route = tsp.traveling_salesman_problem(self._floortile_graph,
                                               nodes=[self.pos, *self._dirt_register.positions])
        return route
|
@ -3,7 +3,7 @@ from enum import Enum
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import numpy as np
|
from environments import helpers as h
|
||||||
from environments.helpers import Constants as c
|
from environments.helpers import Constants as c
|
||||||
import itertools
|
import itertools
|
||||||
|
|
||||||
@ -267,11 +267,7 @@ class Door(Entity):
|
|||||||
neighbor_pos = list(itertools.product([-1, 1, 0], repeat=2))[:-1]
|
neighbor_pos = list(itertools.product([-1, 1, 0], repeat=2))[:-1]
|
||||||
neighbor_tiles = [context.by_pos(tuple([sum(x) for x in zip(self.pos, diff)])) for diff in neighbor_pos]
|
neighbor_tiles = [context.by_pos(tuple([sum(x) for x in zip(self.pos, diff)])) for diff in neighbor_pos]
|
||||||
neighbor_pos = [x.pos for x in neighbor_tiles if x]
|
neighbor_pos = [x.pos for x in neighbor_tiles if x]
|
||||||
possible_connections = itertools.combinations(neighbor_pos, 2)
|
self.connectivity = h.points_to_graph(neighbor_pos)
|
||||||
self.connectivity = nx.Graph()
|
|
||||||
for a, b in possible_connections:
|
|
||||||
if not max(abs(np.subtract(a, b))) > 1:
|
|
||||||
self.connectivity.add_edge(a, b)
|
|
||||||
self.connectivity_subgroups = list(nx.algorithms.components.connected_components(self.connectivity))
|
self.connectivity_subgroups = list(nx.algorithms.components.connected_components(self.connectivity))
|
||||||
for idx, group in enumerate(self.connectivity_subgroups):
|
for idx, group in enumerate(self.connectivity_subgroups):
|
||||||
for tile_pos in group:
|
for tile_pos in group:
|
||||||
|
@ -320,6 +320,9 @@ class Agents(MovingEntityObjectRegister):
|
|||||||
def positions(self):
|
def positions(self):
|
||||||
return [agent.pos for agent in self]
|
return [agent.pos for agent in self]
|
||||||
|
|
||||||
|
def __setitem__(self, key, value):
    """Store ``value`` in the register under the name of the entry ``self[key]`` returns."""
    current_name = self[key].name
    self._register[current_name] = value
|
||||||
|
|
||||||
|
|
||||||
class Doors(EntityObjectRegister):
|
class Doors(EntityObjectRegister):
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@ import random
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from algorithms.TSP_dirt_agent import TSPDirtAgent
|
||||||
from environments.helpers import Constants as c
|
from environments.helpers import Constants as c
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.factory.base.base_factory import BaseFactory
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
@ -262,17 +263,29 @@ if __name__ == '__main__':
|
|||||||
from environments.utility_classes import AgentRenderOptions as ARO
|
from environments.utility_classes import AgentRenderOptions as ARO
|
||||||
render = True
|
render = True
|
||||||
|
|
||||||
dirt_props = DirtProperties(1, 0.05, 0.1, 3, 1, 20, 0)
|
dirt_props = DirtProperties(
|
||||||
|
initial_dirt_ratio=0.35,
|
||||||
|
initial_dirt_spawn_r_var=0.1,
|
||||||
|
clean_amount=0.34,
|
||||||
|
max_spawn_amount=0.1,
|
||||||
|
max_global_amount=20,
|
||||||
|
max_local_amount=1,
|
||||||
|
spawn_frequency=0,
|
||||||
|
max_spawn_ratio=0.05,
|
||||||
|
dirt_smear_amount=0.0,
|
||||||
|
agent_can_interact=True
|
||||||
|
)
|
||||||
|
|
||||||
obs_props = ObservationProperties(render_agents=ARO.COMBINED, omit_agent_self=True,
|
obs_props = ObservationProperties(render_agents=ARO.COMBINED, omit_agent_self=True,
|
||||||
pomdp_r=15, additional_agent_placeholder=None)
|
pomdp_r=2, additional_agent_placeholder=None)
|
||||||
|
|
||||||
move_props = {'allow_square_movement': True,
|
move_props = {'allow_square_movement': True,
|
||||||
'allow_diagonal_movement': False,
|
'allow_diagonal_movement': False,
|
||||||
'allow_no_op': False}
|
'allow_no_op': False}
|
||||||
|
|
||||||
factory = DirtFactory(n_agents=5, done_at_collision=False,
|
factory = DirtFactory(n_agents=1, done_at_collision=False,
|
||||||
level_name='rooms', max_steps=400,
|
level_name='rooms', max_steps=400,
|
||||||
|
doors_have_area=False,
|
||||||
obs_prop=obs_props, parse_doors=True,
|
obs_prop=obs_props, parse_doors=True,
|
||||||
record_episodes=True, verbose=True,
|
record_episodes=True, verbose=True,
|
||||||
mv_prop=move_props, dirt_prop=dirt_props
|
mv_prop=move_props, dirt_prop=dirt_props
|
||||||
@ -287,9 +300,15 @@ if __name__ == '__main__':
|
|||||||
in range(factory.n_agents)] for _
|
in range(factory.n_agents)] for _
|
||||||
in range(factory.max_steps+1)]
|
in range(factory.max_steps+1)]
|
||||||
env_state = factory.reset()
|
env_state = factory.reset()
|
||||||
|
if render:
|
||||||
|
factory.render()
|
||||||
|
random_start_position = factory[c.AGENT][0].tile
|
||||||
|
factory[c.AGENT][0] = tsp_agent = TSPDirtAgent(factory[c.FLOOR], factory[c.DIRT],
|
||||||
|
factory._actions, random_start_position)
|
||||||
|
|
||||||
r = 0
|
r = 0
|
||||||
for agent_i_action in random_actions:
|
for agent_i_action in random_actions:
|
||||||
env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
|
env_state, step_r, done_bool, info_obj = factory.step(tsp_agent.predict())
|
||||||
r += step_r
|
r += step_r
|
||||||
if render:
|
if render:
|
||||||
factory.render()
|
factory.render()
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
|
import itertools
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from typing import Tuple, Union
|
from typing import Tuple, Union
|
||||||
|
|
||||||
|
import networkx as nx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@ -153,6 +155,23 @@ def asset_str(agent):
|
|||||||
return c.AGENT.value, 'idle'
|
return c.AGENT.value, 'idle'
|
||||||
|
|
||||||
|
|
||||||
|
def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True):
    """Connect grid points that lie at most one step apart along every axis.

    Args:
        coordiniates_or_tiles: Iterable of coordinates, or an object exposing
            them through a ``positions`` attribute.
        allow_euclidean_connections: Permit edges between points that differ
            along every axis (diagonal neighbours).
        allow_manhattan_connections: Permit edges between points that differ
            along some, but not all, axes (straight neighbours).

    Returns:
        An undirected ``networkx.Graph``. Points only appear as nodes when
        they take part in at least one edge. With both flags enabled every
        pair whose largest per-axis distance is at most one gets an edge.
    """
    assert allow_euclidean_connections or allow_manhattan_connections

    points = coordiniates_or_tiles
    if hasattr(points, 'positions'):
        points = points.positions

    graph = nx.Graph()
    for first, second in itertools.combinations(points, 2):
        offset = abs(np.subtract(first, second))
        if max(offset) > 1:
            # More than one step apart on some axis: never neighbours.
            continue

        if allow_manhattan_connections and allow_euclidean_connections:
            connected = True
        elif allow_euclidean_connections:
            # Diagonal-only mode: the points must differ on every axis.
            connected = all(offset)
        elif allow_manhattan_connections:
            # Straight-only mode: differ on some axis, but not on all of them.
            connected = any(offset) and not all(offset)
        else:
            connected = False

        if connected:
            graph.add_edge(first, second)
    return graph
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parsed_level = parse_level(Path(__file__).parent / 'factory' / 'levels' / 'simple.txt')
|
parsed_level = parse_level(Path(__file__).parent / 'factory' / 'levels' / 'simple.txt')
|
||||||
y = one_hot_level(parsed_level)
|
y = one_hot_level(parsed_level)
|
||||||
|
@ -75,7 +75,7 @@ baseline_monitor_file = 'e_1_baseline'
|
|||||||
from stable_baselines3 import A2C
|
from stable_baselines3 import A2C
|
||||||
|
|
||||||
def policy_model_kwargs():
|
def policy_model_kwargs():
|
||||||
return dict(gae_lambda=0.25, n_steps=16, max_grad_norm=0, use_rms_prop=False)
|
return dict(gae_lambda=0.25, n_steps=16, max_grad_norm=0, use_rms_prop=True)
|
||||||
|
|
||||||
|
|
||||||
def dqn_model_kwargs():
|
def dqn_model_kwargs():
|
||||||
@ -203,7 +203,7 @@ if __name__ == '__main__':
|
|||||||
frames_to_stack = 3
|
frames_to_stack = 3
|
||||||
|
|
||||||
# Define a global studi save path
|
# Define a global studi save path
|
||||||
start_time = 'adam_no_weight_decay' # int(time.time())
|
start_time = 'rms_weight_decay_0' # int(time.time())
|
||||||
study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}_{start_time}'
|
study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}_{start_time}'
|
||||||
|
|
||||||
# Define Global Env Parameters
|
# Define Global Env Parameters
|
||||||
@ -285,36 +285,36 @@ if __name__ == '__main__':
|
|||||||
pomdp_r=2)
|
pomdp_r=2)
|
||||||
)
|
)
|
||||||
)})
|
)})
|
||||||
observation_modes.update({
|
observation_modes.update({
|
||||||
'seperate_N': dict(
|
'seperate_N': dict(
|
||||||
post_training_kwargs=
|
post_training_kwargs=
|
||||||
dict(obs_prop=ObservationProperties(
|
dict(obs_prop=ObservationProperties(
|
||||||
render_agents=AgentRenderOptions.COMBINED,
|
render_agents=AgentRenderOptions.COMBINED,
|
||||||
additional_agent_placeholder=None,
|
additional_agent_placeholder=None,
|
||||||
omit_agent_self=True,
|
omit_agent_self=True,
|
||||||
frames_to_stack=frames_to_stack,
|
frames_to_stack=frames_to_stack,
|
||||||
pomdp_r=2)
|
pomdp_r=2)
|
||||||
),
|
),
|
||||||
additional_env_kwargs=
|
additional_env_kwargs=
|
||||||
dict(obs_prop=ObservationProperties(
|
dict(obs_prop=ObservationProperties(
|
||||||
render_agents=AgentRenderOptions.NOT,
|
render_agents=AgentRenderOptions.NOT,
|
||||||
additional_agent_placeholder='N',
|
additional_agent_placeholder='N',
|
||||||
omit_agent_self=True,
|
omit_agent_self=True,
|
||||||
frames_to_stack=frames_to_stack,
|
frames_to_stack=frames_to_stack,
|
||||||
pomdp_r=2)
|
pomdp_r=2)
|
||||||
)
|
)
|
||||||
)})
|
)})
|
||||||
observation_modes.update({
|
observation_modes.update({
|
||||||
'in_lvl_obs': dict(
|
'in_lvl_obs': dict(
|
||||||
post_training_kwargs=
|
post_training_kwargs=
|
||||||
dict(obs_prop=ObservationProperties(
|
dict(obs_prop=ObservationProperties(
|
||||||
render_agents=AgentRenderOptions.LEVEL,
|
render_agents=AgentRenderOptions.LEVEL,
|
||||||
omit_agent_self=True,
|
omit_agent_self=True,
|
||||||
additional_agent_placeholder=None,
|
additional_agent_placeholder=None,
|
||||||
frames_to_stack=frames_to_stack,
|
frames_to_stack=frames_to_stack,
|
||||||
pomdp_r=2)
|
pomdp_r=2)
|
||||||
)
|
)
|
||||||
)})
|
)})
|
||||||
observation_modes.update({
|
observation_modes.update({
|
||||||
# No further adjustment needed
|
# No further adjustment needed
|
||||||
'no_obs': dict(
|
'no_obs': dict(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user