zwischenstand, no checkout pls!
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
from typing import List, Union, Iterable
|
||||
from typing import List, Union, Iterable, TypedDict
|
||||
|
||||
import gym
|
||||
from gym import spaces
|
||||
@ -32,6 +32,37 @@ class AgentState:
|
||||
raise AttributeError(f'"{key}" cannot be updated, this attr is not a part of {self.__class__.__name__}')
|
||||
|
||||
|
||||
class Actions:
|
||||
|
||||
def __init__(self, allow_square_movement=True, allow_diagonal_movement=True, allow_no_OP=True):
|
||||
self.allow_no_OP = allow_no_OP
|
||||
self.allow_diagonal_movement = allow_diagonal_movement
|
||||
self.allow_square_movement = allow_square_movement
|
||||
self._registerd_actions = dict()
|
||||
if allow_square_movement:
|
||||
self + {key: val for key, val in enumerate(['north', 'east', 'south', 'west'])}
|
||||
if allow_diagonal_movement:
|
||||
self + {key: val for key, val in enumerate(['north-east', 'south-east', 'south-west', 'north-west'])}
|
||||
|
||||
self._movement_actions = self._registerd_actions.copy()
|
||||
if self.allow_no_OP:
|
||||
self + {0:'no-op'}
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self._registerd_actions)
|
||||
|
||||
def __add__(self, other: dict):
|
||||
assert all([isinstance(x, int) for x in other.keys()]), f'All action keys have to be of type {int}.'
|
||||
assert all([isinstance(x, str) for x in other.values()]), f'All action values have to be of type {str}.'
|
||||
self._registerd_actions.update({key+len(self._registerd_actions): value for key,value in other.items()})
|
||||
return self
|
||||
|
||||
def register_additional_actions(self, other:dict):
|
||||
self_with_additional_actions = self + other
|
||||
return self_with_additional_actions
|
||||
|
||||
|
||||
class BaseFactory(gym.Env):
|
||||
|
||||
@property
|
||||
@ -44,7 +75,16 @@ class BaseFactory(gym.Env):
|
||||
|
||||
@property
|
||||
def movement_actions(self):
|
||||
return (int(self.allow_square_movement) + int(self.allow_diagonal_movement)) * 4
|
||||
if self._movement_actions is None:
|
||||
self._movement_actions = dict()
|
||||
if self.allow_square_movement:
|
||||
self._movement_actions.update(
|
||||
)
|
||||
if self.allow_diagonal_movement:
|
||||
self.{key: val for key, val in zip(range(4), ['ne', 'ne', 'nw', 'nw'])}
|
||||
|
||||
return self._movement_actions
|
||||
|
||||
|
||||
@property
|
||||
def string_slices(self):
|
||||
@ -53,18 +93,17 @@ class BaseFactory(gym.Env):
|
||||
def __init__(self, level='simple', n_agents=1, max_steps=int(2e2)):
|
||||
self.n_agents = n_agents
|
||||
self.max_steps = max_steps
|
||||
self.allow_square_movement = True
|
||||
self.allow_diagonal_movement = True
|
||||
self.allow_no_OP = True
|
||||
self.done_at_collision = False
|
||||
self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions()
|
||||
self._actions = Actions(allow_square_movement=True, allow_diagonal_movement=True, allow_no_OP=False)
|
||||
|
||||
|
||||
self.level = h.one_hot_level(
|
||||
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
|
||||
)
|
||||
self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}}
|
||||
self.reset()
|
||||
|
||||
def register_additional_actions(self) -> int:
|
||||
def register_additional_actions(self) -> dict:
|
||||
raise NotImplementedError('Please register additional actions ')
|
||||
|
||||
def reset(self) -> (np.ndarray, int, bool, dict):
|
||||
@ -123,10 +162,10 @@ class BaseFactory(gym.Env):
|
||||
return self.state, reward, done, info
|
||||
|
||||
def _is_moving_action(self, action):
|
||||
return action < self.movement_actions
|
||||
return self._registered_actions[action] in self.movement_actions
|
||||
|
||||
def _is_no_op(self, action):
|
||||
return self.allow_no_OP and (action - self.movement_actions) == 0
|
||||
return self._registered_actions[action] == 'no-op'
|
||||
|
||||
def check_all_collisions(self, agent_states: List[AgentState], collisions: int) -> np.ndarray:
|
||||
collision_vecs = np.zeros((len(agent_states), collisions)) # n_agents x n_slices
|
||||
|
Reference in New Issue
Block a user