mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
In Debugging
This commit is contained in:
parent
0fc4db193f
commit
4731f63ba6
@ -27,7 +27,7 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def action_space(self):
|
def action_space(self):
|
||||||
return spaces.Discrete(self._actions.n)
|
return spaces.Discrete(len(self._actions))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def observation_space(self):
|
def observation_space(self):
|
||||||
@ -69,6 +69,7 @@ class BaseFactory(gym.Env):
|
|||||||
self._level_shape = None
|
self._level_shape = None
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self._renderer = None # expensive - don't use it when not required !
|
self._renderer = None # expensive - don't use it when not required !
|
||||||
|
self._entities = Entities()
|
||||||
|
|
||||||
self.n_agents = n_agents
|
self.n_agents = n_agents
|
||||||
|
|
||||||
@ -92,6 +93,9 @@ class BaseFactory(gym.Env):
|
|||||||
# Reset
|
# Reset
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
return self._entities[item]
|
||||||
|
|
||||||
def _base_init_env(self):
|
def _base_init_env(self):
|
||||||
# Objects
|
# Objects
|
||||||
entities = {}
|
entities = {}
|
||||||
@ -116,7 +120,7 @@ class BaseFactory(gym.Env):
|
|||||||
entities.update({c.FLOOR: floor})
|
entities.update({c.FLOOR: floor})
|
||||||
|
|
||||||
# NOPOS
|
# NOPOS
|
||||||
self.NO_POS_TILE = Tile(c.NO_POS, c.NO_POS.value)
|
self.NO_POS_TILE = Tile(c.NO_POS.value)
|
||||||
|
|
||||||
# Doors
|
# Doors
|
||||||
parsed_doors = h.one_hot_level(parsed_level, c.DOOR)
|
parsed_doors = h.one_hot_level(parsed_level, c.DOOR)
|
||||||
@ -145,7 +149,7 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
if self.omit_agent_in_obs and self.n_agents == 1:
|
if self.omit_agent_in_obs and self.n_agents == 1:
|
||||||
del arrays[c.AGENT]
|
del arrays[c.AGENT]
|
||||||
obs_cube_z = sum([a.shape[0] if not self._entities[key].is_per_agent else 1 for key, a in arrays.items()])
|
obs_cube_z = sum([a.shape[0] if not self[key].is_per_agent else 1 for key, a in arrays.items()])
|
||||||
self._obs_cube = np.zeros((obs_cube_z, *self._level_shape), dtype=np.float32)
|
self._obs_cube = np.zeros((obs_cube_z, *self._level_shape), dtype=np.float32)
|
||||||
|
|
||||||
# Optionally Pad this obs cube for pomdp cases
|
# Optionally Pad this obs cube for pomdp cases
|
||||||
@ -175,14 +179,14 @@ class BaseFactory(gym.Env):
|
|||||||
self.hook_pre_step()
|
self.hook_pre_step()
|
||||||
|
|
||||||
# Move this in a seperate function?
|
# Move this in a seperate function?
|
||||||
for action, agent in zip(actions, self._entities[c.AGENT]):
|
for action, agent in zip(actions, self[c.AGENT]):
|
||||||
agent.clear_temp_sate()
|
agent.clear_temp_sate()
|
||||||
action_obj = self._actions[action]
|
action_obj = self._actions[action]
|
||||||
if self._actions.is_moving_action(action_obj):
|
if self._actions.is_moving_action(action_obj):
|
||||||
valid = self._move_or_colide(agent, action_obj)
|
valid = self._move_or_colide(agent, action_obj)
|
||||||
elif self._actions.is_no_op(action_obj):
|
elif h.EnvActions.NOOP == agent.temp_action:
|
||||||
valid = c.VALID.value
|
valid = c.VALID
|
||||||
elif self._actions.is_door_usage(action_obj):
|
elif h.EnvActions.USE_DOOR == action_obj:
|
||||||
valid = self._handle_door_interaction(agent)
|
valid = self._handle_door_interaction(agent)
|
||||||
else:
|
else:
|
||||||
valid = self.do_additional_actions(agent, action_obj)
|
valid = self.do_additional_actions(agent, action_obj)
|
||||||
@ -206,7 +210,7 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
# Step the door close intervall
|
# Step the door close intervall
|
||||||
if self.parse_doors:
|
if self.parse_doors:
|
||||||
self._entities[c.DOORS].tick_doors()
|
self[c.DOORS].tick_doors()
|
||||||
|
|
||||||
# Finalize
|
# Finalize
|
||||||
reward, reward_info = self.calculate_reward()
|
reward, reward_info = self.calculate_reward()
|
||||||
@ -224,53 +228,61 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
return obs, reward, done, info
|
return obs, reward, done, info
|
||||||
|
|
||||||
def _handle_door_interaction(self, agent):
|
def _handle_door_interaction(self, agent) -> c:
|
||||||
# Check if agent really is standing on a door:
|
# Check if agent really is standing on a door:
|
||||||
if self.doors_have_area:
|
if self.doors_have_area:
|
||||||
door = self._entities[c.DOORS].get_near_position(agent.pos)
|
door = self[c.DOORS].get_near_position(agent.pos)
|
||||||
else:
|
else:
|
||||||
door = self._entities[c.DOORS].by_pos(agent.pos)
|
door = self[c.DOORS].by_pos(agent.pos)
|
||||||
if door is not None:
|
if door is not None:
|
||||||
door.use()
|
door.use()
|
||||||
return c.VALID.value
|
return c.VALID
|
||||||
# When he doesn't...
|
# When he doesn't...
|
||||||
else:
|
else:
|
||||||
return c.NOT_VALID.value
|
return c.NOT_VALID
|
||||||
|
|
||||||
def _get_observations(self) -> np.ndarray:
|
def _get_observations(self) -> np.ndarray:
|
||||||
|
state_array_dict = self._entities.arrays
|
||||||
if self.n_agents == 1:
|
if self.n_agents == 1:
|
||||||
obs = self._build_per_agent_obs(self._entities[c.AGENT][0])
|
obs = self._build_per_agent_obs(self[c.AGENT][0], state_array_dict)
|
||||||
elif self.n_agents >= 2:
|
elif self.n_agents >= 2:
|
||||||
obs = np.stack([self._build_per_agent_obs(agent) for agent in self._entities[c.AGENT]])
|
obs = np.stack([self._build_per_agent_obs(agent, state_array_dict) for agent in self[c.AGENT]])
|
||||||
else:
|
else:
|
||||||
raise ValueError('n_agents cannot be smaller than 1!!')
|
raise ValueError('n_agents cannot be smaller than 1!!')
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
def _build_per_agent_obs(self, agent: Agent) -> np.ndarray:
|
def _build_per_agent_obs(self, agent: Agent, state_array_dict) -> np.ndarray:
|
||||||
plain_arrays = self._entities.arrays
|
agent_pos_is_omitted = False
|
||||||
if self.omit_agent_in_obs and self.n_agents == 1:
|
if self.omit_agent_in_obs and self.n_agents == 1:
|
||||||
del plain_arrays[c.AGENT]
|
del state_array_dict[c.AGENT]
|
||||||
|
elif self.omit_agent_in_obs and self.combin_agent_obs and self.n_agents > 1:
|
||||||
|
state_array_dict[c.AGENT][0, agent.x, agent.y] -= agent.encoding
|
||||||
|
agent_pos_is_omitted = True
|
||||||
|
|
||||||
running_idx, shadowing_idxs, can_be_shadowed_idxs = 0, [], []
|
running_idx, shadowing_idxs, can_be_shadowed_idxs = 0, [], []
|
||||||
|
|
||||||
for key, array in plain_arrays.items():
|
for key, array in state_array_dict.items():
|
||||||
if self._entities[key].is_per_agent:
|
# Flush state array object representation to obs cube
|
||||||
per_agent_idx = self._entities[key].get_idx_by_name(agent.name)
|
if self[key].is_per_agent:
|
||||||
|
per_agent_idx = self[key].get_idx_by_name(agent.name)
|
||||||
z = 1
|
z = 1
|
||||||
self._obs_cube[running_idx: z] = array[per_agent_idx]
|
self._obs_cube[running_idx: running_idx+z] = array[per_agent_idx]
|
||||||
else:
|
else:
|
||||||
z = array.shape[0]
|
z = array.shape[0]
|
||||||
self._obs_cube[running_idx: z] = array
|
self._obs_cube[running_idx: running_idx+z] = array
|
||||||
# Define which OBS SLices cast a Shadow
|
# Define which OBS SLices cast a Shadow
|
||||||
if self._entities[key].is_blocking_light:
|
if self[key].is_blocking_light:
|
||||||
for i in range(z):
|
for i in range(z):
|
||||||
shadowing_idxs.append(running_idx + i)
|
shadowing_idxs.append(running_idx + i)
|
||||||
# Define which OBS SLices are effected by shadows
|
# Define which OBS SLices are effected by shadows
|
||||||
if self._entities[key].can_be_shadowed:
|
if self[key].can_be_shadowed:
|
||||||
for i in range(z):
|
for i in range(z):
|
||||||
can_be_shadowed_idxs.append(running_idx + i)
|
can_be_shadowed_idxs.append(running_idx + i)
|
||||||
running_idx += z
|
running_idx += z
|
||||||
|
|
||||||
|
if agent_pos_is_omitted:
|
||||||
|
state_array_dict[c.AGENT][0, agent.x, agent.y] += agent.encoding
|
||||||
|
|
||||||
if r := self.pomdp_r:
|
if r := self.pomdp_r:
|
||||||
x, y = self._level_shape
|
x, y = self._level_shape
|
||||||
self._padded_obs_cube[:, r:r + x, r:r + y] = self._obs_cube
|
self._padded_obs_cube[:, r:r + x, r:r + y] = self._obs_cube
|
||||||
@ -284,7 +296,7 @@ class BaseFactory(gym.Env):
|
|||||||
if self.cast_shadows:
|
if self.cast_shadows:
|
||||||
obs_block_light = [obs[idx] != c.OCCUPIED_CELL.value for idx in shadowing_idxs]
|
obs_block_light = [obs[idx] != c.OCCUPIED_CELL.value for idx in shadowing_idxs]
|
||||||
door_shadowing = False
|
door_shadowing = False
|
||||||
if door := self._entities[c.DOORS].by_pos(agent.pos):
|
if door := self[c.DOORS].by_pos(agent.pos):
|
||||||
if door.is_closed:
|
if door.is_closed:
|
||||||
for group in door.connectivity_subgroups:
|
for group in door.connectivity_subgroups:
|
||||||
if agent.last_pos not in group:
|
if agent.last_pos not in group:
|
||||||
@ -319,7 +331,7 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
def get_all_tiles_with_collisions(self) -> List[Tile]:
|
def get_all_tiles_with_collisions(self) -> List[Tile]:
|
||||||
tiles_with_collisions = list()
|
tiles_with_collisions = list()
|
||||||
for tile in self._entities[c.FLOOR]:
|
for tile in self[c.FLOOR]:
|
||||||
if tile.is_occupied():
|
if tile.is_occupied():
|
||||||
guests = [guest for guest in tile.guests if guest.can_collide]
|
guests = [guest for guest in tile.guests if guest.can_collide]
|
||||||
if len(guests) >= 2:
|
if len(guests) >= 2:
|
||||||
@ -337,11 +349,11 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
def _check_agent_move(self, agent, action: Action) -> (Tile, bool):
|
def _check_agent_move(self, agent, action: Action) -> (Tile, bool):
|
||||||
# Actions
|
# Actions
|
||||||
x_diff, y_diff = h.ACTIONMAP[action.name]
|
x_diff, y_diff = h.ACTIONMAP[action.identifier]
|
||||||
x_new = agent.x + x_diff
|
x_new = agent.x + x_diff
|
||||||
y_new = agent.y + y_diff
|
y_new = agent.y + y_diff
|
||||||
|
|
||||||
new_tile = self._entities[c.FLOOR].by_pos((x_new, y_new))
|
new_tile = self[c.FLOOR].by_pos((x_new, y_new))
|
||||||
if new_tile:
|
if new_tile:
|
||||||
valid = c.VALID
|
valid = c.VALID
|
||||||
else:
|
else:
|
||||||
@ -350,13 +362,13 @@ class BaseFactory(gym.Env):
|
|||||||
return tile, valid
|
return tile, valid
|
||||||
|
|
||||||
if self.parse_doors and agent.last_pos != c.NO_POS:
|
if self.parse_doors and agent.last_pos != c.NO_POS:
|
||||||
if door := self._entities[c.DOORS].by_pos(new_tile.pos):
|
if door := self[c.DOORS].by_pos(new_tile.pos):
|
||||||
if door.can_collide:
|
if door.can_collide:
|
||||||
return agent.tile, c.NOT_VALID
|
return agent.tile, c.NOT_VALID
|
||||||
else: # door.is_closed:
|
else: # door.is_closed:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if door := self._entities[c.DOORS].by_pos(agent.pos):
|
if door := self[c.DOORS].by_pos(agent.pos):
|
||||||
if door.is_open:
|
if door.is_open:
|
||||||
pass
|
pass
|
||||||
else: # door.is_closed:
|
else: # door.is_closed:
|
||||||
@ -376,7 +388,7 @@ class BaseFactory(gym.Env):
|
|||||||
info_dict = dict()
|
info_dict = dict()
|
||||||
reward = 0
|
reward = 0
|
||||||
|
|
||||||
for agent in self._entities[c.AGENT]:
|
for agent in self[c.AGENT]:
|
||||||
if self._actions.is_moving_action(agent.temp_action):
|
if self._actions.is_moving_action(agent.temp_action):
|
||||||
if agent.temp_valid:
|
if agent.temp_valid:
|
||||||
# info_dict.update(movement=1)
|
# info_dict.update(movement=1)
|
||||||
@ -387,7 +399,7 @@ class BaseFactory(gym.Env):
|
|||||||
self.print(f'{agent.name} just hit the wall at {agent.pos}.')
|
self.print(f'{agent.name} just hit the wall at {agent.pos}.')
|
||||||
info_dict.update({f'{agent.name}_vs_LEVEL': 1})
|
info_dict.update({f'{agent.name}_vs_LEVEL': 1})
|
||||||
|
|
||||||
elif self._actions.is_door_usage(agent.temp_action):
|
elif h.EnvActions.USE_DOOR == agent.temp_action:
|
||||||
if agent.temp_valid:
|
if agent.temp_valid:
|
||||||
self.print(f'{agent.name} did just use the door at {agent.pos}.')
|
self.print(f'{agent.name} did just use the door at {agent.pos}.')
|
||||||
info_dict.update(door_used=1)
|
info_dict.update(door_used=1)
|
||||||
@ -396,7 +408,7 @@ class BaseFactory(gym.Env):
|
|||||||
self.print(f'{agent.name} just tried to use a door at {agent.pos}, but failed.')
|
self.print(f'{agent.name} just tried to use a door at {agent.pos}, but failed.')
|
||||||
info_dict.update({f'{agent.name}_failed_action': 1})
|
info_dict.update({f'{agent.name}_failed_action': 1})
|
||||||
info_dict.update({f'{agent.name}_failed_door_open': 1})
|
info_dict.update({f'{agent.name}_failed_door_open': 1})
|
||||||
elif self._actions.is_no_op(agent.temp_action):
|
elif h.EnvActions.NOOP == agent.temp_action:
|
||||||
info_dict.update(no_op=1)
|
info_dict.update(no_op=1)
|
||||||
reward -= 0.00
|
reward -= 0.00
|
||||||
|
|
||||||
@ -415,15 +427,15 @@ class BaseFactory(gym.Env):
|
|||||||
height, width = self._obs_cube.shape[1:]
|
height, width = self._obs_cube.shape[1:]
|
||||||
self._renderer = Renderer(width, height, view_radius=self.pomdp_r, fps=5)
|
self._renderer = Renderer(width, height, view_radius=self.pomdp_r, fps=5)
|
||||||
|
|
||||||
walls = [RenderEntity('wall', wall.pos) for wall in self._entities[c.WALLS]]
|
walls = [RenderEntity('wall', wall.pos) for wall in self[c.WALLS]]
|
||||||
|
|
||||||
agents = []
|
agents = []
|
||||||
for i, agent in enumerate(self._entities[c.AGENT]):
|
for i, agent in enumerate(self[c.AGENT]):
|
||||||
name, state = h.asset_str(agent)
|
name, state = h.asset_str(agent)
|
||||||
agents.append(RenderEntity(name, agent.pos, 1, 'none', state, i + 1, agent.temp_light_map))
|
agents.append(RenderEntity(name, agent.pos, 1, 'none', state, i + 1, agent.temp_light_map))
|
||||||
doors = []
|
doors = []
|
||||||
if self.parse_doors:
|
if self.parse_doors:
|
||||||
for i, door in enumerate(self._entities[c.DOORS]):
|
for i, door in enumerate(self[c.DOORS]):
|
||||||
name, state = 'door_open' if door.is_open else 'door_closed', 'blank'
|
name, state = 'door_open' if door.is_open else 'door_closed', 'blank'
|
||||||
doors.append(RenderEntity(name, door.pos, 1, 'none', state, i + 1))
|
doors.append(RenderEntity(name, door.pos, 1, 'none', state, i + 1))
|
||||||
additional_assets = self.render_additional_assets()
|
additional_assets = self.render_additional_assets()
|
||||||
@ -442,7 +454,7 @@ class BaseFactory(gym.Env):
|
|||||||
def _summarize_state(self):
|
def _summarize_state(self):
|
||||||
summary = {f'{REC_TAC}_step': self._steps}
|
summary = {f'{REC_TAC}_step': self._steps}
|
||||||
|
|
||||||
self._entities[c.WALLS].summarize_state()
|
self[c.WALLS].summarize_state()
|
||||||
for entity in self._entities:
|
for entity in self._entities:
|
||||||
if hasattr(entity, 'summarize_state'):
|
if hasattr(entity, 'summarize_state'):
|
||||||
summary.update({f'{REC_TAC}_{entity.name}': entity.summarize_state()})
|
summary.update({f'{REC_TAC}_{entity.name}': entity.summarize_state()})
|
||||||
@ -484,7 +496,7 @@ class BaseFactory(gym.Env):
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def do_additional_actions(self, agent: Agent, action: int) -> Union[None, bool]:
|
def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
|
@ -1,3 +1,6 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from environments.helpers import Constants as c
|
from environments.helpers import Constants as c
|
||||||
@ -6,6 +9,8 @@ import itertools
|
|||||||
|
|
||||||
class Object:
|
class Object:
|
||||||
|
|
||||||
|
_u_idx = 0
|
||||||
|
|
||||||
def __bool__(self):
|
def __bool__(self):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -17,21 +22,41 @@ class Object:
|
|||||||
def name(self):
|
def name(self):
|
||||||
return self._name
|
return self._name
|
||||||
|
|
||||||
def __init__(self, name, name_is_identifier=False, is_blocking_light=False, **kwargs):
|
@property
|
||||||
name = name.name if hasattr(name, 'name') else name
|
def identifier(self):
|
||||||
self._name = f'{self.__class__.__name__}#{name}' if name_is_identifier else name
|
return self._enum_ident
|
||||||
|
|
||||||
|
def __init__(self, enum_ident: Union[Enum, None] = None, is_blocking_light=False, **kwargs):
|
||||||
|
self._enum_ident = enum_ident
|
||||||
|
if self._enum_ident is not None:
|
||||||
|
self._name = f'{self.__class__.__name__}[{self._enum_ident.name}]'
|
||||||
|
else:
|
||||||
|
self._name = f'{self.__class__.__name__}#{self._u_idx}'
|
||||||
|
Object._u_idx += 1
|
||||||
self._is_blocking_light = is_blocking_light
|
self._is_blocking_light = is_blocking_light
|
||||||
if kwargs:
|
if kwargs:
|
||||||
print(f'Following kwargs were passed, but ignored: {kwargs}')
|
print(f'Following kwargs were passed, but ignored: {kwargs}')
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'{self.__class__.__name__}({self.name})'
|
return f'{self.name}'
|
||||||
|
|
||||||
|
def __eq__(self, other) -> bool:
|
||||||
|
if self._enum_ident is not None:
|
||||||
|
if isinstance(other, Enum):
|
||||||
|
return other == self._enum_ident
|
||||||
|
elif isinstance(other, Object):
|
||||||
|
return other._enum_ident == self._enum_ident
|
||||||
|
else:
|
||||||
|
raise ValueError('Must be evaluated against an Enunm Identifier or Object with such.')
|
||||||
|
else:
|
||||||
|
assert isinstance(other, Object), ' This Object can only be compared to other Objects.'
|
||||||
|
return other.name == self.name
|
||||||
|
|
||||||
|
|
||||||
class Action(Object):
|
class Action(Object):
|
||||||
|
|
||||||
def __init__(self, *args):
|
def __init__(self, *args, **kwargs):
|
||||||
super(Action, self).__init__(*args)
|
super(Action, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class Tile(Object):
|
class Tile(Object):
|
||||||
@ -56,8 +81,8 @@ class Tile(Object):
|
|||||||
def pos(self):
|
def pos(self):
|
||||||
return self._pos
|
return self._pos
|
||||||
|
|
||||||
def __init__(self, i, pos, **kwargs):
|
def __init__(self, pos, **kwargs):
|
||||||
super(Tile, self).__init__(i, **kwargs)
|
super(Tile, self).__init__(**kwargs)
|
||||||
self._guests = dict()
|
self._guests = dict()
|
||||||
self._pos = tuple(pos)
|
self._pos = tuple(pos)
|
||||||
|
|
||||||
@ -84,6 +109,9 @@ class Tile(Object):
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f'{self.name}(@{self.pos})'
|
||||||
|
|
||||||
|
|
||||||
class Wall(Tile):
|
class Wall(Tile):
|
||||||
pass
|
pass
|
||||||
@ -97,7 +125,7 @@ class Entity(Object):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
return 1
|
return c.OCCUPIED_CELL.value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def x(self):
|
def x(self):
|
||||||
@ -115,14 +143,17 @@ class Entity(Object):
|
|||||||
def tile(self):
|
def tile(self):
|
||||||
return self._tile
|
return self._tile
|
||||||
|
|
||||||
def __init__(self, identifier, tile: Tile, **kwargs):
|
def __init__(self, tile: Tile, **kwargs):
|
||||||
super(Entity, self).__init__(identifier, **kwargs)
|
super(Entity, self).__init__(**kwargs)
|
||||||
self._tile = tile
|
self._tile = tile
|
||||||
tile.enter(self)
|
tile.enter(self)
|
||||||
|
|
||||||
def summarize_state(self):
|
def summarize_state(self):
|
||||||
return self.__dict__.copy()
|
return self.__dict__.copy()
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f'{self.name}(@{self.pos})'
|
||||||
|
|
||||||
|
|
||||||
class Door(Entity):
|
class Door(Entity):
|
||||||
|
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import random
|
import random
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from enum import Enum
|
|
||||||
from typing import List, Union, Dict
|
from typing import List, Union, Dict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -18,13 +17,8 @@ class Register:
|
|||||||
def name(self):
|
def name(self):
|
||||||
return self.__class__.__name__
|
return self.__class__.__name__
|
||||||
|
|
||||||
@property
|
|
||||||
def n(self):
|
|
||||||
return len(self)
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self._register = dict()
|
self._register = dict()
|
||||||
self._names = dict()
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self._register)
|
return len(self._register)
|
||||||
@ -36,9 +30,7 @@ class Register:
|
|||||||
assert isinstance(other, self._accepted_objects), f'All item names have to be of type ' \
|
assert isinstance(other, self._accepted_objects), f'All item names have to be of type ' \
|
||||||
f'{self._accepted_objects}, ' \
|
f'{self._accepted_objects}, ' \
|
||||||
f'but were {other.__class__}.,'
|
f'but were {other.__class__}.,'
|
||||||
new_idx = len(self._register)
|
self._register.update({other.name: other})
|
||||||
self._names.update({other.name: new_idx})
|
|
||||||
self._register.update({new_idx: other})
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def register_additional_items(self, others: List[_accepted_objects]):
|
def register_additional_items(self, others: List[_accepted_objects]):
|
||||||
@ -56,31 +48,16 @@ class Register:
|
|||||||
return self._register.items()
|
return self._register.items()
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
try:
|
if isinstance(item, int):
|
||||||
return self._register[item]
|
try:
|
||||||
except KeyError as e:
|
return next(v for i, v in enumerate(self._register.values()) if i == item)
|
||||||
print('NO')
|
except StopIteration:
|
||||||
print(e)
|
return None
|
||||||
raise
|
return self._register[item]
|
||||||
|
|
||||||
def by_name(self, item):
|
|
||||||
return self[self._names[item]]
|
|
||||||
|
|
||||||
def by_enum(self, enum_obj: Enum):
|
|
||||||
return self[self._names[enum_obj.name]]
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'{self.__class__.__name__}({self._register})'
|
return f'{self.__class__.__name__}({self._register})'
|
||||||
|
|
||||||
def get_name(self, item):
|
|
||||||
return self._register[item].name
|
|
||||||
|
|
||||||
def get_idx_by_name(self, item):
|
|
||||||
return self._names[item]
|
|
||||||
|
|
||||||
def get_idx(self, enum_obj: Enum):
|
|
||||||
return self._names[enum_obj.name]
|
|
||||||
|
|
||||||
|
|
||||||
class ObjectRegister(Register):
|
class ObjectRegister(Register):
|
||||||
def __init__(self, level_shape: (int, int), *args, individual_slices=False, is_per_agent=False, **kwargs):
|
def __init__(self, level_shape: (int, int), *args, individual_slices=False, is_per_agent=False, **kwargs):
|
||||||
@ -96,7 +73,7 @@ class ObjectRegister(Register):
|
|||||||
self._array = np.zeros((1, *self._level_shape))
|
self._array = np.zeros((1, *self._level_shape))
|
||||||
else:
|
else:
|
||||||
if self.individual_slices:
|
if self.individual_slices:
|
||||||
self._array = np.concatenate((self._array, np.zeros(1, *self._level_shape)))
|
self._array = np.concatenate((self._array, np.zeros((1, *self._level_shape))))
|
||||||
|
|
||||||
|
|
||||||
class EntityObjectRegister(ObjectRegister, ABC):
|
class EntityObjectRegister(ObjectRegister, ABC):
|
||||||
@ -107,8 +84,8 @@ class EntityObjectRegister(ObjectRegister, ABC):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_tiles(cls, tiles, *args, **kwargs):
|
def from_tiles(cls, tiles, *args, **kwargs):
|
||||||
# objects_name = cls._accepted_objects.__name__
|
# objects_name = cls._accepted_objects.__name__
|
||||||
entities = [cls._accepted_objects(i, tile, name_is_identifier=True, **kwargs)
|
entities = [cls._accepted_objects(tile, **kwargs)
|
||||||
for i, tile in enumerate(tiles)]
|
for tile in tiles]
|
||||||
register_obj = cls(*args)
|
register_obj = cls(*args)
|
||||||
register_obj.register_additional_items(entities)
|
register_obj.register_additional_items(entities)
|
||||||
return register_obj
|
return register_obj
|
||||||
@ -119,7 +96,7 @@ class EntityObjectRegister(ObjectRegister, ABC):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def positions(self):
|
def positions(self):
|
||||||
return list(self._tiles.keys())
|
return [x.pos for x in self]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tiles(self):
|
def tiles(self):
|
||||||
@ -128,25 +105,15 @@ class EntityObjectRegister(ObjectRegister, ABC):
|
|||||||
def __init__(self, *args, is_blocking_light=False, is_observable=True, can_be_shadowed=True, **kwargs):
|
def __init__(self, *args, is_blocking_light=False, is_observable=True, can_be_shadowed=True, **kwargs):
|
||||||
super(EntityObjectRegister, self).__init__(*args, **kwargs)
|
super(EntityObjectRegister, self).__init__(*args, **kwargs)
|
||||||
self.can_be_shadowed = can_be_shadowed
|
self.can_be_shadowed = can_be_shadowed
|
||||||
self._tiles = dict()
|
|
||||||
self.is_blocking_light = is_blocking_light
|
self.is_blocking_light = is_blocking_light
|
||||||
self.is_observable = is_observable
|
self.is_observable = is_observable
|
||||||
|
|
||||||
def register_item(self, other):
|
|
||||||
super(EntityObjectRegister, self).register_item(other)
|
|
||||||
self._tiles[other.pos] = other
|
|
||||||
|
|
||||||
def register_additional_items(self, others):
|
|
||||||
for other in others:
|
|
||||||
self.register_item(other)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def by_pos(self, pos):
|
def by_pos(self, pos):
|
||||||
if isinstance(pos, np.ndarray):
|
if isinstance(pos, np.ndarray):
|
||||||
pos = tuple(pos)
|
pos = tuple(pos)
|
||||||
try:
|
try:
|
||||||
return self._tiles[pos]
|
return next(item for item in self.values() if item.pos == pos)
|
||||||
except KeyError:
|
except StopIteration:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@ -159,12 +126,14 @@ class MovingEntityObjectRegister(EntityObjectRegister, ABC):
|
|||||||
if isinstance(pos, np.ndarray):
|
if isinstance(pos, np.ndarray):
|
||||||
pos = tuple(pos)
|
pos = tuple(pos)
|
||||||
try:
|
try:
|
||||||
return [x for x in self if x == pos][0]
|
return next(x for x in self if x.pos == pos)
|
||||||
except IndexError:
|
except StopIteration:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def delete_item(self, item):
|
def delete_item(self, item):
|
||||||
self
|
if not isinstance(item, str):
|
||||||
|
item = item.name
|
||||||
|
del self._register[item]
|
||||||
|
|
||||||
|
|
||||||
class Entities(Register):
|
class Entities(Register):
|
||||||
@ -186,7 +155,7 @@ class Entities(Register):
|
|||||||
return iter([x for sublist in self.values() for x in sublist])
|
return iter([x for sublist in self.values() for x in sublist])
|
||||||
|
|
||||||
def register_item(self, other: dict):
|
def register_item(self, other: dict):
|
||||||
assert not any([key for key in other.keys() if key in self._names]), \
|
assert not any([key for key in other.keys() if key in self.keys()]), \
|
||||||
"This group of entities has already been registered!"
|
"This group of entities has already been registered!"
|
||||||
self._register.update(other)
|
self._register.update(other)
|
||||||
return self
|
return self
|
||||||
@ -206,7 +175,8 @@ class WallTiles(EntityObjectRegister):
|
|||||||
return self._array
|
return self._array
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(WallTiles, self).__init__(*args, individual_slices=False, is_blocking_light=self._light_blocking, **kwargs)
|
super(WallTiles, self).__init__(*args, individual_slices=False,
|
||||||
|
is_blocking_light=self._light_blocking, **kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
@ -221,8 +191,8 @@ class WallTiles(EntityObjectRegister):
|
|||||||
tiles = cls(*args, **kwargs)
|
tiles = cls(*args, **kwargs)
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
tiles.register_additional_items(
|
tiles.register_additional_items(
|
||||||
[cls._accepted_objects(i, pos, name_is_identifier=True, is_blocking_light=cls._light_blocking)
|
[cls._accepted_objects(pos, is_blocking_light=cls._light_blocking)
|
||||||
for i, pos in enumerate(argwhere_coordinates)]
|
for pos in argwhere_coordinates]
|
||||||
)
|
)
|
||||||
return tiles
|
return tiles
|
||||||
|
|
||||||
@ -237,7 +207,7 @@ class FloorTiles(WallTiles):
|
|||||||
_light_blocking = False
|
_light_blocking = False
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(self.__class__, self).__init__(*args, is_observable=False, **kwargs)
|
super(FloorTiles, self).__init__(*args, is_observable=False, **kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
@ -265,8 +235,11 @@ class Agents(MovingEntityObjectRegister):
|
|||||||
def as_array(self):
|
def as_array(self):
|
||||||
self._array[:] = c.FREE_CELL.value
|
self._array[:] = c.FREE_CELL.value
|
||||||
# noinspection PyTupleAssignmentBalance
|
# noinspection PyTupleAssignmentBalance
|
||||||
z, x, y = range(len(self)), *zip(*[x.pos for x in self])
|
for z, x, y, v in zip(range(len(self)), *zip(*[x.pos for x in self]), [x.encoding for x in self]):
|
||||||
self._array[z, x, y] = c.OCCUPIED_CELL.value
|
if self.individual_slices:
|
||||||
|
self._array[z, x, y] += v
|
||||||
|
else:
|
||||||
|
self._array[0, x, y] += v
|
||||||
if self.individual_slices:
|
if self.individual_slices:
|
||||||
return self._array
|
return self._array
|
||||||
else:
|
else:
|
||||||
@ -293,9 +266,9 @@ class Doors(EntityObjectRegister):
|
|||||||
_accepted_objects = Door
|
_accepted_objects = Door
|
||||||
|
|
||||||
def get_near_position(self, position: (int, int)) -> Union[None, Door]:
|
def get_near_position(self, position: (int, int)) -> Union[None, Door]:
|
||||||
if found_doors := [door for door in self if position in door.access_area]:
|
try:
|
||||||
return found_doors[0]
|
return next(door for door in self if position in door.access_area)
|
||||||
else:
|
except StopIteration:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def tick_doors(self):
|
def tick_doors(self):
|
||||||
@ -320,39 +293,23 @@ class Actions(Register):
|
|||||||
super(Actions, self).__init__()
|
super(Actions, self).__init__()
|
||||||
|
|
||||||
if self.allow_square_movement:
|
if self.allow_square_movement:
|
||||||
self.register_additional_items([self._accepted_objects(direction) for direction in h.ManhattanMoves])
|
self.register_additional_items([self._accepted_objects(enum_ident=direction)
|
||||||
|
for direction in h.ManhattanMoves])
|
||||||
if self.allow_diagonal_movement:
|
if self.allow_diagonal_movement:
|
||||||
self.register_additional_items([self._accepted_objects(direction) for direction in h.DiagonalMoves])
|
self.register_additional_items([self._accepted_objects(enum_ident=direction)
|
||||||
|
for direction in h.DiagonalMoves])
|
||||||
self._movement_actions = self._register.copy()
|
self._movement_actions = self._register.copy()
|
||||||
if self.can_use_doors:
|
if self.can_use_doors:
|
||||||
self.register_additional_items([self._accepted_objects(h.EnvActions.USE_DOOR)])
|
self.register_additional_items([self._accepted_objects(enum_ident=h.EnvActions.USE_DOOR)])
|
||||||
if self.allow_no_op:
|
if self.allow_no_op:
|
||||||
self.register_additional_items([self._accepted_objects(h.EnvActions.NOOP)])
|
self.register_additional_items([self._accepted_objects(enum_ident=h.EnvActions.NOOP)])
|
||||||
|
|
||||||
def is_moving_action(self, action: Union[int]):
|
def is_moving_action(self, action: Union[int]):
|
||||||
return action in self.movement_actions.values()
|
return action in self.movement_actions.values()
|
||||||
|
|
||||||
def is_no_op(self, action: Union[str, Action, int]):
|
|
||||||
if isinstance(action, int):
|
|
||||||
action = self[action]
|
|
||||||
if isinstance(action, Action):
|
|
||||||
action = action.name
|
|
||||||
return action == h.EnvActions.NOOP.name
|
|
||||||
|
|
||||||
def is_door_usage(self, action: Union[str, int]):
|
|
||||||
if isinstance(action, int):
|
|
||||||
action = self[action]
|
|
||||||
if isinstance(action, Action):
|
|
||||||
action = action.name
|
|
||||||
return action == h.EnvActions.USE_DOOR.name
|
|
||||||
|
|
||||||
|
|
||||||
class Zones(Register):
|
class Zones(Register):
|
||||||
|
|
||||||
@property
|
|
||||||
def danger_zone(self):
|
|
||||||
return self._zone_slices[self.by_enum(c.DANGER_ZONE)]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def accounting_zones(self):
|
def accounting_zones(self):
|
||||||
return [self[idx] for idx, name in self.items() if name != c.DANGER_ZONE.value]
|
return [self[idx] for idx, name in self.items() if name != c.DANGER_ZONE.value]
|
||||||
@ -380,11 +337,5 @@ class Zones(Register):
|
|||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
return self._zone_slices[item]
|
return self._zone_slices[item]
|
||||||
|
|
||||||
def get_name(self, item):
|
|
||||||
return self._register[item]
|
|
||||||
|
|
||||||
def by_name(self, item):
|
|
||||||
return self[super(Zones, self).by_name(item)]
|
|
||||||
|
|
||||||
def register_additional_items(self, other: Union[str, List[str]]):
|
def register_additional_items(self, other: Union[str, List[str]]):
|
||||||
raise AttributeError('You are not allowed to add additional Zones in runtime.')
|
raise AttributeError('You are not allowed to add additional Zones in runtime.')
|
||||||
|
@ -100,7 +100,7 @@ class Inventories(ObjectRegister):
|
|||||||
can_be_shadowed = False
|
can_be_shadowed = False
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(Inventories, self).__init__(*args, is_per_agent=True, **kwargs)
|
super(Inventories, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
|
||||||
self.is_observable = True
|
self.is_observable = True
|
||||||
|
|
||||||
def as_array(self):
|
def as_array(self):
|
||||||
@ -132,10 +132,10 @@ class DropOffLocation(Entity):
|
|||||||
def place_item(self, item):
|
def place_item(self, item):
|
||||||
if self.is_full:
|
if self.is_full:
|
||||||
raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.")
|
raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.")
|
||||||
return False
|
return c.NOT_VALID
|
||||||
else:
|
else:
|
||||||
self.storage.append(item)
|
self.storage.append(item)
|
||||||
return True
|
return c.VALID
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_full(self):
|
def is_full(self):
|
||||||
@ -166,56 +166,48 @@ class ItemProperties(NamedTuple):
|
|||||||
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
||||||
class DoubleTaskFactory(SimpleFactory):
|
class DoubleTaskFactory(SimpleFactory):
|
||||||
# noinspection PyMissingConstructor
|
# noinspection PyMissingConstructor
|
||||||
def __init__(self, item_properties: ItemProperties, *args, with_dirt=False, env_seed=time.time_ns(), **kwargs):
|
def __init__(self, item_properties: ItemProperties, *args, env_seed=time.time_ns(), **kwargs):
|
||||||
self.item_properties = item_properties
|
self.item_properties = item_properties
|
||||||
kwargs.update(env_seed=env_seed)
|
kwargs.update(env_seed=env_seed)
|
||||||
self._item_rng = np.random.default_rng(env_seed)
|
self._item_rng = np.random.default_rng(env_seed)
|
||||||
assert item_properties.n_items < kwargs.get('pomdp_r', 0) ** 2 or not kwargs.get('pomdp_r', 0)
|
assert item_properties.n_items < kwargs.get('pomdp_r', 0) ** 2 or not kwargs.get('pomdp_r', 0)
|
||||||
self._super = DoubleTaskFactory if with_dirt else SimpleFactory
|
super(DoubleTaskFactory, self).__init__(*args, **kwargs)
|
||||||
super(self._super, self).__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def additional_actions(self) -> Union[Action, List[Action]]:
|
def additional_actions(self) -> Union[Action, List[Action]]:
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
super_actions = super(self._super, self).additional_actions
|
super_actions = super(DoubleTaskFactory, self).additional_actions
|
||||||
super_actions.append(Action(h.EnvActions.ITEM_ACTION))
|
super_actions.append(Action(enum_ident=h.EnvActions.ITEM_ACTION))
|
||||||
return super_actions
|
return super_actions
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def additional_entities(self) -> Dict[(Enum, Entities)]:
|
def additional_entities(self) -> Dict[(Enum, Entities)]:
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
super_entities = super(self._super, self).additional_entities
|
super_entities = super(DoubleTaskFactory, self).additional_entities
|
||||||
|
|
||||||
empty_tiles = self._entities[c.FLOOR].empty_tiles[:self.item_properties.n_drop_off_locations]
|
empty_tiles = self[c.FLOOR].empty_tiles[:self.item_properties.n_drop_off_locations]
|
||||||
drop_offs = DropOffLocations.from_tiles(empty_tiles, self._level_shape,
|
drop_offs = DropOffLocations.from_tiles(empty_tiles, self._level_shape,
|
||||||
storage_size_until_full=self.item_properties.max_dropoff_storage_size)
|
storage_size_until_full=self.item_properties.max_dropoff_storage_size)
|
||||||
item_register = ItemRegister(self._level_shape)
|
item_register = ItemRegister(self._level_shape)
|
||||||
empty_tiles = self._entities[c.FLOOR].empty_tiles[:self.item_properties.n_items]
|
empty_tiles = self[c.FLOOR].empty_tiles[:self.item_properties.n_items]
|
||||||
item_register.spawn_items(empty_tiles)
|
item_register.spawn_items(empty_tiles)
|
||||||
|
|
||||||
inventories = Inventories(self._level_shape)
|
inventories = Inventories(self._level_shape)
|
||||||
inventories.spawn_inventories(self._entities[c.AGENT], self.pomdp_r,
|
inventories.spawn_inventories(self[c.AGENT], self.pomdp_r,
|
||||||
self.item_properties.max_agent_inventory_capacity)
|
self.item_properties.max_agent_inventory_capacity)
|
||||||
|
|
||||||
super_entities.update({c.DROP_OFF: drop_offs, c.ITEM: item_register, c.INVENTORY: inventories})
|
super_entities.update({c.DROP_OFF: drop_offs, c.ITEM: item_register, c.INVENTORY: inventories})
|
||||||
return super_entities
|
return super_entities
|
||||||
|
|
||||||
def _is_item_action(self, action):
|
|
||||||
if isinstance(action, int):
|
|
||||||
action = self._actions[action]
|
|
||||||
if isinstance(action, Action):
|
|
||||||
action = action.name
|
|
||||||
return action == h.EnvActions.ITEM_ACTION.name
|
|
||||||
|
|
||||||
def do_item_action(self, agent: Agent):
|
def do_item_action(self, agent: Agent):
|
||||||
inventory = self._entities[c.INVENTORY].by_name(agent.name)
|
inventory = self[c.INVENTORY].by_name(agent.name)
|
||||||
if drop_off := self._entities[c.DROP_OFF].by_pos(agent.pos):
|
if drop_off := self[c.DROP_OFF].by_pos(agent.pos):
|
||||||
if inventory:
|
if inventory:
|
||||||
valid = drop_off.place_item(inventory.pop(0))
|
valid = drop_off.place_item(inventory.pop(0))
|
||||||
return valid
|
return valid
|
||||||
else:
|
else:
|
||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
elif item := self._entities[c.ITEM].by_pos(agent.pos):
|
elif item := self[c.ITEM].by_pos(agent.pos):
|
||||||
try:
|
try:
|
||||||
inventory.append(item)
|
inventory.append(item)
|
||||||
item.move(self.NO_POS_TILE)
|
item.move(self.NO_POS_TILE)
|
||||||
@ -225,16 +217,16 @@ class DoubleTaskFactory(SimpleFactory):
|
|||||||
else:
|
else:
|
||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
|
|
||||||
def do_additional_actions(self, agent: Agent, action: int) -> Union[None, bool]:
|
def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]:
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
valid = super(self._super, self).do_additional_actions(agent, action)
|
valid = super(DoubleTaskFactory, self).do_additional_actions(agent, action)
|
||||||
if valid is None:
|
if valid is None:
|
||||||
if self._is_item_action(action):
|
if action == h.EnvActions.ITEM_ACTION:
|
||||||
if self.item_properties.agent_can_interact:
|
if self.item_properties.agent_can_interact:
|
||||||
valid = self.do_item_action(agent)
|
valid = self.do_item_action(agent)
|
||||||
return bool(valid)
|
return valid
|
||||||
else:
|
else:
|
||||||
return False
|
return c.NOT_VALID
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
@ -242,14 +234,14 @@ class DoubleTaskFactory(SimpleFactory):
|
|||||||
|
|
||||||
def do_additional_reset(self) -> None:
|
def do_additional_reset(self) -> None:
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
super(self._super, self).do_additional_reset()
|
super(DoubleTaskFactory, self).do_additional_reset()
|
||||||
self._next_item_spawn = self.item_properties.spawn_frequency
|
self._next_item_spawn = self.item_properties.spawn_frequency
|
||||||
self.trigger_item_spawn()
|
self.trigger_item_spawn()
|
||||||
|
|
||||||
def trigger_item_spawn(self):
|
def trigger_item_spawn(self):
|
||||||
if item_to_spawns := max(0, (self.item_properties.n_items - len(self._entities[c.ITEM]))):
|
if item_to_spawns := max(0, (self.item_properties.n_items - len(self[c.ITEM]))):
|
||||||
empty_tiles = self._entities[c.FLOOR].empty_tiles[:item_to_spawns]
|
empty_tiles = self[c.FLOOR].empty_tiles[:item_to_spawns]
|
||||||
self._entities[c.ITEM].spawn_items(empty_tiles)
|
self[c.ITEM].spawn_items(empty_tiles)
|
||||||
self._next_item_spawn = self.item_properties.spawn_frequency
|
self._next_item_spawn = self.item_properties.spawn_frequency
|
||||||
self.print(f'{item_to_spawns} new items have been spawned; next spawn in {self._next_item_spawn}')
|
self.print(f'{item_to_spawns} new items have been spawned; next spawn in {self._next_item_spawn}')
|
||||||
else:
|
else:
|
||||||
@ -257,7 +249,7 @@ class DoubleTaskFactory(SimpleFactory):
|
|||||||
|
|
||||||
def do_additional_step(self) -> dict:
|
def do_additional_step(self) -> dict:
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
info_dict = super(self._super, self).do_additional_step()
|
info_dict = super(DoubleTaskFactory, self).do_additional_step()
|
||||||
if not self._next_item_spawn:
|
if not self._next_item_spawn:
|
||||||
self.trigger_item_spawn()
|
self.trigger_item_spawn()
|
||||||
else:
|
else:
|
||||||
@ -266,10 +258,10 @@ class DoubleTaskFactory(SimpleFactory):
|
|||||||
|
|
||||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
reward, info_dict = super(self._super, self).calculate_additional_reward(agent)
|
reward, info_dict = super(DoubleTaskFactory, self).calculate_additional_reward(agent)
|
||||||
if self._is_item_action(agent.temp_action):
|
if h.EnvActions.ITEM_ACTION == agent.temp_action:
|
||||||
if agent.temp_valid:
|
if agent.temp_valid:
|
||||||
if self._entities[c.DROP_OFF].by_pos(agent.pos):
|
if self[c.DROP_OFF].by_pos(agent.pos):
|
||||||
info_dict.update({f'{agent.name}_item_dropoff': 1})
|
info_dict.update({f'{agent.name}_item_dropoff': 1})
|
||||||
|
|
||||||
reward += 1
|
reward += 1
|
||||||
@ -283,10 +275,10 @@ class DoubleTaskFactory(SimpleFactory):
|
|||||||
|
|
||||||
def render_additional_assets(self, mode='human'):
|
def render_additional_assets(self, mode='human'):
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
additional_assets = super(self._super, self).render_additional_assets()
|
additional_assets = super(DoubleTaskFactory, self).render_additional_assets()
|
||||||
items = [RenderEntity(c.ITEM.value, item.tile.pos) for item in self._entities[c.ITEM]]
|
items = [RenderEntity(c.ITEM.value, item.tile.pos) for item in self[c.ITEM]]
|
||||||
additional_assets.extend(items)
|
additional_assets.extend(items)
|
||||||
drop_offs = [RenderEntity(c.DROP_OFF.value, drop_off.tile.pos) for drop_off in self._entities[c.DROP_OFF]]
|
drop_offs = [RenderEntity(c.DROP_OFF.value, drop_off.tile.pos) for drop_off in self[c.DROP_OFF]]
|
||||||
additional_assets.extend(drop_offs)
|
additional_assets.extend(drop_offs)
|
||||||
return additional_assets
|
return additional_assets
|
||||||
|
|
||||||
@ -297,8 +289,8 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
item_props = ItemProperties()
|
item_props = ItemProperties()
|
||||||
|
|
||||||
factory = DoubleTaskFactory(item_props, n_agents=1, done_at_collision=False, frames_to_stack=0,
|
factory = DoubleTaskFactory(item_props, n_agents=3, done_at_collision=False, frames_to_stack=0,
|
||||||
level_name='rooms', max_steps=400,
|
level_name='rooms', max_steps=4000,
|
||||||
omit_agent_slice_in_obs=True, parse_doors=True, pomdp_r=3,
|
omit_agent_slice_in_obs=True, parse_doors=True, pomdp_r=3,
|
||||||
record_episodes=False, verbose=False
|
record_episodes=False, verbose=False
|
||||||
)
|
)
|
||||||
|
@ -8,7 +8,7 @@ import numpy as np
|
|||||||
from environments.helpers import Constants as c
|
from environments.helpers import Constants as c
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.factory.base.base_factory import BaseFactory
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
from environments.factory.base.objects import Agent, Action, Entity
|
from environments.factory.base.objects import Agent, Action, Entity, Tile
|
||||||
from environments.factory.base.registers import Entities, MovingEntityObjectRegister
|
from environments.factory.base.registers import Entities, MovingEntityObjectRegister
|
||||||
|
|
||||||
from environments.factory.renderer import RenderEntity
|
from environments.factory.renderer import RenderEntity
|
||||||
@ -85,19 +85,21 @@ class DirtRegister(MovingEntityObjectRegister):
|
|||||||
super(DirtRegister, self).__init__(*args)
|
super(DirtRegister, self).__init__(*args)
|
||||||
self._dirt_properties: DirtProperties = dirt_properties
|
self._dirt_properties: DirtProperties = dirt_properties
|
||||||
|
|
||||||
def spawn_dirt(self, then_dirty_tiles) -> None:
|
def spawn_dirt(self, then_dirty_tiles) -> c:
|
||||||
if not self.amount > self.dirt_properties.max_global_amount:
|
if isinstance(then_dirty_tiles, Tile):
|
||||||
# randomly distribute dirt across the grid
|
then_dirty_tiles = [then_dirty_tiles]
|
||||||
for tile in then_dirty_tiles:
|
for tile in then_dirty_tiles:
|
||||||
|
if not self.amount > self.dirt_properties.max_global_amount:
|
||||||
dirt = self.by_pos(tile.pos)
|
dirt = self.by_pos(tile.pos)
|
||||||
if dirt is None:
|
if dirt is None:
|
||||||
dirt = Dirt(0, tile, amount=self.dirt_properties.gain_amount)
|
dirt = Dirt(tile, amount=self.dirt_properties.gain_amount)
|
||||||
self.register_item(dirt)
|
self.register_item(dirt)
|
||||||
else:
|
else:
|
||||||
new_value = dirt.amount + self.dirt_properties.gain_amount
|
new_value = dirt.amount + self.dirt_properties.gain_amount
|
||||||
dirt.set_new_amount(min(new_value, self.dirt_properties.max_local_amount))
|
dirt.set_new_amount(min(new_value, self.dirt_properties.max_local_amount))
|
||||||
else:
|
else:
|
||||||
pass
|
return c.NOT_VALID
|
||||||
|
return c.VALID
|
||||||
|
|
||||||
|
|
||||||
def softmax(x):
|
def softmax(x):
|
||||||
@ -117,7 +119,7 @@ class SimpleFactory(BaseFactory):
|
|||||||
def additional_actions(self) -> Union[Action, List[Action]]:
|
def additional_actions(self) -> Union[Action, List[Action]]:
|
||||||
super_actions = super(SimpleFactory, self).additional_actions
|
super_actions = super(SimpleFactory, self).additional_actions
|
||||||
if self.dirt_properties.agent_can_interact:
|
if self.dirt_properties.agent_can_interact:
|
||||||
super_actions.append(Action(CLEAN_UP_ACTION))
|
super_actions.append(Action(enum_ident=CLEAN_UP_ACTION))
|
||||||
return super_actions
|
return super_actions
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -127,13 +129,6 @@ class SimpleFactory(BaseFactory):
|
|||||||
super_entities.update(({c.DIRT: dirt_register}))
|
super_entities.update(({c.DIRT: dirt_register}))
|
||||||
return super_entities
|
return super_entities
|
||||||
|
|
||||||
def _is_clean_up_action(self, action: Union[str, Action, int]):
|
|
||||||
if isinstance(action, int):
|
|
||||||
action = self._actions[action]
|
|
||||||
if isinstance(action, Action):
|
|
||||||
action = action.name
|
|
||||||
return action == CLEAN_UP_ACTION.name
|
|
||||||
|
|
||||||
def __init__(self, *args, dirt_properties: DirtProperties = DirtProperties(), env_seed=time.time_ns(), **kwargs):
|
def __init__(self, *args, dirt_properties: DirtProperties = DirtProperties(), env_seed=time.time_ns(), **kwargs):
|
||||||
self.dirt_properties = dirt_properties
|
self.dirt_properties = dirt_properties
|
||||||
self._dirt_rng = np.random.default_rng(env_seed)
|
self._dirt_rng = np.random.default_rng(env_seed)
|
||||||
@ -144,38 +139,43 @@ class SimpleFactory(BaseFactory):
|
|||||||
def render_additional_assets(self, mode='human'):
|
def render_additional_assets(self, mode='human'):
|
||||||
additional_assets = super(SimpleFactory, self).render_additional_assets()
|
additional_assets = super(SimpleFactory, self).render_additional_assets()
|
||||||
dirt = [RenderEntity('dirt', dirt.tile.pos, min(0.15 + dirt.amount, 1.5), 'scale')
|
dirt = [RenderEntity('dirt', dirt.tile.pos, min(0.15 + dirt.amount, 1.5), 'scale')
|
||||||
for dirt in self._entities[c.DIRT]]
|
for dirt in self[c.DIRT]]
|
||||||
additional_assets.extend(dirt)
|
additional_assets.extend(dirt)
|
||||||
return additional_assets
|
return additional_assets
|
||||||
|
|
||||||
def clean_up(self, agent: Agent) -> bool:
|
def clean_up(self, agent: Agent) -> c:
|
||||||
if dirt := self._entities[c.DIRT].by_pos(agent.pos):
|
if dirt := self[c.DIRT].by_pos(agent.pos):
|
||||||
new_dirt_amount = dirt.amount - self.dirt_properties.clean_amount
|
new_dirt_amount = dirt.amount - self.dirt_properties.clean_amount
|
||||||
dirt.set_new_amount(max(new_dirt_amount, c.FREE_CELL.value))
|
|
||||||
return True
|
if new_dirt_amount <= 0:
|
||||||
|
self[c.DIRT].delete_item(dirt)
|
||||||
|
else:
|
||||||
|
dirt.set_new_amount(max(new_dirt_amount, c.FREE_CELL.value))
|
||||||
|
return c.VALID
|
||||||
else:
|
else:
|
||||||
return False
|
return c.NOT_VALID
|
||||||
|
|
||||||
def trigger_dirt_spawn(self):
|
def trigger_dirt_spawn(self):
|
||||||
free_for_dirt = self._entities[c.FLOOR].empty_tiles
|
free_for_dirt = self[c.FLOOR].empty_tiles
|
||||||
new_spawn = self._dirt_rng.uniform(0, self.dirt_properties.max_spawn_ratio)
|
new_spawn = self._dirt_rng.uniform(0, self.dirt_properties.max_spawn_ratio)
|
||||||
n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt)))
|
n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt)))
|
||||||
self._entities[c.DIRT].spawn_dirt(free_for_dirt[:n_dirt_tiles])
|
self[c.DIRT].spawn_dirt(free_for_dirt[:n_dirt_tiles])
|
||||||
|
|
||||||
def do_additional_step(self) -> dict:
|
def do_additional_step(self) -> dict:
|
||||||
info_dict = super(SimpleFactory, self).do_additional_step()
|
info_dict = super(SimpleFactory, self).do_additional_step()
|
||||||
if smear_amount := self.dirt_properties.dirt_smear_amount:
|
if smear_amount := self.dirt_properties.dirt_smear_amount:
|
||||||
for agent in self._entities[c.AGENT]:
|
for agent in self[c.AGENT]:
|
||||||
if agent.temp_valid and agent.last_pos != c.NO_POS:
|
if agent.temp_valid and agent.last_pos != c.NO_POS:
|
||||||
if old_pos_dirt := self._entities[c.DIRT].by_pos(agent.last_pos):
|
if self._actions.is_moving_action(agent.temp_action):
|
||||||
if smeared_dirt := round(old_pos_dirt.amount * smear_amount, 2):
|
if old_pos_dirt := self[c.DIRT].by_pos(agent.last_pos):
|
||||||
old_pos_dirt.set_new_amount(max(0, old_pos_dirt.amount-smeared_dirt))
|
if smeared_dirt := round(old_pos_dirt.amount * smear_amount, 2):
|
||||||
if new_pos_dirt := self._entities[c.DIRT].by_pos(agent.pos):
|
old_pos_dirt.set_new_amount(max(0, old_pos_dirt.amount-smeared_dirt))
|
||||||
new_pos_dirt.set_new_amount(max(0, new_pos_dirt.amount + smeared_dirt))
|
if new_pos_dirt := self[c.DIRT].by_pos(agent.pos):
|
||||||
else:
|
new_pos_dirt.set_new_amount(max(0, new_pos_dirt.amount + smeared_dirt))
|
||||||
self._entities[c.Dirt].spawn_dirt(agent.tile)
|
else:
|
||||||
new_pos_dirt = self._entities[c.DIRT].by_pos(agent.pos)
|
if self[c.DIRT].spawn_dirt(agent.tile):
|
||||||
new_pos_dirt.set_new_amount(max(0, new_pos_dirt.amount + smeared_dirt))
|
new_pos_dirt = self[c.DIRT].by_pos(agent.pos)
|
||||||
|
new_pos_dirt.set_new_amount(max(0, new_pos_dirt.amount + smeared_dirt))
|
||||||
|
|
||||||
if not self._next_dirt_spawn:
|
if not self._next_dirt_spawn:
|
||||||
self.trigger_dirt_spawn()
|
self.trigger_dirt_spawn()
|
||||||
@ -184,15 +184,15 @@ class SimpleFactory(BaseFactory):
|
|||||||
self._next_dirt_spawn -= 1
|
self._next_dirt_spawn -= 1
|
||||||
return info_dict
|
return info_dict
|
||||||
|
|
||||||
def do_additional_actions(self, agent: Agent, action: int) -> Union[None, bool]:
|
def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]:
|
||||||
valid = super(SimpleFactory, self).do_additional_actions(agent, action)
|
valid = super(SimpleFactory, self).do_additional_actions(agent, action)
|
||||||
if valid is None:
|
if valid is None:
|
||||||
if self._is_clean_up_action(action):
|
if action == CLEAN_UP_ACTION:
|
||||||
if self.dirt_properties.agent_can_interact:
|
if self.dirt_properties.agent_can_interact:
|
||||||
valid = self.clean_up(agent)
|
valid = self.clean_up(agent)
|
||||||
return valid
|
return valid
|
||||||
else:
|
else:
|
||||||
return False
|
return c.NOT_VALID
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
@ -205,7 +205,7 @@ class SimpleFactory(BaseFactory):
|
|||||||
|
|
||||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
||||||
reward, info_dict = super(SimpleFactory, self).calculate_additional_reward(agent)
|
reward, info_dict = super(SimpleFactory, self).calculate_additional_reward(agent)
|
||||||
dirt = [dirt.amount for dirt in self._entities[c.DIRT]]
|
dirt = [dirt.amount for dirt in self[c.DIRT]]
|
||||||
current_dirt_amount = sum(dirt)
|
current_dirt_amount = sum(dirt)
|
||||||
dirty_tile_count = len(dirt)
|
dirty_tile_count = len(dirt)
|
||||||
if dirty_tile_count:
|
if dirty_tile_count:
|
||||||
@ -220,7 +220,7 @@ class SimpleFactory(BaseFactory):
|
|||||||
if agent.temp_collisions:
|
if agent.temp_collisions:
|
||||||
self.print(f't = {self._steps}\t{agent.name} has collisions with {agent.temp_collisions}')
|
self.print(f't = {self._steps}\t{agent.name} has collisions with {agent.temp_collisions}')
|
||||||
|
|
||||||
if self._is_clean_up_action(agent.temp_action):
|
if agent.temp_action == CLEAN_UP_ACTION:
|
||||||
if agent.temp_valid:
|
if agent.temp_valid:
|
||||||
reward += 0.5
|
reward += 0.5
|
||||||
self.print(f'{agent.name} did just clean up some dirt at {agent.pos}.')
|
self.print(f'{agent.name} did just clean up some dirt at {agent.pos}.')
|
||||||
@ -245,7 +245,7 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
factory = SimpleFactory(n_agents=1, done_at_collision=False, frames_to_stack=0,
|
factory = SimpleFactory(n_agents=1, done_at_collision=False, frames_to_stack=0,
|
||||||
level_name='rooms', max_steps=400,
|
level_name='rooms', max_steps=400,
|
||||||
omit_agent_slice_in_obs=True, parse_doors=True, pomdp_r=3,
|
omit_agent_slice_in_obs=True, parse_doors=True, pomdp_r=2,
|
||||||
record_episodes=False, verbose=False
|
record_episodes=False, verbose=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -27,13 +27,13 @@ class Constants(Enum):
|
|||||||
NO_POS = (-9999, -9999)
|
NO_POS = (-9999, -9999)
|
||||||
|
|
||||||
DOORS = 'Doors'
|
DOORS = 'Doors'
|
||||||
CLOSED_DOOR = 1
|
CLOSED_DOOR = 'closed'
|
||||||
OPEN_DOOR = -1
|
OPEN_DOOR = 'open'
|
||||||
|
|
||||||
ACTION = auto()
|
ACTION = 'action'
|
||||||
COLLISIONS = auto()
|
COLLISIONS = 'collision'
|
||||||
VALID = True
|
VALID = 'valid'
|
||||||
NOT_VALID = False
|
NOT_VALID = 'not_valid'
|
||||||
|
|
||||||
# Dirt Env
|
# Dirt Env
|
||||||
DIRT = 'Dirt'
|
DIRT = 'Dirt'
|
||||||
@ -44,7 +44,10 @@ class Constants(Enum):
|
|||||||
DROP_OFF = 'Drop_Off'
|
DROP_OFF = 'Drop_Off'
|
||||||
|
|
||||||
def __bool__(self):
|
def __bool__(self):
|
||||||
return bool(self.value)
|
if 'not_' in self.value:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return bool(self.value)
|
||||||
|
|
||||||
|
|
||||||
class ManhattanMoves(Enum):
|
class ManhattanMoves(Enum):
|
||||||
@ -72,10 +75,10 @@ d = DiagonalMoves
|
|||||||
m = ManhattanMoves
|
m = ManhattanMoves
|
||||||
c = Constants
|
c = Constants
|
||||||
|
|
||||||
ACTIONMAP = defaultdict(lambda: (0, 0), {m.NORTH.name: (-1, 0), d.NORTHEAST.name: (-1, +1),
|
ACTIONMAP = defaultdict(lambda: (0, 0), {m.NORTH: (-1, 0), d.NORTHEAST: (-1, +1),
|
||||||
m.EAST.name: (0, 1), d.SOUTHEAST.name: (1, 1),
|
m.EAST: (0, 1), d.SOUTHEAST: (1, 1),
|
||||||
m.SOUTH.name: (1, 0), d.SOUTHWEST.name: (+1, -1),
|
m.SOUTH: (1, 0), d.SOUTHWEST: (+1, -1),
|
||||||
m.WEST.name: (0, -1), d.NORTHWEST.name: (-1, -1)
|
m.WEST: (0, -1), d.NORTHWEST: (-1, -1)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
2
main.py
2
main.py
@ -116,7 +116,7 @@ if __name__ == '__main__':
|
|||||||
pomdp_radius=2, max_steps=500, parse_doors=True,
|
pomdp_radius=2, max_steps=500, parse_doors=True,
|
||||||
level_name='rooms', frames_to_stack=3,
|
level_name='rooms', frames_to_stack=3,
|
||||||
omit_agent_slice_in_obs=True, combin_agent_slices_in_obs=True, record_episodes=False,
|
omit_agent_slice_in_obs=True, combin_agent_slices_in_obs=True, record_episodes=False,
|
||||||
cast_shadows=True, doors_have_area=False, seed=seed
|
cast_shadows=True, doors_have_area=False, seed=seed, verbose=True,
|
||||||
) as env:
|
) as env:
|
||||||
|
|
||||||
if modeL_type.__name__ in ["PPO", "A2C"]:
|
if modeL_type.__name__ in ["PPO", "A2C"]:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user