zwischenstand, no checkout pls!

This commit is contained in:
steffen-illium
2021-05-28 14:43:03 +02:00
parent 1b98171f3a
commit 36fe59c95c
4 changed files with 100 additions and 30 deletions

View File

@ -1,4 +1,4 @@
from typing import List, Union, Iterable
from typing import List, Union, Iterable, TypedDict
import gym
from gym import spaces
@ -32,6 +32,37 @@ class AgentState:
raise AttributeError(f'"{key}" cannot be updated, this attr is not a part of {self.__class__.__name__}')
class Actions:
def __init__(self, allow_square_movement=True, allow_diagonal_movement=True, allow_no_OP=True):
self.allow_no_OP = allow_no_OP
self.allow_diagonal_movement = allow_diagonal_movement
self.allow_square_movement = allow_square_movement
self._registerd_actions = dict()
if allow_square_movement:
self + {key: val for key, val in enumerate(['north', 'east', 'south', 'west'])}
if allow_diagonal_movement:
self + {key: val for key, val in enumerate(['north-east', 'south-east', 'south-west', 'north-west'])}
self._movement_actions = self._registerd_actions.copy()
if self.allow_no_OP:
self + {0:'no-op'}
def __len__(self):
return len(self._registerd_actions)
def __add__(self, other: dict):
assert all([isinstance(x, int) for x in other.keys()]), f'All action keys have to be of type {int}.'
assert all([isinstance(x, str) for x in other.values()]), f'All action values have to be of type {str}.'
self._registerd_actions.update({key+len(self._registerd_actions): value for key,value in other.items()})
return self
def register_additional_actions(self, other:dict):
self_with_additional_actions = self + other
return self_with_additional_actions
class BaseFactory(gym.Env):
@property
@ -44,7 +75,16 @@ class BaseFactory(gym.Env):
@property
def movement_actions(self):
return (int(self.allow_square_movement) + int(self.allow_diagonal_movement)) * 4
if self._movement_actions is None:
self._movement_actions = dict()
if self.allow_square_movement:
self._movement_actions.update(
)
if self.allow_diagonal_movement:
self.{key: val for key, val in zip(range(4), ['ne', 'ne', 'nw', 'nw'])}
return self._movement_actions
@property
def string_slices(self):
@ -53,18 +93,17 @@ class BaseFactory(gym.Env):
def __init__(self, level='simple', n_agents=1, max_steps=int(2e2)):
self.n_agents = n_agents
self.max_steps = max_steps
self.allow_square_movement = True
self.allow_diagonal_movement = True
self.allow_no_OP = True
self.done_at_collision = False
self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions()
self._actions = Actions(allow_square_movement=True, allow_diagonal_movement=True, allow_no_OP=False)
self.level = h.one_hot_level(
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
)
self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}}
self.reset()
def register_additional_actions(self) -> int:
def register_additional_actions(self) -> dict:
raise NotImplementedError('Please register additional actions ')
def reset(self) -> (np.ndarray, int, bool, dict):
@ -123,10 +162,10 @@ class BaseFactory(gym.Env):
return self.state, reward, done, info
def _is_moving_action(self, action):
return action < self.movement_actions
return self._registered_actions[action] in self.movement_actions
def _is_no_op(self, action):
return self.allow_no_OP and (action - self.movement_actions) == 0
return self._registered_actions[action] == 'no-op'
def check_all_collisions(self, agent_states: List[AgentState], collisions: int) -> np.ndarray:
collision_vecs = np.zeros((len(agent_states), collisions)) # n_agents x n_slices