Item and Dirt Factory Working again

This commit is contained in:
Steffen Illium
2021-12-23 13:19:31 +01:00
parent b43f595207
commit 78bf19f7f4
11 changed files with 257 additions and 321 deletions

View File

@ -1,8 +1,6 @@
import abc
import enum
import time
from collections import defaultdict
from enum import Enum
from itertools import chain
from pathlib import Path
from typing import List, Union, Iterable, Dict
@ -13,8 +11,8 @@ from gym import spaces
from gym.wrappers import FrameStack
from environments.factory.base.shadow_casting import Map
from environments.helpers import Constants as c, Constants
from environments import helpers as h
from environments.helpers import Constants as c
from environments.factory.base.objects import Agent, Tile, Action
from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles, PlaceHolders, \
GlobalPositions
@ -53,10 +51,9 @@ class BaseFactory(gym.Env):
_, named_obs = self._build_observations()
if self.n_agents > 1:
# Only return the first named obs space, as their structure at the moment is same.
return [{key.name: val for key, val in named_ob.items()} for named_ob in named_obs.values()][0]
return named_obs[list(named_obs.keys())[0]]
else:
return {key.name: val for key, val in named_obs.items()}
return named_obs
@property
def pomdp_diameter(self):
@ -143,27 +140,27 @@ class BaseFactory(gym.Env):
# Walls
walls = WallTiles.from_argwhere_coordinates(
np.argwhere(level_array == c.OCCUPIED_CELL.value),
np.argwhere(level_array == c.OCCUPIED_CELL),
self._level_shape
)
self._entities.register_additional_items({c.WALLS: walls})
# Floor
floor = FloorTiles.from_argwhere_coordinates(
np.argwhere(level_array == c.FREE_CELL.value),
np.argwhere(level_array == c.FREE_CELL),
self._level_shape
)
self._entities.register_additional_items({c.FLOOR: floor})
# NOPOS
self._NO_POS_TILE = Tile(c.NO_POS.value, None)
self._NO_POS_TILE = Tile(c.NO_POS, None)
# Doors
if self.parse_doors:
parsed_doors = h.one_hot_level(self._parsed_level, c.DOOR)
parsed_doors = np.pad(parsed_doors, self.obs_prop.pomdp_r, 'constant', constant_values=0)
if np.any(parsed_doors):
door_tiles = [floor.by_pos(tuple(pos)) for pos in np.argwhere(parsed_doors == c.OCCUPIED_CELL.value)]
door_tiles = [floor.by_pos(tuple(pos)) for pos in np.argwhere(parsed_doors == c.OCCUPIED_CELL)]
doors = Doors.from_tiles(door_tiles, self._level_shape,
entity_kwargs=dict(context=floor)
)
@ -209,7 +206,7 @@ class BaseFactory(gym.Env):
if self.obs_prop.show_global_position_info:
global_positions = GlobalPositions(self._level_shape)
obs_shape_2d = self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2)
global_positions.spawn_GlobalPositionObjects(obs_shape_2d, self[c.AGENT])
global_positions.spawn_global_position_objects(obs_shape_2d, self[c.AGENT])
self._entities.register_additional_items({c.GLOBAL_POSITION: global_positions})
# Return
@ -239,8 +236,8 @@ class BaseFactory(gym.Env):
for action, agent in zip(actions, self[c.AGENT]):
agent.clear_temp_state()
action_obj = self._actions[int(action)]
# self.print(f'Action #{action} has been resolved to: {action_obj}')
if h.MovingAction.is_member(action_obj):
# cls.print(f'Action #{action} has been resolved to: {action_obj}')
if h.EnvActions.is_move(action_obj):
valid = self._move_or_colide(agent, action_obj)
elif h.EnvActions.NOOP == agent.temp_action:
valid = c.VALID
@ -338,12 +335,12 @@ class BaseFactory(gym.Env):
obs_dict[c.AGENT_PLACEHOLDER] = placeholder_obs
obs_dict[c.DOORS] = door_obs
obs_dict.update(add_obs_dict)
observations = np.vstack(list(obs_dict.values()))
obsn = np.vstack(list(obs_dict.values()))
if self.obs_prop.pomdp_r:
observations = self._do_pomdp_cutout(agent, observations)
obsn = self._do_pomdp_cutout(agent, obsn)
raw_obs = self._additional_raw_observations(agent)
observations = np.vstack((observations, *list(raw_obs.values())))
raw_obs = self._additional_per_agent_raw_observations(agent)
obsn = np.vstack((obsn, *list(raw_obs.values())))
keys = list(chain(obs_dict.keys(), raw_obs.keys()))
idxs = np.cumsum([x.shape[0] for x in chain(obs_dict.values(), raw_obs.values())]) - 1
@ -365,7 +362,7 @@ class BaseFactory(gym.Env):
print(e)
raise e
if self.obs_prop.cast_shadows:
obs_block_light = observations[light_block_obs] != c.OCCUPIED_CELL.value
obs_block_light = obsn[light_block_obs] != c.OCCUPIED_CELL
door_shadowing = False
if self.parse_doors:
if doors := self[c.DOORS]:
@ -395,11 +392,11 @@ class BaseFactory(gym.Env):
light_block_map[xs, ys] = 0
agent.temp_light_map = light_block_map.copy()
observations[shadowed_obs] = ((observations[shadowed_obs] * light_block_map) + 0.) - (1 - light_block_map)
obsn[shadowed_obs] = ((obsn[shadowed_obs] * light_block_map) + 0.) - (1 - light_block_map)
else:
pass
per_agent_obsn[agent.name] = observations
per_agent_obsn[agent.name] = obsn
if self.n_agents == 1:
agent_name = self[c.AGENT][0].name
@ -450,7 +447,7 @@ class BaseFactory(gym.Env):
tiles_with_collisions.append(tile)
return tiles_with_collisions
def _move_or_colide(self, agent: Agent, action: Action) -> Constants:
def _move_or_colide(self, agent: Agent, action: Action) -> bool:
new_tile, valid = self._check_agent_move(agent, action)
if valid:
# Does not collide width level boundaries
@ -624,7 +621,7 @@ class BaseFactory(gym.Env):
return []
@property
def additional_entities(self) -> Dict[(Enum, Entities)]:
def additional_entities(self) -> Dict[(str, Entities)]:
"""
When heriting from this Base Class, you musst implement this methode!!!
@ -652,11 +649,11 @@ class BaseFactory(gym.Env):
return False
@abc.abstractmethod
def _additional_observations(self) -> Dict[Constants, np.typing.ArrayLike]:
def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
return {}
@abc.abstractmethod
def _additional_raw_observations(self, agent) -> Dict[Constants, np.typing.ArrayLike]:
def _additional_per_agent_raw_observations(self, agent) -> Dict[str, np.typing.ArrayLike]:
additional_raw_observations = {}
if self.obs_prop.show_global_position_info:
additional_raw_observations.update({c.GLOBAL_POSITION: self[c.GLOBAL_POSITION].by_entity(agent).as_array()})