frame stack

This commit is contained in:
steffen-illium
2021-06-04 12:04:24 +02:00
parent b72013407e
commit 5668f5cb82
5 changed files with 36 additions and 20 deletions

View File

@ -6,6 +6,8 @@ import gym
import numpy as np
from gym import spaces
import yaml
from environments import helpers as h
@ -191,6 +193,7 @@ class BaseFactory(gym.Env):
abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs
obs = obs_padded
else:
assert not self.omit_agent_slice_in_obs
obs = self._state
if self.omit_agent_slice_in_obs:
if obs.shape != (3, 5, 5):
@ -315,7 +318,9 @@ class BaseFactory(gym.Env):
raise NotImplementedError
def save_params(self, filepath: Path):
d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') or not key.startswith('__')}
d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')}
filepath.parent.mkdir(parents=True, exist_ok=True)
with filepath.open('wb') as f:
# yaml.dump(d, f)
pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)