Smaller fixes, now running.
This commit is contained in:
@ -3,23 +3,24 @@ import time
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Iterable, Dict
|
||||
import numpy as np
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
import yaml
|
||||
from gym.wrappers import FrameStack
|
||||
|
||||
from environments.factory.base.shadow_casting import Map
|
||||
from environments.factory.renderer import Renderer, RenderEntity
|
||||
from environments.helpers import Constants as c, Constants
|
||||
from environments import helpers as h
|
||||
from environments.factory.base.objects import Agent, Tile, Action
|
||||
from environments.factory.base.objects import Agent, Tile, Action, Wall
|
||||
from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles
|
||||
from environments.utility_classes import MovementProperties
|
||||
|
||||
REC_TAC = 'rec'
|
||||
import simplejson
|
||||
|
||||
|
||||
REC_TAC = 'rec_'
|
||||
|
||||
|
||||
# noinspection PyAttributeOutsideInit
|
||||
@ -67,6 +68,8 @@ class BaseFactory(gym.Env):
|
||||
self.env_seed = env_seed
|
||||
self.seed(env_seed)
|
||||
self._base_rng = np.random.default_rng(self.env_seed)
|
||||
if isinstance(movement_properties, dict):
|
||||
movement_properties = MovementProperties(**movement_properties)
|
||||
self.movement_properties = movement_properties
|
||||
self.level_name = level_name
|
||||
self._level_shape = None
|
||||
@ -118,7 +121,7 @@ class BaseFactory(gym.Env):
|
||||
entities.update({c.FLOOR: floor})
|
||||
|
||||
# NOPOS
|
||||
self.NO_POS_TILE = Tile(c.NO_POS.value)
|
||||
self._NO_POS_TILE = Tile(c.NO_POS.value)
|
||||
|
||||
# Doors
|
||||
if self.parse_doors:
|
||||
@ -175,7 +178,7 @@ class BaseFactory(gym.Env):
|
||||
|
||||
def step(self, actions):
|
||||
|
||||
if self.n_agents == 1:
|
||||
if self.n_agents == 1 and not isinstance(actions, list):
|
||||
actions = [int(actions)]
|
||||
|
||||
assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]'
|
||||
@ -470,16 +473,16 @@ class BaseFactory(gym.Env):
|
||||
d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')}
|
||||
filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with filepath.open('w') as f:
|
||||
yaml.dump(d, f)
|
||||
# pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
simplejson.dump(d, f, indent=4, namedtuple_as_object=True)
|
||||
|
||||
def _summarize_state(self):
|
||||
summary = {f'{REC_TAC}_step': self._steps}
|
||||
summary = {f'{REC_TAC}step': self._steps}
|
||||
|
||||
self[c.WALLS].summarize_state()
|
||||
for entity in self._entities:
|
||||
if hasattr(entity, 'summarize_state'):
|
||||
summary.update({f'{REC_TAC}_{entity.name}': entity.summarize_state()})
|
||||
if self._steps == 0:
|
||||
summary.update({f'{REC_TAC}{self[c.WALLS].name}': {self[c.WALLS].summarize_states()}})
|
||||
for entity_group in self._entities:
|
||||
if not isinstance(entity_group, WallTiles):
|
||||
summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states()})
|
||||
return summary
|
||||
|
||||
def print(self, string):
|
||||
|
Reference in New Issue
Block a user