named tuples working
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Iterable
|
||||
from typing import List, Union, Iterable, NamedTuple
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
@ -11,6 +11,12 @@ import yaml
|
||||
from environments import helpers as h
|
||||
|
||||
|
||||
class MovementProperties(NamedTuple):
|
||||
allow_square_movement: bool = False
|
||||
allow_diagonal_movement: bool = False
|
||||
allow_no_op: bool = False
|
||||
|
||||
|
||||
class AgentState:
|
||||
|
||||
def __init__(self, i: int, action: int):
|
||||
@ -78,16 +84,17 @@ class Actions(Register):
|
||||
def movement_actions(self):
|
||||
return self._movement_actions
|
||||
|
||||
def __init__(self, allow_square_movement=False, allow_diagonal_movement=False, allow_no_op=False):
|
||||
def __init__(self, movement_properties: MovementProperties):
|
||||
self.allow_no_op = movement_properties.allow_no_op
|
||||
self.allow_diagonal_movement = movement_properties.allow_diagonal_movement
|
||||
self.allow_square_movement = movement_properties.allow_square_movement
|
||||
# FIXME: There is a bug in helpers because there actions are ints. and the order matters.
|
||||
assert not(allow_square_movement is False and allow_diagonal_movement is True), "There is a bug in helpers!!!"
|
||||
assert not(self.allow_square_movement is False and self.allow_diagonal_movement is True), "There is a bug in helpers!!!"
|
||||
super(Actions, self).__init__()
|
||||
self.allow_no_op = allow_no_op
|
||||
self.allow_diagonal_movement = allow_diagonal_movement
|
||||
self.allow_square_movement = allow_square_movement
|
||||
if allow_square_movement:
|
||||
|
||||
if self.allow_square_movement:
|
||||
self + ['north', 'east', 'south', 'west']
|
||||
if allow_diagonal_movement:
|
||||
if self.allow_diagonal_movement:
|
||||
self + ['north-east', 'south-east', 'south-west', 'north-west']
|
||||
self._movement_actions = self._register.copy()
|
||||
if self.allow_no_op:
|
||||
@ -124,20 +131,18 @@ class BaseFactory(gym.Env):
|
||||
return self._actions.movement_actions
|
||||
|
||||
def __init__(self, level='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None,
|
||||
allow_square_movement=True, allow_diagonal_movement=True, allow_no_op=True,
|
||||
movement_properties: MovementProperties = MovementProperties(),
|
||||
omit_agent_slice_in_obs=False, **kwargs):
|
||||
self.allow_no_op = allow_no_op
|
||||
self.allow_diagonal_movement = allow_diagonal_movement
|
||||
self.allow_square_movement = allow_square_movement
|
||||
|
||||
self.movement_properties = movement_properties
|
||||
|
||||
self.n_agents = n_agents
|
||||
self.max_steps = max_steps
|
||||
self.pomdp_radius = pomdp_radius
|
||||
self.omit_agent_slice_in_obs = omit_agent_slice_in_obs
|
||||
|
||||
self.done_at_collision = False
|
||||
_actions = Actions(allow_square_movement=self.allow_square_movement,
|
||||
allow_diagonal_movement=self.allow_diagonal_movement,
|
||||
allow_no_op=allow_no_op)
|
||||
_actions = Actions(self.movement_properties)
|
||||
self._actions = _actions + self.additional_actions
|
||||
|
||||
self._level = h.one_hot_level(
|
||||
|
Reference in New Issue
Block a user