actions are now objects :P

This commit is contained in:
steffen-illium
2021-05-28 17:47:43 +02:00
parent 36fe59c95c
commit 604ffc3b60
3 changed files with 61 additions and 46 deletions

View File

@@ -1,4 +1,4 @@
from typing import List, Union, Iterable, TypedDict
from typing import List, Union, Iterable
import gym
from gym import spaces
@@ -34,40 +34,51 @@ class AgentState:
class Actions:
def __init__(self, allow_square_movement=True, allow_diagonal_movement=True, allow_no_OP=True):
self.allow_no_OP = allow_no_OP
@property
def n(self):
return len(self)
@property
def movement_actions(self):
return self._movement_actions
def __init__(self, allow_square_movement=False, allow_diagonal_movement=False, allow_no_op=False):
# FIXME: There is a bug in helpers because there actions are ints. and the order matters.
assert not(allow_square_movement is False and allow_diagonal_movement is True), "There is a bug in helpers!!!"
self.allow_no_op = allow_no_op
self.allow_diagonal_movement = allow_diagonal_movement
self.allow_square_movement = allow_square_movement
self._registerd_actions = dict()
if allow_square_movement:
self + {key: val for key, val in enumerate(['north', 'east', 'south', 'west'])}
self + ['north', 'east', 'south', 'west']
if allow_diagonal_movement:
self + {key: val for key, val in enumerate(['north-east', 'south-east', 'south-west', 'north-west'])}
self + ['north-east', 'south-east', 'south-west', 'north-west']
self._movement_actions = self._registerd_actions.copy()
if self.allow_no_OP:
self + {0:'no-op'}
if self.allow_no_op:
self + 'no-op'
def __len__(self):
return len(self._registerd_actions)
def __add__(self, other: dict):
assert all([isinstance(x, int) for x in other.keys()]), f'All action keys have to be of type {int}.'
assert all([isinstance(x, str) for x in other.values()]), f'All action values have to be of type {str}.'
self._registerd_actions.update({key+len(self._registerd_actions): value for key,value in other.items()})
def __add__(self, other: Union[str, List[str]]):
other = other if isinstance(other, list) else [other]
assert all([isinstance(x, str) for x in other]), f'All action names have to be of type {str}.'
self._registerd_actions.update({key+len(self._registerd_actions): value for key, value in enumerate(other)})
return self
def register_additional_actions(self, other:dict):
def register_additional_actions(self, other: Union[str, List[str]]):
self_with_additional_actions = self + other
return self_with_additional_actions
return self_with_additional_actions
def __getitem__(self, item):
return self._registerd_actions[item]
class BaseFactory(gym.Env):
@property
def action_space(self):
return spaces.Discrete(self._registered_actions)
return spaces.Discrete(self._actions.n)
@property
def observation_space(self):
@@ -75,27 +86,20 @@ class BaseFactory(gym.Env):
@property
def movement_actions(self):
if self._movement_actions is None:
self._movement_actions = dict()
if self.allow_square_movement:
self._movement_actions.update(
)
if self.allow_diagonal_movement:
self.{key: val for key, val in zip(range(4), ['ne', 'ne', 'nw', 'nw'])}
return self._movement_actions
return self._actions.movement_actions
@property
def string_slices(self):
return {value: key for key, value in self.slice_strings.items()}
def __init__(self, level='simple', n_agents=1, max_steps=int(2e2)):
def __init__(self, level='simple', n_agents=1, max_steps=int(2e2), **kwargs):
self.n_agents = n_agents
self.max_steps = max_steps
self.done_at_collision = False
self._actions = Actions(allow_square_movement=True, allow_diagonal_movement=True, allow_no_OP=False)
_actions = Actions(allow_square_movement=kwargs.get('allow_square_movement', True),
allow_diagonal_movement=kwargs.get('allow_diagonal_movement', True),
allow_no_op=kwargs.get('allow_no_op', True))
self._actions = _actions + self.additional_actions
self.level = h.one_hot_level(
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
@@ -103,7 +107,16 @@ class BaseFactory(gym.Env):
self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}}
self.reset()
def register_additional_actions(self) -> dict:
@property
def additional_actions(self) -> Union[str, List[str]]:
"""
When heriting from this Base Class, you musst implement this methode!!!
Please return a dict with the given types -> {int: str}.
The int should start at 0.
:return: An Actions-object holding all actions with keys in range 0-n.
:rtype: Actions
"""
raise NotImplementedError('Please register additional actions ')
def reset(self) -> (np.ndarray, int, bool, dict):
@@ -125,7 +138,7 @@ class BaseFactory(gym.Env):
# Returns State
return self.state
def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
raise NotImplementedError
def step(self, actions):
@@ -143,7 +156,7 @@ class BaseFactory(gym.Env):
elif self._is_no_op(action):
pos, valid = self.agent_i_position(agent_i), True
else:
pos, valid = self.additional_actions(agent_i, action)
pos, valid = self.do_additional_actions(agent_i, action)
# Update state accordingly
agent_i_state.update(pos=pos, action_valid=valid)
agent_states.append(agent_i_state)
@@ -162,10 +175,10 @@ class BaseFactory(gym.Env):
return self.state, reward, done, info
def _is_moving_action(self, action):
return self._registered_actions[action] in self.movement_actions
return action in self._actions.movement_actions
def _is_no_op(self, action):
return self._registered_actions[action] == 'no-op'
return self._actions[action] == 'no-op'
def check_all_collisions(self, agent_states: List[AgentState], collisions: int) -> np.ndarray:
collision_vecs = np.zeros((len(agent_states), collisions)) # n_agents x n_slices