Debugged Item Factory

This commit is contained in:
Steffen Illium
2021-09-08 11:06:47 +02:00
parent 50c0d90c77
commit b09055d95d
6 changed files with 91 additions and 29 deletions

View File

@ -128,7 +128,7 @@ class BaseFactory(gym.Env):
parsed_doors = h.one_hot_level(parsed_level, c.DOOR)
if np.any(parsed_doors):
door_tiles = [floor.by_pos(pos) for pos in np.argwhere(parsed_doors == c.OCCUPIED_CELL.value)]
doors = Doors.from_tiles(door_tiles, self._level_shape, context=floor, is_blocking_light=True)
doors = Doors.from_tiles(door_tiles, self._level_shape, context=floor)
entities.update({c.DOORS: doors})
# Actions
@ -137,7 +137,8 @@ class BaseFactory(gym.Env):
self._actions.register_additional_items(additional_actions)
# Agents
agents = Agents.from_tiles(floor.empty_tiles[:self.n_agents], self._level_shape)
agents = Agents.from_tiles(floor.empty_tiles[:self.n_agents], self._level_shape,
individual_slices=not self.combin_agent_obs)
entities.update({c.AGENT: agents})
# All entities
@ -152,10 +153,12 @@ class BaseFactory(gym.Env):
return self._entities
def _init_obs_cube(self):
arrays = self._entities.arrays
arrays = self._entities.observable_arrays
if self.omit_agent_in_obs and self.n_agents == 1:
del arrays[c.AGENT]
elif self.omit_agent_in_obs:
arrays[c.AGENT] = np.delete(arrays[c.AGENT], 0, axis=0)
obs_cube_z = sum([a.shape[0] if not self[key].is_per_agent else 1 for key, a in arrays.items()])
self._obs_cube = np.zeros((obs_cube_z, *self._level_shape), dtype=np.float32)
@ -257,7 +260,7 @@ class BaseFactory(gym.Env):
return c.NOT_VALID
def _get_observations(self) -> np.ndarray:
state_array_dict = self._entities.arrays
state_array_dict = self._entities.obs_arrays
if self.n_agents == 1:
obs = self._build_per_agent_obs(self[c.AGENT][0], state_array_dict)
elif self.n_agents >= 2:
@ -268,11 +271,14 @@ class BaseFactory(gym.Env):
def _build_per_agent_obs(self, agent: Agent, state_array_dict) -> np.ndarray:
agent_pos_is_omitted = False
agent_omit_idx = None
if self.omit_agent_in_obs and self.n_agents == 1:
del state_array_dict[c.AGENT]
elif self.omit_agent_in_obs and self.combin_agent_obs and self.n_agents > 1:
state_array_dict[c.AGENT][0, agent.x, agent.y] -= agent.encoding
agent_pos_is_omitted = True
elif self.omit_agent_in_obs and not self.combin_agent_obs and self.n_agents > 1:
agent_omit_idx = next((i for i, a in enumerate(self[c.AGENT]) if a == agent))
running_idx, shadowing_idxs, can_be_shadowed_idxs = 0, [], []
@ -284,8 +290,14 @@ class BaseFactory(gym.Env):
z = 1
self._obs_cube[running_idx: running_idx+z] = array[per_agent_idx]
else:
z = array.shape[0]
self._obs_cube[running_idx: running_idx+z] = array
if key == c.AGENT and agent_omit_idx is not None:
z = array.shape[0] - 1
for array_idx in range(array.shape[0]):
self._obs_cube[running_idx: running_idx+z] = array[[x for x in range(array.shape[0])
if x != agent_omit_idx]]
else:
z = array.shape[0]
self._obs_cube[running_idx: running_idx+z] = array
# Define which OBS SLices cast a Shadow
if self[key].is_blocking_light:
for i in range(z):
@ -345,9 +357,13 @@ class BaseFactory(gym.Env):
else:
pass
# Additional Observation:
for additional_obs in self.additional_obs_build():
obs[running_idx:running_idx+additional_obs.shape[0]] = additional_obs
running_idx += additional_obs.shape[0]
for additional_per_agent_obs in self.additional_per_agent_obs_build(agent):
obs[running_idx:running_idx + additional_per_agent_obs.shape[0]] = additional_per_agent_obs
running_idx += additional_per_agent_obs.shape[0]
return obs
@ -522,6 +538,10 @@ class BaseFactory(gym.Env):
def additional_obs_build(self) -> List[np.ndarray]:
return []
@abc.abstractmethod
def additional_per_agent_obs_build(self, agent) -> List[np.ndarray]:
return []
@abc.abstractmethod
def do_additional_reset(self) -> None:
pass

View File

@ -1,3 +1,4 @@
from collections import defaultdict
from enum import Enum
from typing import Union
@ -9,7 +10,7 @@ import itertools
class Object:
_u_idx = 0
_u_idx = defaultdict(lambda: 0)
def __bool__(self):
return True
@ -40,8 +41,8 @@ class Object:
elif self._str_ident is not None and self._enum_ident is None:
self._name = f'{self.__class__.__name__}[{self._str_ident}]'
elif self._str_ident is None and self._enum_ident is None:
self._name = f'{self.__class__.__name__}#{self._u_idx}'
Object._u_idx += 1
self._name = f'{self.__class__.__name__}#{self._u_idx[self.__class__.__name__]}'
Object._u_idx[self.__class__.__name__] += 1
else:
raise ValueError('Please use either of the idents.')

View File

@ -4,7 +4,7 @@ from typing import List, Union, Dict
import numpy as np
from environments.factory.base.objects import Entity, Tile, Agent, Door, Action, Wall
from environments.factory.base.objects import Entity, Tile, Agent, Door, Action, Wall, Object
from environments.utility_classes import MovementProperties
from environments import helpers as h
from environments.helpers import Constants as c
@ -93,9 +93,13 @@ class EntityObjectRegister(ObjectRegister, ABC):
@classmethod
def from_tiles(cls, tiles, *args, **kwargs):
# objects_name = cls._accepted_objects.__name__
register_obj = cls(*args, **kwargs)
try:
del kwargs['individual_slices']
except KeyError:
pass
entities = [cls._accepted_objects(tile, str_ident=i, **kwargs)
for i, tile in enumerate(tiles)]
register_obj = cls(*args)
register_obj.register_additional_items(entities)
return register_obj
@ -139,10 +143,17 @@ class MovingEntityObjectRegister(EntityObjectRegister, ABC):
except StopIteration:
return None
def __delitem__(self, name):
idx = next(i for i, entity in enumerate(self) if entity.name == name)
del self._register[name]
if self.individual_slices:
self._array = np.delete(self._array, idx, axis=0)
def delete_item(self, item):
if not isinstance(item, str):
item = item.name
del self._register[item]
self.delete_item_by_name(item.name)
def delete_item_by_name(self, name):
del self[name]
class Entities(Register):
@ -150,9 +161,13 @@ class Entities(Register):
_accepted_objects = EntityObjectRegister
@property
def arrays(self):
def observable_arrays(self):
return {key: val.as_array() for key, val in self.items() if val.is_observable}
@property
def obs_arrays(self):
return {key: val.as_array() for key, val in self.items() if val.is_observable and not val.hide_from_obs_builder}
@property
def names(self):
return list(self._register.keys())