Files
marl-factory-grid/environment/utils/utility_classes.py
2023-06-09 14:04:17 +02:00

28 lines
811 B
Python

import gymnasium as gym
class EnvCombiner(object):
def __init__(self, *envs_cls):
self._env_dict = {env_cls.__name__: env_cls for env_cls in envs_cls}
@staticmethod
def combine_cls(name, *envs_cls):
return type(name, envs_cls, {})
def build(self):
name = f'{"".join([x.lower().replace("factory").capitalize() for x in self._env_dict.keys()])}Factory'
return self.combine_cls(name, tuple(self._env_dict.values()))
class MarlFrameStack(gym.ObservationWrapper):
"""todo @romue404"""
def __init__(self, env):
super().__init__(env)
def observation(self, observation):
if isinstance(self.env, gym.wrappers.FrameStack) and self.env.unwrapped.n_agents > 1:
return observation[0:].swapaxes(0, 1)
return observation