Smaller fixes, now running.
This commit is contained in:
@ -3,23 +3,24 @@ import time
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Iterable, Dict
|
||||
import numpy as np
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
import yaml
|
||||
from gym.wrappers import FrameStack
|
||||
|
||||
from environments.factory.base.shadow_casting import Map
|
||||
from environments.factory.renderer import Renderer, RenderEntity
|
||||
from environments.helpers import Constants as c, Constants
|
||||
from environments import helpers as h
|
||||
from environments.factory.base.objects import Agent, Tile, Action
|
||||
from environments.factory.base.objects import Agent, Tile, Action, Wall
|
||||
from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles
|
||||
from environments.utility_classes import MovementProperties
|
||||
|
||||
REC_TAC = 'rec'
|
||||
import simplejson
|
||||
|
||||
|
||||
REC_TAC = 'rec_'
|
||||
|
||||
|
||||
# noinspection PyAttributeOutsideInit
|
||||
@ -67,6 +68,8 @@ class BaseFactory(gym.Env):
|
||||
self.env_seed = env_seed
|
||||
self.seed(env_seed)
|
||||
self._base_rng = np.random.default_rng(self.env_seed)
|
||||
if isinstance(movement_properties, dict):
|
||||
movement_properties = MovementProperties(**movement_properties)
|
||||
self.movement_properties = movement_properties
|
||||
self.level_name = level_name
|
||||
self._level_shape = None
|
||||
@ -118,7 +121,7 @@ class BaseFactory(gym.Env):
|
||||
entities.update({c.FLOOR: floor})
|
||||
|
||||
# NOPOS
|
||||
self.NO_POS_TILE = Tile(c.NO_POS.value)
|
||||
self._NO_POS_TILE = Tile(c.NO_POS.value)
|
||||
|
||||
# Doors
|
||||
if self.parse_doors:
|
||||
@ -175,7 +178,7 @@ class BaseFactory(gym.Env):
|
||||
|
||||
def step(self, actions):
|
||||
|
||||
if self.n_agents == 1:
|
||||
if self.n_agents == 1 and not isinstance(actions, list):
|
||||
actions = [int(actions)]
|
||||
|
||||
assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]'
|
||||
@ -470,16 +473,16 @@ class BaseFactory(gym.Env):
|
||||
d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')}
|
||||
filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with filepath.open('w') as f:
|
||||
yaml.dump(d, f)
|
||||
# pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
simplejson.dump(d, f, indent=4, namedtuple_as_object=True)
|
||||
|
||||
def _summarize_state(self):
|
||||
summary = {f'{REC_TAC}_step': self._steps}
|
||||
summary = {f'{REC_TAC}step': self._steps}
|
||||
|
||||
self[c.WALLS].summarize_state()
|
||||
for entity in self._entities:
|
||||
if hasattr(entity, 'summarize_state'):
|
||||
summary.update({f'{REC_TAC}_{entity.name}': entity.summarize_state()})
|
||||
if self._steps == 0:
|
||||
summary.update({f'{REC_TAC}{self[c.WALLS].name}': {self[c.WALLS].summarize_states()}})
|
||||
for entity_group in self._entities:
|
||||
if not isinstance(entity_group, WallTiles):
|
||||
summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states()})
|
||||
return summary
|
||||
|
||||
def print(self, string):
|
||||
|
@ -124,6 +124,9 @@ class Tile(Object):
|
||||
def __repr__(self):
|
||||
return f'{self.name}(@{self.pos})'
|
||||
|
||||
def summarize_state(self):
|
||||
return dict(name=self.name, x=self.x, y=self.y)
|
||||
|
||||
|
||||
class Wall(Tile):
|
||||
pass
|
||||
@ -160,8 +163,9 @@ class Entity(Object):
|
||||
self._tile = tile
|
||||
tile.enter(self)
|
||||
|
||||
def summarize_state(self):
|
||||
return self.__dict__.copy()
|
||||
def summarize_state(self) -> dict:
|
||||
return dict(name=str(self.name), x=int(self.x), y=int(self.y),
|
||||
tile=str(self.tile.name), can_collide=bool(self.can_collide))
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.name}(@{self.pos})'
|
||||
@ -180,6 +184,10 @@ class Door(Entity):
|
||||
def encoding(self):
|
||||
return 1 if self.is_closed else 0.5
|
||||
|
||||
@property
|
||||
def str_state(self):
|
||||
return 'open' if self.is_open else 'closed'
|
||||
|
||||
@property
|
||||
def access_area(self):
|
||||
return [node for node in self.connectivity.nodes
|
||||
@ -206,6 +214,11 @@ class Door(Entity):
|
||||
if not closed_on_init:
|
||||
self._open()
|
||||
|
||||
def summarize_state(self):
|
||||
state_dict = super().summarize_state()
|
||||
state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close))
|
||||
return state_dict
|
||||
|
||||
@property
|
||||
def is_closed(self):
|
||||
return self._state == c.CLOSED_DOOR
|
||||
@ -296,3 +309,8 @@ class Agent(MoveableEntity):
|
||||
self.temp_valid = None
|
||||
self.temp_action = None
|
||||
self.temp_light_map = None
|
||||
|
||||
def summarize_state(self):
|
||||
state_dict = super().summarize_state()
|
||||
state_dict.update(valid=bool(self.temp_valid), action=str(self.temp_action))
|
||||
return state_dict
|
||||
|
@ -15,7 +15,7 @@ class Register:
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.__class__.__name__
|
||||
return f'{self.__class__.__name__}'
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._register = dict()
|
||||
@ -78,6 +78,9 @@ class ObjectRegister(Register):
|
||||
if self.individual_slices:
|
||||
self._array = np.concatenate((self._array, np.zeros((1, *self._level_shape))))
|
||||
|
||||
def summarize_states(self):
|
||||
return [val.summarize_state() for val in self.values()]
|
||||
|
||||
|
||||
class EntityObjectRegister(ObjectRegister, ABC):
|
||||
|
||||
@ -154,8 +157,8 @@ class Entities(Register):
|
||||
def __init__(self):
|
||||
super(Entities, self).__init__()
|
||||
|
||||
def __iter__(self):
|
||||
return iter([x for sublist in self.values() for x in sublist])
|
||||
def iter_individual_entitites(self):
|
||||
return iter((x for sublist in self.values() for x in sublist))
|
||||
|
||||
def register_item(self, other: dict):
|
||||
assert not any([key for key in other.keys() if key in self.keys()]), \
|
||||
|
Reference in New Issue
Block a user