Debugged Item Factory
This commit is contained in:
@ -128,7 +128,7 @@ class BaseFactory(gym.Env):
|
||||
parsed_doors = h.one_hot_level(parsed_level, c.DOOR)
|
||||
if np.any(parsed_doors):
|
||||
door_tiles = [floor.by_pos(pos) for pos in np.argwhere(parsed_doors == c.OCCUPIED_CELL.value)]
|
||||
doors = Doors.from_tiles(door_tiles, self._level_shape, context=floor, is_blocking_light=True)
|
||||
doors = Doors.from_tiles(door_tiles, self._level_shape, context=floor)
|
||||
entities.update({c.DOORS: doors})
|
||||
|
||||
# Actions
|
||||
@ -137,7 +137,8 @@ class BaseFactory(gym.Env):
|
||||
self._actions.register_additional_items(additional_actions)
|
||||
|
||||
# Agents
|
||||
agents = Agents.from_tiles(floor.empty_tiles[:self.n_agents], self._level_shape)
|
||||
agents = Agents.from_tiles(floor.empty_tiles[:self.n_agents], self._level_shape,
|
||||
individual_slices=not self.combin_agent_obs)
|
||||
entities.update({c.AGENT: agents})
|
||||
|
||||
# All entities
|
||||
@ -152,10 +153,12 @@ class BaseFactory(gym.Env):
|
||||
return self._entities
|
||||
|
||||
def _init_obs_cube(self):
|
||||
arrays = self._entities.arrays
|
||||
arrays = self._entities.observable_arrays
|
||||
|
||||
if self.omit_agent_in_obs and self.n_agents == 1:
|
||||
del arrays[c.AGENT]
|
||||
elif self.omit_agent_in_obs:
|
||||
arrays[c.AGENT] = np.delete(arrays[c.AGENT], 0, axis=0)
|
||||
obs_cube_z = sum([a.shape[0] if not self[key].is_per_agent else 1 for key, a in arrays.items()])
|
||||
self._obs_cube = np.zeros((obs_cube_z, *self._level_shape), dtype=np.float32)
|
||||
|
||||
@ -257,7 +260,7 @@ class BaseFactory(gym.Env):
|
||||
return c.NOT_VALID
|
||||
|
||||
def _get_observations(self) -> np.ndarray:
|
||||
state_array_dict = self._entities.arrays
|
||||
state_array_dict = self._entities.obs_arrays
|
||||
if self.n_agents == 1:
|
||||
obs = self._build_per_agent_obs(self[c.AGENT][0], state_array_dict)
|
||||
elif self.n_agents >= 2:
|
||||
@ -268,11 +271,14 @@ class BaseFactory(gym.Env):
|
||||
|
||||
def _build_per_agent_obs(self, agent: Agent, state_array_dict) -> np.ndarray:
|
||||
agent_pos_is_omitted = False
|
||||
agent_omit_idx = None
|
||||
if self.omit_agent_in_obs and self.n_agents == 1:
|
||||
del state_array_dict[c.AGENT]
|
||||
elif self.omit_agent_in_obs and self.combin_agent_obs and self.n_agents > 1:
|
||||
state_array_dict[c.AGENT][0, agent.x, agent.y] -= agent.encoding
|
||||
agent_pos_is_omitted = True
|
||||
elif self.omit_agent_in_obs and not self.combin_agent_obs and self.n_agents > 1:
|
||||
agent_omit_idx = next((i for i, a in enumerate(self[c.AGENT]) if a == agent))
|
||||
|
||||
running_idx, shadowing_idxs, can_be_shadowed_idxs = 0, [], []
|
||||
|
||||
@ -284,8 +290,14 @@ class BaseFactory(gym.Env):
|
||||
z = 1
|
||||
self._obs_cube[running_idx: running_idx+z] = array[per_agent_idx]
|
||||
else:
|
||||
z = array.shape[0]
|
||||
self._obs_cube[running_idx: running_idx+z] = array
|
||||
if key == c.AGENT and agent_omit_idx is not None:
|
||||
z = array.shape[0] - 1
|
||||
for array_idx in range(array.shape[0]):
|
||||
self._obs_cube[running_idx: running_idx+z] = array[[x for x in range(array.shape[0])
|
||||
if x != agent_omit_idx]]
|
||||
else:
|
||||
z = array.shape[0]
|
||||
self._obs_cube[running_idx: running_idx+z] = array
|
||||
# Define which OBS SLices cast a Shadow
|
||||
if self[key].is_blocking_light:
|
||||
for i in range(z):
|
||||
@ -345,9 +357,13 @@ class BaseFactory(gym.Env):
|
||||
else:
|
||||
pass
|
||||
|
||||
# Additional Observation:
|
||||
for additional_obs in self.additional_obs_build():
|
||||
obs[running_idx:running_idx+additional_obs.shape[0]] = additional_obs
|
||||
running_idx += additional_obs.shape[0]
|
||||
for additional_per_agent_obs in self.additional_per_agent_obs_build(agent):
|
||||
obs[running_idx:running_idx + additional_per_agent_obs.shape[0]] = additional_per_agent_obs
|
||||
running_idx += additional_per_agent_obs.shape[0]
|
||||
|
||||
return obs
|
||||
|
||||
@ -522,6 +538,10 @@ class BaseFactory(gym.Env):
|
||||
def additional_obs_build(self) -> List[np.ndarray]:
|
||||
return []
|
||||
|
||||
@abc.abstractmethod
|
||||
def additional_per_agent_obs_build(self, agent) -> List[np.ndarray]:
|
||||
return []
|
||||
|
||||
@abc.abstractmethod
|
||||
def do_additional_reset(self) -> None:
|
||||
pass
|
||||
|
@ -1,3 +1,4 @@
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from typing import Union
|
||||
|
||||
@ -9,7 +10,7 @@ import itertools
|
||||
|
||||
class Object:
|
||||
|
||||
_u_idx = 0
|
||||
_u_idx = defaultdict(lambda: 0)
|
||||
|
||||
def __bool__(self):
|
||||
return True
|
||||
@ -40,8 +41,8 @@ class Object:
|
||||
elif self._str_ident is not None and self._enum_ident is None:
|
||||
self._name = f'{self.__class__.__name__}[{self._str_ident}]'
|
||||
elif self._str_ident is None and self._enum_ident is None:
|
||||
self._name = f'{self.__class__.__name__}#{self._u_idx}'
|
||||
Object._u_idx += 1
|
||||
self._name = f'{self.__class__.__name__}#{self._u_idx[self.__class__.__name__]}'
|
||||
Object._u_idx[self.__class__.__name__] += 1
|
||||
else:
|
||||
raise ValueError('Please use either of the idents.')
|
||||
|
||||
|
@ -4,7 +4,7 @@ from typing import List, Union, Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from environments.factory.base.objects import Entity, Tile, Agent, Door, Action, Wall
|
||||
from environments.factory.base.objects import Entity, Tile, Agent, Door, Action, Wall, Object
|
||||
from environments.utility_classes import MovementProperties
|
||||
from environments import helpers as h
|
||||
from environments.helpers import Constants as c
|
||||
@ -93,9 +93,13 @@ class EntityObjectRegister(ObjectRegister, ABC):
|
||||
@classmethod
|
||||
def from_tiles(cls, tiles, *args, **kwargs):
|
||||
# objects_name = cls._accepted_objects.__name__
|
||||
register_obj = cls(*args, **kwargs)
|
||||
try:
|
||||
del kwargs['individual_slices']
|
||||
except KeyError:
|
||||
pass
|
||||
entities = [cls._accepted_objects(tile, str_ident=i, **kwargs)
|
||||
for i, tile in enumerate(tiles)]
|
||||
register_obj = cls(*args)
|
||||
register_obj.register_additional_items(entities)
|
||||
return register_obj
|
||||
|
||||
@ -139,10 +143,17 @@ class MovingEntityObjectRegister(EntityObjectRegister, ABC):
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def __delitem__(self, name):
|
||||
idx = next(i for i, entity in enumerate(self) if entity.name == name)
|
||||
del self._register[name]
|
||||
if self.individual_slices:
|
||||
self._array = np.delete(self._array, idx, axis=0)
|
||||
|
||||
def delete_item(self, item):
|
||||
if not isinstance(item, str):
|
||||
item = item.name
|
||||
del self._register[item]
|
||||
self.delete_item_by_name(item.name)
|
||||
|
||||
def delete_item_by_name(self, name):
|
||||
del self[name]
|
||||
|
||||
|
||||
class Entities(Register):
|
||||
@ -150,9 +161,13 @@ class Entities(Register):
|
||||
_accepted_objects = EntityObjectRegister
|
||||
|
||||
@property
|
||||
def arrays(self):
|
||||
def observable_arrays(self):
|
||||
return {key: val.as_array() for key, val in self.items() if val.is_observable}
|
||||
|
||||
@property
|
||||
def obs_arrays(self):
|
||||
return {key: val.as_array() for key, val in self.items() if val.is_observable and not val.hide_from_obs_builder}
|
||||
|
||||
@property
|
||||
def names(self):
|
||||
return list(self._register.keys())
|
||||
|
Reference in New Issue
Block a user