mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-18 18:52:52 +02:00
renaming
This commit is contained in:
27
marl_factory_grid/utils/utility_classes.py
Normal file
27
marl_factory_grid/utils/utility_classes.py
Normal file
@ -0,0 +1,27 @@
|
||||
import gymnasium as gym
|
||||
|
||||
|
||||
class EnvCombiner(object):
|
||||
|
||||
def __init__(self, *envs_cls):
|
||||
self._env_dict = {env_cls.__name__: env_cls for env_cls in envs_cls}
|
||||
|
||||
@staticmethod
|
||||
def combine_cls(name, *envs_cls):
|
||||
return type(name, envs_cls, {})
|
||||
|
||||
def build(self):
|
||||
name = f'{"".join([x.lower().replace("factory").capitalize() for x in self._env_dict.keys()])}Factory'
|
||||
|
||||
return self.combine_cls(name, tuple(self._env_dict.values()))
|
||||
|
||||
|
||||
class MarlFrameStack(gym.ObservationWrapper):
|
||||
"""todo @romue404"""
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
|
||||
def observation(self, observation):
|
||||
if isinstance(self.env, gym.wrappers.FrameStack) and self.env.unwrapped.n_agents > 1:
|
||||
return observation[0:].swapaxes(0, 1)
|
||||
return observation
|
Reference in New Issue
Block a user