Item and Dirt Factory Working again
This commit is contained in:
@ -1,8 +1,6 @@
|
||||
import abc
|
||||
import enum
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Iterable, Dict
|
||||
@ -13,8 +11,8 @@ from gym import spaces
|
||||
from gym.wrappers import FrameStack
|
||||
|
||||
from environments.factory.base.shadow_casting import Map
|
||||
from environments.helpers import Constants as c, Constants
|
||||
from environments import helpers as h
|
||||
from environments.helpers import Constants as c
|
||||
from environments.factory.base.objects import Agent, Tile, Action
|
||||
from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles, PlaceHolders, \
|
||||
GlobalPositions
|
||||
@ -53,10 +51,9 @@ class BaseFactory(gym.Env):
|
||||
_, named_obs = self._build_observations()
|
||||
if self.n_agents > 1:
|
||||
# Only return the first named obs space, as their structure at the moment is same.
|
||||
return [{key.name: val for key, val in named_ob.items()} for named_ob in named_obs.values()][0]
|
||||
return named_obs[list(named_obs.keys())[0]]
|
||||
else:
|
||||
return {key.name: val for key, val in named_obs.items()}
|
||||
|
||||
return named_obs
|
||||
|
||||
@property
|
||||
def pomdp_diameter(self):
|
||||
@ -143,27 +140,27 @@ class BaseFactory(gym.Env):
|
||||
|
||||
# Walls
|
||||
walls = WallTiles.from_argwhere_coordinates(
|
||||
np.argwhere(level_array == c.OCCUPIED_CELL.value),
|
||||
np.argwhere(level_array == c.OCCUPIED_CELL),
|
||||
self._level_shape
|
||||
)
|
||||
self._entities.register_additional_items({c.WALLS: walls})
|
||||
|
||||
# Floor
|
||||
floor = FloorTiles.from_argwhere_coordinates(
|
||||
np.argwhere(level_array == c.FREE_CELL.value),
|
||||
np.argwhere(level_array == c.FREE_CELL),
|
||||
self._level_shape
|
||||
)
|
||||
self._entities.register_additional_items({c.FLOOR: floor})
|
||||
|
||||
# NOPOS
|
||||
self._NO_POS_TILE = Tile(c.NO_POS.value, None)
|
||||
self._NO_POS_TILE = Tile(c.NO_POS, None)
|
||||
|
||||
# Doors
|
||||
if self.parse_doors:
|
||||
parsed_doors = h.one_hot_level(self._parsed_level, c.DOOR)
|
||||
parsed_doors = np.pad(parsed_doors, self.obs_prop.pomdp_r, 'constant', constant_values=0)
|
||||
if np.any(parsed_doors):
|
||||
door_tiles = [floor.by_pos(tuple(pos)) for pos in np.argwhere(parsed_doors == c.OCCUPIED_CELL.value)]
|
||||
door_tiles = [floor.by_pos(tuple(pos)) for pos in np.argwhere(parsed_doors == c.OCCUPIED_CELL)]
|
||||
doors = Doors.from_tiles(door_tiles, self._level_shape,
|
||||
entity_kwargs=dict(context=floor)
|
||||
)
|
||||
@ -209,7 +206,7 @@ class BaseFactory(gym.Env):
|
||||
if self.obs_prop.show_global_position_info:
|
||||
global_positions = GlobalPositions(self._level_shape)
|
||||
obs_shape_2d = self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2)
|
||||
global_positions.spawn_GlobalPositionObjects(obs_shape_2d, self[c.AGENT])
|
||||
global_positions.spawn_global_position_objects(obs_shape_2d, self[c.AGENT])
|
||||
self._entities.register_additional_items({c.GLOBAL_POSITION: global_positions})
|
||||
|
||||
# Return
|
||||
@ -239,8 +236,8 @@ class BaseFactory(gym.Env):
|
||||
for action, agent in zip(actions, self[c.AGENT]):
|
||||
agent.clear_temp_state()
|
||||
action_obj = self._actions[int(action)]
|
||||
# self.print(f'Action #{action} has been resolved to: {action_obj}')
|
||||
if h.MovingAction.is_member(action_obj):
|
||||
# cls.print(f'Action #{action} has been resolved to: {action_obj}')
|
||||
if h.EnvActions.is_move(action_obj):
|
||||
valid = self._move_or_colide(agent, action_obj)
|
||||
elif h.EnvActions.NOOP == agent.temp_action:
|
||||
valid = c.VALID
|
||||
@ -338,12 +335,12 @@ class BaseFactory(gym.Env):
|
||||
obs_dict[c.AGENT_PLACEHOLDER] = placeholder_obs
|
||||
obs_dict[c.DOORS] = door_obs
|
||||
obs_dict.update(add_obs_dict)
|
||||
observations = np.vstack(list(obs_dict.values()))
|
||||
obsn = np.vstack(list(obs_dict.values()))
|
||||
if self.obs_prop.pomdp_r:
|
||||
observations = self._do_pomdp_cutout(agent, observations)
|
||||
obsn = self._do_pomdp_cutout(agent, obsn)
|
||||
|
||||
raw_obs = self._additional_raw_observations(agent)
|
||||
observations = np.vstack((observations, *list(raw_obs.values())))
|
||||
raw_obs = self._additional_per_agent_raw_observations(agent)
|
||||
obsn = np.vstack((obsn, *list(raw_obs.values())))
|
||||
|
||||
keys = list(chain(obs_dict.keys(), raw_obs.keys()))
|
||||
idxs = np.cumsum([x.shape[0] for x in chain(obs_dict.values(), raw_obs.values())]) - 1
|
||||
@ -365,7 +362,7 @@ class BaseFactory(gym.Env):
|
||||
print(e)
|
||||
raise e
|
||||
if self.obs_prop.cast_shadows:
|
||||
obs_block_light = observations[light_block_obs] != c.OCCUPIED_CELL.value
|
||||
obs_block_light = obsn[light_block_obs] != c.OCCUPIED_CELL
|
||||
door_shadowing = False
|
||||
if self.parse_doors:
|
||||
if doors := self[c.DOORS]:
|
||||
@ -395,11 +392,11 @@ class BaseFactory(gym.Env):
|
||||
light_block_map[xs, ys] = 0
|
||||
agent.temp_light_map = light_block_map.copy()
|
||||
|
||||
observations[shadowed_obs] = ((observations[shadowed_obs] * light_block_map) + 0.) - (1 - light_block_map)
|
||||
obsn[shadowed_obs] = ((obsn[shadowed_obs] * light_block_map) + 0.) - (1 - light_block_map)
|
||||
else:
|
||||
pass
|
||||
|
||||
per_agent_obsn[agent.name] = observations
|
||||
per_agent_obsn[agent.name] = obsn
|
||||
|
||||
if self.n_agents == 1:
|
||||
agent_name = self[c.AGENT][0].name
|
||||
@ -450,7 +447,7 @@ class BaseFactory(gym.Env):
|
||||
tiles_with_collisions.append(tile)
|
||||
return tiles_with_collisions
|
||||
|
||||
def _move_or_colide(self, agent: Agent, action: Action) -> Constants:
|
||||
def _move_or_colide(self, agent: Agent, action: Action) -> bool:
|
||||
new_tile, valid = self._check_agent_move(agent, action)
|
||||
if valid:
|
||||
# Does not collide width level boundaries
|
||||
@ -624,7 +621,7 @@ class BaseFactory(gym.Env):
|
||||
return []
|
||||
|
||||
@property
|
||||
def additional_entities(self) -> Dict[(Enum, Entities)]:
|
||||
def additional_entities(self) -> Dict[(str, Entities)]:
|
||||
"""
|
||||
When heriting from this Base Class, you musst implement this methode!!!
|
||||
|
||||
@ -652,11 +649,11 @@ class BaseFactory(gym.Env):
|
||||
return False
|
||||
|
||||
@abc.abstractmethod
|
||||
def _additional_observations(self) -> Dict[Constants, np.typing.ArrayLike]:
|
||||
def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
|
||||
return {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def _additional_raw_observations(self, agent) -> Dict[Constants, np.typing.ArrayLike]:
|
||||
def _additional_per_agent_raw_observations(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_raw_observations = {}
|
||||
if self.obs_prop.show_global_position_info:
|
||||
additional_raw_observations.update({c.GLOBAL_POSITION: self[c.GLOBAL_POSITION].by_entity(agent).as_array()})
|
||||
|
Reference in New Issue
Block a user