mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-21 11:21:35 +02:00
added MarlFrameStack and salina stuff
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
def make(env_name, n_agents=1, pomdp_r=2, max_steps=400, stack_n_frames=3):
|
||||
def make(env_name, pomdp_r=2, max_steps=400, stack_n_frames=3, n_agents=1, individual_rewards=False):
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from environments.factory.combined_factories import DirtItemFactory
|
||||
@ -12,7 +12,8 @@ def make(env_name, n_agents=1, pomdp_r=2, max_steps=400, stack_n_frames=3):
|
||||
obs_props = ObservationProperties(render_agents=AgentRenderOptions.COMBINED,
|
||||
frames_to_stack=stack_n_frames, pomdp_r=pomdp_r)
|
||||
|
||||
factory_kwargs = dict(n_agents=n_agents, max_steps=max_steps, obs_prop=obs_props,
|
||||
factory_kwargs = dict(n_agents=n_agents, individual_rewards=individual_rewards,
|
||||
max_steps=max_steps, obs_prop=obs_props,
|
||||
mv_prop=MovementProperties(**dictionary['movement_props']),
|
||||
dirt_prop=DirtProperties(**dictionary['dirt_props']),
|
||||
record_episodes=False, verbose=False, **dictionary['factory_props']
|
||||
|
@ -15,12 +15,11 @@ from environments.helpers import Constants as c, Constants
|
||||
from environments import helpers as h
|
||||
from environments.factory.base.objects import Agent, Tile, Action
|
||||
from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles, PlaceHolders
|
||||
from environments.utility_classes import MovementProperties, ObservationProperties
|
||||
from environments.utility_classes import MovementProperties, ObservationProperties, MarlFrameStack
|
||||
from environments.utility_classes import AgentRenderOptions as a_obs
|
||||
|
||||
import simplejson
|
||||
|
||||
|
||||
REC_TAC = 'rec_'
|
||||
|
||||
|
||||
@ -57,7 +56,7 @@ class BaseFactory(gym.Env):
|
||||
|
||||
def __enter__(self):
|
||||
return self if self.obs_prop.frames_to_stack == 0 else \
|
||||
FrameStack(self, self.obs_prop.frames_to_stack)
|
||||
MarlFrameStack(FrameStack(self, self.obs_prop.frames_to_stack))
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
@ -1,5 +1,6 @@
|
||||
from enum import Enum
|
||||
from typing import NamedTuple, Union
|
||||
import gym
|
||||
from gym.wrappers.frame_stack import FrameStack
|
||||
|
||||
|
||||
class AgentRenderOptions(object):
|
||||
@ -22,3 +23,14 @@ class ObservationProperties(NamedTuple):
|
||||
cast_shadows = True
|
||||
frames_to_stack: int = 0
|
||||
pomdp_r: int = 0
|
||||
|
||||
|
||||
class MarlFrameStack(gym.ObservationWrapper):
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
|
||||
def observation(self, observation):
|
||||
if isinstance(self.env, FrameStack) and self.env.unwrapped.n_agents > 1:
|
||||
return observation[0:].swapaxes(0, 1)
|
||||
return observation
|
||||
|
||||
|
Reference in New Issue
Block a user