new register objects
state slices are now registers
def __get__(int)
def by_name(str)
✌️
This commit is contained in:
@ -32,12 +32,40 @@ class AgentState:
|
||||
raise AttributeError(f'"{key}" cannot be updated, this attr is not a part of {self.__class__.__name__}')
|
||||
|
||||
|
||||
class Actions:
|
||||
class Register:
|
||||
|
||||
@property
|
||||
def n(self):
|
||||
return len(self)
|
||||
|
||||
def __init__(self):
|
||||
self._register = dict()
|
||||
|
||||
def __len__(self):
|
||||
return len(self._register)
|
||||
|
||||
def __add__(self, other: Union[str, List[str]]):
|
||||
other = other if isinstance(other, list) else [other]
|
||||
assert all([isinstance(x, str) for x in other]), f'All item names have to be of type {str}.'
|
||||
self._register.update({key+len(self._register): value for key, value in enumerate(other)})
|
||||
return self
|
||||
|
||||
def register_additional_items(self, other: Union[str, List[str]]):
|
||||
self_with_additional_items = self + other
|
||||
return self_with_additional_items
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self._register[item]
|
||||
|
||||
def by_name(self, item):
|
||||
return list(self._register.keys())[list(self._register.values()).index(item)]
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self._register})'
|
||||
|
||||
|
||||
class Actions(Register):
|
||||
|
||||
@property
|
||||
def movement_actions(self):
|
||||
return self._movement_actions
|
||||
@ -45,33 +73,25 @@ class Actions:
|
||||
def __init__(self, allow_square_movement=False, allow_diagonal_movement=False, allow_no_op=False):
|
||||
# FIXME: There is a bug in helpers because there actions are ints. and the order matters.
|
||||
assert not(allow_square_movement is False and allow_diagonal_movement is True), "There is a bug in helpers!!!"
|
||||
super(Actions, self).__init__()
|
||||
self.allow_no_op = allow_no_op
|
||||
self.allow_diagonal_movement = allow_diagonal_movement
|
||||
self.allow_square_movement = allow_square_movement
|
||||
self._registerd_actions = dict()
|
||||
if allow_square_movement:
|
||||
self + ['north', 'east', 'south', 'west']
|
||||
if allow_diagonal_movement:
|
||||
self + ['north-east', 'south-east', 'south-west', 'north-west']
|
||||
self._movement_actions = self._registerd_actions.copy()
|
||||
self._movement_actions = self._register.copy()
|
||||
if self.allow_no_op:
|
||||
self + 'no-op'
|
||||
|
||||
def __len__(self):
|
||||
return len(self._registerd_actions)
|
||||
|
||||
def __add__(self, other: Union[str, List[str]]):
|
||||
other = other if isinstance(other, list) else [other]
|
||||
assert all([isinstance(x, str) for x in other]), f'All action names have to be of type {str}.'
|
||||
self._registerd_actions.update({key+len(self._registerd_actions): value for key, value in enumerate(other)})
|
||||
return self
|
||||
class StateSlice(Register):
|
||||
|
||||
def register_additional_actions(self, other: Union[str, List[str]]):
|
||||
self_with_additional_actions = self + other
|
||||
return self_with_additional_actions
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self._registerd_actions[item]
|
||||
def __init__(self, n_agents: int):
|
||||
super(StateSlice, self).__init__()
|
||||
offset = 1
|
||||
self.register_additional_items(['level', *[f'agent#{i}' for i in range(offset, n_agents+offset)]])
|
||||
|
||||
|
||||
class BaseFactory(gym.Env):
|
||||
@ -88,9 +108,6 @@ class BaseFactory(gym.Env):
|
||||
def movement_actions(self):
|
||||
return self._actions.movement_actions
|
||||
|
||||
@property
|
||||
def string_slices(self):
|
||||
return {value: key for key, value in self.slice_strings.items()}
|
||||
|
||||
def __init__(self, level='simple', n_agents=1, max_steps=int(2e2), **kwargs):
|
||||
self.n_agents = n_agents
|
||||
@ -104,7 +121,7 @@ class BaseFactory(gym.Env):
|
||||
self.level = h.one_hot_level(
|
||||
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
|
||||
)
|
||||
self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}}
|
||||
self.state_slices = StateSlice(n_agents)
|
||||
self.reset()
|
||||
|
||||
@property
|
||||
|
Reference in New Issue
Block a user