actions are now objects :P

This commit is contained in:
steffen-illium 2021-05-28 17:47:43 +02:00
parent 36fe59c95c
commit 604ffc3b60
3 changed files with 61 additions and 46 deletions

View File

@ -1,4 +1,4 @@
from typing import List, Union, Iterable, TypedDict from typing import List, Union, Iterable
import gym import gym
from gym import spaces from gym import spaces
@ -34,40 +34,51 @@ class AgentState:
class Actions: class Actions:
def __init__(self, allow_square_movement=True, allow_diagonal_movement=True, allow_no_OP=True): @property
self.allow_no_OP = allow_no_OP def n(self):
return len(self)
@property
def movement_actions(self):
return self._movement_actions
def __init__(self, allow_square_movement=False, allow_diagonal_movement=False, allow_no_op=False):
# FIXME: There is a bug in helpers, because actions are plain ints there and their order matters.
assert not(allow_square_movement is False and allow_diagonal_movement is True), "There is a bug in helpers!!!"
self.allow_no_op = allow_no_op
self.allow_diagonal_movement = allow_diagonal_movement self.allow_diagonal_movement = allow_diagonal_movement
self.allow_square_movement = allow_square_movement self.allow_square_movement = allow_square_movement
self._registerd_actions = dict() self._registerd_actions = dict()
if allow_square_movement: if allow_square_movement:
self + {key: val for key, val in enumerate(['north', 'east', 'south', 'west'])} self + ['north', 'east', 'south', 'west']
if allow_diagonal_movement: if allow_diagonal_movement:
self + {key: val for key, val in enumerate(['north-east', 'south-east', 'south-west', 'north-west'])} self + ['north-east', 'south-east', 'south-west', 'north-west']
self._movement_actions = self._registerd_actions.copy() self._movement_actions = self._registerd_actions.copy()
if self.allow_no_OP: if self.allow_no_op:
self + {0:'no-op'} self + 'no-op'
def __len__(self): def __len__(self):
return len(self._registerd_actions) return len(self._registerd_actions)
def __add__(self, other: dict): def __add__(self, other: Union[str, List[str]]):
assert all([isinstance(x, int) for x in other.keys()]), f'All action keys have to be of type {int}.' other = other if isinstance(other, list) else [other]
assert all([isinstance(x, str) for x in other.values()]), f'All action values have to be of type {str}.' assert all([isinstance(x, str) for x in other]), f'All action names have to be of type {str}.'
self._registerd_actions.update({key+len(self._registerd_actions): value for key,value in other.items()}) self._registerd_actions.update({key+len(self._registerd_actions): value for key, value in enumerate(other)})
return self return self
def register_additional_actions(self, other:dict): def register_additional_actions(self, other: Union[str, List[str]]):
self_with_additional_actions = self + other self_with_additional_actions = self + other
return self_with_additional_actions return self_with_additional_actions
def __getitem__(self, item):
return self._registerd_actions[item]
class BaseFactory(gym.Env): class BaseFactory(gym.Env):
@property @property
def action_space(self): def action_space(self):
return spaces.Discrete(self._registered_actions) return spaces.Discrete(self._actions.n)
@property @property
def observation_space(self): def observation_space(self):
@ -75,27 +86,20 @@ class BaseFactory(gym.Env):
@property @property
def movement_actions(self): def movement_actions(self):
if self._movement_actions is None: return self._actions.movement_actions
self._movement_actions = dict()
if self.allow_square_movement:
self._movement_actions.update(
)
if self.allow_diagonal_movement:
self.{key: val for key, val in zip(range(4), ['ne', 'ne', 'nw', 'nw'])}
return self._movement_actions
@property @property
def string_slices(self): def string_slices(self):
return {value: key for key, value in self.slice_strings.items()} return {value: key for key, value in self.slice_strings.items()}
def __init__(self, level='simple', n_agents=1, max_steps=int(2e2)): def __init__(self, level='simple', n_agents=1, max_steps=int(2e2), **kwargs):
self.n_agents = n_agents self.n_agents = n_agents
self.max_steps = max_steps self.max_steps = max_steps
self.done_at_collision = False self.done_at_collision = False
self._actions = Actions(allow_square_movement=True, allow_diagonal_movement=True, allow_no_OP=False) _actions = Actions(allow_square_movement=kwargs.get('allow_square_movement', True),
allow_diagonal_movement=kwargs.get('allow_diagonal_movement', True),
allow_no_op=kwargs.get('allow_no_op', True))
self._actions = _actions + self.additional_actions
self.level = h.one_hot_level( self.level = h.one_hot_level(
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt') h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
@ -103,7 +107,16 @@ class BaseFactory(gym.Env):
self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}} self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}}
self.reset() self.reset()
def register_additional_actions(self) -> dict: @property
def additional_actions(self) -> Union[str, List[str]]:
"""
When inheriting from this base class, you must implement this property!
Please return the name of the additional action as a str, or a list of such names.
The integer keys are assigned automatically, continuing after the already registered actions.
:return: The name(s) of the additional action(s) to register.
:rtype: Union[str, List[str]]
"""
raise NotImplementedError('Please register additional actions ') raise NotImplementedError('Please register additional actions ')
def reset(self) -> (np.ndarray, int, bool, dict): def reset(self) -> (np.ndarray, int, bool, dict):
@ -125,7 +138,7 @@ class BaseFactory(gym.Env):
# Returns State # Returns State
return self.state return self.state
def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool): def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
raise NotImplementedError raise NotImplementedError
def step(self, actions): def step(self, actions):
@ -143,7 +156,7 @@ class BaseFactory(gym.Env):
elif self._is_no_op(action): elif self._is_no_op(action):
pos, valid = self.agent_i_position(agent_i), True pos, valid = self.agent_i_position(agent_i), True
else: else:
pos, valid = self.additional_actions(agent_i, action) pos, valid = self.do_additional_actions(agent_i, action)
# Update state accordingly # Update state accordingly
agent_i_state.update(pos=pos, action_valid=valid) agent_i_state.update(pos=pos, action_valid=valid)
agent_states.append(agent_i_state) agent_states.append(agent_i_state)
@ -162,10 +175,10 @@ class BaseFactory(gym.Env):
return self.state, reward, done, info return self.state, reward, done, info
def _is_moving_action(self, action): def _is_moving_action(self, action):
return self._registered_actions[action] in self.movement_actions return action in self._actions.movement_actions
def _is_no_op(self, action): def _is_no_op(self, action):
return self._registered_actions[action] == 'no-op' return self._actions[action] == 'no-op'
def check_all_collisions(self, agent_states: List[AgentState], collisions: int) -> np.ndarray: def check_all_collisions(self, agent_states: List[AgentState], collisions: int) -> np.ndarray:
collision_vecs = np.zeros((len(agent_states), collisions)) # n_agents x n_slices collision_vecs = np.zeros((len(agent_states), collisions)) # n_agents x n_slices

View File

@ -1,6 +1,6 @@
from collections import OrderedDict from collections import OrderedDict
from dataclasses import dataclass from dataclasses import dataclass
from typing import List from typing import List, Union
import random import random
import numpy as np import numpy as np
@ -12,7 +12,7 @@ from environments.logging.monitor import MonitorCallback
from environments.factory.renderer import Renderer, Entity from environments.factory.renderer import Renderer, Entity
DIRT_INDEX = -1 DIRT_INDEX = -1
CLEAN_UP_ACTION = 'clean_up'
@dataclass @dataclass
class DirtProperties: class DirtProperties:
@ -26,13 +26,14 @@ class DirtProperties:
class SimpleFactory(BaseFactory): class SimpleFactory(BaseFactory):
def register_additional_actions(self): @property
return 1 def additional_actions(self) -> Union[str, List[str]]:
return CLEAN_UP_ACTION
def _is_clean_up_action(self, action): def _is_clean_up_action(self, action):
return self.action_space.n - 1 == action return self._actions[action] == CLEAN_UP_ACTION
def __init__(self, *args, dirt_properties: DirtProperties, verbose=False, force_skip_render=False, **kwargs): def __init__(self, *args, dirt_properties: DirtProperties, verbose=False, **kwargs):
self._dirt_properties = dirt_properties self._dirt_properties = dirt_properties
self.verbose = verbose self.verbose = verbose
self.max_dirt = 20 self.max_dirt = 20
@ -98,7 +99,7 @@ class SimpleFactory(BaseFactory):
self.next_dirt_spawn -= 1 self.next_dirt_spawn -= 1
return self.state, r, done, info return self.state, r, done, info
def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool): def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
if action != self._is_moving_action(action): if action != self._is_moving_action(action):
if self._is_clean_up_action(action): if self._is_clean_up_action(action):
agent_i_pos = self.agent_i_position(agent_i) agent_i_pos = self.agent_i_position(agent_i)
@ -175,9 +176,10 @@ if __name__ == '__main__':
dirt_props = DirtProperties() dirt_props = DirtProperties()
factory = SimpleFactory(n_agents=2, dirt_properties=dirt_props) factory = SimpleFactory(n_agents=2, dirt_properties=dirt_props)
n_actions = factory.action_space.n - 1
with MonitorCallback(factory): with MonitorCallback(factory):
for epoch in range(100): for epoch in range(100):
random_actions = [(random.randint(0, 8), random.randint(0, 8)) for _ in range(200)] random_actions = [(random.randint(0, n_actions), random.randint(0, n_actions)) for _ in range(200)]
env_state, this_reward, done_bool, _ = factory.reset() env_state, this_reward, done_bool, _ = factory.reset()
for agent_i_action in random_actions: for agent_i_action in random_actions:
env_state, reward, done_bool, info_obj = factory.step(agent_i_action) env_state, reward, done_bool, info_obj = factory.step(agent_i_action)

View File

@ -45,7 +45,7 @@ def combine_runs(run_path: Union[str, PathLike]):
df = pd.concat(df_list, ignore_index=True) df = pd.concat(df_list, ignore_index=True)
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'}) df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'})
df_group = df.groupby(['Episode', 'Run']).aggregate({col: 'mean' if col in ['dirt_amount', df_group = df.groupby(['Episode', 'Run']).aggregate({col: 'sum' if col in ['dirt_amount',
'dirty_tiles'] else 'sum' 'dirty_tiles'] else 'sum'
for col in df.columns if for col in df.columns if
col not in ['Episode', 'Run', 'train_step'] col not in ['Episode', 'Run', 'train_step']
@ -66,8 +66,8 @@ def combine_runs(run_path: Union[str, PathLike]):
if __name__ == '__main__': if __name__ == '__main__':
combine_runs('debug_out/PPO_1622128912') # combine_runs('debug_out/PPO_1622128912')
exit() # exit()
from stable_baselines3 import DQN, PPO from stable_baselines3 import DQN, PPO
@ -78,7 +78,7 @@ if __name__ == '__main__':
for seed in range(5): for seed in range(5):
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, force_skip_render=True) env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, allow_diagonal_movement=False, allow_no_op=False)
model = PPO("MlpPolicy", env, verbose=1, ent_coef=0.0, seed=seed, device='cpu') model = PPO("MlpPolicy", env, verbose=1, ent_coef=0.0, seed=seed, device='cpu')