TSP Single Agent
This commit is contained in:
@ -3,7 +3,7 @@ from enum import Enum
|
||||
from typing import Union
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
from environments import helpers as h
|
||||
from environments.helpers import Constants as c
|
||||
import itertools
|
||||
|
||||
@ -267,11 +267,7 @@ class Door(Entity):
|
||||
neighbor_pos = list(itertools.product([-1, 1, 0], repeat=2))[:-1]
|
||||
neighbor_tiles = [context.by_pos(tuple([sum(x) for x in zip(self.pos, diff)])) for diff in neighbor_pos]
|
||||
neighbor_pos = [x.pos for x in neighbor_tiles if x]
|
||||
possible_connections = itertools.combinations(neighbor_pos, 2)
|
||||
self.connectivity = nx.Graph()
|
||||
for a, b in possible_connections:
|
||||
if not max(abs(np.subtract(a, b))) > 1:
|
||||
self.connectivity.add_edge(a, b)
|
||||
self.connectivity = h.points_to_graph(neighbor_pos)
|
||||
self.connectivity_subgroups = list(nx.algorithms.components.connected_components(self.connectivity))
|
||||
for idx, group in enumerate(self.connectivity_subgroups):
|
||||
for tile_pos in group:
|
||||
|
@ -320,6 +320,9 @@ class Agents(MovingEntityObjectRegister):
|
||||
def positions(self):
|
||||
return [agent.pos for agent in self]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self._register[self[key].name] = value
|
||||
|
||||
|
||||
class Doors(EntityObjectRegister):
|
||||
|
||||
|
@ -5,6 +5,7 @@ import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
from algorithms.TSP_dirt_agent import TSPDirtAgent
|
||||
from environments.helpers import Constants as c
|
||||
from environments import helpers as h
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
@ -262,17 +263,29 @@ if __name__ == '__main__':
|
||||
from environments.utility_classes import AgentRenderOptions as ARO
|
||||
render = True
|
||||
|
||||
dirt_props = DirtProperties(1, 0.05, 0.1, 3, 1, 20, 0)
|
||||
dirt_props = DirtProperties(
|
||||
initial_dirt_ratio=0.35,
|
||||
initial_dirt_spawn_r_var=0.1,
|
||||
clean_amount=0.34,
|
||||
max_spawn_amount=0.1,
|
||||
max_global_amount=20,
|
||||
max_local_amount=1,
|
||||
spawn_frequency=0,
|
||||
max_spawn_ratio=0.05,
|
||||
dirt_smear_amount=0.0,
|
||||
agent_can_interact=True
|
||||
)
|
||||
|
||||
obs_props = ObservationProperties(render_agents=ARO.COMBINED, omit_agent_self=True,
|
||||
pomdp_r=15, additional_agent_placeholder=None)
|
||||
pomdp_r=2, additional_agent_placeholder=None)
|
||||
|
||||
move_props = {'allow_square_movement': True,
|
||||
'allow_diagonal_movement': False,
|
||||
'allow_no_op': False}
|
||||
|
||||
factory = DirtFactory(n_agents=5, done_at_collision=False,
|
||||
factory = DirtFactory(n_agents=1, done_at_collision=False,
|
||||
level_name='rooms', max_steps=400,
|
||||
doors_have_area=False,
|
||||
obs_prop=obs_props, parse_doors=True,
|
||||
record_episodes=True, verbose=True,
|
||||
mv_prop=move_props, dirt_prop=dirt_props
|
||||
@ -287,9 +300,15 @@ if __name__ == '__main__':
|
||||
in range(factory.n_agents)] for _
|
||||
in range(factory.max_steps+1)]
|
||||
env_state = factory.reset()
|
||||
if render:
|
||||
factory.render()
|
||||
random_start_position = factory[c.AGENT][0].tile
|
||||
factory[c.AGENT][0] = tsp_agent = TSPDirtAgent(factory[c.FLOOR], factory[c.DIRT],
|
||||
factory._actions, random_start_position)
|
||||
|
||||
r = 0
|
||||
for agent_i_action in random_actions:
|
||||
env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
|
||||
env_state, step_r, done_bool, info_obj = factory.step(tsp_agent.predict())
|
||||
r += step_r
|
||||
if render:
|
||||
factory.render()
|
||||
|
@ -1,7 +1,9 @@
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from enum import Enum, auto
|
||||
from typing import Tuple, Union
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
@ -153,6 +155,23 @@ def asset_str(agent):
|
||||
return c.AGENT.value, 'idle'
|
||||
|
||||
|
||||
def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True):
|
||||
assert allow_euclidean_connections or allow_manhattan_connections
|
||||
if hasattr(coordiniates_or_tiles, 'positions'):
|
||||
coordiniates_or_tiles = coordiniates_or_tiles.positions
|
||||
possible_connections = itertools.combinations(coordiniates_or_tiles, 2)
|
||||
graph = nx.Graph()
|
||||
for a, b in possible_connections:
|
||||
diff = abs(np.subtract(a, b))
|
||||
if not max(diff) > 1:
|
||||
if allow_manhattan_connections and allow_euclidean_connections:
|
||||
graph.add_edge(a, b)
|
||||
elif not allow_manhattan_connections and allow_euclidean_connections and all(diff):
|
||||
graph.add_edge(a, b)
|
||||
elif allow_manhattan_connections and not allow_euclidean_connections and not all(diff) and any(diff):
|
||||
graph.add_edge(a, b)
|
||||
return graph
|
||||
|
||||
if __name__ == '__main__':
|
||||
parsed_level = parse_level(Path(__file__).parent / 'factory' / 'levels' / 'simple.txt')
|
||||
y = one_hot_level(parsed_level)
|
||||
|
Reference in New Issue
Block a user