new register objects

state slices are now registers
def __get__(int)
def by_name(str)

✌️
This commit is contained in:
steffen-illium
2021-05-29 10:49:39 +02:00
parent efedce579e
commit 7b4e60b0aa
5 changed files with 57 additions and 31 deletions

View File

@ -32,12 +32,40 @@ class AgentState:
raise AttributeError(f'"{key}" cannot be updated, this attr is not a part of {self.__class__.__name__}')
class Actions:
class Register:
@property
def n(self):
return len(self)
def __init__(self):
self._register = dict()
def __len__(self):
return len(self._register)
def __add__(self, other: Union[str, List[str]]):
other = other if isinstance(other, list) else [other]
assert all([isinstance(x, str) for x in other]), f'All item names have to be of type {str}.'
self._register.update({key+len(self._register): value for key, value in enumerate(other)})
return self
def register_additional_items(self, other: Union[str, List[str]]):
self_with_additional_items = self + other
return self_with_additional_items
def __getitem__(self, item):
return self._register[item]
def by_name(self, item):
return list(self._register.keys())[list(self._register.values()).index(item)]
def __repr__(self):
return f'{self.__class__.__name__}({self._register})'
class Actions(Register):
@property
def movement_actions(self):
return self._movement_actions
@ -45,33 +73,25 @@ class Actions:
def __init__(self, allow_square_movement=False, allow_diagonal_movement=False, allow_no_op=False):
# FIXME: There is a bug in helpers because there actions are ints. and the order matters.
assert not(allow_square_movement is False and allow_diagonal_movement is True), "There is a bug in helpers!!!"
super(Actions, self).__init__()
self.allow_no_op = allow_no_op
self.allow_diagonal_movement = allow_diagonal_movement
self.allow_square_movement = allow_square_movement
self._registerd_actions = dict()
if allow_square_movement:
self + ['north', 'east', 'south', 'west']
if allow_diagonal_movement:
self + ['north-east', 'south-east', 'south-west', 'north-west']
self._movement_actions = self._registerd_actions.copy()
self._movement_actions = self._register.copy()
if self.allow_no_op:
self + 'no-op'
def __len__(self):
return len(self._registerd_actions)
def __add__(self, other: Union[str, List[str]]):
other = other if isinstance(other, list) else [other]
assert all([isinstance(x, str) for x in other]), f'All action names have to be of type {str}.'
self._registerd_actions.update({key+len(self._registerd_actions): value for key, value in enumerate(other)})
return self
class StateSlice(Register):
def register_additional_actions(self, other: Union[str, List[str]]):
self_with_additional_actions = self + other
return self_with_additional_actions
def __getitem__(self, item):
return self._registerd_actions[item]
def __init__(self, n_agents: int):
super(StateSlice, self).__init__()
offset = 1
self.register_additional_items(['level', *[f'agent#{i}' for i in range(offset, n_agents+offset)]])
class BaseFactory(gym.Env):
@ -88,9 +108,6 @@ class BaseFactory(gym.Env):
def movement_actions(self):
return self._actions.movement_actions
@property
def string_slices(self):
return {value: key for key, value in self.slice_strings.items()}
def __init__(self, level='simple', n_agents=1, max_steps=int(2e2), **kwargs):
self.n_agents = n_agents
@ -104,7 +121,7 @@ class BaseFactory(gym.Env):
self.level = h.one_hot_level(
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
)
self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}}
self.state_slices = StateSlice(n_agents)
self.reset()
@property