added MarlFrameStack and salina stuff
This commit is contained in:
@ -1,5 +1,6 @@
|
||||
from enum import Enum
|
||||
from typing import NamedTuple, Union
|
||||
import gym
|
||||
from gym.wrappers.frame_stack import FrameStack
|
||||
|
||||
|
||||
class AgentRenderOptions(object):
|
||||
@ -22,3 +23,14 @@ class ObservationProperties(NamedTuple):
|
||||
cast_shadows = True
|
||||
frames_to_stack: int = 0
|
||||
pomdp_r: int = 0
|
||||
|
||||
|
||||
class MarlFrameStack(gym.ObservationWrapper):
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
|
||||
def observation(self, observation):
|
||||
if isinstance(self.env, FrameStack) and self.env.unwrapped.n_agents > 1:
|
||||
return observation[0:].swapaxes(0, 1)
|
||||
return observation
|
||||
|
||||
|
Reference in New Issue
Block a user