yaml working

This commit is contained in:
steffen-illium
2021-06-04 18:11:23 +02:00
parent 8ce92d5db4
commit f1306d4d6f
5 changed files with 41 additions and 28 deletions

View File

@ -1,4 +1,4 @@
import pickle
from argparse import Namespace
from pathlib import Path
from typing import List, Union, Iterable, NamedTuple
@ -89,7 +89,8 @@ class Actions(Register):
self.allow_diagonal_movement = movement_properties.allow_diagonal_movement
self.allow_square_movement = movement_properties.allow_square_movement
# FIXME: There is a bug in helpers because there actions are ints. and the order matters.
assert not(self.allow_square_movement is False and self.allow_diagonal_movement is True), "There is a bug in helpers!!!"
assert not(self.allow_square_movement is False and self.allow_diagonal_movement is True), \
"There is a bug in helpers!!!"
super(Actions, self).__init__()
if self.allow_square_movement:
@ -109,11 +110,14 @@ class StateSlice(Register):
self.register_additional_items(['level', *[f'agent#{i}' for i in range(offset, n_agents+offset)]])
# noinspection PyAttributeOutsideInit
class BaseFactory(gym.Env):
# def __setattr__(self, key, value):
# if isinstance(value, dict):
def __setattr__(self, key, value):
if isinstance(value, dict):
super(BaseFactory, self).__setattr__(key, Namespace(**value))
else:
super(BaseFactory, self).__setattr__(key, value)
@property
def action_space(self):
@ -199,8 +203,8 @@ class BaseFactory(gym.Env):
obs_padded = np.full((obs.shape[0], self.pomdp_radius ** 2 + 1, self.pomdp_radius ** 2 + 1), 1)
a_pos = np.argwhere(obs[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0]
obs_padded[:,
abs(a_pos[0]-self.pomdp_radius):abs(a_pos[0]-self.pomdp_radius)+obs.shape[1],
abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs
abs(a_pos[0]-self.pomdp_radius):abs(a_pos[0]-self.pomdp_radius)+obs.shape[1],
abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs
obs = obs_padded
else:
obs = self._state
@ -323,9 +327,11 @@ class BaseFactory(gym.Env):
raise NotImplementedError
def save_params(self, filepath: Path):
d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')}
# noinspection PyProtectedMember
d = {key: val._asdict() if hasattr(val, '_asdict') else val for key, val in self.__dict__.items()
if not key.startswith('_') and not key.startswith('__')}
filepath.parent.mkdir(parents=True, exist_ok=True)
with filepath.open('wb') as f:
# yaml.dump(d, f)
pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)
with filepath.open('w') as f:
yaml.dump(d, f)
# pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)