2021-11-25 14:48:34 +01:00

67 lines
2.9 KiB
Python

import numpy as np
from networkx.algorithms.approximation import traveling_salesman as tsp
from environments.factory.base.objects import Agent
from environments.factory.base.registers import FloorTiles, Actions
from environments.helpers import points_to_graph
from environments import helpers as h
class TSPDirtAgent(Agent):
    """Scripted agent that walks a travelling-salesman route over the dirt positions.

    On every call to :meth:`predict` the agent cleans when the dirt register
    reports an entry at its position, uses a closed door found on its tile,
    and otherwise takes the next step along a route computed once by
    :meth:`calculate_tsp_route`.
    """

    def __init__(self, floortiles: FloorTiles, dirt_register, actions: Actions, *args,
                 static_problem: bool = True, **kwargs):
        """Store the environment registers and build the floor-tile graph.

        Args:
            floortiles: Register whose ``positions`` become the graph nodes.
            dirt_register: Register queried through ``by_pos`` and ``positions``.
            actions: Action register; also supplies the movement flags used
                to decide which connections the graph may contain.
            *args: Forwarded to ``Agent.__init__``.
            static_problem: When true, the route is computed once and then
                consumed step by step. The non-static case is not implemented.
            **kwargs: Forwarded to ``Agent.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.static_problem = static_problem
        self._floortiles = floortiles
        self._actions = actions
        self._dirt_register = dirt_register
        self._floortile_graph = points_to_graph(
            self._floortiles.positions,
            allow_euclidean_connections=self._actions.allow_diagonal_movement,
            allow_manhattan_connections=self._actions.allow_square_movement,
        )
        # Filled lazily on the first move request, see _predict_move.
        self._static_route = None

    def predict(self, *_, **__):
        """Return the index (within ``self._actions``) of the next action.

        All positional and keyword arguments are ignored; they only exist so
        the agent can be called like any other model.

        Raises:
            StopIteration: If the chosen action is not contained in
                ``self._actions``.
        """
        if self._dirt_register.by_pos(self.pos) is not None:
            action = h.EnvActions.CLEAN_UP
        else:
            # Single scan for the first door-like guest on the current tile.
            door = next((x for x in self.tile.guests if 'door' in x.name.lower()), None)
            if door is not None and door.is_closed:
                action = h.EnvActions.USE_DOOR
            else:
                action = self._predict_move()
        # Translate the action_object to an integer to have the same output as any other model
        action_idx = next(action_i for action_i, action_obj in enumerate(self._actions)
                          if action_obj == action)
        return action_idx

    def _predict_move(self):
        """Return the movement action leading to the next position on the route.

        Raises:
            NotImplementedError: If ``static_problem`` is false.
            ValueError: If the offset to the next route position does not
                match any entry of ``h.ACTIONMAP``.
        """
        if not self.static_problem:
            raise NotImplementedError

        if self._static_route is None:
            self._static_route = self.calculate_tsp_route()
        # NOTE(review): popping from an exhausted route fails here (IndexError
        # if the route is a plain list) - confirm callers stop before that.
        # NOTE(review): this assumes the route passes through self.pos in
        # single-step increments - confirm against the route's start node.
        next_pos = self._static_route.pop(0)
        # Skip route entries that equal the current position.
        while next_pos == self.pos:
            next_pos = self._static_route.pop(0)

        diff = np.subtract(next_pos, self.pos)
        # Retrieve action based on the pos dif (like in: What do i have to do to get there?)
        try:
            action = next(action for action, pos_diff in h.ACTIONMAP.items()
                          if (diff == pos_diff).all())
        except StopIteration:
            # Previously this only printed and then failed with an
            # UnboundLocalError on the return; raise a descriptive error instead.
            raise ValueError(
                f'This Should not happen! No action matches the position offset {diff} '
                f'(from {self.pos} to {next_pos}).'
            ) from None
        return action

    def calculate_tsp_route(self):
        """Compute a route over the agent position and all dirt positions.

        Returns:
            The result of networkx' approximate ``traveling_salesman_problem``
            on the floor-tile graph, restricted to the agent's current
            position plus every position in the dirt register.
        """
        route = tsp.traveling_salesman_problem(
            self._floortile_graph,
            nodes=[self.pos] + list(self._dirt_register.positions),
        )
        return route