diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py
index 4661552..5afd889 100644
--- a/environments/factory/base_factory.py
+++ b/environments/factory/base_factory.py
@@ -1,4 +1,4 @@
-from typing import List, Union, Iterable, TypedDict
+from typing import List, Union, Iterable
 
 import gym
 from gym import spaces
@@ -34,40 +34,51 @@ class AgentState:
 
 
 class Actions:
 
-    def __init__(self, allow_square_movement=True, allow_diagonal_movement=True, allow_no_OP=True):
-        self.allow_no_OP = allow_no_OP
+    @property
+    def n(self):
+        return len(self)
+
+    @property
+    def movement_actions(self):
+        return self._movement_actions
+
+    def __init__(self, allow_square_movement=False, allow_diagonal_movement=False, allow_no_op=False):
+        # FIXME: There is a bug in helpers because actions there are plain ints, so the order matters.
+        assert not(allow_square_movement is False and allow_diagonal_movement is True), "There is a bug in helpers!!!"
+        self.allow_no_op = allow_no_op
         self.allow_diagonal_movement = allow_diagonal_movement
         self.allow_square_movement = allow_square_movement
         self._registerd_actions = dict()
         if allow_square_movement:
-            self + {key: val for key, val in enumerate(['north', 'east', 'south', 'west'])}
+            self + ['north', 'east', 'south', 'west']
         if allow_diagonal_movement:
-            self + {key: val for key, val in enumerate(['north-east', 'south-east', 'south-west', 'north-west'])}
-
+            self + ['north-east', 'south-east', 'south-west', 'north-west']
         self._movement_actions = self._registerd_actions.copy()
-        if self.allow_no_OP:
-            self + {0:'no-op'}
-
+        if self.allow_no_op:
+            self + 'no-op'
 
     def __len__(self):
        return len(self._registerd_actions)
 
-    def __add__(self, other: dict):
-        assert all([isinstance(x, int) for x in other.keys()]), f'All action keys have to be of type {int}.'
-        assert all([isinstance(x, str) for x in other.values()]), f'All action values have to be of type {str}.'
-        self._registerd_actions.update({key+len(self._registerd_actions): value for key,value in other.items()})
+    def __add__(self, other: Union[str, List[str]]):
+        other = other if isinstance(other, list) else [other]
+        assert all([isinstance(x, str) for x in other]), f'All action names have to be of type {str}.'
+        self._registerd_actions.update({key+len(self._registerd_actions): value for key, value in enumerate(other)})
         return self
 
-    def register_additional_actions(self, other:dict):
+    def register_additional_actions(self, other: Union[str, List[str]]):
         self_with_additional_actions = self + other
-        return self_with_additional_actions
+        return self_with_additional_actions
+
+    def __getitem__(self, item):
+        return self._registerd_actions[item]
 
 
 class BaseFactory(gym.Env):
 
     @property
     def action_space(self):
-        return spaces.Discrete(self._registered_actions)
+        return spaces.Discrete(self._actions.n)
 
     @property
     def observation_space(self):
@@ -75,27 +86,20 @@ class BaseFactory(gym.Env):
 
     @property
     def movement_actions(self):
-        if self._movement_actions is None:
-            self._movement_actions = dict()
-            if self.allow_square_movement:
-                self._movement_actions.update(
-                )
-            if self.allow_diagonal_movement:
-                self.{key: val for key, val in zip(range(4), ['ne', 'ne', 'nw', 'nw'])}
-
-        return self._movement_actions
-
+        return self._actions.movement_actions
 
     @property
     def string_slices(self):
         return {value: key for key, value in self.slice_strings.items()}
 
-    def __init__(self, level='simple', n_agents=1, max_steps=int(2e2)):
+    def __init__(self, level='simple', n_agents=1, max_steps=int(2e2), **kwargs):
         self.n_agents = n_agents
         self.max_steps = max_steps
         self.done_at_collision = False
-        self._actions = Actions(allow_square_movement=True, allow_diagonal_movement=True, allow_no_OP=False)
-
+        _actions = Actions(allow_square_movement=kwargs.get('allow_square_movement', True),
+                           allow_diagonal_movement=kwargs.get('allow_diagonal_movement', True),
+                           allow_no_op=kwargs.get('allow_no_op', True))
+        self._actions = _actions + self.additional_actions
 
         self.level = h.one_hot_level(
             h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
@@ -103,7 +107,16 @@ class BaseFactory(gym.Env):
         self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}}
         self.reset()
 
-    def register_additional_actions(self) -> dict:
+    @property
+    def additional_actions(self) -> Union[str, List[str]]:
+        """
+        When inheriting from this base class, you must implement this property.
+        Return the name(s) of your additional action(s) as a single string or a list of strings;
+        they are appended to the movement (and no-op) actions registered by the base class.
+
+        :return: The name(s) of the additional action(s).
+        :rtype: Union[str, List[str]]
+        """
         raise NotImplementedError('Please register additional actions ')
 
     def reset(self) -> (np.ndarray, int, bool, dict):
@@ -125,7 +138,7 @@ class BaseFactory(gym.Env):
         # Returns State
         return self.state
 
-    def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
+    def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
         raise NotImplementedError
 
     def step(self, actions):
@@ -143,7 +156,7 @@ class BaseFactory(gym.Env):
             elif self._is_no_op(action):
                 pos, valid = self.agent_i_position(agent_i), True
             else:
-                pos, valid = self.additional_actions(agent_i, action)
+                pos, valid = self.do_additional_actions(agent_i, action)
             # Update state accordingly
             agent_i_state.update(pos=pos, action_valid=valid)
             agent_states.append(agent_i_state)
@@ -162,10 +175,10 @@ class BaseFactory(gym.Env):
         return self.state, reward, done, info
 
     def _is_moving_action(self, action):
-        return self._registered_actions[action] in self.movement_actions
+        return action in self._actions.movement_actions
 
     def _is_no_op(self, action):
-        return self._registered_actions[action] == 'no-op'
+        return self._actions[action] == 'no-op'
 
     def check_all_collisions(self, agent_states: List[AgentState], collisions: int) -> np.ndarray:
         collision_vecs = np.zeros((len(agent_states), collisions))  # n_agents x n_slices
diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py
index a1e495a..e285550 100644
--- a/environments/factory/simple_factory.py
+++ b/environments/factory/simple_factory.py
@@ -1,6 +1,6 @@
 from collections import OrderedDict
 from dataclasses import dataclass
-from typing import List
+from typing import List, Union
 import random
 
 import numpy as np
@@ -12,7 +12,7 @@
 from environments.logging.monitor import MonitorCallback
 from environments.factory.renderer import Renderer, Entity
 
 DIRT_INDEX = -1
-
+CLEAN_UP_ACTION = 'clean_up'
 @dataclass
 class DirtProperties:
@@ -26,13 +26,14 @@ class DirtProperties:
 
 class SimpleFactory(BaseFactory):
 
-    def register_additional_actions(self):
-        return 1
+    @property
+    def additional_actions(self) -> Union[str, List[str]]:
+        return CLEAN_UP_ACTION
 
     def _is_clean_up_action(self, action):
-        return self.action_space.n - 1 == action
+        return self._actions[action] == CLEAN_UP_ACTION
 
-    def __init__(self, *args, dirt_properties: DirtProperties, verbose=False, force_skip_render=False, **kwargs):
+    def __init__(self, *args, dirt_properties: DirtProperties, verbose=False, **kwargs):
         self._dirt_properties = dirt_properties
         self.verbose = verbose
         self.max_dirt = 20
@@ -98,7 +99,7 @@ class SimpleFactory(BaseFactory):
             self.next_dirt_spawn -= 1
         return self.state, r, done, info
 
-    def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
+    def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
         if action != self._is_moving_action(action):
             if self._is_clean_up_action(action):
                 agent_i_pos = self.agent_i_position(agent_i)
@@ -175,9 +176,10 @@ if __name__ == '__main__':
     dirt_props = DirtProperties()
     factory = SimpleFactory(n_agents=2, dirt_properties=dirt_props)
+    n_actions = factory.action_space.n - 1
 
     with MonitorCallback(factory):
         for epoch in range(100):
-            random_actions = [(random.randint(0, 8), random.randint(0, 8)) for _ in range(200)]
+            random_actions = [(random.randint(0, n_actions), random.randint(0, n_actions)) for _ in range(200)]
             env_state, this_reward, done_bool, _ = factory.reset()
             for agent_i_action in random_actions:
                 env_state, reward, done_bool, info_obj = factory.step(agent_i_action)
diff --git a/main.py b/main.py
index 1aedd46..d71d26c 100644
--- a/main.py
+++ b/main.py
@@ -45,7 +45,7 @@ def combine_runs(run_path: Union[str, PathLike]):
     df = pd.concat(df_list, ignore_index=True)
     df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'})
 
-    df_group = df.groupby(['Episode', 'Run']).aggregate({col: 'mean' if col in ['dirt_amount',
+    df_group = df.groupby(['Episode', 'Run']).aggregate({col: 'sum' if col in ['dirt_amount',
                                                                                 'dirty_tiles'] else 'sum'
                                                          for col in df.columns if
                                                          col not in ['Episode', 'Run', 'train_step']
@@ -66,8 +66,8 @@ def combine_runs(run_path: Union[str, PathLike]):
 
 
 if __name__ == '__main__':
-    combine_runs('debug_out/PPO_1622128912')
-    exit()
+    # combine_runs('debug_out/PPO_1622128912')
+    # exit()
 
     from stable_baselines3 import DQN, PPO
 
@@ -78,7 +78,7 @@ if __name__ == '__main__':
 
     for seed in range(5):
-        env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, force_skip_render=True)
+        env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, allow_diagonal_movement=False, allow_no_op=False)
         model = PPO("MlpPolicy", env, verbose=1, ent_coef=0.0, seed=seed, device='cpu')
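
For context, a minimal sketch of how a subclass plugs into the reworked action API introduced by this diff, modelled on the SimpleFactory changes above. The ChargeFactory class and the 'charge' action are hypothetical illustrations; additional_actions, do_additional_actions, self._actions[...] and agent_i_position are the hooks shown in the BaseFactory changes.

from typing import List, Union

from environments.factory.base_factory import BaseFactory

CHARGE_ACTION = 'charge'  # hypothetical action name, analogous to CLEAN_UP_ACTION


class ChargeFactory(BaseFactory):

    @property
    def additional_actions(self) -> Union[str, List[str]]:
        # Names returned here are appended to the movement/no-op actions in BaseFactory.__init__.
        return CHARGE_ACTION

    def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
        # Actions are resolved by name via Actions.__getitem__, so no hard-coded integer indices.
        if self._actions[action] == CHARGE_ACTION:
            pos = self.agent_i_position(agent_i)
            valid = True  # a real implementation would check for a charging pad at `pos` here
            return pos, valid
        return self.agent_i_position(agent_i), False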