From 7b4e60b0aa897140c5a1c197bef53a5e33f3eb4b Mon Sep 17 00:00:00 2001 From: steffen-illium Date: Sat, 29 May 2021 10:49:39 +0200 Subject: [PATCH] new register objects state slices are now registers def __get__(int) def by_name(str) :v: --- environments/factory/base_factory.py | 57 +++++++++++++++++--------- environments/factory/simple_factory.py | 18 +++++--- environments/logging/monitor.py | 3 ++ environments/logging/plotting.py | 4 +- main.py | 6 +-- 5 files changed, 57 insertions(+), 31 deletions(-) diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index 5afd889..4134735 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -32,12 +32,40 @@ class AgentState: raise AttributeError(f'"{key}" cannot be updated, this attr is not a part of {self.__class__.__name__}') -class Actions: +class Register: @property def n(self): return len(self) + def __init__(self): + self._register = dict() + + def __len__(self): + return len(self._register) + + def __add__(self, other: Union[str, List[str]]): + other = other if isinstance(other, list) else [other] + assert all([isinstance(x, str) for x in other]), f'All item names have to be of type {str}.' + self._register.update({key+len(self._register): value for key, value in enumerate(other)}) + return self + + def register_additional_items(self, other: Union[str, List[str]]): + self_with_additional_items = self + other + return self_with_additional_items + + def __getitem__(self, item): + return self._register[item] + + def by_name(self, item): + return list(self._register.keys())[list(self._register.values()).index(item)] + + def __repr__(self): + return f'{self.__class__.__name__}({self._register})' + + +class Actions(Register): + @property def movement_actions(self): return self._movement_actions @@ -45,33 +73,25 @@ class Actions: def __init__(self, allow_square_movement=False, allow_diagonal_movement=False, allow_no_op=False): # FIXME: There is a bug in helpers because there actions are ints. and the order matters. assert not(allow_square_movement is False and allow_diagonal_movement is True), "There is a bug in helpers!!!" + super(Actions, self).__init__() self.allow_no_op = allow_no_op self.allow_diagonal_movement = allow_diagonal_movement self.allow_square_movement = allow_square_movement - self._registerd_actions = dict() if allow_square_movement: self + ['north', 'east', 'south', 'west'] if allow_diagonal_movement: self + ['north-east', 'south-east', 'south-west', 'north-west'] - self._movement_actions = self._registerd_actions.copy() + self._movement_actions = self._register.copy() if self.allow_no_op: self + 'no-op' - def __len__(self): - return len(self._registerd_actions) - def __add__(self, other: Union[str, List[str]]): - other = other if isinstance(other, list) else [other] - assert all([isinstance(x, str) for x in other]), f'All action names have to be of type {str}.' - self._registerd_actions.update({key+len(self._registerd_actions): value for key, value in enumerate(other)}) - return self +class StateSlice(Register): - def register_additional_actions(self, other: Union[str, List[str]]): - self_with_additional_actions = self + other - return self_with_additional_actions - - def __getitem__(self, item): - return self._registerd_actions[item] + def __init__(self, n_agents: int): + super(StateSlice, self).__init__() + offset = 1 + self.register_additional_items(['level', *[f'agent#{i}' for i in range(offset, n_agents+offset)]]) class BaseFactory(gym.Env): @@ -88,9 +108,6 @@ class BaseFactory(gym.Env): def movement_actions(self): return self._actions.movement_actions - @property - def string_slices(self): - return {value: key for key, value in self.slice_strings.items()} def __init__(self, level='simple', n_agents=1, max_steps=int(2e2), **kwargs): self.n_agents = n_agents @@ -104,7 +121,7 @@ class BaseFactory(gym.Env): self.level = h.one_hot_level( h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt') ) - self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}} + self.state_slices = StateSlice(n_agents) self.reset() @property diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py index e285550..1595a62 100644 --- a/environments/factory/simple_factory.py +++ b/environments/factory/simple_factory.py @@ -38,8 +38,8 @@ class SimpleFactory(BaseFactory): self.verbose = verbose self.max_dirt = 20 super(SimpleFactory, self).__init__(*args, **kwargs) - self.slice_strings.update({self.state.shape[0]-1: 'dirt'}) - self.renderer = None # expensive - dont use it when not required ! + self.state_slices.register_additional_items('dirt') + self.renderer = None # expensive - don't use it when not required ! def render(self): @@ -52,7 +52,9 @@ class SimpleFactory(BaseFactory): walls = [Entity('wall', pos) for pos in np.argwhere(self.state[h.LEVEL_IDX] > h.IS_FREE_CELL)] def asset_str(agent): - cols = ' '.join([self.slice_strings[j] for j in agent.collisions]) + if any([x is None for x in [self.state_slices[j] for j in agent.collisions]]): + print('error') + cols = ' '.join([self.state_slices[j] for j in agent.collisions]) if 'agent' in cols: return 'agent_collision' elif not agent.action_valid or 'level' in cols or 'agent' in cols: @@ -131,8 +133,12 @@ class SimpleFactory(BaseFactory): for agent_state in agent_states: cols = agent_state.collisions + + list_of_collisions = [self.state_slices[entity] for entity in cols + if entity != self.state_slices.by_name("dirt")] + self.print(f't = {self.steps}\tAgent {agent_state.i} has collisions with ' - f'{[self.slice_strings[entity] for entity in cols if entity != self.string_slices["dirt"]]}') + f'{list_of_collisions}') if self._is_clean_up_action(agent_state.action): if agent_state.action_valid: reward += 1 @@ -155,8 +161,8 @@ class SimpleFactory(BaseFactory): reward -= 0.25 for entity in cols: - if entity != self.string_slices["dirt"]: - self.monitor.set(f'agent_{agent_state.i}_vs_{self.slice_strings[entity]}', 1) + if entity != self.state_slices.by_name("dirt"): + self.monitor.set(f'agent_{agent_state.i}_vs_{self.state_slices[entity]}', 1) self.monitor.set('dirt_amount', current_dirt_amount) self.monitor.set('dirty_tiles', dirty_tiles) diff --git a/environments/logging/monitor.py b/environments/logging/monitor.py index 6b4acae..ccfe1f7 100644 --- a/environments/logging/monitor.py +++ b/environments/logging/monitor.py @@ -104,6 +104,9 @@ class MonitorCallback(BaseCallback): df = pd.DataFrame(columns=monitor.columns) for _, row in monitor.iterrows(): df.loc[df.shape[0]] = row + if df is None: # The env exited premature, we catch it. + self.closed = True + return for column in list(df.columns): if column != 'episode': df[f'{column}_roll'] = df[column].rolling(window=50).mean() diff --git a/environments/logging/plotting.py b/environments/logging/plotting.py index 0592091..4613698 100644 --- a/environments/logging/plotting.py +++ b/environments/logging/plotting.py @@ -20,13 +20,13 @@ PALETTE = 10 * ( ) -def plot(filepath, ext='png', tag='monitor', **kwargs): +def plot(filepath, ext='png', **kwargs): plt.rcParams.update(kwargs) plt.tight_layout() figure = plt.gcf() plt.show() - figure.savefig(str(filepath.parent / f'{filepath.stem}_{tag}_measures.{ext}'), format=ext) + figure.savefig(str(filepath), format=ext) def prepare_plot(filepath, results_df, ext='png', tag=''): diff --git a/main.py b/main.py index 3199496..2b74805 100644 --- a/main.py +++ b/main.py @@ -61,8 +61,8 @@ def combine_runs(run_path: Union[str, PathLike]): if __name__ == '__main__': - combine_runs('debug_out/PPO_1622120377') - exit() + # combine_runs('debug_out/PPO_1622120377') + # exit() from stable_baselines3 import PPO # DQN dirt_props = DirtProperties() @@ -72,7 +72,7 @@ if __name__ == '__main__': for seed in range(5): - env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, allow_diagonal_movement=False, allow_no_op=False) + env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, allow_diagonal_movement=True, allow_no_op=False) model = PPO("MlpPolicy", env, verbose=1, ent_coef=0.0, seed=seed, device='cpu')