added MarlFrameStack and salina stuff

This commit is contained in:
Robert Müller
2021-11-23 14:03:52 +01:00
parent 59484f49c9
commit 5c15bb2ddf
6 changed files with 109 additions and 23 deletions

View File

@ -1,4 +1,4 @@
def make(env_name, n_agents=1, pomdp_r=2, max_steps=400, stack_n_frames=3):
def make(env_name, pomdp_r=2, max_steps=400, stack_n_frames=3, n_agents=1, individual_rewards=False):
import yaml
from pathlib import Path
from environments.factory.combined_factories import DirtItemFactory
@ -12,7 +12,8 @@ def make(env_name, n_agents=1, pomdp_r=2, max_steps=400, stack_n_frames=3):
obs_props = ObservationProperties(render_agents=AgentRenderOptions.COMBINED,
frames_to_stack=stack_n_frames, pomdp_r=pomdp_r)
factory_kwargs = dict(n_agents=n_agents, max_steps=max_steps, obs_prop=obs_props,
factory_kwargs = dict(n_agents=n_agents, individual_rewards=individual_rewards,
max_steps=max_steps, obs_prop=obs_props,
mv_prop=MovementProperties(**dictionary['movement_props']),
dirt_prop=DirtProperties(**dictionary['dirt_props']),
record_episodes=False, verbose=False, **dictionary['factory_props']

View File

@ -15,12 +15,11 @@ from environments.helpers import Constants as c, Constants
from environments import helpers as h
from environments.factory.base.objects import Agent, Tile, Action
from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles, PlaceHolders
from environments.utility_classes import MovementProperties, ObservationProperties
from environments.utility_classes import MovementProperties, ObservationProperties, MarlFrameStack
from environments.utility_classes import AgentRenderOptions as a_obs
import simplejson
REC_TAC = 'rec_'
@ -57,7 +56,7 @@ class BaseFactory(gym.Env):
def __enter__(self):
return self if self.obs_prop.frames_to_stack == 0 else \
FrameStack(self, self.obs_prop.frames_to_stack)
MarlFrameStack(FrameStack(self, self.obs_prop.frames_to_stack))
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

View File

@ -1,5 +1,6 @@
from enum import Enum
from typing import NamedTuple, Union
import gym
from gym.wrappers.frame_stack import FrameStack
class AgentRenderOptions(object):
@ -22,3 +23,14 @@ class ObservationProperties(NamedTuple):
cast_shadows = True
frames_to_stack: int = 0
pomdp_r: int = 0
class MarlFrameStack(gym.ObservationWrapper):
def __init__(self, env):
super().__init__(env)
def observation(self, observation):
if isinstance(self.env, FrameStack) and self.env.unwrapped.n_agents > 1:
return observation[0:].swapaxes(0, 1)
return observation