Rework of Observations and Entity Differentiation, lazy obs build by notification
This commit is contained in:
@ -1,7 +1,9 @@
|
||||
import abc
|
||||
import enum
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Iterable, Dict
|
||||
import numpy as np
|
||||
@ -14,7 +16,8 @@ from environments.factory.base.shadow_casting import Map
|
||||
from environments.helpers import Constants as c, Constants
|
||||
from environments import helpers as h
|
||||
from environments.factory.base.objects import Agent, Tile, Action
|
||||
from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles, PlaceHolders
|
||||
from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles, PlaceHolders, \
|
||||
GlobalPositions
|
||||
from environments.utility_classes import MovementProperties, ObservationProperties, MarlFrameStack
|
||||
from environments.utility_classes import AgentRenderOptions as a_obs
|
||||
|
||||
@ -30,17 +33,31 @@ class BaseFactory(gym.Env):
|
||||
def action_space(self):
|
||||
return spaces.Discrete(len(self._actions))
|
||||
|
||||
@property
|
||||
def named_action_space(self):
|
||||
return {x.identifier.value: idx for idx, x in enumerate(self._actions.values())}
|
||||
|
||||
@property
|
||||
def observation_space(self):
|
||||
if r := self._pomdp_r:
|
||||
z = self._obs_cube.shape[0]
|
||||
xy = r*2 + 1
|
||||
level_shape = (z, xy, xy)
|
||||
obs, _ = self._build_observations()
|
||||
if self.n_agents > 1:
|
||||
shape = obs[0].shape
|
||||
else:
|
||||
level_shape = self._obs_cube.shape
|
||||
space = spaces.Box(low=0, high=1, shape=level_shape, dtype=np.float32)
|
||||
shape = obs.shape
|
||||
space = spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
|
||||
return space
|
||||
|
||||
@property
|
||||
def named_observation_space(self):
|
||||
# Build it
|
||||
_, named_obs = self._build_observations()
|
||||
if self.n_agents > 1:
|
||||
# Only return the first named obs space, since all agents currently share the same structure.
|
||||
return [{key.name: val for key, val in named_ob.items()} for named_ob in named_obs.values()][0]
|
||||
else:
|
||||
return {key.name: val for key, val in named_obs.items()}
|
||||
|
||||
|
||||
@property
|
||||
def pomdp_diameter(self):
|
||||
return self._pomdp_r * 2 + 1
|
||||
@ -86,11 +103,14 @@ class BaseFactory(gym.Env):
|
||||
self.obs_prop = obs_prop
|
||||
self.level_name = level_name
|
||||
self._level_shape = None
|
||||
self._obs_shape = None
|
||||
self.verbose = verbose
|
||||
self._renderer = None # expensive - don't use it when not required !
|
||||
self._entities = Entities()
|
||||
|
||||
self.n_agents = n_agents
|
||||
level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.level_name}.txt'
|
||||
self._parsed_level = h.parse_level(level_filepath)
|
||||
|
||||
self.max_steps = max_steps
|
||||
self._pomdp_r = self.obs_prop.pomdp_r
|
||||
@ -114,10 +134,12 @@ class BaseFactory(gym.Env):
|
||||
# Objects
|
||||
self._entities = Entities()
|
||||
# Level
|
||||
level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.level_name}.txt'
|
||||
parsed_level = h.parse_level(level_filepath)
|
||||
level_array = h.one_hot_level(parsed_level)
|
||||
|
||||
level_array = h.one_hot_level(self._parsed_level)
|
||||
level_array = np.pad(level_array, self.obs_prop.pomdp_r, 'constant', constant_values=1)
|
||||
|
||||
self._level_shape = level_array.shape
|
||||
self._obs_shape = self._level_shape if not self.obs_prop.pomdp_r else (self.pomdp_diameter, ) * 2
|
||||
|
||||
# Walls
|
||||
walls = WallTiles.from_argwhere_coordinates(
|
||||
@ -134,13 +156,14 @@ class BaseFactory(gym.Env):
|
||||
self._entities.register_additional_items({c.FLOOR: floor})
|
||||
|
||||
# NOPOS
|
||||
self._NO_POS_TILE = Tile(c.NO_POS.value)
|
||||
self._NO_POS_TILE = Tile(c.NO_POS.value, None)
|
||||
|
||||
# Doors
|
||||
if self.parse_doors:
|
||||
parsed_doors = h.one_hot_level(parsed_level, c.DOOR)
|
||||
parsed_doors = h.one_hot_level(self._parsed_level, c.DOOR)
|
||||
parsed_doors = np.pad(parsed_doors, self.obs_prop.pomdp_r, 'constant', constant_values=0)
|
||||
if np.any(parsed_doors):
|
||||
door_tiles = [floor.by_pos(pos) for pos in np.argwhere(parsed_doors == c.OCCUPIED_CELL.value)]
|
||||
door_tiles = [floor.by_pos(tuple(pos)) for pos in np.argwhere(parsed_doors == c.OCCUPIED_CELL.value)]
|
||||
doors = Doors.from_tiles(door_tiles, self._level_shape,
|
||||
entity_kwargs=dict(context=floor)
|
||||
)
|
||||
@ -153,12 +176,11 @@ class BaseFactory(gym.Env):
|
||||
|
||||
# Agents
|
||||
agents_to_spawn = self.n_agents-len(self._injected_agents)
|
||||
agents_kwargs = dict(level_shape=self._level_shape,
|
||||
individual_slices=self.obs_prop.render_agents == a_obs.SEPERATE,
|
||||
hide_from_obs_builder=self.obs_prop.render_agents == a_obs.LEVEL,
|
||||
is_observable=self.obs_prop.render_agents != a_obs.NOT)
|
||||
agents_kwargs = dict(individual_slices=self.obs_prop.render_agents == a_obs.SEPERATE,
|
||||
hide_from_obs_builder=self.obs_prop.render_agents in [a_obs.NOT, a_obs.LEVEL],
|
||||
)
|
||||
if agents_to_spawn:
|
||||
agents = Agents.from_tiles(floor.empty_tiles[:agents_to_spawn], **agents_kwargs)
|
||||
agents = Agents.from_tiles(floor.empty_tiles[:agents_to_spawn], self._level_shape, **agents_kwargs)
|
||||
else:
|
||||
agents = Agents(**agents_kwargs)
|
||||
if self._injected_agents:
|
||||
@ -173,10 +195,10 @@ class BaseFactory(gym.Env):
|
||||
# TODO: Make this accept Lists for multiple placeholders
|
||||
|
||||
# Empty Observations with either [0, 1, N(0, 1)]
|
||||
placeholder = PlaceHolders.from_tiles([self._NO_POS_TILE], self._level_shape,
|
||||
entity_kwargs=dict(
|
||||
fill_value=self.obs_prop.additional_agent_placeholder)
|
||||
)
|
||||
placeholder = PlaceHolders.from_values(self.obs_prop.additional_agent_placeholder, self._level_shape,
|
||||
entity_kwargs=dict(
|
||||
fill_value=self.obs_prop.additional_agent_placeholder)
|
||||
)
|
||||
|
||||
self._entities.register_additional_items({c.AGENT_PLACEHOLDER: placeholder})
|
||||
|
||||
@ -184,24 +206,22 @@ class BaseFactory(gym.Env):
|
||||
if additional_entities := self.additional_entities:
|
||||
self._entities.register_additional_items(additional_entities)
|
||||
|
||||
if self.obs_prop.show_global_position_info:
|
||||
global_positions = GlobalPositions(self._level_shape)
|
||||
obs_shape_2d = self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2)
|
||||
global_positions.spawn_GlobalPositionObjects(obs_shape_2d, self[c.AGENT])
|
||||
self._entities.register_additional_items({c.GLOBAL_POSITION: global_positions})
|
||||
|
||||
# Return
|
||||
return self._entities
|
||||
|
||||
def _init_obs_cube(self):
|
||||
arrays = self._entities.obs_arrays
|
||||
|
||||
obs_cube_z = sum([a.shape[0] if not self[key].is_per_agent else 1 for key, a in arrays.items()])
|
||||
obs_cube_z += 1 if self.obs_prop.show_global_position_info else 0
|
||||
self._obs_cube = np.zeros((obs_cube_z, *self._level_shape), dtype=np.float32)
|
||||
|
||||
def reset(self) -> (np.ndarray, int, bool, dict):
|
||||
def reset(self) -> (np.typing.ArrayLike, int, bool, dict):
|
||||
_ = self._base_init_env()
|
||||
self._init_obs_cube()
|
||||
self.do_additional_reset()
|
||||
|
||||
self._steps = 0
|
||||
|
||||
obs = self._get_observations()
|
||||
obs, _ = self._build_observations()
|
||||
return obs
|
||||
|
||||
def step(self, actions):
|
||||
@ -264,7 +284,7 @@ class BaseFactory(gym.Env):
|
||||
# Post step Hook for later use
|
||||
info.update(self.hook_post_step())
|
||||
|
||||
obs = self._get_observations()
|
||||
obs, _ = self._build_observations()
|
||||
|
||||
return obs, reward, done, info
|
||||
|
||||
@ -284,141 +304,120 @@ class BaseFactory(gym.Env):
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
|
||||
def _get_observations(self) -> np.ndarray:
|
||||
state_array_dict = self._entities.obs_arrays
|
||||
if self.n_agents == 1:
|
||||
obs = self._build_per_agent_obs(self[c.AGENT][0], state_array_dict)
|
||||
elif self.n_agents >= 2:
|
||||
obs = np.stack([self._build_per_agent_obs(agent, state_array_dict) for agent in self[c.AGENT]])
|
||||
else:
|
||||
raise ValueError('n_agents cannot be smaller than 1!!')
|
||||
return obs
|
||||
def _build_observations(self) -> np.typing.ArrayLike:
|
||||
# Observation dict:
|
||||
per_agent_expl_idx = dict()
|
||||
per_agent_obsn = dict()
|
||||
# General Observations
|
||||
lvl_obs = self[c.WALLS].as_array()
|
||||
door_obs = self[c.DOORS].as_array()
|
||||
agent_obs = self[c.AGENT].as_array() if self.obs_prop.render_agents != a_obs.NOT else None
|
||||
placeholder_obs = self[c.AGENT_PLACEHOLDER].as_array() if self[c.AGENT_PLACEHOLDER] else None
|
||||
add_obs_dict = self._additional_observations()
|
||||
|
||||
def _build_per_agent_obs(self, agent: Agent, state_array_dict) -> np.ndarray:
|
||||
agent_pos_is_omitted = False
|
||||
agent_omit_idx = None
|
||||
|
||||
if self.obs_prop.omit_agent_self and self.n_agents == 1:
|
||||
pass
|
||||
elif self.obs_prop.omit_agent_self and self.obs_prop.render_agents in [a_obs.COMBINED, ] and self.n_agents > 1:
|
||||
state_array_dict[c.AGENT][0, agent.x, agent.y] -= agent.encoding
|
||||
agent_pos_is_omitted = True
|
||||
elif self.obs_prop.omit_agent_self and self.obs_prop.render_agents == a_obs.SEPERATE and self.n_agents > 1:
|
||||
agent_omit_idx = next((i for i, a in enumerate(self[c.AGENT]) if a == agent))
|
||||
|
||||
running_idx, shadowing_idxs, can_be_shadowed_idxs = 0, [], []
|
||||
self._obs_cube[:] = 0
|
||||
|
||||
# FIXME: Refactor this! Make a globally build observation, then add individual per-agent-obs
|
||||
for key, array in state_array_dict.items():
|
||||
# Flush state array object representation to obs cube
|
||||
if not self[key].hide_from_obs_builder:
|
||||
if self[key].is_per_agent:
|
||||
per_agent_idx = self[key].idx_by_entity(agent)
|
||||
z = 1
|
||||
self._obs_cube[running_idx: running_idx+z] = array[per_agent_idx]
|
||||
else:
|
||||
if key == c.AGENT and agent_omit_idx is not None:
|
||||
z = array.shape[0] - 1
|
||||
for array_idx in range(array.shape[0]):
|
||||
self._obs_cube[running_idx: running_idx+z] = array[[x for x in range(array.shape[0])
|
||||
if x != agent_omit_idx]]
|
||||
# Agent OBS are combined
|
||||
elif key == c.AGENT and self.obs_prop.omit_agent_self \
|
||||
and self.obs_prop.render_agents == a_obs.COMBINED:
|
||||
z = 1
|
||||
self._obs_cube[running_idx: running_idx + z] = array
|
||||
# Each Agent is rendered on a separate array slice
|
||||
for agent_idx, agent in enumerate(self[c.AGENT]):
|
||||
obs_dict = dict()
|
||||
# Build Agent Observations
|
||||
if self.obs_prop.render_agents != a_obs.NOT:
|
||||
if self.obs_prop.omit_agent_self:
|
||||
if self.obs_prop.render_agents == a_obs.SEPERATE:
|
||||
agent_obs = np.take(agent_obs, [x for x in range(self.n_agents) if x != agent_idx], axis=0)
|
||||
else:
|
||||
z = array.shape[0]
|
||||
self._obs_cube[running_idx: running_idx + z] = array
|
||||
# Define which OBS slices cast a shadow
|
||||
if self[key].is_blocking_light:
|
||||
for i in range(z):
|
||||
shadowing_idxs.append(running_idx + i)
|
||||
# Define which OBS slices are affected by shadows
|
||||
if self[key].can_be_shadowed:
|
||||
for i in range(z):
|
||||
can_be_shadowed_idxs.append(running_idx + i)
|
||||
running_idx += z
|
||||
agent_obs = agent_obs.copy()
|
||||
agent_obs[(0, *agent.pos)] -= agent.encoding
|
||||
|
||||
if agent_pos_is_omitted:
|
||||
state_array_dict[c.AGENT][0, agent.x, agent.y] += agent.encoding
|
||||
|
||||
if self._pomdp_r:
|
||||
obs = self._do_pomdp_obs_cutout(agent, self._obs_cube)
|
||||
else:
|
||||
obs = self._obs_cube
|
||||
|
||||
obs = obs.copy()
|
||||
|
||||
if self.obs_prop.cast_shadows:
|
||||
obs_block_light = [obs[idx] != c.OCCUPIED_CELL.value for idx in shadowing_idxs]
|
||||
door_shadowing = False
|
||||
if self.parse_doors:
|
||||
if doors := self[c.DOORS]:
|
||||
if door := doors.by_pos(agent.pos):
|
||||
if door.is_closed:
|
||||
for group in door.connectivity_subgroups:
|
||||
if agent.last_pos not in group:
|
||||
door_shadowing = True
|
||||
if self._pomdp_r:
|
||||
blocking = [tuple(np.subtract(x, agent.pos) + (self._pomdp_r, self._pomdp_r))
|
||||
for x in group]
|
||||
xs, ys = zip(*blocking)
|
||||
else:
|
||||
xs, ys = zip(*group)
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
obs_block_light[0][xs, ys] = False
|
||||
|
||||
light_block_map = Map((np.prod(obs_block_light, axis=0) != True).astype(int))
|
||||
if self._pomdp_r:
|
||||
light_block_map = light_block_map.do_fov(self._pomdp_r, self._pomdp_r, max(self._level_shape))
|
||||
else:
|
||||
light_block_map = light_block_map.do_fov(*agent.pos, max(self._level_shape))
|
||||
if door_shadowing:
|
||||
# noinspection PyUnboundLocalVariable
|
||||
light_block_map[xs, ys] = 0
|
||||
agent.temp_light_map = light_block_map
|
||||
for obs_idx in can_be_shadowed_idxs:
|
||||
obs[obs_idx] = ((obs[obs_idx] * light_block_map) + 0.) - (1 - light_block_map) # * obs[0])
|
||||
else:
|
||||
pass
|
||||
|
||||
# Agents observe other agents as wall
|
||||
if self.obs_prop.render_agents == a_obs.LEVEL and self.n_agents > 1:
|
||||
other_agent_obs = self[c.AGENT].as_array()
|
||||
if self.obs_prop.omit_agent_self:
|
||||
other_agent_obs[:, agent.x, agent.y] -= agent.encoding
|
||||
# Build Level Observations
|
||||
if self.obs_prop.render_agents == a_obs.LEVEL:
|
||||
lvl_obs = lvl_obs.copy()
|
||||
lvl_obs += agent_obs
|
||||
|
||||
obs_dict[c.WALLS] = lvl_obs
|
||||
if self.obs_prop.render_agents in [a_obs.SEPERATE, a_obs.COMBINED]:
|
||||
obs_dict[c.AGENT] = agent_obs
|
||||
if self[c.AGENT_PLACEHOLDER]:
|
||||
obs_dict[c.AGENT_PLACEHOLDER] = placeholder_obs
|
||||
obs_dict[c.DOORS] = door_obs
|
||||
obs_dict.update(add_obs_dict)
|
||||
observations = np.vstack(list(obs_dict.values()))
|
||||
if self.obs_prop.pomdp_r:
|
||||
oobs = self._do_pomdp_obs_cutout(agent, other_agent_obs)[0]
|
||||
# noinspection PyUnresolvedReferences
|
||||
mask = (oobs != c.SHADOWED_CELL.value).astype(int)
|
||||
obs[0] += oobs * mask
|
||||
observations = self._do_pomdp_cutout(agent, observations)
|
||||
|
||||
raw_obs = self._additional_raw_observations(agent)
|
||||
observations = np.vstack((observations, *list(raw_obs.values())))
|
||||
|
||||
keys = list(chain(obs_dict.keys(), raw_obs.keys()))
|
||||
idxs = np.cumsum([x.shape[0] for x in chain(obs_dict.values(), raw_obs.values())]) - 1
|
||||
per_agent_expl_idx[agent.name] = {key: list(range(a, b)) for key, a, b in
|
||||
zip(keys, idxs, list(idxs[1:]) + [idxs[-1]+1, ])}
|
||||
|
||||
# Shadow Casting
|
||||
try:
|
||||
light_block_obs = [obs_idx for key, obs_idx in per_agent_expl_idx[agent.name].items()
|
||||
if self[key].is_blocking_light]
|
||||
# Flatten
|
||||
light_block_obs = [x for y in light_block_obs for x in y]
|
||||
shadowed_obs = [obs_idx for key, obs_idx in per_agent_expl_idx[agent.name].items()
|
||||
if self[key].can_be_shadowed]
|
||||
# Flatten
|
||||
shadowed_obs = [x for y in shadowed_obs for x in y]
|
||||
except AttributeError as e:
|
||||
print('Check your Keys! Only use Constants as Keys!')
|
||||
print(e)
|
||||
raise e
|
||||
if self.obs_prop.cast_shadows:
|
||||
obs_block_light = observations[light_block_obs] != c.OCCUPIED_CELL.value
|
||||
door_shadowing = False
|
||||
if self.parse_doors:
|
||||
if doors := self[c.DOORS]:
|
||||
if door := doors.by_pos(agent.pos):
|
||||
if door.is_closed:
|
||||
for group in door.connectivity_subgroups:
|
||||
if agent.last_pos not in group:
|
||||
door_shadowing = True
|
||||
if self._pomdp_r:
|
||||
blocking = [
|
||||
tuple(np.subtract(x, agent.pos) + (self._pomdp_r, self._pomdp_r))
|
||||
for x in group]
|
||||
xs, ys = zip(*blocking)
|
||||
else:
|
||||
xs, ys = zip(*group)
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
obs_block_light[:, xs, ys] = False
|
||||
|
||||
light_block_map = Map((np.prod(obs_block_light, axis=0) != True).astype(int).squeeze())
|
||||
if self._pomdp_r:
|
||||
light_block_map = light_block_map.do_fov(self._pomdp_r, self._pomdp_r, max(self._level_shape))
|
||||
else:
|
||||
light_block_map = light_block_map.do_fov(*agent.pos, max(self._level_shape))
|
||||
if door_shadowing:
|
||||
# noinspection PyUnboundLocalVariable
|
||||
light_block_map[xs, ys] = 0
|
||||
agent.temp_light_map = light_block_map.copy()
|
||||
|
||||
observations[shadowed_obs] = ((observations[shadowed_obs] * light_block_map) + 0.) - (1 - light_block_map)
|
||||
else:
|
||||
obs[0] += other_agent_obs
|
||||
pass
|
||||
|
||||
# Additional Observation:
|
||||
for additional_obs in self.additional_obs_build():
|
||||
obs[running_idx:running_idx+additional_obs.shape[0]] = additional_obs
|
||||
running_idx += additional_obs.shape[0]
|
||||
for additional_per_agent_obs in self.additional_per_agent_obs_build(agent):
|
||||
obs[running_idx:running_idx + additional_per_agent_obs.shape[0]] = additional_per_agent_obs
|
||||
running_idx += additional_per_agent_obs.shape[0]
|
||||
per_agent_obsn[agent.name] = observations
|
||||
|
||||
return obs
|
||||
if self.n_agents == 1:
|
||||
agent_name = self[c.AGENT][0].name
|
||||
obs, explained_idx = per_agent_obsn[agent_name], per_agent_expl_idx[agent_name]
|
||||
elif self.n_agents >= 2:
|
||||
obs, explained_idx = np.stack(list(per_agent_obsn.values())), per_agent_expl_idx
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
def _do_pomdp_obs_cutout(self, agent, obs_to_be_padded):
|
||||
return obs, explained_idx
|
||||
|
||||
def _do_pomdp_cutout(self, agent, obs_to_be_padded):
|
||||
assert obs_to_be_padded.ndim == 3
|
||||
r, d = self._pomdp_r, self.pomdp_diameter
|
||||
x0, x1 = max(0, agent.x - r), min(agent.x + r + 1, self._level_shape[0])
|
||||
y0, y1 = max(0, agent.y - r), min(agent.y + r + 1, self._level_shape[1])
|
||||
# Other Agent Obs = oobs
|
||||
oobs = obs_to_be_padded[:, x0:x1, y0:y1]
|
||||
if oobs.shape[0:] != (d, d):
|
||||
if oobs.shape[1:] != (d, d):
|
||||
if xd := oobs.shape[1] % d:
|
||||
if agent.x > r:
|
||||
x0_pad = 0
|
||||
@ -478,7 +477,7 @@ class BaseFactory(gym.Env):
|
||||
if doors := self[c.DOORS]:
|
||||
if self.doors_have_area:
|
||||
if door := doors.by_pos(new_tile.pos):
|
||||
if door.can_collide:
|
||||
if door.is_open:
|
||||
return agent.tile, c.NOT_VALID
|
||||
else: # door.is_closed:
|
||||
pass
|
||||
@ -569,7 +568,7 @@ class BaseFactory(gym.Env):
|
||||
if not self._renderer: # lazy init
|
||||
from environments.factory.base.renderer import Renderer, RenderEntity
|
||||
global Renderer, RenderEntity
|
||||
height, width = self._obs_cube.shape[1:]
|
||||
height, width = self._level_shape
|
||||
self._renderer = Renderer(width, height, view_radius=self._pomdp_r, fps=5)
|
||||
|
||||
# noinspection PyUnboundLocalVariable
|
||||
@ -636,20 +635,6 @@ class BaseFactory(gym.Env):
|
||||
|
||||
# Functions which provide additions to functions of the base class
|
||||
# Always call super!!!!!!
|
||||
@abc.abstractmethod
|
||||
def additional_obs_build(self) -> List[np.ndarray]:
|
||||
return []
|
||||
|
||||
def additional_per_agent_obs_build(self, agent) -> List[np.ndarray]:
|
||||
additional_per_agent_obs = []
|
||||
if self.obs_prop.show_global_position_info:
|
||||
pos_array = np.zeros(self.observation_space.shape[1:])
|
||||
for xy in range(1):
|
||||
pos_array[0, xy] = agent.pos[xy] / self._level_shape[xy]
|
||||
additional_per_agent_obs.append(pos_array)
|
||||
|
||||
return additional_per_agent_obs
|
||||
|
||||
@abc.abstractmethod
|
||||
def do_additional_reset(self) -> None:
|
||||
pass
|
||||
@ -666,6 +651,17 @@ class BaseFactory(gym.Env):
|
||||
def check_additional_done(self) -> bool:
|
||||
return False
|
||||
|
||||
@abc.abstractmethod
|
||||
def _additional_observations(self) -> Dict[Constants, np.typing.ArrayLike]:
|
||||
return {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def _additional_raw_observations(self, agent) -> Dict[Constants, np.typing.ArrayLike]:
|
||||
additional_raw_observations = {}
|
||||
if self.obs_prop.show_global_position_info:
|
||||
additional_raw_observations.update({c.GLOBAL_POSITION: self[c.GLOBAL_POSITION].by_entity(agent).as_array()})
|
||||
return additional_raw_observations
|
||||
|
||||
@abc.abstractmethod
|
||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
||||
return 0, {}
|
||||
|
Reference in New Issue
Block a user