mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
Destinations implemented and debugged
This commit is contained in:
parent
3d81b7577d
commit
7f7a3d9a3b
@ -3,27 +3,28 @@ import numpy as np
|
|||||||
from networkx.algorithms.approximation import traveling_salesman as tsp
|
from networkx.algorithms.approximation import traveling_salesman as tsp
|
||||||
|
|
||||||
from environments.factory.base.objects import Agent
|
from environments.factory.base.objects import Agent
|
||||||
from environments.factory.base.registers import FloorTiles, Actions
|
|
||||||
from environments.helpers import points_to_graph
|
from environments.helpers import points_to_graph
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
|
from environments.helpers import Constants as c
|
||||||
|
|
||||||
|
|
||||||
|
future_planning = 7
|
||||||
|
|
||||||
class TSPDirtAgent(Agent):
|
class TSPDirtAgent(Agent):
|
||||||
|
|
||||||
def __init__(self, floortiles: FloorTiles, dirt_register, actions: Actions, *args,
|
def __init__(self, env, *args,
|
||||||
static_problem: bool = True, **kwargs):
|
static_problem: bool = True, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.static_problem = static_problem
|
self.static_problem = static_problem
|
||||||
self._floortiles = floortiles
|
self.local_optimization = True
|
||||||
self._actions = actions
|
self._env = env
|
||||||
self._dirt_register = dirt_register
|
self._floortile_graph = points_to_graph(self._env[c.FLOOR].positions,
|
||||||
self._floortile_graph = points_to_graph(self._floortiles.positions,
|
allow_euclidean_connections=self._env._actions.allow_diagonal_movement,
|
||||||
allow_euclidean_connections=self._actions.allow_diagonal_movement,
|
allow_manhattan_connections=self._env._actions.allow_square_movement)
|
||||||
allow_manhattan_connections=self._actions.allow_square_movement)
|
|
||||||
self._static_route = None
|
self._static_route = None
|
||||||
|
|
||||||
def predict(self, *_, **__):
|
def predict(self, *_, **__):
|
||||||
if self._dirt_register.by_pos(self.pos) is not None:
|
if self._env[c.DIRT].by_pos(self.pos) is not None:
|
||||||
# Translate the action_object to an integer to have the same output as any other model
|
# Translate the action_object to an integer to have the same output as any other model
|
||||||
action = h.EnvActions.CLEAN_UP
|
action = h.EnvActions.CLEAN_UP
|
||||||
elif any('door' in x.name.lower() for x in self.tile.guests):
|
elif any('door' in x.name.lower() for x in self.tile.guests):
|
||||||
@ -36,12 +37,13 @@ class TSPDirtAgent(Agent):
|
|||||||
else:
|
else:
|
||||||
action = self._predict_move()
|
action = self._predict_move()
|
||||||
# Translate the action_object to an integer to have the same output as any other model
|
# Translate the action_object to an integer to have the same output as any other model
|
||||||
action_obj = next(action_i for action_i, action_obj in enumerate(self._actions) if action_obj == action)
|
action_obj = next(action_i for action_i, action_obj in enumerate(self._env._actions) if action_obj == action)
|
||||||
return action_obj
|
return action_obj
|
||||||
|
|
||||||
def _predict_move(self):
|
def _predict_move(self):
|
||||||
|
if len(self._env[c.DIRT]) >= 1:
|
||||||
if self.static_problem:
|
if self.static_problem:
|
||||||
if self._static_route is None:
|
if not self._static_route:
|
||||||
self._static_route = self.calculate_tsp_route()
|
self._static_route = self.calculate_tsp_route()
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
@ -49,7 +51,11 @@ class TSPDirtAgent(Agent):
|
|||||||
while next_pos == self.pos:
|
while next_pos == self.pos:
|
||||||
next_pos = self._static_route.pop(0)
|
next_pos = self._static_route.pop(0)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
if not self._static_route:
|
||||||
|
self._static_route = self.calculate_tsp_route()[:7]
|
||||||
|
next_pos = self._static_route.pop(0)
|
||||||
|
while next_pos == self.pos:
|
||||||
|
next_pos = self._static_route.pop(0)
|
||||||
|
|
||||||
diff = np.subtract(next_pos, self.pos)
|
diff = np.subtract(next_pos, self.pos)
|
||||||
# Retrieve action based on the pos dif (like in: What do i have to do to get there?)
|
# Retrieve action based on the pos dif (like in: What do i have to do to get there?)
|
||||||
@ -58,9 +64,23 @@ class TSPDirtAgent(Agent):
|
|||||||
if (diff == pos_diff).all())
|
if (diff == pos_diff).all())
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
print('This Should not happen!')
|
print('This Should not happen!')
|
||||||
|
else:
|
||||||
|
action = int(np.random.randint(self._env.action_space.n))
|
||||||
return action
|
return action
|
||||||
|
|
||||||
def calculate_tsp_route(self):
|
def calculate_tsp_route(self):
|
||||||
|
if self.local_optimization:
|
||||||
|
nodes = \
|
||||||
|
[self.pos] + \
|
||||||
|
[x for x in self._env[c.DIRT].positions if max(abs(np.subtract(x, self.pos))) < 3]
|
||||||
|
try:
|
||||||
|
while len(nodes) < 7:
|
||||||
|
nodes += [next(x for x in self._env[c.DIRT].positions if x not in nodes)]
|
||||||
|
except StopIteration:
|
||||||
|
nodes = [self.pos] + self._env[c.DIRT].positions
|
||||||
|
|
||||||
|
else:
|
||||||
|
nodes = [self.pos] + self._env[c.DIRT].positions
|
||||||
route = tsp.traveling_salesman_problem(self._floortile_graph,
|
route = tsp.traveling_salesman_problem(self._floortile_graph,
|
||||||
nodes=[self.pos] + [x for x in self._dirt_register.positions])
|
nodes=nodes, cycle=True, method=tsp.greedy_tsp)
|
||||||
return route
|
return route
|
||||||
|
BIN
environments/factory/assets/charge_pod.png
Normal file
BIN
environments/factory/assets/charge_pod.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 6.5 KiB |
BIN
environments/factory/assets/destination.png
Normal file
BIN
environments/factory/assets/destination.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 6.9 KiB |
@ -64,7 +64,7 @@ class BaseFactory(gym.Env):
|
|||||||
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2),
|
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2),
|
||||||
mv_prop: MovementProperties = MovementProperties(),
|
mv_prop: MovementProperties = MovementProperties(),
|
||||||
obs_prop: ObservationProperties = ObservationProperties(),
|
obs_prop: ObservationProperties = ObservationProperties(),
|
||||||
parse_doors=False, done_at_collision=False,
|
parse_doors=False, done_at_collision=False, inject_agents: Union[None, List] = None,
|
||||||
verbose=False, doors_have_area=True, env_seed=time.time_ns(), individual_rewards=False,
|
verbose=False, doors_have_area=True, env_seed=time.time_ns(), individual_rewards=False,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
|
||||||
@ -98,6 +98,7 @@ class BaseFactory(gym.Env):
|
|||||||
self.done_at_collision = done_at_collision
|
self.done_at_collision = done_at_collision
|
||||||
self._record_episodes = False
|
self._record_episodes = False
|
||||||
self.parse_doors = parse_doors
|
self.parse_doors = parse_doors
|
||||||
|
self._injected_agents = inject_agents or []
|
||||||
self.doors_have_area = doors_have_area
|
self.doors_have_area = doors_have_area
|
||||||
self.individual_rewards = individual_rewards
|
self.individual_rewards = individual_rewards
|
||||||
|
|
||||||
@ -108,8 +109,10 @@ class BaseFactory(gym.Env):
|
|||||||
return self._entities[item]
|
return self._entities[item]
|
||||||
|
|
||||||
def _base_init_env(self):
|
def _base_init_env(self):
|
||||||
|
|
||||||
|
# All entities
|
||||||
# Objects
|
# Objects
|
||||||
entities = {}
|
self._entities = Entities()
|
||||||
# Level
|
# Level
|
||||||
level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.level_name}.txt'
|
level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.level_name}.txt'
|
||||||
parsed_level = h.parse_level(level_filepath)
|
parsed_level = h.parse_level(level_filepath)
|
||||||
@ -121,14 +124,14 @@ class BaseFactory(gym.Env):
|
|||||||
np.argwhere(level_array == c.OCCUPIED_CELL.value),
|
np.argwhere(level_array == c.OCCUPIED_CELL.value),
|
||||||
self._level_shape
|
self._level_shape
|
||||||
)
|
)
|
||||||
entities.update({c.WALLS: walls})
|
self._entities.register_additional_items({c.WALLS: walls})
|
||||||
|
|
||||||
# Floor
|
# Floor
|
||||||
floor = FloorTiles.from_argwhere_coordinates(
|
floor = FloorTiles.from_argwhere_coordinates(
|
||||||
np.argwhere(level_array == c.FREE_CELL.value),
|
np.argwhere(level_array == c.FREE_CELL.value),
|
||||||
self._level_shape
|
self._level_shape
|
||||||
)
|
)
|
||||||
entities.update({c.FLOOR: floor})
|
self._entities.register_additional_items({c.FLOOR: floor})
|
||||||
|
|
||||||
# NOPOS
|
# NOPOS
|
||||||
self._NO_POS_TILE = Tile(c.NO_POS.value)
|
self._NO_POS_TILE = Tile(c.NO_POS.value)
|
||||||
@ -141,7 +144,7 @@ class BaseFactory(gym.Env):
|
|||||||
doors = Doors.from_tiles(door_tiles, self._level_shape,
|
doors = Doors.from_tiles(door_tiles, self._level_shape,
|
||||||
entity_kwargs=dict(context=floor)
|
entity_kwargs=dict(context=floor)
|
||||||
)
|
)
|
||||||
entities.update({c.DOORS: doors})
|
self._entities.register_additional_items({c.DOORS: doors})
|
||||||
|
|
||||||
# Actions
|
# Actions
|
||||||
self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors)
|
self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors)
|
||||||
@ -149,12 +152,22 @@ class BaseFactory(gym.Env):
|
|||||||
self._actions.register_additional_items(additional_actions)
|
self._actions.register_additional_items(additional_actions)
|
||||||
|
|
||||||
# Agents
|
# Agents
|
||||||
agents = Agents.from_tiles(floor.empty_tiles[:self.n_agents], self._level_shape,
|
agents_to_spawn = self.n_agents-len(self._injected_agents)
|
||||||
|
agents_kwargs = dict(level_shape=self._level_shape,
|
||||||
individual_slices=self.obs_prop.render_agents == a_obs.SEPERATE,
|
individual_slices=self.obs_prop.render_agents == a_obs.SEPERATE,
|
||||||
hide_from_obs_builder=self.obs_prop.render_agents == a_obs.LEVEL,
|
hide_from_obs_builder=self.obs_prop.render_agents == a_obs.LEVEL,
|
||||||
is_observable=self.obs_prop.render_agents != a_obs.NOT
|
is_observable=self.obs_prop.render_agents != a_obs.NOT)
|
||||||
)
|
if agents_to_spawn:
|
||||||
entities.update({c.AGENT: agents})
|
agents = Agents.from_tiles(floor.empty_tiles[:agents_to_spawn], **agents_kwargs)
|
||||||
|
else:
|
||||||
|
agents = Agents(**agents_kwargs)
|
||||||
|
if self._injected_agents:
|
||||||
|
initialized_injections = list()
|
||||||
|
for i, injection in enumerate(self._injected_agents):
|
||||||
|
agents.register_item(injection(self, floor.empty_tiles[agents_to_spawn+i+1], static_problem=False))
|
||||||
|
initialized_injections.append(agents[-1])
|
||||||
|
self._initialized_injections = initialized_injections
|
||||||
|
self._entities.register_additional_items({c.AGENT: agents})
|
||||||
|
|
||||||
if self.obs_prop.additional_agent_placeholder is not None:
|
if self.obs_prop.additional_agent_placeholder is not None:
|
||||||
# TODO: Make this accept Lists for multiple placeholders
|
# TODO: Make this accept Lists for multiple placeholders
|
||||||
@ -165,11 +178,7 @@ class BaseFactory(gym.Env):
|
|||||||
fill_value=self.obs_prop.additional_agent_placeholder)
|
fill_value=self.obs_prop.additional_agent_placeholder)
|
||||||
)
|
)
|
||||||
|
|
||||||
entities.update({c.AGENT_PLACEHOLDER: placeholder})
|
self._entities.register_additional_items({c.AGENT_PLACEHOLDER: placeholder})
|
||||||
|
|
||||||
# All entities
|
|
||||||
self._entities = Entities()
|
|
||||||
self._entities.register_additional_items(entities)
|
|
||||||
|
|
||||||
# Additional Entitites from SubEnvs
|
# Additional Entitites from SubEnvs
|
||||||
if additional_entities := self.additional_entities:
|
if additional_entities := self.additional_entities:
|
||||||
@ -182,6 +191,7 @@ class BaseFactory(gym.Env):
|
|||||||
arrays = self._entities.obs_arrays
|
arrays = self._entities.obs_arrays
|
||||||
|
|
||||||
obs_cube_z = sum([a.shape[0] if not self[key].is_per_agent else 1 for key, a in arrays.items()])
|
obs_cube_z = sum([a.shape[0] if not self[key].is_per_agent else 1 for key, a in arrays.items()])
|
||||||
|
obs_cube_z += 1 if self.obs_prop.show_global_position_info else 0
|
||||||
self._obs_cube = np.zeros((obs_cube_z, *self._level_shape), dtype=np.float32)
|
self._obs_cube = np.zeros((obs_cube_z, *self._level_shape), dtype=np.float32)
|
||||||
|
|
||||||
def reset(self) -> (np.ndarray, int, bool, dict):
|
def reset(self) -> (np.ndarray, int, bool, dict):
|
||||||
@ -279,7 +289,7 @@ class BaseFactory(gym.Env):
|
|||||||
if self.n_agents == 1:
|
if self.n_agents == 1:
|
||||||
obs = self._build_per_agent_obs(self[c.AGENT][0], state_array_dict)
|
obs = self._build_per_agent_obs(self[c.AGENT][0], state_array_dict)
|
||||||
elif self.n_agents >= 2:
|
elif self.n_agents >= 2:
|
||||||
obs = np.stack(self._build_per_agent_obs(agent, state_array_dict) for agent in self[c.AGENT])
|
obs = np.stack([self._build_per_agent_obs(agent, state_array_dict) for agent in self[c.AGENT]])
|
||||||
else:
|
else:
|
||||||
raise ValueError('n_agents cannot be smaller than 1!!')
|
raise ValueError('n_agents cannot be smaller than 1!!')
|
||||||
return obs
|
return obs
|
||||||
@ -384,6 +394,7 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
if self.obs_prop.pomdp_r:
|
if self.obs_prop.pomdp_r:
|
||||||
oobs = self._do_pomdp_obs_cutout(agent, other_agent_obs)[0]
|
oobs = self._do_pomdp_obs_cutout(agent, other_agent_obs)[0]
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
mask = (oobs != c.SHADOWED_CELL.value).astype(int)
|
mask = (oobs != c.SHADOWED_CELL.value).astype(int)
|
||||||
obs[0] += oobs * mask
|
obs[0] += oobs * mask
|
||||||
|
|
||||||
@ -497,7 +508,7 @@ class BaseFactory(gym.Env):
|
|||||||
if self._actions.is_moving_action(agent.temp_action):
|
if self._actions.is_moving_action(agent.temp_action):
|
||||||
if agent.temp_valid:
|
if agent.temp_valid:
|
||||||
# info_dict.update(movement=1)
|
# info_dict.update(movement=1)
|
||||||
per_agent_reward -= 0.01
|
per_agent_reward -= 0.001
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
per_agent_reward -= 0.05
|
per_agent_reward -= 0.05
|
||||||
@ -553,6 +564,7 @@ class BaseFactory(gym.Env):
|
|||||||
self.print(f"reward is {reward}")
|
self.print(f"reward is {reward}")
|
||||||
return reward, combined_info_dict
|
return reward, combined_info_dict
|
||||||
|
|
||||||
|
# noinspection PyGlobalUndefined
|
||||||
def render(self, mode='human'):
|
def render(self, mode='human'):
|
||||||
if not self._renderer: # lazy init
|
if not self._renderer: # lazy init
|
||||||
from environments.factory.base.renderer import Renderer, RenderEntity
|
from environments.factory.base.renderer import Renderer, RenderEntity
|
||||||
@ -560,6 +572,7 @@ class BaseFactory(gym.Env):
|
|||||||
height, width = self._obs_cube.shape[1:]
|
height, width = self._obs_cube.shape[1:]
|
||||||
self._renderer = Renderer(width, height, view_radius=self._pomdp_r, fps=5)
|
self._renderer = Renderer(width, height, view_radius=self._pomdp_r, fps=5)
|
||||||
|
|
||||||
|
# noinspection PyUnboundLocalVariable
|
||||||
walls = [RenderEntity('wall', wall.pos) for wall in self[c.WALLS]]
|
walls = [RenderEntity('wall', wall.pos) for wall in self[c.WALLS]]
|
||||||
|
|
||||||
agents = []
|
agents = []
|
||||||
@ -582,6 +595,12 @@ class BaseFactory(gym.Env):
|
|||||||
with filepath.open('w') as f:
|
with filepath.open('w') as f:
|
||||||
simplejson.dump(d, f, indent=4, namedtuple_as_object=True)
|
simplejson.dump(d, f, indent=4, namedtuple_as_object=True)
|
||||||
|
|
||||||
|
def get_injected_agents(self) -> list:
|
||||||
|
if hasattr(self, '_initialized_injections'):
|
||||||
|
return self._initialized_injections
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
def _summarize_state(self):
|
def _summarize_state(self):
|
||||||
summary = {f'{REC_TAC}step': self._steps}
|
summary = {f'{REC_TAC}step': self._steps}
|
||||||
|
|
||||||
@ -621,9 +640,15 @@ class BaseFactory(gym.Env):
|
|||||||
def additional_obs_build(self) -> List[np.ndarray]:
|
def additional_obs_build(self) -> List[np.ndarray]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def additional_per_agent_obs_build(self, agent) -> List[np.ndarray]:
|
def additional_per_agent_obs_build(self, agent) -> List[np.ndarray]:
|
||||||
return []
|
additional_per_agent_obs = []
|
||||||
|
if self.obs_prop.show_global_position_info:
|
||||||
|
pos_array = np.zeros(self.observation_space.shape[1:])
|
||||||
|
for xy in range(1):
|
||||||
|
pos_array[0, xy] = agent.pos[xy] / self._level_shape[xy]
|
||||||
|
additional_per_agent_obs.append(pos_array)
|
||||||
|
|
||||||
|
return additional_per_agent_obs
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def do_additional_reset(self) -> None:
|
def do_additional_reset(self) -> None:
|
||||||
|
@ -50,6 +50,8 @@ class Register:
|
|||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
if isinstance(item, (int, np.int64, np.int32)):
|
if isinstance(item, (int, np.int64, np.int32)):
|
||||||
|
if item < 0:
|
||||||
|
item = len(self._register) - abs(item)
|
||||||
try:
|
try:
|
||||||
return next(v for i, v in enumerate(self._register.values()) if i == item)
|
return next(v for i, v in enumerate(self._register.values()) if i == item)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
@ -147,10 +149,10 @@ class MovingEntityObjectRegister(EntityObjectRegister, ABC):
|
|||||||
if self.individual_slices:
|
if self.individual_slices:
|
||||||
self._array = np.delete(self._array, idx, axis=0)
|
self._array = np.delete(self._array, idx, axis=0)
|
||||||
|
|
||||||
def delete_item(self, item):
|
def delete_entity(self, item):
|
||||||
self.delete_item_by_name(item.name)
|
self.delete_entity_by_name(item.name)
|
||||||
|
|
||||||
def delete_item_by_name(self, name):
|
def delete_entity_by_name(self, name):
|
||||||
del self[name]
|
del self[name]
|
||||||
|
|
||||||
|
|
||||||
@ -320,8 +322,11 @@ class Agents(MovingEntityObjectRegister):
|
|||||||
def positions(self):
|
def positions(self):
|
||||||
return [agent.pos for agent in self]
|
return [agent.pos for agent in self]
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
def replace_agent(self, key, agent):
|
||||||
self._register[self[key].name] = value
|
old_agent = self[key]
|
||||||
|
self[key].tile.leave(self[key])
|
||||||
|
agent._name = old_agent.name
|
||||||
|
self._register[agent.name] = agent
|
||||||
|
|
||||||
|
|
||||||
class Doors(EntityObjectRegister):
|
class Doors(EntityObjectRegister):
|
||||||
|
292
environments/factory/factory_destination.py
Normal file
292
environments/factory/factory_destination.py
Normal file
@ -0,0 +1,292 @@
|
|||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
from enum import Enum
|
||||||
|
from typing import List, Union, NamedTuple, Dict
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
|
||||||
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
|
from environments.helpers import Constants as c
|
||||||
|
from environments import helpers as h
|
||||||
|
from environments.factory.base.objects import Agent, Entity, Action, Tile
|
||||||
|
from environments.factory.base.registers import Entities, MovingEntityObjectRegister
|
||||||
|
|
||||||
|
from environments.factory.base.renderer import RenderEntity
|
||||||
|
|
||||||
|
|
||||||
|
DESTINATION = 1
|
||||||
|
DESTINATION_DONE = 0.5
|
||||||
|
|
||||||
|
|
||||||
|
class Destination(Entity):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def any_agent_has_dwelled(self):
|
||||||
|
return bool(len(self._per_agent_times))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def currently_dwelling_names(self):
|
||||||
|
return self._per_agent_times.keys()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def can_collide(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encoding(self):
|
||||||
|
return DESTINATION
|
||||||
|
|
||||||
|
def __init__(self, *args, dwell_time: int = 0, **kwargs):
|
||||||
|
super(Destination, self).__init__(*args, **kwargs)
|
||||||
|
self.dwell_time = dwell_time
|
||||||
|
self._per_agent_times = defaultdict(lambda: dwell_time)
|
||||||
|
|
||||||
|
def wait(self, agent: Agent):
|
||||||
|
self._per_agent_times[agent.name] -= 1
|
||||||
|
return c.VALID
|
||||||
|
|
||||||
|
def leave(self, agent: Agent):
|
||||||
|
del self._per_agent_times[agent.name]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_considered_reached(self):
|
||||||
|
agent_at_position = any(c.AGENT.name.lower() in x.name.lower() for x in self.tile.guests_that_can_collide)
|
||||||
|
return (agent_at_position and not self.dwell_time) or any(x == 0 for x in self._per_agent_times.values())
|
||||||
|
|
||||||
|
def agent_is_dwelling(self, agent: Agent):
|
||||||
|
return self._per_agent_times[agent.name] < self.dwell_time
|
||||||
|
|
||||||
|
def summarize_state(self, n_steps=None) -> dict:
|
||||||
|
state_summary = super().summarize_state(n_steps=n_steps)
|
||||||
|
state_summary.update(per_agent_times=self._per_agent_times)
|
||||||
|
return state_summary
|
||||||
|
|
||||||
|
|
||||||
|
class Destinations(MovingEntityObjectRegister):
|
||||||
|
|
||||||
|
_accepted_objects = Destination
|
||||||
|
_light_blocking = False
|
||||||
|
|
||||||
|
def as_array(self):
|
||||||
|
self._array[:] = c.FREE_CELL.value
|
||||||
|
for item in self:
|
||||||
|
if item.pos != c.NO_POS.value:
|
||||||
|
self._array[0, item.x, item.y] = item.encoding
|
||||||
|
return self._array
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
super(Destinations, self).__repr__()
|
||||||
|
|
||||||
|
|
||||||
|
class ReachedDestinations(Destinations):
|
||||||
|
_accepted_objects = Destination
|
||||||
|
_light_blocking = False
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(ReachedDestinations, self).__init__(*args, is_observable=False, **kwargs)
|
||||||
|
|
||||||
|
def summarize_states(self, n_steps=None):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
class DestSpawnMode(object):
|
||||||
|
DONE = 'DONE'
|
||||||
|
GROUPED = 'GROUPED'
|
||||||
|
PER_DEST = 'PER_DEST'
|
||||||
|
|
||||||
|
|
||||||
|
class DestinationProperties(NamedTuple):
|
||||||
|
n_dests: int = 1 # How many destinations are there
|
||||||
|
dwell_time: int = 0 # How long does the agent need to "wait" on a destination
|
||||||
|
spawn_frequency: int = 0
|
||||||
|
spawn_in_other_zone: bool = True #
|
||||||
|
spawn_mode: str = DestSpawnMode.DONE
|
||||||
|
|
||||||
|
assert dwell_time >= 0, 'dwell_time cannot be < 0!'
|
||||||
|
assert spawn_frequency >= 0, 'spawn_frequency cannot be < 0!'
|
||||||
|
assert n_dests >= 0, 'n_destinations cannot be < 0!'
|
||||||
|
assert (spawn_mode == DestSpawnMode.DONE) != bool(spawn_frequency)
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
||||||
|
class DestinationFactory(BaseFactory):
|
||||||
|
# noinspection PyMissingConstructor
|
||||||
|
|
||||||
|
def __init__(self, *args, dest_prop: DestinationProperties = DestinationProperties(),
|
||||||
|
env_seed=time.time_ns(), **kwargs):
|
||||||
|
if isinstance(dest_prop, dict):
|
||||||
|
dest_prop = DestinationProperties(**dest_prop)
|
||||||
|
self.dest_prop = dest_prop
|
||||||
|
kwargs.update(env_seed=env_seed)
|
||||||
|
self._dest_rng = np.random.default_rng(env_seed)
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def additional_actions(self) -> Union[Action, List[Action]]:
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
|
super_actions = super().additional_actions
|
||||||
|
if self.dest_prop.dwell_time:
|
||||||
|
super_actions.append(Action(enum_ident=h.EnvActions.WAIT_ON_DEST))
|
||||||
|
return super_actions
|
||||||
|
|
||||||
|
@property
|
||||||
|
def additional_entities(self) -> Dict[(Enum, Entities)]:
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
|
super_entities = super().additional_entities
|
||||||
|
|
||||||
|
empty_tiles = self[c.FLOOR].empty_tiles[:self.dest_prop.n_dests]
|
||||||
|
destinations = Destinations.from_tiles(
|
||||||
|
empty_tiles, self._level_shape,
|
||||||
|
entity_kwargs=dict(
|
||||||
|
dwell_time=self.dest_prop.dwell_time)
|
||||||
|
)
|
||||||
|
reached_destinations = ReachedDestinations(level_shape=self._level_shape)
|
||||||
|
|
||||||
|
super_entities.update({c.DESTINATION: destinations, c.REACHEDDESTINATION: reached_destinations})
|
||||||
|
return super_entities
|
||||||
|
|
||||||
|
def additional_per_agent_obs_build(self, agent) -> List[np.ndarray]:
|
||||||
|
additional_per_agent_obs_build = super().additional_per_agent_obs_build(agent)
|
||||||
|
return additional_per_agent_obs_build
|
||||||
|
|
||||||
|
def wait(self, agent: Agent):
|
||||||
|
if destiantion := self[c.DESTINATION].by_pos(agent.pos):
|
||||||
|
valid = destiantion.wait(agent)
|
||||||
|
return valid
|
||||||
|
else:
|
||||||
|
return c.NOT_VALID
|
||||||
|
|
||||||
|
def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]:
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
|
valid = super().do_additional_actions(agent, action)
|
||||||
|
if valid is None:
|
||||||
|
if action == h.EnvActions.WAIT_ON_DEST:
|
||||||
|
valid = self.wait(agent)
|
||||||
|
return valid
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return valid
|
||||||
|
|
||||||
|
def do_additional_reset(self) -> None:
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
|
super().do_additional_reset()
|
||||||
|
self._dest_spawn_timer = dict()
|
||||||
|
|
||||||
|
def trigger_destination_spawn(self):
|
||||||
|
destinations_to_spawn = [key for key, val in self._dest_spawn_timer.items()
|
||||||
|
if val == self.dest_prop.spawn_frequency]
|
||||||
|
if destinations_to_spawn:
|
||||||
|
n_dest_to_spawn = len(destinations_to_spawn)
|
||||||
|
if self.dest_prop.spawn_mode != DestSpawnMode.GROUPED:
|
||||||
|
destinations = [Destination(tile) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
||||||
|
self[c.DESTINATION].register_additional_items(destinations)
|
||||||
|
for dest in destinations_to_spawn:
|
||||||
|
del self._dest_spawn_timer[dest]
|
||||||
|
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
|
||||||
|
elif self.dest_prop.spawn_mode == DestSpawnMode.GROUPED and n_dest_to_spawn == self.dest_prop.n_dests:
|
||||||
|
destinations = [Destination(tile) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
||||||
|
self[c.DESTINATION].register_additional_items(destinations)
|
||||||
|
for dest in destinations_to_spawn:
|
||||||
|
del self._dest_spawn_timer[dest]
|
||||||
|
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
|
||||||
|
else:
|
||||||
|
self.print(f'{n_dest_to_spawn} new destinations could be spawned, but waiting for all.')
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
self.print('No Items are spawning, limit is reached.')
|
||||||
|
|
||||||
|
def do_additional_step(self) -> dict:
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
|
info_dict = super().do_additional_step()
|
||||||
|
for key, val in self._dest_spawn_timer.items():
|
||||||
|
self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1)
|
||||||
|
for dest in list(self[c.DESTINATION].values()):
|
||||||
|
if dest.is_considered_reached:
|
||||||
|
self[c.REACHEDDESTINATION].register_item(dest)
|
||||||
|
self[c.DESTINATION].delete_entity(dest)
|
||||||
|
self._dest_spawn_timer[dest.name] = 0
|
||||||
|
self.print(f'{dest.name} is reached now, removing...')
|
||||||
|
else:
|
||||||
|
for agent_name in dest.currently_dwelling_names:
|
||||||
|
agent = self[c.AGENT].by_name(agent_name)
|
||||||
|
if agent.pos == dest.pos:
|
||||||
|
self.print(f'{agent.name} is still waiting.')
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
dest.leave(agent)
|
||||||
|
self.print(f'{agent.name} left the destination early.')
|
||||||
|
self.trigger_destination_spawn()
|
||||||
|
return info_dict
|
||||||
|
|
||||||
|
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
|
reward, info_dict = super().calculate_additional_reward(agent)
|
||||||
|
if h.EnvActions.WAIT_ON_DEST == agent.temp_action:
|
||||||
|
if agent.temp_valid:
|
||||||
|
info_dict.update({f'{agent.name}_waiting_at_dest': 1})
|
||||||
|
info_dict.update(agent_waiting_at_dest=1)
|
||||||
|
self.print(f'{agent.name} just waited at {agent.pos}')
|
||||||
|
reward += 0.1
|
||||||
|
else:
|
||||||
|
info_dict.update({f'{agent.name}_tried_failed': 1})
|
||||||
|
info_dict.update(agent_waiting_failed=1)
|
||||||
|
self.print(f'{agent.name} just tried to wait wait at {agent.pos} but failed')
|
||||||
|
reward -= 0.1
|
||||||
|
if len(self[c.REACHEDDESTINATION]):
|
||||||
|
for reached_dest in list(self[c.REACHEDDESTINATION]):
|
||||||
|
if agent.pos == reached_dest.pos:
|
||||||
|
info_dict.update({f'{agent.name}_reached_destination': 1})
|
||||||
|
info_dict.update(agent_reached_destination=1)
|
||||||
|
self.print(f'{agent.name} just reached destination at {agent.pos}')
|
||||||
|
reward += 0.5
|
||||||
|
self[c.REACHEDDESTINATION].delete_entity(reached_dest)
|
||||||
|
return reward, info_dict
|
||||||
|
|
||||||
|
def render_additional_assets(self, mode='human'):
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
|
additional_assets = super().render_additional_assets()
|
||||||
|
destinations = [RenderEntity(c.DESTINATION.value, dest.pos) for dest in self[c.DESTINATION]]
|
||||||
|
additional_assets.extend(destinations)
|
||||||
|
return additional_assets
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
from environments.utility_classes import AgentRenderOptions as ARO, ObservationProperties
|
||||||
|
|
||||||
|
render = True
|
||||||
|
|
||||||
|
dest_probs = DestinationProperties(n_dests=2, spawn_frequency=5, spawn_mode=DestSpawnMode.GROUPED)
|
||||||
|
|
||||||
|
obs_props = ObservationProperties(render_agents=ARO.LEVEL, omit_agent_self=True, pomdp_r=2)
|
||||||
|
|
||||||
|
move_props = {'allow_square_movement': True,
|
||||||
|
'allow_diagonal_movement': False,
|
||||||
|
'allow_no_op': False}
|
||||||
|
|
||||||
|
factory = DestinationFactory(n_agents=10, done_at_collision=False,
|
||||||
|
level_name='rooms', max_steps=400,
|
||||||
|
obs_prop=obs_props, parse_doors=True,
|
||||||
|
verbose=True,
|
||||||
|
mv_prop=move_props, dest_prop=dest_probs
|
||||||
|
)
|
||||||
|
|
||||||
|
# noinspection DuplicatedCode
|
||||||
|
n_actions = factory.action_space.n - 1
|
||||||
|
_ = factory.observation_space
|
||||||
|
|
||||||
|
for epoch in range(4):
|
||||||
|
random_actions = [[random.randint(0, n_actions) for _
|
||||||
|
in range(factory.n_agents)] for _
|
||||||
|
in range(factory.max_steps + 1)]
|
||||||
|
env_state = factory.reset()
|
||||||
|
r = 0
|
||||||
|
for agent_i_action in random_actions:
|
||||||
|
env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
|
||||||
|
r += step_r
|
||||||
|
if render:
|
||||||
|
factory.render()
|
||||||
|
if done_bool:
|
||||||
|
break
|
||||||
|
print(f'Factory run {epoch} done, reward is:\n {r}')
|
||||||
|
pass
|
@ -66,7 +66,7 @@ class DirtRegister(MovingEntityObjectRegister):
|
|||||||
self._array[:] = c.FREE_CELL.value
|
self._array[:] = c.FREE_CELL.value
|
||||||
for dirt in list(self.values()):
|
for dirt in list(self.values()):
|
||||||
if dirt.amount == 0:
|
if dirt.amount == 0:
|
||||||
self.delete_item(dirt)
|
self.delete_entity(dirt)
|
||||||
self._array[0, dirt.x, dirt.y] = dirt.amount
|
self._array[0, dirt.x, dirt.y] = dirt.amount
|
||||||
else:
|
else:
|
||||||
self._array = np.zeros((1, *self._level_shape))
|
self._array = np.zeros((1, *self._level_shape))
|
||||||
@ -155,7 +155,7 @@ class DirtFactory(BaseFactory):
|
|||||||
new_dirt_amount = dirt.amount - self.dirt_prop.clean_amount
|
new_dirt_amount = dirt.amount - self.dirt_prop.clean_amount
|
||||||
|
|
||||||
if new_dirt_amount <= 0:
|
if new_dirt_amount <= 0:
|
||||||
self[c.DIRT].delete_item(dirt)
|
self[c.DIRT].delete_entity(dirt)
|
||||||
else:
|
else:
|
||||||
dirt.set_new_amount(max(new_dirt_amount, c.FREE_CELL.value))
|
dirt.set_new_amount(max(new_dirt_amount, c.FREE_CELL.value))
|
||||||
return c.VALID
|
return c.VALID
|
||||||
@ -243,11 +243,12 @@ class DirtFactory(BaseFactory):
|
|||||||
# Reward if pickup succeds,
|
# Reward if pickup succeds,
|
||||||
# 0.5 on every pickup
|
# 0.5 on every pickup
|
||||||
reward += 0.5
|
reward += 0.5
|
||||||
|
info_dict.update(dirt_cleaned=1)
|
||||||
if self.dirt_prop.done_when_clean and (len(self[c.DIRT]) == 0):
|
if self.dirt_prop.done_when_clean and (len(self[c.DIRT]) == 0):
|
||||||
# 0.5 additional reward for the very last pickup
|
# 0.5 additional reward for the very last pickup
|
||||||
reward += 0.5
|
reward += 4.5
|
||||||
|
info_dict.update(done_clean=1)
|
||||||
self.print(f'{agent.name} did just clean up some dirt at {agent.pos}.')
|
self.print(f'{agent.name} did just clean up some dirt at {agent.pos}.')
|
||||||
info_dict.update(dirt_cleaned=1)
|
|
||||||
else:
|
else:
|
||||||
reward -= 0.01
|
reward -= 0.01
|
||||||
self.print(f'{agent.name} just tried to clean up some dirt at {agent.pos}, but failed.')
|
self.print(f'{agent.name} just tried to clean up some dirt at {agent.pos}, but failed.')
|
||||||
@ -288,23 +289,22 @@ if __name__ == '__main__':
|
|||||||
doors_have_area=False,
|
doors_have_area=False,
|
||||||
obs_prop=obs_props, parse_doors=True,
|
obs_prop=obs_props, parse_doors=True,
|
||||||
record_episodes=True, verbose=True,
|
record_episodes=True, verbose=True,
|
||||||
mv_prop=move_props, dirt_prop=dirt_props
|
mv_prop=move_props, dirt_prop=dirt_props,
|
||||||
|
inject_agents=[TSPDirtAgent]
|
||||||
)
|
)
|
||||||
|
|
||||||
# noinspection DuplicatedCode
|
# noinspection DuplicatedCode
|
||||||
n_actions = factory.action_space.n - 1
|
n_actions = factory.action_space.n - 1
|
||||||
_ = factory.observation_space
|
_ = factory.observation_space
|
||||||
|
|
||||||
for epoch in range(4):
|
for epoch in range(10):
|
||||||
random_actions = [[random.randint(0, n_actions) for _
|
random_actions = [[random.randint(0, n_actions) for _
|
||||||
in range(factory.n_agents)] for _
|
in range(factory.n_agents)] for _
|
||||||
in range(factory.max_steps+1)]
|
in range(factory.max_steps+1)]
|
||||||
env_state = factory.reset()
|
env_state = factory.reset()
|
||||||
if render:
|
if render:
|
||||||
factory.render()
|
factory.render()
|
||||||
random_start_position = factory[c.AGENT][0].tile
|
tsp_agent = factory.get_injected_agents()[0]
|
||||||
factory[c.AGENT][0] = tsp_agent = TSPDirtAgent(factory[c.FLOOR], factory[c.DIRT],
|
|
||||||
factory._actions, random_start_position)
|
|
||||||
|
|
||||||
r = 0
|
r = 0
|
||||||
for agent_i_action in random_actions:
|
for agent_i_action in random_actions:
|
||||||
|
@ -308,7 +308,7 @@ class ItemFactory(BaseFactory):
|
|||||||
if item.auto_despawn >= 1:
|
if item.auto_despawn >= 1:
|
||||||
item.set_auto_despawn(item.auto_despawn-1)
|
item.set_auto_despawn(item.auto_despawn-1)
|
||||||
elif not item.auto_despawn:
|
elif not item.auto_despawn:
|
||||||
self[c.ITEM].delete_item(item)
|
self[c.ITEM].delete_entity(item)
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -55,6 +55,10 @@ class Constants(Enum):
|
|||||||
CHARGE_POD = 'Charge_Pod'
|
CHARGE_POD = 'Charge_Pod'
|
||||||
BATTERIES = 'BATTERIES'
|
BATTERIES = 'BATTERIES'
|
||||||
|
|
||||||
|
# Destination Env
|
||||||
|
DESTINATION = 'Destination'
|
||||||
|
REACHEDDESTINATION = 'ReachedDestination'
|
||||||
|
|
||||||
def __bool__(self):
|
def __bool__(self):
|
||||||
if 'not_' in self.value:
|
if 'not_' in self.value:
|
||||||
return False
|
return False
|
||||||
@ -91,6 +95,7 @@ class EnvActions(Enum):
|
|||||||
CLEAN_UP = 'clean_up'
|
CLEAN_UP = 'clean_up'
|
||||||
ITEM_ACTION = 'item_action'
|
ITEM_ACTION = 'item_action'
|
||||||
CHARGE = 'charge'
|
CHARGE = 'charge'
|
||||||
|
WAIT_ON_DEST = 'wait'
|
||||||
|
|
||||||
|
|
||||||
m = MovingAction
|
m = MovingAction
|
||||||
|
@ -23,6 +23,7 @@ class ObservationProperties(NamedTuple):
|
|||||||
cast_shadows = True
|
cast_shadows = True
|
||||||
frames_to_stack: int = 0
|
frames_to_stack: int = 0
|
||||||
pomdp_r: int = 0
|
pomdp_r: int = 0
|
||||||
|
show_global_position_info: bool = True
|
||||||
|
|
||||||
|
|
||||||
class MarlFrameStack(gym.ObservationWrapper):
|
class MarlFrameStack(gym.ObservationWrapper):
|
||||||
|
@ -31,6 +31,8 @@ def prepare_tex(df, hue, style, hue_order):
|
|||||||
lineplot = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
|
lineplot = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
|
||||||
hue_order=hue_order, hue=hue, style=style)
|
hue_order=hue_order, hue=hue, style=style)
|
||||||
# lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
|
# lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
|
||||||
|
plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
|
||||||
|
plt.tight_layout()
|
||||||
return lineplot
|
return lineplot
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,8 +20,8 @@ if __name__ == '__main__':
|
|||||||
render = True
|
render = True
|
||||||
record = True
|
record = True
|
||||||
seed = 67
|
seed = 67
|
||||||
n_agents = 2
|
n_agents = 1
|
||||||
out_path = Path('study_out/e_1_obs_stack_3_gae_0.25_n_steps_16/seperate_N/dirt/A2C_obs_stack_3_gae_0.25_n_steps_16/0_A2C_obs_stack_3_gae_0.25_n_steps_16')
|
out_path = Path('study_out/e_1_new_reward/no_obs/dirt/A2C_new_reward/0_A2C_new_reward')
|
||||||
out_path_2 = Path('study_out/e_1_obs_stack_3_gae_0.25_n_steps_16/seperate_N/dirt/A2C_obs_stack_3_gae_0.25_n_steps_16/1_A2C_obs_stack_3_gae_0.25_n_steps_16')
|
out_path_2 = Path('study_out/e_1_obs_stack_3_gae_0.25_n_steps_16/seperate_N/dirt/A2C_obs_stack_3_gae_0.25_n_steps_16/1_A2C_obs_stack_3_gae_0.25_n_steps_16')
|
||||||
model_path = out_path
|
model_path = out_path
|
||||||
|
|
||||||
@ -38,7 +38,7 @@ if __name__ == '__main__':
|
|||||||
other_model = out_path / 'model.zip'
|
other_model = out_path / 'model.zip'
|
||||||
|
|
||||||
model_cls = next(val for key, val in h.MODEL_MAP.items() if key in out_path.parent.name)
|
model_cls = next(val for key, val in h.MODEL_MAP.items() if key in out_path.parent.name)
|
||||||
models = [model_cls.load(this_model), model_cls.load(other_model)]
|
models = [model_cls.load(this_model)] # , model_cls.load(other_model)]
|
||||||
|
|
||||||
# Init Env
|
# Init Env
|
||||||
with DirtFactory(**env_kwargs) as env:
|
with DirtFactory(**env_kwargs) as env:
|
||||||
|
@ -75,7 +75,7 @@ baseline_monitor_file = 'e_1_baseline'
|
|||||||
from stable_baselines3 import A2C
|
from stable_baselines3 import A2C
|
||||||
|
|
||||||
def policy_model_kwargs():
|
def policy_model_kwargs():
|
||||||
return dict(gae_lambda=0.25, n_steps=16, max_grad_norm=0, use_rms_prop=True)
|
return dict() # gae_lambda=0.25, n_steps=16, max_grad_norm=0.25, use_rms_prop=True)
|
||||||
|
|
||||||
|
|
||||||
def dqn_model_kwargs():
|
def dqn_model_kwargs():
|
||||||
@ -198,12 +198,12 @@ if __name__ == '__main__':
|
|||||||
ood_run = True
|
ood_run = True
|
||||||
plotting = True
|
plotting = True
|
||||||
|
|
||||||
train_steps = 5e6
|
train_steps = 1e7
|
||||||
n_seeds = 3
|
n_seeds = 3
|
||||||
frames_to_stack = 3
|
frames_to_stack = 3
|
||||||
|
|
||||||
# Define a global studi save path
|
# Define a global studi save path
|
||||||
start_time = 'rms_weight_decay_0' # int(time.time())
|
start_time = 'new_reward' # int(time.time())
|
||||||
study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}_{start_time}'
|
study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}_{start_time}'
|
||||||
|
|
||||||
# Define Global Env Parameters
|
# Define Global Env Parameters
|
||||||
@ -516,7 +516,7 @@ if __name__ == '__main__':
|
|||||||
# df_melted["Measurements"] = df_melted["Measurement"] + " " + df_melted["monitor"]
|
# df_melted["Measurements"] = df_melted["Measurement"] + " " + df_melted["monitor"]
|
||||||
|
|
||||||
# Plotting
|
# Plotting
|
||||||
# fig, ax = plt.subplots(figsize=(11.7, 8.27))
|
fig, ax = plt.subplots(figsize=(11.7, 8.27))
|
||||||
|
|
||||||
c = sns.catplot(data=df_melted[df_melted['env'] == env_name],
|
c = sns.catplot(data=df_melted[df_melted['env'] == env_name],
|
||||||
x='Measurement', hue='monitor', row='model', col='obs_mode', y='Score',
|
x='Measurement', hue='monitor', row='model', col='obs_mode', y='Score',
|
||||||
@ -525,7 +525,7 @@ if __name__ == '__main__':
|
|||||||
c.set_xticklabels(rotation=65, horizontalalignment='right')
|
c.set_xticklabels(rotation=65, horizontalalignment='right')
|
||||||
# c.fig.subplots_adjust(top=0.9) # adjust the Figure in rp
|
# c.fig.subplots_adjust(top=0.9) # adjust the Figure in rp
|
||||||
c.fig.suptitle(f"Cat plot for {env_name}")
|
c.fig.suptitle(f"Cat plot for {env_name}")
|
||||||
# plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
|
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.savefig(study_root_path / f'results_{n_agents}_agents_{env_name}.png')
|
plt.savefig(study_root_path / f'results_{n_agents}_agents_{env_name}.png')
|
||||||
pass
|
pass
|
||||||
|
Loading…
x
Reference in New Issue
Block a user