named tuples working

This commit is contained in:
steffen-illium
2021-06-04 16:43:34 +02:00
parent 1e87c4807f
commit 8ba7c418f0
3 changed files with 27 additions and 19 deletions

View File

@ -1,6 +1,6 @@
import pickle
from pathlib import Path
from typing import List, Union, Iterable
from typing import List, Union, Iterable, NamedTuple
import gym
import numpy as np
@ -11,6 +11,12 @@ import yaml
from environments import helpers as h
class MovementProperties(NamedTuple):
allow_square_movement: bool = False
allow_diagonal_movement: bool = False
allow_no_op: bool = False
class AgentState:
def __init__(self, i: int, action: int):
@ -78,16 +84,17 @@ class Actions(Register):
def movement_actions(self):
return self._movement_actions
def __init__(self, allow_square_movement=False, allow_diagonal_movement=False, allow_no_op=False):
def __init__(self, movement_properties: MovementProperties):
self.allow_no_op = movement_properties.allow_no_op
self.allow_diagonal_movement = movement_properties.allow_diagonal_movement
self.allow_square_movement = movement_properties.allow_square_movement
# FIXME: There is a bug in helpers because there actions are ints. and the order matters.
assert not(allow_square_movement is False and allow_diagonal_movement is True), "There is a bug in helpers!!!"
assert not(self.allow_square_movement is False and self.allow_diagonal_movement is True), "There is a bug in helpers!!!"
super(Actions, self).__init__()
self.allow_no_op = allow_no_op
self.allow_diagonal_movement = allow_diagonal_movement
self.allow_square_movement = allow_square_movement
if allow_square_movement:
if self.allow_square_movement:
self + ['north', 'east', 'south', 'west']
if allow_diagonal_movement:
if self.allow_diagonal_movement:
self + ['north-east', 'south-east', 'south-west', 'north-west']
self._movement_actions = self._register.copy()
if self.allow_no_op:
@ -124,20 +131,18 @@ class BaseFactory(gym.Env):
return self._actions.movement_actions
def __init__(self, level='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None,
allow_square_movement=True, allow_diagonal_movement=True, allow_no_op=True,
movement_properties: MovementProperties = MovementProperties(),
omit_agent_slice_in_obs=False, **kwargs):
self.allow_no_op = allow_no_op
self.allow_diagonal_movement = allow_diagonal_movement
self.allow_square_movement = allow_square_movement
self.movement_properties = movement_properties
self.n_agents = n_agents
self.max_steps = max_steps
self.pomdp_radius = pomdp_radius
self.omit_agent_slice_in_obs = omit_agent_slice_in_obs
self.done_at_collision = False
_actions = Actions(allow_square_movement=self.allow_square_movement,
allow_diagonal_movement=self.allow_diagonal_movement,
allow_no_op=allow_no_op)
_actions = Actions(self.movement_properties)
self._actions = _actions + self.additional_actions
self._level = h.one_hot_level(