recoder adaption

This commit is contained in:
Steffen Illium
2021-10-04 17:53:19 +02:00
parent 4c21a0af7c
commit 696e520862
21 changed files with 665 additions and 380 deletions

View File

@ -13,8 +13,8 @@ from environments.factory.base.shadow_casting import Map
from environments.factory.renderer import Renderer, RenderEntity
from environments.helpers import Constants as c, Constants
from environments import helpers as h
from environments.factory.base.objects import Agent, Tile, Action, Wall
from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles
from environments.factory.base.objects import Agent, Tile, Action
from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles, PlaceHolders
from environments.utility_classes import MovementProperties
import simplejson
@ -58,7 +58,7 @@ class BaseFactory(gym.Env):
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2), pomdp_r: Union[None, int] = 0,
movement_properties: MovementProperties = MovementProperties(), parse_doors=False,
combin_agent_obs: bool = False, frames_to_stack=0, record_episodes=False,
omit_agent_in_obs=False, done_at_collision=False, cast_shadows=True,
omit_agent_in_obs=False, done_at_collision=False, cast_shadows=True, additional_agent_placeholder=None,
verbose=False, doors_have_area=True, env_seed=time.time_ns(), **kwargs):
assert frames_to_stack != 1 and frames_to_stack >= 0, "'frames_to_stack' cannot be negative or 1."
if kwargs:
@ -74,6 +74,7 @@ class BaseFactory(gym.Env):
self.level_name = level_name
self._level_shape = None
self.verbose = verbose
self.additional_agent_placeholder = additional_agent_placeholder
self._renderer = None # expensive - don't use it when not required !
self._entities = Entities()
@ -141,6 +142,14 @@ class BaseFactory(gym.Env):
individual_slices=not self.combin_agent_obs)
entities.update({c.AGENT: agents})
if self.additional_agent_placeholder is not None:
# Empty Observations with either [0, 1, N(0, 1)]
placeholder = PlaceHolders.from_tiles([self._NO_POS_TILE], self._level_shape,
fill_value=self.additional_agent_placeholder)
entities.update({c.AGENT_PLACEHOLDER: placeholder})
# All entities
self._entities = Entities()
self._entities.register_additional_items(entities)
@ -155,10 +164,12 @@ class BaseFactory(gym.Env):
def _init_obs_cube(self):
arrays = self._entities.observable_arrays
# FIXME: Move logic to Register
if self.omit_agent_in_obs and self.n_agents == 1:
del arrays[c.AGENT]
elif self.omit_agent_in_obs:
arrays[c.AGENT] = np.delete(arrays[c.AGENT], 0, axis=0)
# This does not seem to be necesarry, because this case is allready handled by the Agent Register Class
# elif self.omit_agent_in_obs:
# arrays[c.AGENT] = np.delete(arrays[c.AGENT], 0, axis=0)
obs_cube_z = sum([a.shape[0] if not self[key].is_per_agent else 1 for key, a in arrays.items()])
self._obs_cube = np.zeros((obs_cube_z, *self._level_shape), dtype=np.float32)
@ -273,6 +284,7 @@ class BaseFactory(gym.Env):
agent_pos_is_omitted = False
agent_omit_idx = None
if self.omit_agent_in_obs and self.n_agents == 1:
# There is only a single agent and we want to omit the agent obs, so just remove the array.
del state_array_dict[c.AGENT]
elif self.omit_agent_in_obs and self.combin_agent_obs and self.n_agents > 1:
state_array_dict[c.AGENT][0, agent.x, agent.y] -= agent.encoding
@ -295,6 +307,9 @@ class BaseFactory(gym.Env):
for array_idx in range(array.shape[0]):
self._obs_cube[running_idx: running_idx+z] = array[[x for x in range(array.shape[0])
if x != agent_omit_idx]]
elif key == c.AGENT and self.omit_agent_in_obs and self.combin_agent_obs:
z = 1
self._obs_cube[running_idx: running_idx + z] = array
else:
z = array.shape[0]
self._obs_cube[running_idx: running_idx+z] = array
@ -499,12 +514,8 @@ class BaseFactory(gym.Env):
def _summarize_state(self):
summary = {f'{REC_TAC}step': self._steps}
if self._steps == 0:
summary.update({f'{REC_TAC}{self[c.WALLS].name}': {self[c.WALLS].summarize_states()},
'FactoryName': self.__class__.__name__})
for entity_group in self._entities:
if not isinstance(entity_group, WallTiles):
summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states()})
summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states(n_steps=self._steps)})
return summary
def print(self, string):

View File

@ -93,11 +93,11 @@ class Entity(Object):
return self._tile
def __init__(self, tile, **kwargs):
super(Entity, self).__init__(**kwargs)
super().__init__(**kwargs)
self._tile = tile
tile.enter(self)
def summarize_state(self) -> dict:
def summarize_state(self, **_) -> dict:
return dict(name=str(self.name), x=int(self.x), y=int(self.y),
tile=str(self.tile.name), can_collide=bool(self.can_collide))
@ -125,7 +125,7 @@ class MoveableEntity(Entity):
return last_x-curr_x, last_y-curr_y
def __init__(self, *args, **kwargs):
super(MoveableEntity, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
self._last_tile = None
def move(self, next_tile):
@ -143,11 +143,34 @@ class MoveableEntity(Entity):
class Action(Object):
def __init__(self, *args, **kwargs):
super(Action, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
class PlaceHolder(MoveableEntity):
pass
def __init__(self, *args, fill_value=0, **kwargs):
super().__init__(*args, **kwargs)
self._fill_value = fill_value
@property
def last_tile(self):
return self.tile
@property
def direction_of_view(self):
return self.pos
@property
def can_collide(self):
return False
@property
def encoding(self):
return c.NO_POS.value[0]
@property
def name(self):
return "PlaceHolder"
class Tile(Object):
@ -203,8 +226,8 @@ class Tile(Object):
def __repr__(self):
return f'{self.name}(@{self.pos})'
def summarize_state(self):
return dict(name=self.name, x=self.x, y=self.y)
def summarize_state(self, **_):
return dict(name=self.name, x=int(self.x), y=int(self.y))
class Wall(Tile):
@ -254,8 +277,8 @@ class Door(Entity):
if not closed_on_init:
self._open()
def summarize_state(self):
state_dict = super().summarize_state()
def summarize_state(self, **kwargs):
state_dict = super().summarize_state(**kwargs)
state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close))
return state_dict
@ -315,7 +338,7 @@ class Agent(MoveableEntity):
self.temp_action = None
self.temp_light_map = None
def summarize_state(self):
state_dict = super().summarize_state()
def summarize_state(self, **kwargs):
state_dict = super().summarize_state(**kwargs)
state_dict.update(valid=bool(self.temp_valid), action=str(self.temp_action))
return state_dict

View File

@ -81,8 +81,8 @@ class ObjectRegister(Register):
if self.individual_slices:
self._array = np.concatenate((self._array, np.zeros((1, *self._array.shape[1:]))))
def summarize_states(self):
return [val.summarize_state() for val in self.values()]
def summarize_states(self, n_steps=None):
return [val.summarize_state(n_steps=n_steps) for val in self.values()]
class EntityObjectRegister(ObjectRegister, ABC):
@ -156,23 +156,25 @@ class MovingEntityObjectRegister(EntityObjectRegister, ABC):
del self[name]
class PlaceHolderRegister(MovingEntityObjectRegister):
class PlaceHolders(MovingEntityObjectRegister):
_accepted_objects = PlaceHolder
def __init__(self, *args, fill_value: Union[str, int] = 0, **kwargs):
super().__init__(*args, **kwargs)
self.fill_value = fill_value
# noinspection DuplicatedCode
def as_array(self):
self._array[:] = c.FREE_CELL.value
# noinspection PyTupleAssignmentBalance
for z, x, y, v in zip(range(len(self)), *zip(*[x.pos for x in self]), [x.encoding for x in self]):
if self.individual_slices:
self._array[z, x, y] += v
else:
self._array[0, x, y] += v
if isinstance(self.fill_value, int):
self._array[:] = self.fill_value
elif self.fill_value == "normal":
self._array = np.random.normal(size=self._array.shape)
if self.individual_slices:
return self._array
else:
return self._array.sum(axis=0, keepdims=True)
return self._array[None, 0]
class Entities(Register):
@ -243,6 +245,12 @@ class WallTiles(EntityObjectRegister):
def from_tiles(cls, tiles, *args, **kwargs):
raise RuntimeError()
def summarize_states(self, n_steps=None):
if n_steps == h.STEPS_START:
return super(WallTiles, self).summarize_states(n_steps=n_steps)
else:
return {}
class FloorTiles(WallTiles):
@ -272,6 +280,10 @@ class FloorTiles(WallTiles):
def from_tiles(cls, tiles, *args, **kwargs):
raise RuntimeError()
def summarize_states(self, n_steps=None):
# Do not summarize
return {}
class Agents(MovingEntityObjectRegister):