correct plotting an reloading
This commit is contained in:
@ -138,11 +138,12 @@ class BaseFactory(gym.Env):
|
||||
def movement_actions(self):
|
||||
return self._actions.movement_actions
|
||||
|
||||
def __init__(self, level='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None,
|
||||
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None,
|
||||
movement_properties: MovementProperties = MovementProperties(),
|
||||
omit_agent_slice_in_obs=False, **kwargs):
|
||||
|
||||
self.movement_properties = movement_properties
|
||||
self.level_name = level_name
|
||||
|
||||
self.n_agents = n_agents
|
||||
self.max_steps = max_steps
|
||||
@ -154,7 +155,7 @@ class BaseFactory(gym.Env):
|
||||
self._actions = _actions + self.additional_actions
|
||||
|
||||
self._level = h.one_hot_level(
|
||||
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
|
||||
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{self.level_name}.txt')
|
||||
)
|
||||
self._state_slices = StateSlice(n_agents)
|
||||
if 'additional_slices' in kwargs:
|
||||
@ -328,8 +329,8 @@ class BaseFactory(gym.Env):
|
||||
|
||||
def save_params(self, filepath: Path):
|
||||
# noinspection PyProtectedMember
|
||||
d = {key: val._asdict() if hasattr(val, '_asdict') else val for key, val in self.__dict__.items()
|
||||
if not key.startswith('_') and not key.startswith('__')}
|
||||
# d = {key: val._asdict() if hasattr(val, '_asdict') else val for key, val in self.__dict__.items()
|
||||
d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')}
|
||||
filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with filepath.open('w') as f:
|
||||
|
Reference in New Issue
Block a user