mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-22 23:06:43 +02:00
Rework of Observations and Entity Differentiation, lazy obs build by notification
This commit is contained in:
parent
7f7a3d9a3b
commit
b43f595207
@ -1,7 +1,9 @@
|
|||||||
import abc
|
import abc
|
||||||
|
import enum
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from itertools import chain
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Union, Iterable, Dict
|
from typing import List, Union, Iterable, Dict
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -14,7 +16,8 @@ from environments.factory.base.shadow_casting import Map
|
|||||||
from environments.helpers import Constants as c, Constants
|
from environments.helpers import Constants as c, Constants
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.factory.base.objects import Agent, Tile, Action
|
from environments.factory.base.objects import Agent, Tile, Action
|
||||||
from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles, PlaceHolders
|
from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles, PlaceHolders, \
|
||||||
|
GlobalPositions
|
||||||
from environments.utility_classes import MovementProperties, ObservationProperties, MarlFrameStack
|
from environments.utility_classes import MovementProperties, ObservationProperties, MarlFrameStack
|
||||||
from environments.utility_classes import AgentRenderOptions as a_obs
|
from environments.utility_classes import AgentRenderOptions as a_obs
|
||||||
|
|
||||||
@ -30,17 +33,31 @@ class BaseFactory(gym.Env):
|
|||||||
def action_space(self):
|
def action_space(self):
|
||||||
return spaces.Discrete(len(self._actions))
|
return spaces.Discrete(len(self._actions))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def named_action_space(self):
|
||||||
|
return {x.identifier.value: idx for idx, x in enumerate(self._actions.values())}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def observation_space(self):
|
def observation_space(self):
|
||||||
if r := self._pomdp_r:
|
obs, _ = self._build_observations()
|
||||||
z = self._obs_cube.shape[0]
|
if self.n_agents > 1:
|
||||||
xy = r*2 + 1
|
shape = obs[0].shape
|
||||||
level_shape = (z, xy, xy)
|
|
||||||
else:
|
else:
|
||||||
level_shape = self._obs_cube.shape
|
shape = obs.shape
|
||||||
space = spaces.Box(low=0, high=1, shape=level_shape, dtype=np.float32)
|
space = spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
|
||||||
return space
|
return space
|
||||||
|
|
||||||
|
@property
|
||||||
|
def named_observation_space(self):
|
||||||
|
# Build it
|
||||||
|
_, named_obs = self._build_observations()
|
||||||
|
if self.n_agents > 1:
|
||||||
|
# Only return the first named obs space, as their structure at the moment is same.
|
||||||
|
return [{key.name: val for key, val in named_ob.items()} for named_ob in named_obs.values()][0]
|
||||||
|
else:
|
||||||
|
return {key.name: val for key, val in named_obs.items()}
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pomdp_diameter(self):
|
def pomdp_diameter(self):
|
||||||
return self._pomdp_r * 2 + 1
|
return self._pomdp_r * 2 + 1
|
||||||
@ -86,11 +103,14 @@ class BaseFactory(gym.Env):
|
|||||||
self.obs_prop = obs_prop
|
self.obs_prop = obs_prop
|
||||||
self.level_name = level_name
|
self.level_name = level_name
|
||||||
self._level_shape = None
|
self._level_shape = None
|
||||||
|
self._obs_shape = None
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self._renderer = None # expensive - don't use it when not required !
|
self._renderer = None # expensive - don't use it when not required !
|
||||||
self._entities = Entities()
|
self._entities = Entities()
|
||||||
|
|
||||||
self.n_agents = n_agents
|
self.n_agents = n_agents
|
||||||
|
level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.level_name}.txt'
|
||||||
|
self._parsed_level = h.parse_level(level_filepath)
|
||||||
|
|
||||||
self.max_steps = max_steps
|
self.max_steps = max_steps
|
||||||
self._pomdp_r = self.obs_prop.pomdp_r
|
self._pomdp_r = self.obs_prop.pomdp_r
|
||||||
@ -114,10 +134,12 @@ class BaseFactory(gym.Env):
|
|||||||
# Objects
|
# Objects
|
||||||
self._entities = Entities()
|
self._entities = Entities()
|
||||||
# Level
|
# Level
|
||||||
level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.level_name}.txt'
|
|
||||||
parsed_level = h.parse_level(level_filepath)
|
level_array = h.one_hot_level(self._parsed_level)
|
||||||
level_array = h.one_hot_level(parsed_level)
|
level_array = np.pad(level_array, self.obs_prop.pomdp_r, 'constant', constant_values=1)
|
||||||
|
|
||||||
self._level_shape = level_array.shape
|
self._level_shape = level_array.shape
|
||||||
|
self._obs_shape = self._level_shape if not self.obs_prop.pomdp_r else (self.pomdp_diameter, ) * 2
|
||||||
|
|
||||||
# Walls
|
# Walls
|
||||||
walls = WallTiles.from_argwhere_coordinates(
|
walls = WallTiles.from_argwhere_coordinates(
|
||||||
@ -134,13 +156,14 @@ class BaseFactory(gym.Env):
|
|||||||
self._entities.register_additional_items({c.FLOOR: floor})
|
self._entities.register_additional_items({c.FLOOR: floor})
|
||||||
|
|
||||||
# NOPOS
|
# NOPOS
|
||||||
self._NO_POS_TILE = Tile(c.NO_POS.value)
|
self._NO_POS_TILE = Tile(c.NO_POS.value, None)
|
||||||
|
|
||||||
# Doors
|
# Doors
|
||||||
if self.parse_doors:
|
if self.parse_doors:
|
||||||
parsed_doors = h.one_hot_level(parsed_level, c.DOOR)
|
parsed_doors = h.one_hot_level(self._parsed_level, c.DOOR)
|
||||||
|
parsed_doors = np.pad(parsed_doors, self.obs_prop.pomdp_r, 'constant', constant_values=0)
|
||||||
if np.any(parsed_doors):
|
if np.any(parsed_doors):
|
||||||
door_tiles = [floor.by_pos(pos) for pos in np.argwhere(parsed_doors == c.OCCUPIED_CELL.value)]
|
door_tiles = [floor.by_pos(tuple(pos)) for pos in np.argwhere(parsed_doors == c.OCCUPIED_CELL.value)]
|
||||||
doors = Doors.from_tiles(door_tiles, self._level_shape,
|
doors = Doors.from_tiles(door_tiles, self._level_shape,
|
||||||
entity_kwargs=dict(context=floor)
|
entity_kwargs=dict(context=floor)
|
||||||
)
|
)
|
||||||
@ -153,12 +176,11 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
# Agents
|
# Agents
|
||||||
agents_to_spawn = self.n_agents-len(self._injected_agents)
|
agents_to_spawn = self.n_agents-len(self._injected_agents)
|
||||||
agents_kwargs = dict(level_shape=self._level_shape,
|
agents_kwargs = dict(individual_slices=self.obs_prop.render_agents == a_obs.SEPERATE,
|
||||||
individual_slices=self.obs_prop.render_agents == a_obs.SEPERATE,
|
hide_from_obs_builder=self.obs_prop.render_agents in [a_obs.NOT, a_obs.LEVEL],
|
||||||
hide_from_obs_builder=self.obs_prop.render_agents == a_obs.LEVEL,
|
)
|
||||||
is_observable=self.obs_prop.render_agents != a_obs.NOT)
|
|
||||||
if agents_to_spawn:
|
if agents_to_spawn:
|
||||||
agents = Agents.from_tiles(floor.empty_tiles[:agents_to_spawn], **agents_kwargs)
|
agents = Agents.from_tiles(floor.empty_tiles[:agents_to_spawn], self._level_shape, **agents_kwargs)
|
||||||
else:
|
else:
|
||||||
agents = Agents(**agents_kwargs)
|
agents = Agents(**agents_kwargs)
|
||||||
if self._injected_agents:
|
if self._injected_agents:
|
||||||
@ -173,10 +195,10 @@ class BaseFactory(gym.Env):
|
|||||||
# TODO: Make this accept Lists for multiple placeholders
|
# TODO: Make this accept Lists for multiple placeholders
|
||||||
|
|
||||||
# Empty Observations with either [0, 1, N(0, 1)]
|
# Empty Observations with either [0, 1, N(0, 1)]
|
||||||
placeholder = PlaceHolders.from_tiles([self._NO_POS_TILE], self._level_shape,
|
placeholder = PlaceHolders.from_values(self.obs_prop.additional_agent_placeholder, self._level_shape,
|
||||||
entity_kwargs=dict(
|
entity_kwargs=dict(
|
||||||
fill_value=self.obs_prop.additional_agent_placeholder)
|
fill_value=self.obs_prop.additional_agent_placeholder)
|
||||||
)
|
)
|
||||||
|
|
||||||
self._entities.register_additional_items({c.AGENT_PLACEHOLDER: placeholder})
|
self._entities.register_additional_items({c.AGENT_PLACEHOLDER: placeholder})
|
||||||
|
|
||||||
@ -184,24 +206,22 @@ class BaseFactory(gym.Env):
|
|||||||
if additional_entities := self.additional_entities:
|
if additional_entities := self.additional_entities:
|
||||||
self._entities.register_additional_items(additional_entities)
|
self._entities.register_additional_items(additional_entities)
|
||||||
|
|
||||||
|
if self.obs_prop.show_global_position_info:
|
||||||
|
global_positions = GlobalPositions(self._level_shape)
|
||||||
|
obs_shape_2d = self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2)
|
||||||
|
global_positions.spawn_GlobalPositionObjects(obs_shape_2d, self[c.AGENT])
|
||||||
|
self._entities.register_additional_items({c.GLOBAL_POSITION: global_positions})
|
||||||
|
|
||||||
# Return
|
# Return
|
||||||
return self._entities
|
return self._entities
|
||||||
|
|
||||||
def _init_obs_cube(self):
|
def reset(self) -> (np.typing.ArrayLike, int, bool, dict):
|
||||||
arrays = self._entities.obs_arrays
|
|
||||||
|
|
||||||
obs_cube_z = sum([a.shape[0] if not self[key].is_per_agent else 1 for key, a in arrays.items()])
|
|
||||||
obs_cube_z += 1 if self.obs_prop.show_global_position_info else 0
|
|
||||||
self._obs_cube = np.zeros((obs_cube_z, *self._level_shape), dtype=np.float32)
|
|
||||||
|
|
||||||
def reset(self) -> (np.ndarray, int, bool, dict):
|
|
||||||
_ = self._base_init_env()
|
_ = self._base_init_env()
|
||||||
self._init_obs_cube()
|
|
||||||
self.do_additional_reset()
|
self.do_additional_reset()
|
||||||
|
|
||||||
self._steps = 0
|
self._steps = 0
|
||||||
|
|
||||||
obs = self._get_observations()
|
obs, _ = self._build_observations()
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
def step(self, actions):
|
def step(self, actions):
|
||||||
@ -264,7 +284,7 @@ class BaseFactory(gym.Env):
|
|||||||
# Post step Hook for later use
|
# Post step Hook for later use
|
||||||
info.update(self.hook_post_step())
|
info.update(self.hook_post_step())
|
||||||
|
|
||||||
obs = self._get_observations()
|
obs, _ = self._build_observations()
|
||||||
|
|
||||||
return obs, reward, done, info
|
return obs, reward, done, info
|
||||||
|
|
||||||
@ -284,141 +304,120 @@ class BaseFactory(gym.Env):
|
|||||||
else:
|
else:
|
||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
|
|
||||||
def _get_observations(self) -> np.ndarray:
|
def _build_observations(self) -> np.typing.ArrayLike:
|
||||||
state_array_dict = self._entities.obs_arrays
|
# Observation dict:
|
||||||
if self.n_agents == 1:
|
per_agent_expl_idx = dict()
|
||||||
obs = self._build_per_agent_obs(self[c.AGENT][0], state_array_dict)
|
per_agent_obsn = dict()
|
||||||
elif self.n_agents >= 2:
|
# Generel Observations
|
||||||
obs = np.stack([self._build_per_agent_obs(agent, state_array_dict) for agent in self[c.AGENT]])
|
lvl_obs = self[c.WALLS].as_array()
|
||||||
else:
|
door_obs = self[c.DOORS].as_array()
|
||||||
raise ValueError('n_agents cannot be smaller than 1!!')
|
agent_obs = self[c.AGENT].as_array() if self.obs_prop.render_agents != a_obs.NOT else None
|
||||||
return obs
|
placeholder_obs = self[c.AGENT_PLACEHOLDER].as_array() if self[c.AGENT_PLACEHOLDER] else None
|
||||||
|
add_obs_dict = self._additional_observations()
|
||||||
|
|
||||||
def _build_per_agent_obs(self, agent: Agent, state_array_dict) -> np.ndarray:
|
for agent_idx, agent in enumerate(self[c.AGENT]):
|
||||||
agent_pos_is_omitted = False
|
obs_dict = dict()
|
||||||
agent_omit_idx = None
|
# Build Agent Observations
|
||||||
|
if self.obs_prop.render_agents != a_obs.NOT:
|
||||||
if self.obs_prop.omit_agent_self and self.n_agents == 1:
|
if self.obs_prop.omit_agent_self:
|
||||||
pass
|
if self.obs_prop.render_agents == a_obs.SEPERATE:
|
||||||
elif self.obs_prop.omit_agent_self and self.obs_prop.render_agents in [a_obs.COMBINED, ] and self.n_agents > 1:
|
agent_obs = np.take(agent_obs, [x for x in range(self.n_agents) if x != agent_idx], axis=0)
|
||||||
state_array_dict[c.AGENT][0, agent.x, agent.y] -= agent.encoding
|
|
||||||
agent_pos_is_omitted = True
|
|
||||||
elif self.obs_prop.omit_agent_self and self.obs_prop.render_agents == a_obs.SEPERATE and self.n_agents > 1:
|
|
||||||
agent_omit_idx = next((i for i, a in enumerate(self[c.AGENT]) if a == agent))
|
|
||||||
|
|
||||||
running_idx, shadowing_idxs, can_be_shadowed_idxs = 0, [], []
|
|
||||||
self._obs_cube[:] = 0
|
|
||||||
|
|
||||||
# FIXME: Refactor this! Make a globally build observation, then add individual per-agent-obs
|
|
||||||
for key, array in state_array_dict.items():
|
|
||||||
# Flush state array object representation to obs cube
|
|
||||||
if not self[key].hide_from_obs_builder:
|
|
||||||
if self[key].is_per_agent:
|
|
||||||
per_agent_idx = self[key].idx_by_entity(agent)
|
|
||||||
z = 1
|
|
||||||
self._obs_cube[running_idx: running_idx+z] = array[per_agent_idx]
|
|
||||||
else:
|
|
||||||
if key == c.AGENT and agent_omit_idx is not None:
|
|
||||||
z = array.shape[0] - 1
|
|
||||||
for array_idx in range(array.shape[0]):
|
|
||||||
self._obs_cube[running_idx: running_idx+z] = array[[x for x in range(array.shape[0])
|
|
||||||
if x != agent_omit_idx]]
|
|
||||||
# Agent OBS are combined
|
|
||||||
elif key == c.AGENT and self.obs_prop.omit_agent_self \
|
|
||||||
and self.obs_prop.render_agents == a_obs.COMBINED:
|
|
||||||
z = 1
|
|
||||||
self._obs_cube[running_idx: running_idx + z] = array
|
|
||||||
# Each Agent is rendered on a seperate array slice
|
|
||||||
else:
|
else:
|
||||||
z = array.shape[0]
|
agent_obs = agent_obs.copy()
|
||||||
self._obs_cube[running_idx: running_idx + z] = array
|
agent_obs[(0, *agent.pos)] -= agent.encoding
|
||||||
# Define which OBS SLices cast a Shadow
|
|
||||||
if self[key].is_blocking_light:
|
|
||||||
for i in range(z):
|
|
||||||
shadowing_idxs.append(running_idx + i)
|
|
||||||
# Define which OBS SLices are effected by shadows
|
|
||||||
if self[key].can_be_shadowed:
|
|
||||||
for i in range(z):
|
|
||||||
can_be_shadowed_idxs.append(running_idx + i)
|
|
||||||
running_idx += z
|
|
||||||
|
|
||||||
if agent_pos_is_omitted:
|
# Build Level Observations
|
||||||
state_array_dict[c.AGENT][0, agent.x, agent.y] += agent.encoding
|
if self.obs_prop.render_agents == a_obs.LEVEL:
|
||||||
|
lvl_obs = lvl_obs.copy()
|
||||||
if self._pomdp_r:
|
lvl_obs += agent_obs
|
||||||
obs = self._do_pomdp_obs_cutout(agent, self._obs_cube)
|
|
||||||
else:
|
|
||||||
obs = self._obs_cube
|
|
||||||
|
|
||||||
obs = obs.copy()
|
|
||||||
|
|
||||||
if self.obs_prop.cast_shadows:
|
|
||||||
obs_block_light = [obs[idx] != c.OCCUPIED_CELL.value for idx in shadowing_idxs]
|
|
||||||
door_shadowing = False
|
|
||||||
if self.parse_doors:
|
|
||||||
if doors := self[c.DOORS]:
|
|
||||||
if door := doors.by_pos(agent.pos):
|
|
||||||
if door.is_closed:
|
|
||||||
for group in door.connectivity_subgroups:
|
|
||||||
if agent.last_pos not in group:
|
|
||||||
door_shadowing = True
|
|
||||||
if self._pomdp_r:
|
|
||||||
blocking = [tuple(np.subtract(x, agent.pos) + (self._pomdp_r, self._pomdp_r))
|
|
||||||
for x in group]
|
|
||||||
xs, ys = zip(*blocking)
|
|
||||||
else:
|
|
||||||
xs, ys = zip(*group)
|
|
||||||
|
|
||||||
# noinspection PyUnresolvedReferences
|
|
||||||
obs_block_light[0][xs, ys] = False
|
|
||||||
|
|
||||||
light_block_map = Map((np.prod(obs_block_light, axis=0) != True).astype(int))
|
|
||||||
if self._pomdp_r:
|
|
||||||
light_block_map = light_block_map.do_fov(self._pomdp_r, self._pomdp_r, max(self._level_shape))
|
|
||||||
else:
|
|
||||||
light_block_map = light_block_map.do_fov(*agent.pos, max(self._level_shape))
|
|
||||||
if door_shadowing:
|
|
||||||
# noinspection PyUnboundLocalVariable
|
|
||||||
light_block_map[xs, ys] = 0
|
|
||||||
agent.temp_light_map = light_block_map
|
|
||||||
for obs_idx in can_be_shadowed_idxs:
|
|
||||||
obs[obs_idx] = ((obs[obs_idx] * light_block_map) + 0.) - (1 - light_block_map) # * obs[0])
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Agents observe other agents as wall
|
|
||||||
if self.obs_prop.render_agents == a_obs.LEVEL and self.n_agents > 1:
|
|
||||||
other_agent_obs = self[c.AGENT].as_array()
|
|
||||||
if self.obs_prop.omit_agent_self:
|
|
||||||
other_agent_obs[:, agent.x, agent.y] -= agent.encoding
|
|
||||||
|
|
||||||
|
obs_dict[c.WALLS] = lvl_obs
|
||||||
|
if self.obs_prop.render_agents in [a_obs.SEPERATE, a_obs.COMBINED]:
|
||||||
|
obs_dict[c.AGENT] = agent_obs
|
||||||
|
if self[c.AGENT_PLACEHOLDER]:
|
||||||
|
obs_dict[c.AGENT_PLACEHOLDER] = placeholder_obs
|
||||||
|
obs_dict[c.DOORS] = door_obs
|
||||||
|
obs_dict.update(add_obs_dict)
|
||||||
|
observations = np.vstack(list(obs_dict.values()))
|
||||||
if self.obs_prop.pomdp_r:
|
if self.obs_prop.pomdp_r:
|
||||||
oobs = self._do_pomdp_obs_cutout(agent, other_agent_obs)[0]
|
observations = self._do_pomdp_cutout(agent, observations)
|
||||||
# noinspection PyUnresolvedReferences
|
|
||||||
mask = (oobs != c.SHADOWED_CELL.value).astype(int)
|
|
||||||
obs[0] += oobs * mask
|
|
||||||
|
|
||||||
|
raw_obs = self._additional_raw_observations(agent)
|
||||||
|
observations = np.vstack((observations, *list(raw_obs.values())))
|
||||||
|
|
||||||
|
keys = list(chain(obs_dict.keys(), raw_obs.keys()))
|
||||||
|
idxs = np.cumsum([x.shape[0] for x in chain(obs_dict.values(), raw_obs.values())]) - 1
|
||||||
|
per_agent_expl_idx[agent.name] = {key: list(range(a, b)) for key, a, b in
|
||||||
|
zip(keys, idxs, list(idxs[1:]) + [idxs[-1]+1, ])}
|
||||||
|
|
||||||
|
# Shadow Casting
|
||||||
|
try:
|
||||||
|
light_block_obs = [obs_idx for key, obs_idx in per_agent_expl_idx[agent.name].items()
|
||||||
|
if self[key].is_blocking_light]
|
||||||
|
# Flatten
|
||||||
|
light_block_obs = [x for y in light_block_obs for x in y]
|
||||||
|
shadowed_obs = [obs_idx for key, obs_idx in per_agent_expl_idx[agent.name].items()
|
||||||
|
if self[key].can_be_shadowed]
|
||||||
|
# Flatten
|
||||||
|
shadowed_obs = [x for y in shadowed_obs for x in y]
|
||||||
|
except AttributeError as e:
|
||||||
|
print('Check your Keys! Only use Constants as Keys!')
|
||||||
|
print(e)
|
||||||
|
raise e
|
||||||
|
if self.obs_prop.cast_shadows:
|
||||||
|
obs_block_light = observations[light_block_obs] != c.OCCUPIED_CELL.value
|
||||||
|
door_shadowing = False
|
||||||
|
if self.parse_doors:
|
||||||
|
if doors := self[c.DOORS]:
|
||||||
|
if door := doors.by_pos(agent.pos):
|
||||||
|
if door.is_closed:
|
||||||
|
for group in door.connectivity_subgroups:
|
||||||
|
if agent.last_pos not in group:
|
||||||
|
door_shadowing = True
|
||||||
|
if self._pomdp_r:
|
||||||
|
blocking = [
|
||||||
|
tuple(np.subtract(x, agent.pos) + (self._pomdp_r, self._pomdp_r))
|
||||||
|
for x in group]
|
||||||
|
xs, ys = zip(*blocking)
|
||||||
|
else:
|
||||||
|
xs, ys = zip(*group)
|
||||||
|
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
|
obs_block_light[:, xs, ys] = False
|
||||||
|
|
||||||
|
light_block_map = Map((np.prod(obs_block_light, axis=0) != True).astype(int).squeeze())
|
||||||
|
if self._pomdp_r:
|
||||||
|
light_block_map = light_block_map.do_fov(self._pomdp_r, self._pomdp_r, max(self._level_shape))
|
||||||
|
else:
|
||||||
|
light_block_map = light_block_map.do_fov(*agent.pos, max(self._level_shape))
|
||||||
|
if door_shadowing:
|
||||||
|
# noinspection PyUnboundLocalVariable
|
||||||
|
light_block_map[xs, ys] = 0
|
||||||
|
agent.temp_light_map = light_block_map.copy()
|
||||||
|
|
||||||
|
observations[shadowed_obs] = ((observations[shadowed_obs] * light_block_map) + 0.) - (1 - light_block_map)
|
||||||
else:
|
else:
|
||||||
obs[0] += other_agent_obs
|
pass
|
||||||
|
|
||||||
# Additional Observation:
|
per_agent_obsn[agent.name] = observations
|
||||||
for additional_obs in self.additional_obs_build():
|
|
||||||
obs[running_idx:running_idx+additional_obs.shape[0]] = additional_obs
|
|
||||||
running_idx += additional_obs.shape[0]
|
|
||||||
for additional_per_agent_obs in self.additional_per_agent_obs_build(agent):
|
|
||||||
obs[running_idx:running_idx + additional_per_agent_obs.shape[0]] = additional_per_agent_obs
|
|
||||||
running_idx += additional_per_agent_obs.shape[0]
|
|
||||||
|
|
||||||
return obs
|
if self.n_agents == 1:
|
||||||
|
agent_name = self[c.AGENT][0].name
|
||||||
|
obs, explained_idx = per_agent_obsn[agent_name], per_agent_expl_idx[agent_name]
|
||||||
|
elif self.n_agents >= 2:
|
||||||
|
obs, explained_idx = np.stack(list(per_agent_obsn.values())), per_agent_expl_idx
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
def _do_pomdp_obs_cutout(self, agent, obs_to_be_padded):
|
return obs, explained_idx
|
||||||
|
|
||||||
|
def _do_pomdp_cutout(self, agent, obs_to_be_padded):
|
||||||
assert obs_to_be_padded.ndim == 3
|
assert obs_to_be_padded.ndim == 3
|
||||||
r, d = self._pomdp_r, self.pomdp_diameter
|
r, d = self._pomdp_r, self.pomdp_diameter
|
||||||
x0, x1 = max(0, agent.x - r), min(agent.x + r + 1, self._level_shape[0])
|
x0, x1 = max(0, agent.x - r), min(agent.x + r + 1, self._level_shape[0])
|
||||||
y0, y1 = max(0, agent.y - r), min(agent.y + r + 1, self._level_shape[1])
|
y0, y1 = max(0, agent.y - r), min(agent.y + r + 1, self._level_shape[1])
|
||||||
# Other Agent Obs = oobs
|
|
||||||
oobs = obs_to_be_padded[:, x0:x1, y0:y1]
|
oobs = obs_to_be_padded[:, x0:x1, y0:y1]
|
||||||
if oobs.shape[0:] != (d, d):
|
if oobs.shape[1:] != (d, d):
|
||||||
if xd := oobs.shape[1] % d:
|
if xd := oobs.shape[1] % d:
|
||||||
if agent.x > r:
|
if agent.x > r:
|
||||||
x0_pad = 0
|
x0_pad = 0
|
||||||
@ -478,7 +477,7 @@ class BaseFactory(gym.Env):
|
|||||||
if doors := self[c.DOORS]:
|
if doors := self[c.DOORS]:
|
||||||
if self.doors_have_area:
|
if self.doors_have_area:
|
||||||
if door := doors.by_pos(new_tile.pos):
|
if door := doors.by_pos(new_tile.pos):
|
||||||
if door.can_collide:
|
if door.is_open:
|
||||||
return agent.tile, c.NOT_VALID
|
return agent.tile, c.NOT_VALID
|
||||||
else: # door.is_closed:
|
else: # door.is_closed:
|
||||||
pass
|
pass
|
||||||
@ -569,7 +568,7 @@ class BaseFactory(gym.Env):
|
|||||||
if not self._renderer: # lazy init
|
if not self._renderer: # lazy init
|
||||||
from environments.factory.base.renderer import Renderer, RenderEntity
|
from environments.factory.base.renderer import Renderer, RenderEntity
|
||||||
global Renderer, RenderEntity
|
global Renderer, RenderEntity
|
||||||
height, width = self._obs_cube.shape[1:]
|
height, width = self._level_shape
|
||||||
self._renderer = Renderer(width, height, view_radius=self._pomdp_r, fps=5)
|
self._renderer = Renderer(width, height, view_radius=self._pomdp_r, fps=5)
|
||||||
|
|
||||||
# noinspection PyUnboundLocalVariable
|
# noinspection PyUnboundLocalVariable
|
||||||
@ -636,20 +635,6 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
# Functions which provide additions to functions of the base class
|
# Functions which provide additions to functions of the base class
|
||||||
# Always call super!!!!!!
|
# Always call super!!!!!!
|
||||||
@abc.abstractmethod
|
|
||||||
def additional_obs_build(self) -> List[np.ndarray]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
def additional_per_agent_obs_build(self, agent) -> List[np.ndarray]:
|
|
||||||
additional_per_agent_obs = []
|
|
||||||
if self.obs_prop.show_global_position_info:
|
|
||||||
pos_array = np.zeros(self.observation_space.shape[1:])
|
|
||||||
for xy in range(1):
|
|
||||||
pos_array[0, xy] = agent.pos[xy] / self._level_shape[xy]
|
|
||||||
additional_per_agent_obs.append(pos_array)
|
|
||||||
|
|
||||||
return additional_per_agent_obs
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def do_additional_reset(self) -> None:
|
def do_additional_reset(self) -> None:
|
||||||
pass
|
pass
|
||||||
@ -666,6 +651,17 @@ class BaseFactory(gym.Env):
|
|||||||
def check_additional_done(self) -> bool:
|
def check_additional_done(self) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _additional_observations(self) -> Dict[Constants, np.typing.ArrayLike]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _additional_raw_observations(self, agent) -> Dict[Constants, np.typing.ArrayLike]:
|
||||||
|
additional_raw_observations = {}
|
||||||
|
if self.obs_prop.show_global_position_info:
|
||||||
|
additional_raw_observations.update({c.GLOBAL_POSITION: self[c.GLOBAL_POSITION].by_entity(agent).as_array()})
|
||||||
|
return additional_raw_observations
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
||||||
return 0, {}
|
return 0, {}
|
||||||
|
@ -3,22 +3,26 @@ from enum import Enum
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.helpers import Constants as c
|
from environments.helpers import Constants as c
|
||||||
import itertools
|
import itertools
|
||||||
|
|
||||||
|
##########################################################################
|
||||||
|
# ##################### Base Object Definition ######################### #
|
||||||
|
##########################################################################
|
||||||
|
|
||||||
|
|
||||||
class Object:
|
class Object:
|
||||||
|
|
||||||
|
"""Generell Objects for Organisation and Maintanance such as Actions etc..."""
|
||||||
|
|
||||||
_u_idx = defaultdict(lambda: 0)
|
_u_idx = defaultdict(lambda: 0)
|
||||||
|
|
||||||
def __bool__(self):
|
def __bool__(self):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@property
|
|
||||||
def is_blocking_light(self):
|
|
||||||
return self._is_blocking_light
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return self._name
|
return self._name
|
||||||
@ -43,7 +47,7 @@ class Object:
|
|||||||
elif self._str_ident is not None and self._enum_ident is None:
|
elif self._str_ident is not None and self._enum_ident is None:
|
||||||
self._name = f'{self.__class__.__name__}[{self._str_ident}]'
|
self._name = f'{self.__class__.__name__}[{self._str_ident}]'
|
||||||
elif self._str_ident is None and self._enum_ident is None:
|
elif self._str_ident is None and self._enum_ident is None:
|
||||||
self._name = f'{self.__class__.__name__}#{self._u_idx[self.__class__.__name__]}'
|
self._name = f'{self.__class__.__name__}#{Object._u_idx[self.__class__.__name__]}'
|
||||||
Object._u_idx[self.__class__.__name__] += 1
|
Object._u_idx[self.__class__.__name__] += 1
|
||||||
else:
|
else:
|
||||||
raise ValueError('Please use either of the idents.')
|
raise ValueError('Please use either of the idents.')
|
||||||
@ -68,16 +72,56 @@ class Object:
|
|||||||
return other.name == self.name
|
return other.name == self.name
|
||||||
|
|
||||||
|
|
||||||
class Entity(Object):
|
class EnvObject(Object):
|
||||||
|
|
||||||
@property
|
"""Objects that hold Information that are observable, but have no position on the env grid. Inventories etc..."""
|
||||||
def can_collide(self):
|
|
||||||
return True
|
_u_idx = defaultdict(lambda: 0)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
return c.OCCUPIED_CELL.value
|
return c.OCCUPIED_CELL.value
|
||||||
|
|
||||||
|
def __init__(self, register, **kwargs):
|
||||||
|
super(EnvObject, self).__init__(**kwargs)
|
||||||
|
self._register = register
|
||||||
|
|
||||||
|
|
||||||
|
class BoundingMixin:
|
||||||
|
|
||||||
|
@property
|
||||||
|
def bound_entity(self):
|
||||||
|
return self._bound_entity
|
||||||
|
|
||||||
|
def __init__(self, entity_to_be_bound, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
assert entity_to_be_bound is not None
|
||||||
|
self._bound_entity = entity_to_be_bound
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
s = super(BoundingMixin, self).__repr__()
|
||||||
|
i = s[:s.find('(')]
|
||||||
|
return f'{s[:i]}[{self.bound_entity.name}]{s[i:]}'
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return f'{super(BoundingMixin, self).name}({self._bound_entity.name})'
|
||||||
|
|
||||||
|
def belongs_to_entity(self, entity):
|
||||||
|
return entity == self.bound_entity
|
||||||
|
|
||||||
|
|
||||||
|
class Entity(EnvObject):
|
||||||
|
"""Full Env Entity that lives on the env Grid. Doors, Items, Dirt etc..."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_blocking_light(self):
|
||||||
|
return self._is_blocking_light
|
||||||
|
|
||||||
|
@property
|
||||||
|
def can_collide(self):
|
||||||
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def x(self):
|
def x(self):
|
||||||
return self.pos[0]
|
return self.pos[0]
|
||||||
@ -94,9 +138,10 @@ class Entity(Object):
|
|||||||
def tile(self):
|
def tile(self):
|
||||||
return self._tile
|
return self._tile
|
||||||
|
|
||||||
def __init__(self, tile, **kwargs):
|
def __init__(self, tile, *args, is_blocking_light=True, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._tile = tile
|
self._tile = tile
|
||||||
|
self._is_blocking_light = is_blocking_light
|
||||||
tile.enter(self)
|
tile.enter(self)
|
||||||
|
|
||||||
def summarize_state(self, **_) -> dict:
|
def summarize_state(self, **_) -> dict:
|
||||||
@ -104,7 +149,7 @@ class Entity(Object):
|
|||||||
tile=str(self.tile.name), can_collide=bool(self.can_collide))
|
tile=str(self.tile.name), can_collide=bool(self.can_collide))
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'{self.name}(@{self.pos})'
|
return super(Entity, self).__repr__() + f'(@{self.pos})'
|
||||||
|
|
||||||
|
|
||||||
class MoveableEntity(Entity):
|
class MoveableEntity(Entity):
|
||||||
@ -118,7 +163,7 @@ class MoveableEntity(Entity):
|
|||||||
if self._last_tile:
|
if self._last_tile:
|
||||||
return self._last_tile.pos
|
return self._last_tile.pos
|
||||||
else:
|
else:
|
||||||
return c.NO_POS
|
return c.NO_POS.value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def direction_of_view(self):
|
def direction_of_view(self):
|
||||||
@ -137,45 +182,66 @@ class MoveableEntity(Entity):
|
|||||||
curr_tile.leave(self)
|
curr_tile.leave(self)
|
||||||
self._tile = next_tile
|
self._tile = next_tile
|
||||||
self._last_tile = curr_tile
|
self._last_tile = curr_tile
|
||||||
|
self._register.notify_change_to_value(self)
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
##########################################################################
|
||||||
|
# ####################### Objects and Entitys ########################## #
|
||||||
|
##########################################################################
|
||||||
|
|
||||||
|
|
||||||
class Action(Object):
|
class Action(Object):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class PlaceHolder(MoveableEntity):
|
class PlaceHolder(Object):
|
||||||
|
|
||||||
def __init__(self, *args, fill_value=0, **kwargs):
|
def __init__(self, *args, fill_value=0, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._fill_value = fill_value
|
self._fill_value = fill_value
|
||||||
|
|
||||||
@property
|
|
||||||
def last_tile(self):
|
|
||||||
return self.tile
|
|
||||||
|
|
||||||
@property
|
|
||||||
def direction_of_view(self):
|
|
||||||
return self.pos
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def can_collide(self):
|
def can_collide(self):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
return c.NO_POS.value[0]
|
return self._fill_value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return "PlaceHolder"
|
return "PlaceHolder"
|
||||||
|
|
||||||
|
|
||||||
class Tile(Object):
|
class GlobalPosition(EnvObject):
|
||||||
|
|
||||||
|
def belongs_to_entity(self, entity):
|
||||||
|
return self._agent == entity
|
||||||
|
|
||||||
|
def __init__(self, level_shape, obs_shape, agent, normalized: bool = True):
|
||||||
|
super(GlobalPosition, self).__init__(self)
|
||||||
|
self._obs_shape = (1, *obs_shape) if len(obs_shape) == 2 else obs_shape
|
||||||
|
self._agent = agent
|
||||||
|
self._level_shape = level_shape
|
||||||
|
self._normalized = normalized
|
||||||
|
|
||||||
|
def as_array(self):
|
||||||
|
pos_array = np.zeros(self._obs_shape)
|
||||||
|
for xy in range(1):
|
||||||
|
pos_array[0, 0, xy] = self._agent.pos[xy] / self._level_shape[xy]
|
||||||
|
return pos_array
|
||||||
|
|
||||||
|
|
||||||
|
class Tile(EnvObject):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encoding(self):
|
||||||
|
return c.FREE_CELL.value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def guests_that_can_collide(self):
|
def guests_that_can_collide(self):
|
||||||
@ -197,8 +263,8 @@ class Tile(Object):
|
|||||||
def pos(self):
|
def pos(self):
|
||||||
return self._pos
|
return self._pos
|
||||||
|
|
||||||
def __init__(self, pos, **kwargs):
|
def __init__(self, pos, *args, **kwargs):
|
||||||
super(Tile, self).__init__(**kwargs)
|
super(Tile, self).__init__(*args, **kwargs)
|
||||||
self._guests = dict()
|
self._guests = dict()
|
||||||
self._pos = tuple(pos)
|
self._pos = tuple(pos)
|
||||||
|
|
||||||
@ -233,6 +299,11 @@ class Tile(Object):
|
|||||||
|
|
||||||
|
|
||||||
class Wall(Tile):
|
class Wall(Tile):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encoding(self):
|
||||||
|
return c.OCCUPIED_CELL.value
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -247,7 +318,8 @@ class Door(Entity):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
return 1 if self.is_closed else 2
|
# This is important as it shadow is checked by occupation value
|
||||||
|
return c.OCCUPIED_CELL.value if self.is_closed else 2
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def str_state(self):
|
def str_state(self):
|
||||||
@ -307,11 +379,13 @@ class Door(Entity):
|
|||||||
def _open(self):
|
def _open(self):
|
||||||
self.connectivity.add_edges_from([(self.pos, x) for x in range(len(self.connectivity_subgroups))])
|
self.connectivity.add_edges_from([(self.pos, x) for x in range(len(self.connectivity_subgroups))])
|
||||||
self._state = c.OPEN_DOOR
|
self._state = c.OPEN_DOOR
|
||||||
|
self._register.notify_change_to_value(self)
|
||||||
self.time_to_close = self.auto_close_interval
|
self.time_to_close = self.auto_close_interval
|
||||||
|
|
||||||
def _close(self):
|
def _close(self):
|
||||||
self.connectivity.remove_node(self.pos)
|
self.connectivity.remove_node(self.pos)
|
||||||
self._state = c.CLOSED_DOOR
|
self._state = c.CLOSED_DOOR
|
||||||
|
self._register.notify_change_to_value(self)
|
||||||
|
|
||||||
def is_linked(self, old_pos, new_pos):
|
def is_linked(self, old_pos, new_pos):
|
||||||
try:
|
try:
|
||||||
|
@ -1,18 +1,23 @@
|
|||||||
import numbers
|
import numbers
|
||||||
import random
|
import random
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import List, Union, Dict
|
from typing import List, Union, Dict, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from environments.factory.base.objects import Entity, Tile, Agent, Door, Action, Wall, Object, PlaceHolder
|
from environments.factory.base.objects import Entity, Tile, Agent, Door, Action, Wall, PlaceHolder, GlobalPosition, \
|
||||||
|
Object, EnvObject
|
||||||
from environments.utility_classes import MovementProperties
|
from environments.utility_classes import MovementProperties
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.helpers import Constants as c
|
from environments.helpers import Constants as c
|
||||||
|
|
||||||
|
##########################################################################
|
||||||
|
# ##################### Base Register Definition ####################### #
|
||||||
|
##########################################################################
|
||||||
|
|
||||||
class Register:
|
|
||||||
_accepted_objects = Entity
|
class ObjectRegister:
|
||||||
|
_accepted_objects = Object
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
@ -48,6 +53,12 @@ class Register:
|
|||||||
def items(self):
|
def items(self):
|
||||||
return self._register.items()
|
return self._register.items()
|
||||||
|
|
||||||
|
def _get_index(self, item):
|
||||||
|
try:
|
||||||
|
return next(i for i, v in enumerate(self._register.values()) if v == item)
|
||||||
|
except (StopIteration, AssertionError):
|
||||||
|
return None
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
if isinstance(item, (int, np.int64, np.int32)):
|
if isinstance(item, (int, np.int64, np.int32)):
|
||||||
if item < 0:
|
if item < 0:
|
||||||
@ -65,39 +76,66 @@ class Register:
|
|||||||
return f'{self.__class__.__name__}({self._register})'
|
return f'{self.__class__.__name__}({self._register})'
|
||||||
|
|
||||||
|
|
||||||
class ObjectRegister(Register):
|
class EnvObjectRegister(ObjectRegister):
|
||||||
|
|
||||||
hide_from_obs_builder = False
|
_accepted_objects = EnvObject
|
||||||
|
|
||||||
def __init__(self, level_shape: (int, int), *args, individual_slices=False, is_per_agent=False, **kwargs):
|
def __init__(self, obs_shape: (int, int), *args, **kwargs):
|
||||||
super(ObjectRegister, self).__init__(*args, **kwargs)
|
super(EnvObjectRegister, self).__init__(*args, **kwargs)
|
||||||
self.is_per_agent = is_per_agent
|
self._shape = obs_shape
|
||||||
self.individual_slices = individual_slices
|
|
||||||
self._level_shape = level_shape
|
|
||||||
self._array = None
|
self._array = None
|
||||||
|
self.hide_from_obs_builder = False
|
||||||
|
self._lazy_eval_transforms = []
|
||||||
|
|
||||||
def register_item(self, other):
|
def register_item(self, other: EnvObject):
|
||||||
super(ObjectRegister, self).register_item(other)
|
super(EnvObjectRegister, self).register_item(other)
|
||||||
if self._array is None:
|
if self._array is None:
|
||||||
self._array = np.zeros((1, *self._level_shape))
|
self._array = np.zeros((1, *self._shape))
|
||||||
else:
|
self.notify_change_to_value(other)
|
||||||
if self.individual_slices:
|
|
||||||
self._array = np.concatenate((self._array, np.zeros((1, *self._array.shape[1:]))))
|
def as_array(self):
|
||||||
|
if self._lazy_eval_transforms:
|
||||||
|
idxs, values = zip(*self._lazy_eval_transforms)
|
||||||
|
# nuumpy put repects the ordering so that
|
||||||
|
np.put(self._array, idxs, values)
|
||||||
|
self._lazy_eval_transforms = []
|
||||||
|
return self._array
|
||||||
|
|
||||||
def summarize_states(self, n_steps=None):
|
def summarize_states(self, n_steps=None):
|
||||||
return [val.summarize_state(n_steps=n_steps) for val in self.values()]
|
return [val.summarize_state(n_steps=n_steps) for val in self.values()]
|
||||||
|
|
||||||
|
def notify_change_to_free(self, env_object: EnvObject):
|
||||||
|
self._array_change_notifyer(env_object, value=c.FREE_CELL.value)
|
||||||
|
|
||||||
class EntityObjectRegister(ObjectRegister, ABC):
|
def notify_change_to_value(self, env_object: EnvObject):
|
||||||
|
self._array_change_notifyer(env_object)
|
||||||
|
|
||||||
def as_array(self):
|
def _array_change_notifyer(self, env_object: EnvObject, value=None):
|
||||||
raise NotImplementedError
|
pos = self._get_index(env_object)
|
||||||
|
value = value if value is not None else env_object.encoding
|
||||||
|
self._lazy_eval_transforms.append((pos, value))
|
||||||
|
|
||||||
|
def __delitem__(self, name):
|
||||||
|
self.notify_change_to_free(self._register[name])
|
||||||
|
del self._register[name]
|
||||||
|
|
||||||
|
def delete_env_object(self, env_object: EnvObject):
|
||||||
|
del self[env_object.name]
|
||||||
|
|
||||||
|
def delete_env_object_by_name(self, name):
|
||||||
|
del self[name]
|
||||||
|
|
||||||
|
|
||||||
|
class EntityRegister(EnvObjectRegister, ABC):
|
||||||
|
|
||||||
|
_accepted_objects = Entity
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs):
|
def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs):
|
||||||
# objects_name = cls._accepted_objects.__name__
|
# objects_name = cls._accepted_objects.__name__
|
||||||
register_obj = cls(*args, **kwargs)
|
register_obj = cls(*args, **kwargs)
|
||||||
entities = [cls._accepted_objects(tile, str_ident=i, **entity_kwargs if entity_kwargs is not None else {})
|
entities = [cls._accepted_objects(tile, register_obj, str_ident=i,
|
||||||
|
**entity_kwargs if entity_kwargs is not None else {})
|
||||||
for i, tile in enumerate(tiles)]
|
for i, tile in enumerate(tiles)]
|
||||||
register_obj.register_additional_items(entities)
|
register_obj.register_additional_items(entities)
|
||||||
return register_obj
|
return register_obj
|
||||||
@ -115,86 +153,172 @@ class EntityObjectRegister(ObjectRegister, ABC):
|
|||||||
def tiles(self):
|
def tiles(self):
|
||||||
return [entity.tile for entity in self]
|
return [entity.tile for entity in self]
|
||||||
|
|
||||||
def __init__(self, *args, is_blocking_light=False, is_observable=True, can_be_shadowed=True, **kwargs):
|
@property
|
||||||
super(EntityObjectRegister, self).__init__(*args, **kwargs)
|
def encodings(self):
|
||||||
self.can_be_shadowed = can_be_shadowed
|
return [x.encoding for x in self]
|
||||||
self.is_blocking_light = is_blocking_light
|
|
||||||
self.is_observable = is_observable
|
|
||||||
|
|
||||||
def by_pos(self, pos):
|
def __init__(self, level_shape, *args,
|
||||||
if isinstance(pos, np.ndarray):
|
is_blocking_light: bool = False,
|
||||||
pos = tuple(pos)
|
can_be_shadowed: bool = True,
|
||||||
|
individual_slices: bool = False, **kwargs):
|
||||||
|
super(EntityRegister, self).__init__(level_shape, *args, **kwargs)
|
||||||
|
self._lazy_eval_transforms = []
|
||||||
|
self.can_be_shadowed = can_be_shadowed
|
||||||
|
self.individual_slices = individual_slices
|
||||||
|
self.is_blocking_light = is_blocking_light
|
||||||
|
|
||||||
|
def __delitem__(self, name):
|
||||||
|
idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
|
||||||
|
obj.tile.leave(obj)
|
||||||
|
super(EntityRegister, self).__delitem__(name)
|
||||||
|
if self.individual_slices:
|
||||||
|
self._array = np.delete(self._array, idx, axis=0)
|
||||||
|
|
||||||
|
def as_array(self):
|
||||||
|
if self._lazy_eval_transforms:
|
||||||
|
idxs, values = zip(*self._lazy_eval_transforms)
|
||||||
|
# numpy put repects the ordering so that
|
||||||
|
# Todo: Export the index building in a seperate function
|
||||||
|
np.put(self._array, [np.ravel_multi_index(idx, self._array.shape) for idx in idxs], values)
|
||||||
|
self._lazy_eval_transforms = []
|
||||||
|
return self._array
|
||||||
|
|
||||||
|
def _array_change_notifyer(self, entity, pos=None, value=None):
|
||||||
|
# Todo: Export the contruction in a seperate function
|
||||||
|
pos = pos if pos is not None else entity.pos
|
||||||
|
value = value if value is not None else entity.encoding
|
||||||
|
x, y = pos
|
||||||
|
if self.individual_slices:
|
||||||
|
idx = (self._get_index(entity), x, y)
|
||||||
|
else:
|
||||||
|
idx = (0, x, y)
|
||||||
|
self._lazy_eval_transforms.append((idx, value))
|
||||||
|
|
||||||
|
def by_pos(self, pos: Tuple[int, int]):
|
||||||
try:
|
try:
|
||||||
return next(item for item in self.values() if item.pos == pos)
|
return next(item for item in self if item.pos == tuple(pos))
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class MovingEntityObjectRegister(EntityObjectRegister, ABC):
|
class BoundRegisterMixin(EnvObjectRegister, ABC):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_entities_to_bind(self, entitites):
|
||||||
|
def from_values(cls, values: Union[str, numbers.Number, List[Union[str, numbers.Number]]],
|
||||||
|
*args, object_kwargs=None, **kwargs):
|
||||||
|
# objects_name = cls._accepted_objects.__name__
|
||||||
|
if isinstance(values, (str, numbers.Number)):
|
||||||
|
values = [values]
|
||||||
|
register_obj = cls(*args, **kwargs)
|
||||||
|
objects = [cls._accepted_objects(register_obj, str_ident=i, fill_value=value,
|
||||||
|
**object_kwargs if object_kwargs is not None else {})
|
||||||
|
for i, value in enumerate(values)]
|
||||||
|
register_obj.register_additional_items(objects)
|
||||||
|
return register_obj
|
||||||
|
|
||||||
|
|
||||||
|
class MovingEntityObjectRegister(EntityRegister, ABC):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(MovingEntityObjectRegister, self).__init__(*args, **kwargs)
|
super(MovingEntityObjectRegister, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
def by_pos(self, pos):
|
def notify_change_to_value(self, entity):
|
||||||
if isinstance(pos, np.ndarray):
|
super(MovingEntityObjectRegister, self).notify_change_to_value(entity)
|
||||||
pos = tuple(pos)
|
if entity.last_pos != c.NO_POS.value:
|
||||||
|
try:
|
||||||
|
self._array_change_notifyer(entity, entity.last_pos, value=c.FREE_CELL.value)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
##########################################################################
|
||||||
|
# ################# Objects and Entity Registers ####################### #
|
||||||
|
##########################################################################
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalPositions(EnvObjectRegister):
|
||||||
|
_accepted_objects = GlobalPosition
|
||||||
|
is_blocking_light = False
|
||||||
|
can_be_shadowed = False
|
||||||
|
hide_from_obs_builder = True
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(GlobalPositions, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
|
||||||
|
|
||||||
|
def as_array(self):
|
||||||
|
# Todo make this lazy?
|
||||||
|
return np.stack([gp.as_array() for inv_idx, gp in enumerate(self)])
|
||||||
|
|
||||||
|
def spawn_GlobalPositionObjects(self, obs_shape, agents):
|
||||||
|
global_positions = [self._accepted_objects(self._shape, obs_shape, agent)
|
||||||
|
for _, agent in enumerate(agents)]
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
self.register_additional_items(global_positions)
|
||||||
|
|
||||||
|
def summarize_states(self, n_steps=None):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def idx_by_entity(self, entity):
|
||||||
try:
|
try:
|
||||||
return next(x for x in self if x.pos == pos)
|
return next((idx for idx, inv in enumerate(self) if inv.belongs_to_entity(entity)))
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def __delitem__(self, name):
|
def by_entity(self, entity):
|
||||||
idx = next(i for i, entity in enumerate(self) if entity.name == name)
|
try:
|
||||||
del self._register[name]
|
return next((inv for inv in self if inv.belongs_to_entity(entity)))
|
||||||
if self.individual_slices:
|
except StopIteration:
|
||||||
self._array = np.delete(self._array, idx, axis=0)
|
return None
|
||||||
|
|
||||||
def delete_entity(self, item):
|
|
||||||
self.delete_entity_by_name(item.name)
|
|
||||||
|
|
||||||
def delete_entity_by_name(self, name):
|
|
||||||
del self[name]
|
|
||||||
|
|
||||||
|
|
||||||
class PlaceHolders(MovingEntityObjectRegister):
|
class PlaceHolders(EnvObjectRegister):
|
||||||
|
|
||||||
_accepted_objects = PlaceHolder
|
_accepted_objects = PlaceHolder
|
||||||
|
|
||||||
def __init__(self, *args, fill_value: Union[str, int] = 0, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
assert not 'individual_slices' in kwargs, 'Keyword - "individual_slices": "True" and must not be altered'
|
||||||
|
kwargs.update(individual_slices=False)
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.fill_value = fill_value
|
|
||||||
|
@classmethod
|
||||||
|
def from_values(cls, values: Union[str, numbers.Number, List[Union[str, numbers.Number]]],
|
||||||
|
*args, object_kwargs=None, **kwargs):
|
||||||
|
# objects_name = cls._accepted_objects.__name__
|
||||||
|
if isinstance(values, (str, numbers.Number)):
|
||||||
|
values = [values]
|
||||||
|
register_obj = cls(*args, **kwargs)
|
||||||
|
objects = [cls._accepted_objects(register_obj, str_ident=i, fill_value=value,
|
||||||
|
**object_kwargs if object_kwargs is not None else {})
|
||||||
|
for i, value in enumerate(values)]
|
||||||
|
register_obj.register_additional_items(objects)
|
||||||
|
return register_obj
|
||||||
|
|
||||||
# noinspection DuplicatedCode
|
# noinspection DuplicatedCode
|
||||||
def as_array(self):
|
def as_array(self):
|
||||||
if isinstance(self.fill_value, numbers.Number):
|
for idx, placeholder in enumerate(self):
|
||||||
self._array[:] = self.fill_value
|
if isinstance(placeholder.encoding, numbers.Number):
|
||||||
elif isinstance(self.fill_value, str):
|
self._array[idx][:] = placeholder.fill_value
|
||||||
if self.fill_value.lower() in ['normal', 'n']:
|
elif isinstance(placeholder.fill_value, str):
|
||||||
self._array = np.random.normal(size=self._array.shape)
|
if placeholder.fill_value.lower() in ['normal', 'n']:
|
||||||
|
self._array[:] = np.random.normal(size=self._array.shape)
|
||||||
|
else:
|
||||||
|
raise ValueError('Choose one of: ["normal", "N"]')
|
||||||
else:
|
else:
|
||||||
raise ValueError('Choose one of: ["normal", "N"]')
|
raise TypeError('Objects of type "str" or "number" is required here.')
|
||||||
else:
|
|
||||||
raise TypeError('Objects of type "str" or "number" is required here.')
|
|
||||||
|
|
||||||
if self.individual_slices:
|
return self._array
|
||||||
return self._array
|
|
||||||
else:
|
|
||||||
return self._array[None, 0]
|
|
||||||
|
|
||||||
|
|
||||||
class Entities(Register):
|
class Entities(ObjectRegister):
|
||||||
|
_accepted_objects = EntityRegister
|
||||||
_accepted_objects = EntityObjectRegister
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def observable_arrays(self):
|
def arrays(self):
|
||||||
# FIXME: Find a better name
|
return {key: val.as_array() for key, val in self.items()}
|
||||||
return {key: val.as_array() for key, val in self.items() if val.is_observable}
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def obs_arrays(self):
|
def obs_arrays(self):
|
||||||
# FIXME: Find a better name
|
return {key: val.as_array() for key, val in self.items() if not val.hide_from_obs_builder}
|
||||||
return {key: val.as_array() for key, val in self.items() if val.is_observable and not val.hide_from_obs_builder}
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def names(self):
|
def names(self):
|
||||||
@ -220,34 +344,34 @@ class Entities(Register):
|
|||||||
return found_entities
|
return found_entities
|
||||||
|
|
||||||
|
|
||||||
class WallTiles(EntityObjectRegister):
|
class WallTiles(EntityRegister):
|
||||||
_accepted_objects = Wall
|
_accepted_objects = Wall
|
||||||
_light_blocking = True
|
_light_blocking = True
|
||||||
|
hide_from_obs_builder = True
|
||||||
|
|
||||||
def as_array(self):
|
def as_array(self):
|
||||||
if not np.any(self._array):
|
if not np.any(self._array):
|
||||||
|
# Which is Faster?
|
||||||
|
# indices = [x.pos for x in self]
|
||||||
|
# np.put(self._array, [np.ravel_multi_index((0, *x), self._array.shape) for x in indices], self.encodings)
|
||||||
x, y = zip(*[x.pos for x in self])
|
x, y = zip(*[x.pos for x in self])
|
||||||
self._array[0, x, y] = self.encoding
|
self._array[0, x, y] = self.encoding
|
||||||
return self._array
|
return self._array
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(WallTiles, self).__init__(*args, individual_slices=False,
|
super(WallTiles, self).__init__(*args, is_blocking_light=self._light_blocking, individual_slices=False,
|
||||||
is_blocking_light=self._light_blocking, **kwargs)
|
**kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
return c.OCCUPIED_CELL.value
|
return c.OCCUPIED_CELL.value
|
||||||
|
|
||||||
@property
|
|
||||||
def array(self):
|
|
||||||
return self._array
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_argwhere_coordinates(cls, argwhere_coordinates, *args, **kwargs):
|
def from_argwhere_coordinates(cls, argwhere_coordinates, *args, **kwargs):
|
||||||
tiles = cls(*args, **kwargs)
|
tiles = cls(*args, **kwargs)
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
tiles.register_additional_items(
|
tiles.register_additional_items(
|
||||||
[cls._accepted_objects(pos, is_blocking_light=cls._light_blocking)
|
[cls._accepted_objects(pos, tiles, is_blocking_light=cls._light_blocking)
|
||||||
for pos in argwhere_coordinates]
|
for pos in argwhere_coordinates]
|
||||||
)
|
)
|
||||||
return tiles
|
return tiles
|
||||||
@ -264,12 +388,11 @@ class WallTiles(EntityObjectRegister):
|
|||||||
|
|
||||||
|
|
||||||
class FloorTiles(WallTiles):
|
class FloorTiles(WallTiles):
|
||||||
|
|
||||||
_accepted_objects = Tile
|
_accepted_objects = Tile
|
||||||
_light_blocking = False
|
_light_blocking = False
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(FloorTiles, self).__init__(*args, is_observable=False, **kwargs)
|
super(FloorTiles, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
@ -297,22 +420,21 @@ class FloorTiles(WallTiles):
|
|||||||
|
|
||||||
|
|
||||||
class Agents(MovingEntityObjectRegister):
|
class Agents(MovingEntityObjectRegister):
|
||||||
|
|
||||||
_accepted_objects = Agent
|
_accepted_objects = Agent
|
||||||
|
|
||||||
def __init__(self, *args, hide_from_obs_builder=False, **kwargs):
|
def __init__(self, *args, hide_from_obs_builder=False, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.hide_from_obs_builder = hide_from_obs_builder
|
self.hide_from_obs_builder = hide_from_obs_builder
|
||||||
|
|
||||||
# noinspection DuplicatedCode
|
@DeprecationWarning
|
||||||
def as_array(self):
|
def Xas_array(self):
|
||||||
self._array[:] = c.FREE_CELL.value
|
# Super Safe Version
|
||||||
# noinspection PyTupleAssignmentBalance
|
# self._array[:] = c.FREE_CELL.value
|
||||||
for z, x, y, v in zip(range(len(self)), *zip(*[x.pos for x in self]), [x.encoding for x in self]):
|
indices = list(zip(range(len(self)), *zip(*[x.last_pos for x in self])))
|
||||||
if self.individual_slices:
|
np.put(self._array, [np.ravel_multi_index(x, self._array.shape) for x in indices], c.FREE_CELL.value)
|
||||||
self._array[z, x, y] += v
|
indices = list(zip(range(len(self)), *zip(*[x.pos for x in self])))
|
||||||
else:
|
np.put(self._array, [np.ravel_multi_index(x, self._array.shape) for x in indices], self.encodings)
|
||||||
self._array[0, x, y] += v
|
|
||||||
if self.individual_slices:
|
if self.individual_slices:
|
||||||
return self._array
|
return self._array
|
||||||
else:
|
else:
|
||||||
@ -329,17 +451,11 @@ class Agents(MovingEntityObjectRegister):
|
|||||||
self._register[agent.name] = agent
|
self._register[agent.name] = agent
|
||||||
|
|
||||||
|
|
||||||
class Doors(EntityObjectRegister):
|
class Doors(EntityRegister):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(Doors, self).__init__(*args, is_blocking_light=True, **kwargs)
|
super(Doors, self).__init__(*args, is_blocking_light=True, **kwargs)
|
||||||
|
|
||||||
def as_array(self):
|
|
||||||
self._array[:] = 0
|
|
||||||
for door in self:
|
|
||||||
self._array[0, door.x, door.y] = door.encoding
|
|
||||||
return self._array
|
|
||||||
|
|
||||||
_accepted_objects = Door
|
_accepted_objects = Door
|
||||||
|
|
||||||
def get_near_position(self, position: (int, int)) -> Union[None, Door]:
|
def get_near_position(self, position: (int, int)) -> Union[None, Door]:
|
||||||
@ -353,8 +469,7 @@ class Doors(EntityObjectRegister):
|
|||||||
door.tick()
|
door.tick()
|
||||||
|
|
||||||
|
|
||||||
class Actions(Register):
|
class Actions(ObjectRegister):
|
||||||
|
|
||||||
_accepted_objects = Action
|
_accepted_objects = Action
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -385,7 +500,7 @@ class Actions(Register):
|
|||||||
return action in self.movement_actions.values()
|
return action in self.movement_actions.values()
|
||||||
|
|
||||||
|
|
||||||
class Zones(Register):
|
class Zones(ObjectRegister):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def accounting_zones(self):
|
def accounting_zones(self):
|
||||||
|
@ -13,7 +13,7 @@ mult_array = np.asarray([
|
|||||||
class Map(object):
|
class Map(object):
|
||||||
# Multipliers for transforming coordinates to other octants:
|
# Multipliers for transforming coordinates to other octants:
|
||||||
|
|
||||||
def __init__(self, map_array: np.ndarray, diamond_slope: float = 0.9):
|
def __init__(self, map_array: np.typing.ArrayLike, diamond_slope: float = 0.9):
|
||||||
self.data = map_array
|
self.data = map_array
|
||||||
self.width, self.height = map_array.shape
|
self.width, self.height = map_array.shape
|
||||||
self.light = np.full_like(self.data, c.FREE_CELL.value)
|
self.light = np.full_like(self.data, c.FREE_CELL.value)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import random
|
import random
|
||||||
|
|
||||||
from environments.factory.factory_battery import BatteryFactory, BatteryProperties
|
from environments.factory.factory_battery import BatteryFactory, BatteryProperties
|
||||||
|
from environments.factory.factory_dest import DestFactory
|
||||||
from environments.factory.factory_dirt import DirtFactory, DirtProperties
|
from environments.factory.factory_dirt import DirtFactory, DirtProperties
|
||||||
from environments.factory.factory_item import ItemFactory
|
from environments.factory.factory_item import ItemFactory
|
||||||
|
|
||||||
@ -17,6 +18,12 @@ class DirtBatteryFactory(DirtFactory, BatteryFactory):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection PyAbstractClass
|
||||||
|
class DirtDestItemFactory(ItemFactory, DirtFactory, DestFactory):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
from environments.utility_classes import AgentRenderOptions as ARO, ObservationProperties
|
from environments.utility_classes import AgentRenderOptions as ARO, ObservationProperties
|
||||||
|
|
||||||
|
@ -1,18 +1,18 @@
|
|||||||
from typing import Union, NamedTuple
|
from typing import Union, NamedTuple, Dict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from environments.factory.base.base_factory import BaseFactory
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
from environments.factory.base.objects import Agent, Action, Entity
|
from environments.factory.base.objects import Agent, Action, Entity, EnvObject, BoundingMixin
|
||||||
from environments.factory.base.registers import EntityObjectRegister, ObjectRegister
|
from environments.factory.base.registers import EntityRegister, EnvObjectRegister
|
||||||
from environments.factory.base.renderer import RenderEntity
|
from environments.factory.base.renderer import RenderEntity
|
||||||
from environments.helpers import Constants as c
|
from environments.helpers import Constants as c, Constants
|
||||||
|
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
|
|
||||||
|
|
||||||
CHARGE_ACTION = h.EnvActions.CHARGE
|
CHARGE_ACTION = h.EnvActions.CHARGE
|
||||||
ITEM_DROP_OFF = 1
|
CHARGE_POD = 1
|
||||||
|
|
||||||
|
|
||||||
class BatteryProperties(NamedTuple):
|
class BatteryProperties(NamedTuple):
|
||||||
@ -24,42 +24,18 @@ class BatteryProperties(NamedTuple):
|
|||||||
multi_charge: bool = False
|
multi_charge: bool = False
|
||||||
|
|
||||||
|
|
||||||
class Battery(object):
|
class Battery(EnvObject, BoundingMixin):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_discharged(self):
|
def is_discharged(self):
|
||||||
return self.charge_level == 0
|
return self.charge_level == 0
|
||||||
|
|
||||||
@property
|
def __init__(self, initial_charge_level: float, *args, **kwargs):
|
||||||
def is_blocking_light(self):
|
super(Battery, self).__init__(*args, **kwargs)
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def can_collide(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self):
|
|
||||||
return f'{self.__class__.__name__}({self.agent.name})'
|
|
||||||
|
|
||||||
def __init__(self, pomdp_r: int, level_shape: (int, int), agent: Agent, initial_charge_level: float):
|
|
||||||
super().__init__()
|
|
||||||
self.agent = agent
|
|
||||||
self._pomdp_r = pomdp_r
|
|
||||||
self._level_shape = level_shape
|
|
||||||
if self._pomdp_r:
|
|
||||||
self._array = np.zeros((1, pomdp_r * 2 + 1, pomdp_r * 2 + 1))
|
|
||||||
else:
|
|
||||||
self._array = np.zeros((1, *self._level_shape))
|
|
||||||
self.charge_level = initial_charge_level
|
self.charge_level = initial_charge_level
|
||||||
|
|
||||||
def as_array(self):
|
def encoding(self):
|
||||||
self._array[:] = c.FREE_CELL.value
|
return self.charge_level
|
||||||
self._array[0, 0] = self.charge_level
|
|
||||||
return self._array
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f'{self.__class__.__name__}[{self.agent.name}]({self.charge_level})'
|
|
||||||
|
|
||||||
def charge(self, amount) -> c:
|
def charge(self, amount) -> c:
|
||||||
if self.charge_level < 1:
|
if self.charge_level < 1:
|
||||||
@ -73,12 +49,10 @@ class Battery(object):
|
|||||||
if self.charge_level != 0:
|
if self.charge_level != 0:
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
self.charge_level = max(0, amount + self.charge_level)
|
self.charge_level = max(0, amount + self.charge_level)
|
||||||
|
self._register.notify_change_to_value(self)
|
||||||
return c.VALID
|
return c.VALID
|
||||||
else:
|
else:
|
||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
|
|
||||||
def belongs_to_entity(self, entity):
|
|
||||||
return self.agent == entity
|
|
||||||
|
|
||||||
def summarize_state(self, **kwargs):
|
def summarize_state(self, **kwargs):
|
||||||
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
||||||
@ -86,7 +60,7 @@ class Battery(object):
|
|||||||
return attr_dict
|
return attr_dict
|
||||||
|
|
||||||
|
|
||||||
class BatteriesRegister(ObjectRegister):
|
class BatteriesRegister(EnvObjectRegister):
|
||||||
|
|
||||||
_accepted_objects = Battery
|
_accepted_objects = Battery
|
||||||
is_blocking_light = False
|
is_blocking_light = False
|
||||||
@ -98,16 +72,17 @@ class BatteriesRegister(ObjectRegister):
|
|||||||
self.is_observable = True
|
self.is_observable = True
|
||||||
|
|
||||||
def as_array(self):
|
def as_array(self):
|
||||||
# self._array[:] = c.FREE_CELL.value
|
# ToDO: Make this Lazy
|
||||||
|
self._array[:] = c.FREE_CELL.value
|
||||||
for inv_idx, battery in enumerate(self):
|
for inv_idx, battery in enumerate(self):
|
||||||
self._array[inv_idx] = battery.as_array()
|
self._array[inv_idx] = battery.as_array()
|
||||||
return self._array
|
return self._array
|
||||||
|
|
||||||
def spawn_batteries(self, agents, pomdp_r, initial_charge_level):
|
def spawn_batteries(self, agents, pomdp_r, initial_charge_level):
|
||||||
inventories = [self._accepted_objects(pomdp_r, self._level_shape, agent,
|
batteries = [self._accepted_objects(pomdp_r, self._shape, agent,
|
||||||
initial_charge_level)
|
initial_charge_level)
|
||||||
for _, agent in enumerate(agents)]
|
for _, agent in enumerate(agents)]
|
||||||
self.register_additional_items(inventories)
|
self.register_additional_items(batteries)
|
||||||
|
|
||||||
def idx_by_entity(self, entity):
|
def idx_by_entity(self, entity):
|
||||||
try:
|
try:
|
||||||
@ -135,7 +110,7 @@ class ChargePod(Entity):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
return ITEM_DROP_OFF
|
return CHARGE_POD
|
||||||
|
|
||||||
def __init__(self, *args, charge_rate: float = 0.4,
|
def __init__(self, *args, charge_rate: float = 0.4,
|
||||||
multi_charge: bool = False, **kwargs):
|
multi_charge: bool = False, **kwargs):
|
||||||
@ -157,11 +132,12 @@ class ChargePod(Entity):
|
|||||||
return summary
|
return summary
|
||||||
|
|
||||||
|
|
||||||
class ChargePods(EntityObjectRegister):
|
class ChargePods(EntityRegister):
|
||||||
|
|
||||||
_accepted_objects = ChargePod
|
_accepted_objects = ChargePod
|
||||||
|
|
||||||
def as_array(self):
|
@DeprecationWarning
|
||||||
|
def Xas_array(self):
|
||||||
self._array[:] = c.FREE_CELL.value
|
self._array[:] = c.FREE_CELL.value
|
||||||
for item in self:
|
for item in self:
|
||||||
if item.pos != c.NO_POS.value:
|
if item.pos != c.NO_POS.value:
|
||||||
@ -180,6 +156,16 @@ class BatteryFactory(BaseFactory):
|
|||||||
self.btry_prop = btry_prop
|
self.btry_prop = btry_prop
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def _additional_raw_observations(self, agent) -> Dict[Constants, np.typing.ArrayLike]:
|
||||||
|
additional_raw_observations = super()._additional_raw_observations(agent)
|
||||||
|
additional_raw_observations.update({c.BATTERIES: self[c.BATTERIES].by_entity(agent).as_array()})
|
||||||
|
return additional_raw_observations
|
||||||
|
|
||||||
|
def _additional_observations(self) -> Dict[Constants, np.typing.ArrayLike]:
|
||||||
|
additional_observations = super()._additional_observations()
|
||||||
|
additional_observations.update({c.CHARGE_POD: self[c.CHARGE_POD].as_array()})
|
||||||
|
return additional_observations
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def additional_entities(self):
|
def additional_entities(self):
|
||||||
super_entities = super().additional_entities
|
super_entities = super().additional_entities
|
||||||
|
@ -6,10 +6,10 @@ import numpy as np
|
|||||||
import random
|
import random
|
||||||
|
|
||||||
from environments.factory.base.base_factory import BaseFactory
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
from environments.helpers import Constants as c
|
from environments.helpers import Constants as c, Constants
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.factory.base.objects import Agent, Entity, Action, Tile
|
from environments.factory.base.objects import Agent, Entity, Action
|
||||||
from environments.factory.base.registers import Entities, MovingEntityObjectRegister
|
from environments.factory.base.registers import Entities, EntityRegister
|
||||||
|
|
||||||
from environments.factory.base.renderer import RenderEntity
|
from environments.factory.base.renderer import RenderEntity
|
||||||
|
|
||||||
@ -62,13 +62,16 @@ class Destination(Entity):
|
|||||||
return state_summary
|
return state_summary
|
||||||
|
|
||||||
|
|
||||||
class Destinations(MovingEntityObjectRegister):
|
class Destinations(EntityRegister):
|
||||||
|
|
||||||
_accepted_objects = Destination
|
_accepted_objects = Destination
|
||||||
_light_blocking = False
|
_light_blocking = False
|
||||||
|
|
||||||
def as_array(self):
|
def as_array(self):
|
||||||
self._array[:] = c.FREE_CELL.value
|
self._array[:] = c.FREE_CELL.value
|
||||||
|
# ToDo: Switch to new Style Array Put
|
||||||
|
# indices = list(zip(range(len(self)), *zip(*[x.pos for x in self])))
|
||||||
|
# np.put(self._array, [np.ravel_multi_index(x, self._array.shape) for x in indices], self.encodings)
|
||||||
for item in self:
|
for item in self:
|
||||||
if item.pos != c.NO_POS.value:
|
if item.pos != c.NO_POS.value:
|
||||||
self._array[0, item.x, item.y] = item.encoding
|
self._array[0, item.x, item.y] = item.encoding
|
||||||
@ -83,39 +86,39 @@ class ReachedDestinations(Destinations):
|
|||||||
_light_blocking = False
|
_light_blocking = False
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(ReachedDestinations, self).__init__(*args, is_observable=False, **kwargs)
|
super(ReachedDestinations, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
def summarize_states(self, n_steps=None):
|
def summarize_states(self, n_steps=None):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
class DestSpawnMode(object):
|
class DestModeOptions(object):
|
||||||
DONE = 'DONE'
|
DONE = 'DONE'
|
||||||
GROUPED = 'GROUPED'
|
GROUPED = 'GROUPED'
|
||||||
PER_DEST = 'PER_DEST'
|
PER_DEST = 'PER_DEST'
|
||||||
|
|
||||||
|
|
||||||
class DestinationProperties(NamedTuple):
|
class DestProperties(NamedTuple):
|
||||||
n_dests: int = 1 # How many destinations are there
|
n_dests: int = 1 # How many destinations are there
|
||||||
dwell_time: int = 0 # How long does the agent need to "wait" on a destination
|
dwell_time: int = 0 # How long does the agent need to "wait" on a destination
|
||||||
spawn_frequency: int = 0
|
spawn_frequency: int = 0
|
||||||
spawn_in_other_zone: bool = True #
|
spawn_in_other_zone: bool = True #
|
||||||
spawn_mode: str = DestSpawnMode.DONE
|
spawn_mode: str = DestModeOptions.DONE
|
||||||
|
|
||||||
assert dwell_time >= 0, 'dwell_time cannot be < 0!'
|
assert dwell_time >= 0, 'dwell_time cannot be < 0!'
|
||||||
assert spawn_frequency >= 0, 'spawn_frequency cannot be < 0!'
|
assert spawn_frequency >= 0, 'spawn_frequency cannot be < 0!'
|
||||||
assert n_dests >= 0, 'n_destinations cannot be < 0!'
|
assert n_dests >= 0, 'n_destinations cannot be < 0!'
|
||||||
assert (spawn_mode == DestSpawnMode.DONE) != bool(spawn_frequency)
|
assert (spawn_mode == DestModeOptions.DONE) != bool(spawn_frequency)
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
||||||
class DestinationFactory(BaseFactory):
|
class DestFactory(BaseFactory):
|
||||||
# noinspection PyMissingConstructor
|
# noinspection PyMissingConstructor
|
||||||
|
|
||||||
def __init__(self, *args, dest_prop: DestinationProperties = DestinationProperties(),
|
def __init__(self, *args, dest_prop: DestProperties = DestProperties(),
|
||||||
env_seed=time.time_ns(), **kwargs):
|
env_seed=time.time_ns(), **kwargs):
|
||||||
if isinstance(dest_prop, dict):
|
if isinstance(dest_prop, dict):
|
||||||
dest_prop = DestinationProperties(**dest_prop)
|
dest_prop = DestProperties(**dest_prop)
|
||||||
self.dest_prop = dest_prop
|
self.dest_prop = dest_prop
|
||||||
kwargs.update(env_seed=env_seed)
|
kwargs.update(env_seed=env_seed)
|
||||||
self._dest_rng = np.random.default_rng(env_seed)
|
self._dest_rng = np.random.default_rng(env_seed)
|
||||||
@ -145,10 +148,6 @@ class DestinationFactory(BaseFactory):
|
|||||||
super_entities.update({c.DESTINATION: destinations, c.REACHEDDESTINATION: reached_destinations})
|
super_entities.update({c.DESTINATION: destinations, c.REACHEDDESTINATION: reached_destinations})
|
||||||
return super_entities
|
return super_entities
|
||||||
|
|
||||||
def additional_per_agent_obs_build(self, agent) -> List[np.ndarray]:
|
|
||||||
additional_per_agent_obs_build = super().additional_per_agent_obs_build(agent)
|
|
||||||
return additional_per_agent_obs_build
|
|
||||||
|
|
||||||
def wait(self, agent: Agent):
|
def wait(self, agent: Agent):
|
||||||
if destiantion := self[c.DESTINATION].by_pos(agent.pos):
|
if destiantion := self[c.DESTINATION].by_pos(agent.pos):
|
||||||
valid = destiantion.wait(agent)
|
valid = destiantion.wait(agent)
|
||||||
@ -178,13 +177,13 @@ class DestinationFactory(BaseFactory):
|
|||||||
if val == self.dest_prop.spawn_frequency]
|
if val == self.dest_prop.spawn_frequency]
|
||||||
if destinations_to_spawn:
|
if destinations_to_spawn:
|
||||||
n_dest_to_spawn = len(destinations_to_spawn)
|
n_dest_to_spawn = len(destinations_to_spawn)
|
||||||
if self.dest_prop.spawn_mode != DestSpawnMode.GROUPED:
|
if self.dest_prop.spawn_mode != DestModeOptions.GROUPED:
|
||||||
destinations = [Destination(tile) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
destinations = [Destination(tile) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
||||||
self[c.DESTINATION].register_additional_items(destinations)
|
self[c.DESTINATION].register_additional_items(destinations)
|
||||||
for dest in destinations_to_spawn:
|
for dest in destinations_to_spawn:
|
||||||
del self._dest_spawn_timer[dest]
|
del self._dest_spawn_timer[dest]
|
||||||
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
|
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
|
||||||
elif self.dest_prop.spawn_mode == DestSpawnMode.GROUPED and n_dest_to_spawn == self.dest_prop.n_dests:
|
elif self.dest_prop.spawn_mode == DestModeOptions.GROUPED and n_dest_to_spawn == self.dest_prop.n_dests:
|
||||||
destinations = [Destination(tile) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
destinations = [Destination(tile) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
||||||
self[c.DESTINATION].register_additional_items(destinations)
|
self[c.DESTINATION].register_additional_items(destinations)
|
||||||
for dest in destinations_to_spawn:
|
for dest in destinations_to_spawn:
|
||||||
@ -204,7 +203,7 @@ class DestinationFactory(BaseFactory):
|
|||||||
for dest in list(self[c.DESTINATION].values()):
|
for dest in list(self[c.DESTINATION].values()):
|
||||||
if dest.is_considered_reached:
|
if dest.is_considered_reached:
|
||||||
self[c.REACHEDDESTINATION].register_item(dest)
|
self[c.REACHEDDESTINATION].register_item(dest)
|
||||||
self[c.DESTINATION].delete_entity(dest)
|
self[c.DESTINATION].delete_env_object(dest)
|
||||||
self._dest_spawn_timer[dest.name] = 0
|
self._dest_spawn_timer[dest.name] = 0
|
||||||
self.print(f'{dest.name} is reached now, removing...')
|
self.print(f'{dest.name} is reached now, removing...')
|
||||||
else:
|
else:
|
||||||
@ -219,6 +218,11 @@ class DestinationFactory(BaseFactory):
|
|||||||
self.trigger_destination_spawn()
|
self.trigger_destination_spawn()
|
||||||
return info_dict
|
return info_dict
|
||||||
|
|
||||||
|
def _additional_observations(self) -> Dict[Constants, np.typing.ArrayLike]:
|
||||||
|
additional_observations = super()._additional_observations()
|
||||||
|
additional_observations.update({c.DESTINATION: self[c.DESTINATION].as_array()})
|
||||||
|
return additional_observations
|
||||||
|
|
||||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
reward, info_dict = super().calculate_additional_reward(agent)
|
reward, info_dict = super().calculate_additional_reward(agent)
|
||||||
@ -240,7 +244,7 @@ class DestinationFactory(BaseFactory):
|
|||||||
info_dict.update(agent_reached_destination=1)
|
info_dict.update(agent_reached_destination=1)
|
||||||
self.print(f'{agent.name} just reached destination at {agent.pos}')
|
self.print(f'{agent.name} just reached destination at {agent.pos}')
|
||||||
reward += 0.5
|
reward += 0.5
|
||||||
self[c.REACHEDDESTINATION].delete_entity(reached_dest)
|
self[c.REACHEDDESTINATION].delete_env_object(reached_dest)
|
||||||
return reward, info_dict
|
return reward, info_dict
|
||||||
|
|
||||||
def render_additional_assets(self, mode='human'):
|
def render_additional_assets(self, mode='human'):
|
||||||
@ -256,7 +260,7 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
render = True
|
render = True
|
||||||
|
|
||||||
dest_probs = DestinationProperties(n_dests=2, spawn_frequency=5, spawn_mode=DestSpawnMode.GROUPED)
|
dest_probs = DestProperties(n_dests=2, spawn_frequency=5, spawn_mode=DestModeOptions.GROUPED)
|
||||||
|
|
||||||
obs_props = ObservationProperties(render_agents=ARO.LEVEL, omit_agent_self=True, pomdp_r=2)
|
obs_props = ObservationProperties(render_agents=ARO.LEVEL, omit_agent_self=True, pomdp_r=2)
|
||||||
|
|
||||||
@ -264,12 +268,12 @@ if __name__ == '__main__':
|
|||||||
'allow_diagonal_movement': False,
|
'allow_diagonal_movement': False,
|
||||||
'allow_no_op': False}
|
'allow_no_op': False}
|
||||||
|
|
||||||
factory = DestinationFactory(n_agents=10, done_at_collision=False,
|
factory = DestFactory(n_agents=10, done_at_collision=False,
|
||||||
level_name='rooms', max_steps=400,
|
level_name='rooms', max_steps=400,
|
||||||
obs_prop=obs_props, parse_doors=True,
|
obs_prop=obs_props, parse_doors=True,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
mv_prop=move_props, dest_prop=dest_probs
|
mv_prop=move_props, dest_prop=dest_probs
|
||||||
)
|
)
|
||||||
|
|
||||||
# noinspection DuplicatedCode
|
# noinspection DuplicatedCode
|
||||||
n_actions = factory.action_space.n - 1
|
n_actions = factory.action_space.n - 1
|
@ -6,11 +6,11 @@ import random
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from algorithms.TSP_dirt_agent import TSPDirtAgent
|
from algorithms.TSP_dirt_agent import TSPDirtAgent
|
||||||
from environments.helpers import Constants as c
|
from environments.helpers import Constants as c, Constants
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.factory.base.base_factory import BaseFactory
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
from environments.factory.base.objects import Agent, Action, Entity, Tile
|
from environments.factory.base.objects import Agent, Action, Entity, Tile
|
||||||
from environments.factory.base.registers import Entities, MovingEntityObjectRegister
|
from environments.factory.base.registers import Entities, MovingEntityObjectRegister, EntityRegister
|
||||||
|
|
||||||
from environments.factory.base.renderer import RenderEntity
|
from environments.factory.base.renderer import RenderEntity
|
||||||
from environments.utility_classes import ObservationProperties
|
from environments.utility_classes import ObservationProperties
|
||||||
@ -42,6 +42,7 @@ class Dirt(Entity):
|
|||||||
def amount(self):
|
def amount(self):
|
||||||
return self._amount
|
return self._amount
|
||||||
|
|
||||||
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
# Edit this if you want items to be drawn in the ops differntly
|
# Edit this if you want items to be drawn in the ops differntly
|
||||||
return self._amount
|
return self._amount
|
||||||
@ -52,6 +53,7 @@ class Dirt(Entity):
|
|||||||
|
|
||||||
def set_new_amount(self, amount):
|
def set_new_amount(self, amount):
|
||||||
self._amount = amount
|
self._amount = amount
|
||||||
|
self._register.notify_change_to_value(self)
|
||||||
|
|
||||||
def summarize_state(self, **kwargs):
|
def summarize_state(self, **kwargs):
|
||||||
state_dict = super().summarize_state(**kwargs)
|
state_dict = super().summarize_state(**kwargs)
|
||||||
@ -59,18 +61,7 @@ class Dirt(Entity):
|
|||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
class DirtRegister(MovingEntityObjectRegister):
|
class DirtRegister(EntityRegister):
|
||||||
|
|
||||||
def as_array(self):
|
|
||||||
if self._array is not None:
|
|
||||||
self._array[:] = c.FREE_CELL.value
|
|
||||||
for dirt in list(self.values()):
|
|
||||||
if dirt.amount == 0:
|
|
||||||
self.delete_entity(dirt)
|
|
||||||
self._array[0, dirt.x, dirt.y] = dirt.amount
|
|
||||||
else:
|
|
||||||
self._array = np.zeros((1, *self._level_shape))
|
|
||||||
return self._array
|
|
||||||
|
|
||||||
_accepted_objects = Dirt
|
_accepted_objects = Dirt
|
||||||
|
|
||||||
@ -93,7 +84,7 @@ class DirtRegister(MovingEntityObjectRegister):
|
|||||||
if not self.amount > self.dirt_properties.max_global_amount:
|
if not self.amount > self.dirt_properties.max_global_amount:
|
||||||
dirt = self.by_pos(tile.pos)
|
dirt = self.by_pos(tile.pos)
|
||||||
if dirt is None:
|
if dirt is None:
|
||||||
dirt = Dirt(tile, amount=self.dirt_properties.max_spawn_amount)
|
dirt = Dirt(tile, self, amount=self.dirt_properties.max_spawn_amount)
|
||||||
self.register_item(dirt)
|
self.register_item(dirt)
|
||||||
else:
|
else:
|
||||||
new_value = dirt.amount + self.dirt_properties.max_spawn_amount
|
new_value = dirt.amount + self.dirt_properties.max_spawn_amount
|
||||||
@ -155,7 +146,7 @@ class DirtFactory(BaseFactory):
|
|||||||
new_dirt_amount = dirt.amount - self.dirt_prop.clean_amount
|
new_dirt_amount = dirt.amount - self.dirt_prop.clean_amount
|
||||||
|
|
||||||
if new_dirt_amount <= 0:
|
if new_dirt_amount <= 0:
|
||||||
self[c.DIRT].delete_entity(dirt)
|
self[c.DIRT].delete_env_object(dirt)
|
||||||
else:
|
else:
|
||||||
dirt.set_new_amount(max(new_dirt_amount, c.FREE_CELL.value))
|
dirt.set_new_amount(max(new_dirt_amount, c.FREE_CELL.value))
|
||||||
return c.VALID
|
return c.VALID
|
||||||
@ -224,6 +215,11 @@ class DirtFactory(BaseFactory):
|
|||||||
done = self.dirt_prop.done_when_clean and (len(self[c.DIRT]) == 0)
|
done = self.dirt_prop.done_when_clean and (len(self[c.DIRT]) == 0)
|
||||||
return super_done or done
|
return super_done or done
|
||||||
|
|
||||||
|
def _additional_observations(self) -> Dict[Constants, np.typing.ArrayLike]:
|
||||||
|
additional_observations = super()._additional_observations()
|
||||||
|
additional_observations.update({c.DIRT: self[c.DIRT].as_array()})
|
||||||
|
return additional_observations
|
||||||
|
|
||||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
||||||
reward, info_dict = super().calculate_additional_reward(agent)
|
reward, info_dict = super().calculate_additional_reward(agent)
|
||||||
dirt = [dirt.amount for dirt in self[c.DIRT]]
|
dirt = [dirt.amount for dirt in self[c.DIRT]]
|
||||||
@ -278,41 +274,52 @@ if __name__ == '__main__':
|
|||||||
)
|
)
|
||||||
|
|
||||||
obs_props = ObservationProperties(render_agents=ARO.COMBINED, omit_agent_self=True,
|
obs_props = ObservationProperties(render_agents=ARO.COMBINED, omit_agent_self=True,
|
||||||
pomdp_r=2, additional_agent_placeholder=None)
|
pomdp_r=2, additional_agent_placeholder=None, cast_shadows=True)
|
||||||
|
|
||||||
move_props = {'allow_square_movement': True,
|
move_props = {'allow_square_movement': True,
|
||||||
'allow_diagonal_movement': False,
|
'allow_diagonal_movement': False,
|
||||||
'allow_no_op': False}
|
'allow_no_op': False}
|
||||||
|
global_timings = []
|
||||||
|
for i in range(20):
|
||||||
|
|
||||||
factory = DirtFactory(n_agents=1, done_at_collision=False,
|
factory = DirtFactory(n_agents=2, done_at_collision=False,
|
||||||
level_name='rooms', max_steps=400,
|
level_name='rooms', max_steps=1000,
|
||||||
doors_have_area=False,
|
doors_have_area=False,
|
||||||
obs_prop=obs_props, parse_doors=True,
|
obs_prop=obs_props, parse_doors=True,
|
||||||
record_episodes=True, verbose=True,
|
record_episodes=True, verbose=True,
|
||||||
mv_prop=move_props, dirt_prop=dirt_props,
|
mv_prop=move_props, dirt_prop=dirt_props,
|
||||||
inject_agents=[TSPDirtAgent]
|
# inject_agents=[TSPDirtAgent],
|
||||||
)
|
)
|
||||||
|
|
||||||
# noinspection DuplicatedCode
|
# noinspection DuplicatedCode
|
||||||
n_actions = factory.action_space.n - 1
|
n_actions = factory.action_space.n - 1
|
||||||
_ = factory.observation_space
|
_ = factory.observation_space
|
||||||
|
obs_space = factory.observation_space
|
||||||
for epoch in range(10):
|
obs_space_named = factory.named_observation_space
|
||||||
random_actions = [[random.randint(0, n_actions) for _
|
times = []
|
||||||
in range(factory.n_agents)] for _
|
import time
|
||||||
in range(factory.max_steps+1)]
|
for epoch in range(10):
|
||||||
env_state = factory.reset()
|
start_time = time.time()
|
||||||
if render:
|
random_actions = [[random.randint(0, n_actions) for _
|
||||||
factory.render()
|
in range(factory.n_agents)] for _
|
||||||
tsp_agent = factory.get_injected_agents()[0]
|
in range(factory.max_steps+1)]
|
||||||
|
env_state = factory.reset()
|
||||||
r = 0
|
|
||||||
for agent_i_action in random_actions:
|
|
||||||
env_state, step_r, done_bool, info_obj = factory.step(tsp_agent.predict())
|
|
||||||
r += step_r
|
|
||||||
if render:
|
if render:
|
||||||
factory.render()
|
factory.render()
|
||||||
if done_bool:
|
# tsp_agent = factory.get_injected_agents()[0]
|
||||||
break
|
|
||||||
print(f'Factory run {epoch} done, reward is:\n {r}')
|
r = 0
|
||||||
|
for agent_i_action in random_actions:
|
||||||
|
env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
|
||||||
|
r += step_r
|
||||||
|
if render:
|
||||||
|
factory.render()
|
||||||
|
if done_bool:
|
||||||
|
break
|
||||||
|
times.append(time.time() - start_time)
|
||||||
|
# print(f'Factory run {epoch} done, reward is:\n {r}')
|
||||||
|
print('Time Taken: ', sum(times) / 10)
|
||||||
|
global_timings.append(sum(times) / 10)
|
||||||
|
print('Time Taken: ', sum(global_timings[10:]) / 10)
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
@ -6,11 +6,11 @@ import numpy as np
|
|||||||
import random
|
import random
|
||||||
|
|
||||||
from environments.factory.base.base_factory import BaseFactory
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
from environments.helpers import Constants as c
|
from environments.helpers import Constants as c, Constants
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.factory.base.objects import Agent, Entity, Action, Tile, MoveableEntity
|
from environments.factory.base.objects import Agent, Entity, Action, Tile, MoveableEntity
|
||||||
from environments.factory.base.registers import Entities, EntityObjectRegister, ObjectRegister, \
|
from environments.factory.base.registers import Entities, EntityRegister, EnvObjectRegister, MovingEntityObjectRegister, \
|
||||||
MovingEntityObjectRegister
|
BoundRegisterMixin
|
||||||
|
|
||||||
from environments.factory.base.renderer import RenderEntity
|
from environments.factory.base.renderer import RenderEntity
|
||||||
|
|
||||||
@ -19,7 +19,7 @@ NO_ITEM = 0
|
|||||||
ITEM_DROP_OFF = 1
|
ITEM_DROP_OFF = 1
|
||||||
|
|
||||||
|
|
||||||
class Item(MoveableEntity):
|
class Item(Entity):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@ -41,20 +41,21 @@ class Item(MoveableEntity):
|
|||||||
def set_auto_despawn(self, auto_despawn):
|
def set_auto_despawn(self, auto_despawn):
|
||||||
self._auto_despawn = auto_despawn
|
self._auto_despawn = auto_despawn
|
||||||
|
|
||||||
|
def despawn(self):
|
||||||
|
# Todo: Move this to base class?
|
||||||
|
curr_tile = self.tile
|
||||||
|
curr_tile.leave(self)
|
||||||
|
self._tile = None
|
||||||
|
self._register.notify_change_to_value(self)
|
||||||
|
return True
|
||||||
|
|
||||||
class ItemRegister(MovingEntityObjectRegister):
|
|
||||||
|
|
||||||
def as_array(self):
|
class ItemRegister(EntityRegister):
|
||||||
self._array[:] = c.FREE_CELL.value
|
|
||||||
for item in self:
|
|
||||||
if item.pos != c.NO_POS.value:
|
|
||||||
self._array[0, item.x, item.y] = item.encoding
|
|
||||||
return self._array
|
|
||||||
|
|
||||||
_accepted_objects = Item
|
_accepted_objects = Item
|
||||||
|
|
||||||
def spawn_items(self, tiles: List[Tile]):
|
def spawn_items(self, tiles: List[Tile]):
|
||||||
items = [Item(tile) for tile in tiles]
|
items = [Item(tile, self) for tile in tiles]
|
||||||
self.register_additional_items(items)
|
self.register_additional_items(items)
|
||||||
|
|
||||||
def despawn_items(self, items: List[Item]):
|
def despawn_items(self, items: List[Item]):
|
||||||
@ -63,7 +64,7 @@ class ItemRegister(MovingEntityObjectRegister):
|
|||||||
del self[item]
|
del self[item]
|
||||||
|
|
||||||
|
|
||||||
class Inventory(UserList):
|
class Inventory(EntityRegister, BoundRegisterMixin):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_blocking_light(self):
|
def is_blocking_light(self):
|
||||||
@ -73,19 +74,18 @@ class Inventory(UserList):
|
|||||||
def name(self):
|
def name(self):
|
||||||
return f'{self.__class__.__name__}({self.agent.name})'
|
return f'{self.__class__.__name__}({self.agent.name})'
|
||||||
|
|
||||||
def __init__(self, pomdp_r: int, level_shape: (int, int), agent: Agent, capacity: int):
|
def __init__(self, obs_shape: (int, int), agent: Agent, capacity: int):
|
||||||
super(Inventory, self).__init__()
|
super(Inventory, self).__init__()
|
||||||
self.agent = agent
|
self.agent = agent
|
||||||
self.pomdp_r = pomdp_r
|
self._obs_shape = obs_shape
|
||||||
self._level_shape = level_shape
|
|
||||||
if self.pomdp_r:
|
self._array = np.zeros((1, *self._obs_shape))
|
||||||
self._array = np.zeros((1, pomdp_r * 2 + 1, pomdp_r * 2 + 1))
|
|
||||||
else:
|
|
||||||
self._array = np.zeros((1, *self._level_shape))
|
|
||||||
self.capacity = min(capacity, self._array.size)
|
self.capacity = min(capacity, self._array.size)
|
||||||
|
|
||||||
def as_array(self):
|
def as_array(self):
|
||||||
self._array[:] = c.FREE_CELL.value
|
self._array[:] = c.FREE_CELL.value
|
||||||
|
# ToDo: Make this Lazy
|
||||||
for item_idx, item in enumerate(self):
|
for item_idx, item in enumerate(self):
|
||||||
x_diff, y_diff = divmod(item_idx, self._array.shape[1])
|
x_diff, y_diff = divmod(item_idx, self._array.shape[1])
|
||||||
self._array[0, int(x_diff), int(y_diff)] = item.encoding
|
self._array[0, int(x_diff), int(y_diff)] = item.encoding
|
||||||
@ -110,25 +110,22 @@ class Inventory(UserList):
|
|||||||
return attr_dict
|
return attr_dict
|
||||||
|
|
||||||
|
|
||||||
class Inventories(ObjectRegister):
|
class Inventories(EnvObjectRegister):
|
||||||
|
|
||||||
_accepted_objects = Inventory
|
_accepted_objects = Inventory
|
||||||
is_blocking_light = False
|
is_blocking_light = False
|
||||||
can_be_shadowed = False
|
can_be_shadowed = False
|
||||||
hide_from_obs_builder = True
|
hide_from_obs_builder = True
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, obs_shape, *args, **kwargs):
|
||||||
super(Inventories, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
|
super(Inventories, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
|
||||||
self.is_observable = True
|
self._obs_shape = obs_shape
|
||||||
|
|
||||||
def as_array(self):
|
def as_array(self):
|
||||||
# self._array[:] = c.FREE_CELL.value
|
return np.stack([inventory.as_array() for inv_idx, inventory in enumerate(self)])
|
||||||
for inv_idx, inventory in enumerate(self):
|
|
||||||
self._array[inv_idx] = inventory.as_array()
|
|
||||||
return self._array
|
|
||||||
|
|
||||||
def spawn_inventories(self, agents, pomdp_r, capacity):
|
def spawn_inventories(self, agents, capacity):
|
||||||
inventories = [self._accepted_objects(pomdp_r, self._level_shape, agent, capacity)
|
inventories = [self._accepted_objects(self._obs_shape, agent, capacity)
|
||||||
for _, agent in enumerate(agents)]
|
for _, agent in enumerate(agents)]
|
||||||
self.register_additional_items(inventories)
|
self.register_additional_items(inventories)
|
||||||
|
|
||||||
@ -183,20 +180,20 @@ class DropOffLocation(Entity):
|
|||||||
return super().summarize_state(n_steps=n_steps)
|
return super().summarize_state(n_steps=n_steps)
|
||||||
|
|
||||||
|
|
||||||
class DropOffLocations(EntityObjectRegister):
|
class DropOffLocations(EntityRegister):
|
||||||
|
|
||||||
_accepted_objects = DropOffLocation
|
_accepted_objects = DropOffLocation
|
||||||
|
|
||||||
def as_array(self):
|
@DeprecationWarning
|
||||||
|
def Xas_array(self):
|
||||||
|
# Todo: Which is faster?
|
||||||
|
# indices = list(zip(range(len(self)), *zip(*[x.pos for x in self])))
|
||||||
|
# np.put(self._array, [np.ravel_multi_index(x, self._array.shape) for x in indices], self.encodings)
|
||||||
self._array[:] = c.FREE_CELL.value
|
self._array[:] = c.FREE_CELL.value
|
||||||
for item in self:
|
indices = list(zip([0, ] * len(self), *zip(*[x.pos for x in self])))
|
||||||
if item.pos != c.NO_POS.value:
|
np.put(self._array, [np.ravel_multi_index(x, self._array.shape) for x in indices], self.encodings)
|
||||||
self._array[0, item.x, item.y] = item.encoding
|
|
||||||
return self._array
|
return self._array
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
super(DropOffLocations, self).__repr__()
|
|
||||||
|
|
||||||
|
|
||||||
class ItemProperties(NamedTuple):
|
class ItemProperties(NamedTuple):
|
||||||
n_items: int = 5 # How many items are there at the same time
|
n_items: int = 5 # How many items are there at the same time
|
||||||
@ -241,17 +238,23 @@ class ItemFactory(BaseFactory):
|
|||||||
empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_items]
|
empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_items]
|
||||||
item_register.spawn_items(empty_tiles)
|
item_register.spawn_items(empty_tiles)
|
||||||
|
|
||||||
inventories = Inventories(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2))
|
inventories = Inventories(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2),
|
||||||
inventories.spawn_inventories(self[c.AGENT], self._pomdp_r,
|
self._level_shape)
|
||||||
self.item_prop.max_agent_inventory_capacity)
|
inventories.spawn_inventories(self[c.AGENT], self.item_prop.max_agent_inventory_capacity)
|
||||||
|
|
||||||
super_entities.update({c.DROP_OFF: drop_offs, c.ITEM: item_register, c.INVENTORY: inventories})
|
super_entities.update({c.DROP_OFF: drop_offs, c.ITEM: item_register, c.INVENTORY: inventories})
|
||||||
return super_entities
|
return super_entities
|
||||||
|
|
||||||
def additional_per_agent_obs_build(self, agent) -> List[np.ndarray]:
|
def _additional_raw_observations(self, agent) -> Dict[Constants, np.typing.ArrayLike]:
|
||||||
additional_per_agent_obs_build = super().additional_per_agent_obs_build(agent)
|
additional_raw_observations = super()._additional_raw_observations(agent)
|
||||||
additional_per_agent_obs_build.append(self[c.INVENTORY].by_entity(agent).as_array())
|
additional_raw_observations.update({c.INVENTORY: self[c.INVENTORY].by_entity(agent).as_array()})
|
||||||
return additional_per_agent_obs_build
|
return additional_raw_observations
|
||||||
|
|
||||||
|
def _additional_observations(self) -> Dict[Constants, np.typing.ArrayLike]:
|
||||||
|
additional_observations = super()._additional_observations()
|
||||||
|
additional_observations.update({c.ITEM: self[c.ITEM].as_array()})
|
||||||
|
additional_observations.update({c.DROP_OFF: self[c.DROP_OFF].as_array()})
|
||||||
|
return additional_observations
|
||||||
|
|
||||||
def do_item_action(self, agent: Agent):
|
def do_item_action(self, agent: Agent):
|
||||||
inventory = self[c.INVENTORY].by_entity(agent)
|
inventory = self[c.INVENTORY].by_entity(agent)
|
||||||
@ -264,7 +267,7 @@ class ItemFactory(BaseFactory):
|
|||||||
elif item := self[c.ITEM].by_pos(agent.pos):
|
elif item := self[c.ITEM].by_pos(agent.pos):
|
||||||
try:
|
try:
|
||||||
inventory.append(item)
|
inventory.append(item)
|
||||||
item.move(self._NO_POS_TILE)
|
item.despawn()
|
||||||
return c.VALID
|
return c.VALID
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
@ -308,7 +311,7 @@ class ItemFactory(BaseFactory):
|
|||||||
if item.auto_despawn >= 1:
|
if item.auto_despawn >= 1:
|
||||||
item.set_auto_despawn(item.auto_despawn-1)
|
item.set_auto_despawn(item.auto_despawn-1)
|
||||||
elif not item.auto_despawn:
|
elif not item.auto_despawn:
|
||||||
self[c.ITEM].delete_entity(item)
|
self[c.ITEM].delete_env_object(item)
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -327,12 +330,12 @@ class ItemFactory(BaseFactory):
|
|||||||
info_dict.update({f'{agent.name}_item_drop_off': 1})
|
info_dict.update({f'{agent.name}_item_drop_off': 1})
|
||||||
info_dict.update(item_drop_off=1)
|
info_dict.update(item_drop_off=1)
|
||||||
self.print(f'{agent.name} just dropped of an item at {drop_off.pos}.')
|
self.print(f'{agent.name} just dropped of an item at {drop_off.pos}.')
|
||||||
reward += 0.5
|
reward += 1
|
||||||
else:
|
else:
|
||||||
info_dict.update({f'{agent.name}_item_pickup': 1})
|
info_dict.update({f'{agent.name}_item_pickup': 1})
|
||||||
info_dict.update(item_pickup=1)
|
info_dict.update(item_pickup=1)
|
||||||
self.print(f'{agent.name} just picked up an item at {agent.pos}')
|
self.print(f'{agent.name} just picked up an item at {agent.pos}')
|
||||||
reward += 0.1
|
reward += 0.2
|
||||||
else:
|
else:
|
||||||
if self[c.DROP_OFF].by_pos(agent.pos):
|
if self[c.DROP_OFF].by_pos(agent.pos):
|
||||||
info_dict.update({f'{agent.name}_failed_drop_off': 1})
|
info_dict.update({f'{agent.name}_failed_drop_off': 1})
|
||||||
@ -363,13 +366,13 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
item_probs = ItemProperties()
|
item_probs = ItemProperties()
|
||||||
|
|
||||||
obs_props = ObservationProperties(render_agents=ARO.LEVEL, omit_agent_self=True, pomdp_r=2)
|
obs_props = ObservationProperties(render_agents=ARO.SEPERATE, omit_agent_self=True, pomdp_r=2)
|
||||||
|
|
||||||
move_props = {'allow_square_movement': True,
|
move_props = {'allow_square_movement': True,
|
||||||
'allow_diagonal_movement': False,
|
'allow_diagonal_movement': True,
|
||||||
'allow_no_op': False}
|
'allow_no_op': False}
|
||||||
|
|
||||||
factory = ItemFactory(n_agents=3, done_at_collision=False,
|
factory = ItemFactory(n_agents=2, done_at_collision=False,
|
||||||
level_name='rooms', max_steps=400,
|
level_name='rooms', max_steps=400,
|
||||||
obs_prop=obs_props, parse_doors=True,
|
obs_prop=obs_props, parse_doors=True,
|
||||||
record_episodes=True, verbose=True,
|
record_episodes=True, verbose=True,
|
||||||
@ -378,7 +381,8 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
# noinspection DuplicatedCode
|
# noinspection DuplicatedCode
|
||||||
n_actions = factory.action_space.n - 1
|
n_actions = factory.action_space.n - 1
|
||||||
_ = factory.observation_space
|
obs_space = factory.observation_space
|
||||||
|
obs_space_named = factory.named_observation_space
|
||||||
|
|
||||||
for epoch in range(4):
|
for epoch in range(4):
|
||||||
random_actions = [[random.randint(0, n_actions) for _
|
random_actions = [[random.randint(0, n_actions) for _
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
import itertools
|
import itertools
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from enum import Enum, auto
|
from enum import Enum
|
||||||
from typing import Tuple, Union
|
from pathlib import Path
|
||||||
|
from typing import Tuple, Union, Dict, List
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pathlib import Path
|
from numpy.typing import ArrayLike
|
||||||
|
|
||||||
from stable_baselines3 import PPO, DQN, A2C
|
from stable_baselines3 import PPO, DQN, A2C
|
||||||
|
|
||||||
MODEL_MAP = dict(PPO=PPO, DQN=DQN, A2C=A2C)
|
MODEL_MAP = dict(PPO=PPO, DQN=DQN, A2C=A2C)
|
||||||
@ -29,6 +29,7 @@ class Constants(Enum):
|
|||||||
LEVEL = 'Level'
|
LEVEL = 'Level'
|
||||||
AGENT = 'Agent'
|
AGENT = 'Agent'
|
||||||
AGENT_PLACEHOLDER = 'AGENT_PLACEHOLDER'
|
AGENT_PLACEHOLDER = 'AGENT_PLACEHOLDER'
|
||||||
|
GLOBAL_POSITION = 'GLOBAL_POSITION'
|
||||||
FREE_CELL = 0
|
FREE_CELL = 0
|
||||||
OCCUPIED_CELL = 1
|
OCCUPIED_CELL = 1
|
||||||
SHADOWED_CELL = -1
|
SHADOWED_CELL = -1
|
||||||
@ -109,6 +110,58 @@ ACTIONMAP = defaultdict(lambda: (0, 0), {m.NORTH: (-1, 0), m.NORTHEAST: (-1, +1)
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ObservationTranslator:
|
||||||
|
|
||||||
|
def __init__(self, obs_shape_2d: (int, int), this_named_observation_space: Dict[str, dict],
|
||||||
|
*per_agent_named_obs_space: Dict[str, dict],
|
||||||
|
placeholder_fill_value: Union[int, str] = 'N'):
|
||||||
|
assert len(obs_shape_2d) == 2
|
||||||
|
self.obs_shape = obs_shape_2d
|
||||||
|
if isinstance(placeholder_fill_value, str):
|
||||||
|
if placeholder_fill_value.lower() in ['normal', 'n']:
|
||||||
|
self.random_fill = lambda: np.random.normal(size=self.obs_shape)
|
||||||
|
elif placeholder_fill_value.lower() in ['uniform', 'u']:
|
||||||
|
self.random_fill = lambda: np.random.uniform(size=self.obs_shape)
|
||||||
|
else:
|
||||||
|
raise ValueError('Please chooe between "uniform" or "normal"')
|
||||||
|
else:
|
||||||
|
self.random_fill = None
|
||||||
|
|
||||||
|
self._this_named_obs_space = this_named_observation_space
|
||||||
|
self._per_agent_named_obs_space = list(per_agent_named_obs_space)
|
||||||
|
|
||||||
|
def translate_observation(self, agent_idx: int, obs: np.ndarray):
|
||||||
|
target_obs_space = self._per_agent_named_obs_space[agent_idx]
|
||||||
|
translation = [idx_space_dict['explained_idxs'] for name, idx_space_dict in target_obs_space.items()]
|
||||||
|
flat_translation = [x for y in translation for x in y]
|
||||||
|
return np.take(obs, flat_translation, axis=1 if obs.ndim == 4 else 0)
|
||||||
|
|
||||||
|
def translate_observations(self, observations: List[ArrayLike]):
|
||||||
|
return [self.translate_observation(idx, observation) for idx, observation in enumerate(observations)]
|
||||||
|
|
||||||
|
def __call__(self, observations):
|
||||||
|
return self.translate_observations(observations)
|
||||||
|
|
||||||
|
|
||||||
|
class ActionTranslator:
|
||||||
|
|
||||||
|
def __init__(self, target_named_action_space: Dict[str, int], *per_agent_named_action_space: Dict[str, int]):
|
||||||
|
self._target_named_action_space = target_named_action_space
|
||||||
|
self._per_agent_named_action_space = list(per_agent_named_action_space)
|
||||||
|
self._per_agent_idx_actions = [{idx: a for a, idx in x.items()} for x in self._per_agent_named_action_space]
|
||||||
|
|
||||||
|
def translate_action(self, agent_idx: int, action: int):
|
||||||
|
named_action = self._per_agent_idx_actions[agent_idx][action]
|
||||||
|
translated_action = self._target_named_action_space[named_action]
|
||||||
|
return translated_action
|
||||||
|
|
||||||
|
def translate_actions(self, actions: List[int]):
|
||||||
|
return [self.translate_action(idx, action) for idx, action in enumerate(actions)]
|
||||||
|
|
||||||
|
def __call__(self, actions):
|
||||||
|
return self.translate_actions(actions)
|
||||||
|
|
||||||
|
|
||||||
# Utility functions
|
# Utility functions
|
||||||
def parse_level(path):
|
def parse_level(path):
|
||||||
with path.open('r') as lvl:
|
with path.open('r') as lvl:
|
||||||
@ -128,7 +181,7 @@ def one_hot_level(level, wall_char: Union[c, str] = c.WALL):
|
|||||||
return binary_grid
|
return binary_grid
|
||||||
|
|
||||||
|
|
||||||
def check_position(slice_to_check_against: np.ndarray, position_to_check: Tuple[int, int]):
|
def check_position(slice_to_check_against: ArrayLike, position_to_check: Tuple[int, int]):
|
||||||
x_pos, y_pos = position_to_check
|
x_pos, y_pos = position_to_check
|
||||||
|
|
||||||
# Check if agent colides with grid boundrys
|
# Check if agent colides with grid boundrys
|
||||||
@ -177,6 +230,7 @@ def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, all
|
|||||||
graph.add_edge(a, b)
|
graph.add_edge(a, b)
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parsed_level = parse_level(Path(__file__).parent / 'factory' / 'levels' / 'simple.txt')
|
parsed_level = parse_level(Path(__file__).parent / 'factory' / 'levels' / 'simple.txt')
|
||||||
y = one_hot_level(parsed_level)
|
y = one_hot_level(parsed_level)
|
||||||
|
@ -20,10 +20,10 @@ class ObservationProperties(NamedTuple):
|
|||||||
render_agents: AgentRenderOptions = AgentRenderOptions.SEPERATE
|
render_agents: AgentRenderOptions = AgentRenderOptions.SEPERATE
|
||||||
omit_agent_self: bool = True
|
omit_agent_self: bool = True
|
||||||
additional_agent_placeholder: Union[None, str, int] = None
|
additional_agent_placeholder: Union[None, str, int] = None
|
||||||
cast_shadows = True
|
cast_shadows: bool = True
|
||||||
frames_to_stack: int = 0
|
frames_to_stack: int = 0
|
||||||
pomdp_r: int = 0
|
pomdp_r: int = 0
|
||||||
show_global_position_info: bool = True
|
show_global_position_info: bool = False
|
||||||
|
|
||||||
|
|
||||||
class MarlFrameStack(gym.ObservationWrapper):
|
class MarlFrameStack(gym.ObservationWrapper):
|
||||||
|
@ -3,6 +3,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import yaml
|
import yaml
|
||||||
|
from stable_baselines3 import A2C
|
||||||
|
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.helpers import Constants as c
|
from environments.helpers import Constants as c
|
||||||
@ -16,13 +17,12 @@ warnings.filterwarnings('ignore', category=UserWarning)
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
determin = False
|
determin = True
|
||||||
render = True
|
render = True
|
||||||
record = True
|
record = False
|
||||||
seed = 67
|
seed = 67
|
||||||
n_agents = 1
|
n_agents = 1
|
||||||
out_path = Path('study_out/e_1_new_reward/no_obs/dirt/A2C_new_reward/0_A2C_new_reward')
|
out_path = Path('study_out/single_run_with_export/dirt')
|
||||||
out_path_2 = Path('study_out/e_1_obs_stack_3_gae_0.25_n_steps_16/seperate_N/dirt/A2C_obs_stack_3_gae_0.25_n_steps_16/1_A2C_obs_stack_3_gae_0.25_n_steps_16')
|
|
||||||
model_path = out_path
|
model_path = out_path
|
||||||
|
|
||||||
with (out_path / f'env_params.json').open('r') as f:
|
with (out_path / f'env_params.json').open('r') as f:
|
||||||
@ -35,10 +35,9 @@ if __name__ == '__main__':
|
|||||||
env_kwargs.update(record_episodes=record, done_at_collision=True)
|
env_kwargs.update(record_episodes=record, done_at_collision=True)
|
||||||
|
|
||||||
this_model = out_path / 'model.zip'
|
this_model = out_path / 'model.zip'
|
||||||
other_model = out_path / 'model.zip'
|
|
||||||
|
|
||||||
model_cls = next(val for key, val in h.MODEL_MAP.items() if key in out_path.parent.name)
|
model_cls = A2C # next(val for key, val in h.MODEL_MAP.items() if key in out_path.parent.name)
|
||||||
models = [model_cls.load(this_model)] # , model_cls.load(other_model)]
|
models = [model_cls.load(this_model)]
|
||||||
|
|
||||||
# Init Env
|
# Init Env
|
||||||
with DirtFactory(**env_kwargs) as env:
|
with DirtFactory(**env_kwargs) as env:
|
||||||
@ -59,6 +58,8 @@ if __name__ == '__main__':
|
|||||||
rew += step_r
|
rew += step_r
|
||||||
if render:
|
if render:
|
||||||
env.render()
|
env.render()
|
||||||
|
if not env.unwrapped.unwrapped[c.AGENT][0].temp_valid:
|
||||||
|
print('Invalid ACtions')
|
||||||
if done_bool:
|
if done_bool:
|
||||||
break
|
break
|
||||||
print(f'Factory run {episode} done, reward is:\n {rew}')
|
print(f'Factory run {episode} done, reward is:\n {rew}')
|
||||||
|
@ -434,8 +434,8 @@ if __name__ == '__main__':
|
|||||||
# Iteration
|
# Iteration
|
||||||
start_mp_baseline_run(env_map, policy_path)
|
start_mp_baseline_run(env_map, policy_path)
|
||||||
|
|
||||||
# for seed_path in (y for y in policy_path.iterdir() if y.is_dir()):
|
# for policy_path in (y for y in policy_path.iterdir() if y.is_dir()):
|
||||||
# load_model_run_baseline(seed_path)
|
# load_model_run_baseline(policy_path)
|
||||||
print('Baseline Tracking done')
|
print('Baseline Tracking done')
|
||||||
|
|
||||||
# Then iterate over every model and monitor "ood behavior" - "is it ood?"
|
# Then iterate over every model and monitor "ood behavior" - "is it ood?"
|
||||||
@ -448,11 +448,11 @@ if __name__ == '__main__':
|
|||||||
for policy_path in [x for x in env_path.iterdir() if x. is_dir()]:
|
for policy_path in [x for x in env_path.iterdir() if x. is_dir()]:
|
||||||
# FIXME: Pick random seed or iterate over available seeds
|
# FIXME: Pick random seed or iterate over available seeds
|
||||||
# First seed path version
|
# First seed path version
|
||||||
# seed_path = next((y for y in policy_path.iterdir() if y.is_dir()))
|
# policy_path = next((y for y in policy_path.iterdir() if y.is_dir()))
|
||||||
# Iteration
|
# Iteration
|
||||||
start_mp_study_run(env_map, policy_path)
|
start_mp_study_run(env_map, policy_path)
|
||||||
#for seed_path in (y for y in policy_path.iterdir() if y.is_dir()):
|
#for policy_path in (y for y in policy_path.iterdir() if y.is_dir()):
|
||||||
# load_model_run_study(seed_path, env_map[env_path.name][0], observation_modes[obs_mode])
|
# load_model_run_study(policy_path, env_map[env_path.name][0], observation_modes[obs_mode])
|
||||||
print('OOD Tracking Done')
|
print('OOD Tracking Done')
|
||||||
|
|
||||||
# Plotting
|
# Plotting
|
||||||
|
226
studies/single_run_with_export.py
Normal file
226
studies/single_run_with_export.py
Normal file
@ -0,0 +1,226 @@
|
|||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
try:
|
||||||
|
# noinspection PyUnboundLocalVariable
|
||||||
|
if __package__ is None:
|
||||||
|
DIR = Path(__file__).resolve().parent
|
||||||
|
sys.path.insert(0, str(DIR.parent))
|
||||||
|
__package__ = DIR.name
|
||||||
|
else:
|
||||||
|
DIR = None
|
||||||
|
except NameError:
|
||||||
|
DIR = None
|
||||||
|
pass
|
||||||
|
|
||||||
|
import simplejson
|
||||||
|
from environments.helpers import ActionTranslator, ObservationTranslator
|
||||||
|
from environments.logging.recorder import EnvRecorder
|
||||||
|
from environments import helpers as h
|
||||||
|
from environments.factory.factory_dirt import DirtProperties, DirtFactory
|
||||||
|
from environments.factory.factory_item import ItemProperties, ItemFactory
|
||||||
|
from environments.factory.factory_dest import DestProperties, DestFactory, DestModeOptions
|
||||||
|
from environments.factory.combined_factories import DirtDestItemFactory
|
||||||
|
from environments.logging.envmonitor import EnvMonitor
|
||||||
|
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
|
||||||
|
|
||||||
|
"""
|
||||||
|
In this studie, we want to export trained Agents for debugging purposes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def encapsule_env_factory(env_fctry, env_kwrgs):
|
||||||
|
|
||||||
|
def _init():
|
||||||
|
with env_fctry(**env_kwrgs) as init_env:
|
||||||
|
return init_env
|
||||||
|
|
||||||
|
return _init
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_run_baseline(policy_path, env_to_run):
|
||||||
|
# retrieve model class
|
||||||
|
model_cls = h.MODEL_MAP['A2C']
|
||||||
|
# Load both agents
|
||||||
|
model = model_cls.load(policy_path / 'model.zip', device='cpu')
|
||||||
|
# Load old env kwargs
|
||||||
|
with next(policy_path.glob('*.json')).open('r') as f:
|
||||||
|
env_kwargs = simplejson.load(f)
|
||||||
|
env_kwargs.update(done_at_collision=True)
|
||||||
|
# Init Env
|
||||||
|
with env_to_run(**env_kwargs) as env_factory:
|
||||||
|
monitored_env_factory = EnvMonitor(env_factory)
|
||||||
|
recorded_env_factory = EnvRecorder(monitored_env_factory)
|
||||||
|
|
||||||
|
# Evaluation Loop for i in range(n Episodes)
|
||||||
|
for episode in range(5):
|
||||||
|
env_state = recorded_env_factory.reset()
|
||||||
|
rew, done_bool = 0, False
|
||||||
|
while not done_bool:
|
||||||
|
action = model.predict(env_state, deterministic=True)[0]
|
||||||
|
env_state, step_r, done_bool, info_obj = recorded_env_factory.step(action)
|
||||||
|
rew += step_r
|
||||||
|
if done_bool:
|
||||||
|
break
|
||||||
|
print(f'Factory run {episode} done, reward is:\n {rew}')
|
||||||
|
recorded_env_factory.save_run(filepath=policy_path / f'monitor.pick')
|
||||||
|
recorded_env_factory.save_records(filepath=policy_path / f'recorder.json')
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_run_combined(root_path, env_to_run, env_kwargs):
|
||||||
|
# retrieve model class
|
||||||
|
model_cls = h.MODEL_MAP['A2C']
|
||||||
|
# Load both agents
|
||||||
|
models = [model_cls.load(model_zip, device='cpu') for model_zip in root_path.rglob('model.zip')]
|
||||||
|
# Load old env kwargs
|
||||||
|
env_kwargs = env_kwargs.copy()
|
||||||
|
env_kwargs.update(
|
||||||
|
n_agents=len(models),
|
||||||
|
done_at_collision=False)
|
||||||
|
|
||||||
|
# Init Env
|
||||||
|
with env_to_run(**env_kwargs) as env_factory:
|
||||||
|
|
||||||
|
action_translator = ActionTranslator(env_factory.named_action_space,
|
||||||
|
*[x.named_action_space for x in models])
|
||||||
|
observation_translator = ObservationTranslator(env_factory.observation_space.shape[-2:],
|
||||||
|
env_factory.named_observation_space,
|
||||||
|
*[x.named_observation_space for x in models])
|
||||||
|
|
||||||
|
monitored_env_factory = EnvMonitor(env_factory)
|
||||||
|
recorded_env_factory = EnvRecorder(monitored_env_factory)
|
||||||
|
# Evaluation Loop for i in range(n Episodes)
|
||||||
|
for episode in range(5):
|
||||||
|
env_state = recorded_env_factory.reset()
|
||||||
|
rew, done_bool = 0, False
|
||||||
|
while not done_bool:
|
||||||
|
translated_observations = observation_translator(env_state)
|
||||||
|
actions = [model.predict(translated_observations[model_idx], deterministic=True)[0]
|
||||||
|
for model_idx, model in enumerate(models)]
|
||||||
|
translated_actions = action_translator(actions)
|
||||||
|
env_state, step_r, done_bool, info_obj = recorded_env_factory.step(translated_actions)
|
||||||
|
rew += step_r
|
||||||
|
if done_bool:
|
||||||
|
break
|
||||||
|
print(f'Factory run {episode} done, reward is:\n {rew}')
|
||||||
|
recorded_env_factory.save_run(filepath=policy_path / f'monitor.pick')
|
||||||
|
recorded_env_factory.save_records(filepath=policy_path / f'recorder.json')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# What to do:
|
||||||
|
train = True
|
||||||
|
individual_run = True
|
||||||
|
combined_run = True
|
||||||
|
|
||||||
|
train_steps = 2e6
|
||||||
|
frames_to_stack = 3
|
||||||
|
|
||||||
|
# Define a global studi save path
|
||||||
|
study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}'
|
||||||
|
|
||||||
|
# Define Global Env Parameters
|
||||||
|
# Define properties object parameters
|
||||||
|
obs_props = ObservationProperties(render_agents=AgentRenderOptions.NOT,
|
||||||
|
additional_agent_placeholder=None,
|
||||||
|
omit_agent_self=True,
|
||||||
|
frames_to_stack=frames_to_stack,
|
||||||
|
pomdp_r=2, cast_shadows=True)
|
||||||
|
move_props = MovementProperties(allow_diagonal_movement=True,
|
||||||
|
allow_square_movement=True,
|
||||||
|
allow_no_op=False)
|
||||||
|
dirt_props = DirtProperties(initial_dirt_ratio=0.35, initial_dirt_spawn_r_var=0.1,
|
||||||
|
clean_amount=0.34,
|
||||||
|
max_spawn_amount=0.1, max_global_amount=20,
|
||||||
|
max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05,
|
||||||
|
dirt_smear_amount=0.0, agent_can_interact=True)
|
||||||
|
item_props = ItemProperties(n_items=10, agent_can_interact=True,
|
||||||
|
spawn_frequency=30, n_drop_off_locations=2,
|
||||||
|
max_agent_inventory_capacity=15)
|
||||||
|
dest_props = DestProperties(n_dests=4, spawn_mode=DestModeOptions.GROUPED, spawn_frequency=1)
|
||||||
|
factory_kwargs = dict(n_agents=1, max_steps=400, parse_doors=True,
|
||||||
|
level_name='rooms', doors_have_area=True,
|
||||||
|
verbose=False,
|
||||||
|
mv_prop=move_props,
|
||||||
|
obs_prop=obs_props,
|
||||||
|
done_at_collision=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Bundle both environments with global kwargs and parameters
|
||||||
|
env_map = {}
|
||||||
|
env_map.update({'dirt': (DirtFactory, dict(dirt_prop=dirt_props,
|
||||||
|
**factory_kwargs.copy()))})
|
||||||
|
env_map.update({'item': (ItemFactory, dict(item_prop=item_props,
|
||||||
|
**factory_kwargs.copy()))})
|
||||||
|
env_map.update({'dest': (DestFactory, dict(dest_prop=dest_props,
|
||||||
|
**factory_kwargs.copy()))})
|
||||||
|
env_map.update({'combined': (DirtDestItemFactory, dict(dest_prop=dest_props,
|
||||||
|
item_prop=item_props,
|
||||||
|
dirt_prop=dirt_props,
|
||||||
|
**factory_kwargs.copy()))})
|
||||||
|
env_names = list(env_map.keys())
|
||||||
|
|
||||||
|
# Train starts here ############################################################
|
||||||
|
# Build Major Loop parameters, parameter versions, Env Classes and models
|
||||||
|
if train:
|
||||||
|
for env_key in (env_key for env_key in env_map if 'combined' != env_key):
|
||||||
|
model_cls = h.MODEL_MAP['A2C']
|
||||||
|
combination_path = study_root_path / env_key
|
||||||
|
env_class, env_kwargs = env_map[env_key]
|
||||||
|
|
||||||
|
# Output folder
|
||||||
|
if (combination_path / 'monitor.pick').exists():
|
||||||
|
continue
|
||||||
|
combination_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
with env_class(**env_kwargs) as env_factory:
|
||||||
|
param_path = combination_path / f'env_params.json'
|
||||||
|
env_factory.save_params(param_path)
|
||||||
|
|
||||||
|
# EnvMonitor Init
|
||||||
|
callbacks = [EnvMonitor(env_factory)]
|
||||||
|
|
||||||
|
# Model Init
|
||||||
|
model = model_cls("MlpPolicy", env_factory,
|
||||||
|
verbose=1, seed=69, device='cpu')
|
||||||
|
|
||||||
|
# Model train
|
||||||
|
model.learn(total_timesteps=int(train_steps), callback=callbacks)
|
||||||
|
|
||||||
|
# Model save
|
||||||
|
model.named_action_space = env_factory.unwrapped.named_action_space
|
||||||
|
model.named_observation_space = env_factory.unwrapped.named_observation_space
|
||||||
|
save_path = combination_path / f'model.zip'
|
||||||
|
model.save(save_path)
|
||||||
|
|
||||||
|
# Monitor Save
|
||||||
|
callbacks[0].save_run(combination_path / 'monitor.pick')
|
||||||
|
|
||||||
|
# Better be save then sorry: Clean up!
|
||||||
|
del env_factory, model
|
||||||
|
import gc
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
# Train ends here ############################################################
|
||||||
|
|
||||||
|
# Evaluation starts here #####################################################
|
||||||
|
# First Iterate over every model and monitor "as trained"
|
||||||
|
if individual_run:
|
||||||
|
print('Start Individual Recording')
|
||||||
|
for env_key in (env_key for env_key in env_map if 'combined' != env_key):
|
||||||
|
# For trained policy in study_root_path / identifier
|
||||||
|
policy_path = study_root_path / env_key
|
||||||
|
load_model_run_baseline(policy_path, env_map[policy_path.name][0])
|
||||||
|
|
||||||
|
# for policy_path in (y for y in policy_path.iterdir() if y.is_dir()):
|
||||||
|
# load_model_run_baseline(policy_path)
|
||||||
|
print('Start Individual Training')
|
||||||
|
|
||||||
|
# Then iterate over every model and monitor "ood behavior" - "is it ood?"
|
||||||
|
if combined_run:
|
||||||
|
print('Start combined run')
|
||||||
|
for env_key in (env_key for env_key in env_map if 'combined' == env_key):
|
||||||
|
# For trained policy in study_root_path / identifier
|
||||||
|
factory, kwargs = env_map[env_key]
|
||||||
|
load_model_run_combined(study_root_path, factory, kwargs)
|
||||||
|
print('OOD Tracking Done')
|
Loading…
x
Reference in New Issue
Block a user