frame stack
This commit is contained in:
@ -6,6 +6,8 @@ import gym
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
import yaml
|
||||
|
||||
from environments import helpers as h
|
||||
|
||||
|
||||
@ -191,6 +193,7 @@ class BaseFactory(gym.Env):
|
||||
abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs
|
||||
obs = obs_padded
|
||||
else:
|
||||
assert not self.omit_agent_slice_in_obs
|
||||
obs = self._state
|
||||
if self.omit_agent_slice_in_obs:
|
||||
if obs.shape != (3, 5, 5):
|
||||
@ -315,7 +318,9 @@ class BaseFactory(gym.Env):
|
||||
raise NotImplementedError
|
||||
|
||||
def save_params(self, filepath: Path):
|
||||
d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') or not key.startswith('__')}
|
||||
d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')}
|
||||
filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with filepath.open('wb') as f:
|
||||
# yaml.dump(d, f)
|
||||
pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
Reference in New Issue
Block a user