From dfca68cbeb0bef53ca9e48446bde77743da72c18 Mon Sep 17 00:00:00 2001 From: steffen-illium Date: Wed, 2 Jun 2021 09:17:51 +0200 Subject: [PATCH] alles was ich hab --- environments/factory/base_factory.py | 31 ++++++++++++++++---------- environments/factory/simple_factory.py | 6 ++--- environments/logging/plotting.py | 2 -- main.py | 6 ++--- reload_agent.py | 6 ++--- 5 files changed, 27 insertions(+), 24 deletions(-) diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index cdbef58..2b58085 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -113,15 +113,19 @@ class BaseFactory(gym.Env): def movement_actions(self): return self._actions.movement_actions - def __init__(self, level='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None, **kwargs): + def __init__(self, level='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None, + allow_square_movement=True, allow_diagonal_movement=True, allow_no_op=True, **kwargs): + self.allow_no_op = allow_no_op + self.allow_diagonal_movement = allow_diagonal_movement + self.allow_square_movement = allow_square_movement self.n_agents = n_agents self.max_steps = max_steps self.pomdp_radius = pomdp_radius self.done_at_collision = False - _actions = Actions(allow_square_movement=kwargs.get('allow_square_movement', True), - allow_diagonal_movement=kwargs.get('allow_diagonal_movement', True), - allow_no_op=kwargs.get('allow_no_op', True)) + _actions = Actions(allow_square_movement=self.allow_square_movement, + allow_diagonal_movement=self.allow_diagonal_movement, + allow_no_op=allow_no_op) self._actions = _actions + self.additional_actions self._level = h.one_hot_level( @@ -165,12 +169,16 @@ class BaseFactory(gym.Env): pos = self._agent_states[0].pos # pos = [agent_state.pos for agent_state in self.agent_states] # obs = [] ... list comprehension... pos per agent - npad = [(0, 0)] + [(self.pomdp_radius, self.pomdp_radius)] * (self._state.ndim - 1) - x_roll = self.pomdp_radius-pos[0] - y_roll = self.pomdp_radius-pos[1] - padded_state = np.pad(self._state, pad_width=npad, mode='constant', constant_values=0) - padded_state = np.roll(np.roll(padded_state, x_roll, axis=1), y_roll, axis=2) - obs = padded_state[:, :self.pomdp_radius * 2 + 1, :self.pomdp_radius * 2 + 1] + x0, x1 = max(0, pos[0] - self.pomdp_radius), pos[0] + self.pomdp_radius + 1 + y0, y1 = max(0, pos[1] - self.pomdp_radius), pos[1] + self.pomdp_radius + 1 + obs = self._state[:, x0:x1, y0:y1] + if obs.shape[1] != self.pomdp_radius ** 2 + 1 or obs.shape[2] != self.pomdp_radius ** 2 + 1: + obs_padded = np.full((obs.shape[0], self.pomdp_radius ** 2 + 1, self.pomdp_radius ** 2 + 1), 1) + a_pos = np.argwhere(obs[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0] + obs_padded[:, + abs(a_pos[0]-self.pomdp_radius):abs(a_pos[0]-self.pomdp_radius)+obs.shape[1], + abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs + obs = obs_padded else: obs = self._state return obs @@ -211,8 +219,7 @@ class BaseFactory(gym.Env): info.update(step_reward=reward, step=self.steps) - obs = self._return_state() - return obs, reward, done, info + return None, reward, done, info def _is_moving_action(self, action): return action in self._actions.movement_actions diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py index bc30484..6bc460a 100644 --- a/environments/factory/simple_factory.py +++ b/environments/factory/simple_factory.py @@ -150,7 +150,7 @@ class SimpleFactory(BaseFactory): self.print(f'Agent {agent_state.i} did just clean up some dirt at {agent_state.pos}.') info_dict.update(dirt_cleaned=1) else: - reward -= 0.0 + reward -= 0.01 self.print(f'Agent {agent_state.i} just tried to clean up some dirt ' f'at {agent_state.pos}, but was unsucsessfull.') info_dict.update(failed_cleanup_attempt=1) @@ -162,14 +162,14 @@ class SimpleFactory(BaseFactory): else: # info_dict.update(collision=1) # self.print('collision') - reward -= 0.00 + reward -= 0.01 else: info_dict.update(no_op=1) reward -= 0.00 for entity in list_of_collisions: - info_dict.update({f'agent_{agent_state.i}_vs_{self._state_slices.by_name(entity)}': 1}) + info_dict.update({f'agent_{agent_state.i}_vs_{entity}': 1}) info_dict.update(dirt_amount=current_dirt_amount) info_dict.update(dirty_tile_count=dirty_tiles) diff --git a/environments/logging/plotting.py b/environments/logging/plotting.py index 1eb9035..654f150 100644 --- a/environments/logging/plotting.py +++ b/environments/logging/plotting.py @@ -1,8 +1,6 @@ import seaborn as sns - from matplotlib import pyplot as plt - PALETTE = 10 * ( "#377eb8", "#4daf4a", diff --git a/main.py b/main.py index 1a2aeea..e364e80 100644 --- a/main.py +++ b/main.py @@ -12,7 +12,6 @@ from environments.factory.simple_factory import DirtProperties, SimpleFactory from environments.helpers import IGNORED_DF_COLUMNS from environments.logging.monitor import MonitorCallback from environments.logging.plotting import prepare_plot -from environments.logging.training import TraningMonitor warnings.filterwarnings('ignore', category=FutureWarning) warnings.filterwarnings('ignore', category=UserWarning) @@ -55,7 +54,7 @@ if __name__ == '__main__': for seed in range(5): env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400, - allow_diagonal_movement=False, allow_no_op=False, verbose=True) + allow_diagonal_movement=False, allow_no_op=False, verbose=False) model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu') @@ -65,8 +64,7 @@ if __name__ == '__main__': out_path /= identifier callbacks = CallbackList( - [TraningMonitor(out_path / f'train_logging_{identifier}.csv'), - MonitorCallback(env, filepath=out_path / f'monitor_{identifier}.pick', plotting=False)] + [MonitorCallback(env, filepath=out_path / f'monitor_{identifier}.pick', plotting=False)] ) model.learn(total_timesteps=int(2e5), callback=callbacks) diff --git a/reload_agent.py b/reload_agent.py index 93806cc..9973452 100644 --- a/reload_agent.py +++ b/reload_agent.py @@ -14,13 +14,13 @@ warnings.filterwarnings('ignore', category=UserWarning) if __name__ == '__main__': - out_path = Path(r'C:\Users\steff\projects\f_iks\debug_out\A2C_1622557712') + out_path = Path(r'C:\Users\steff\projects\f_iks\debug_out\A2C_1622558379') with (out_path / f'env_{out_path.name}.pick').open('rb') as f: env_kwargs = pickle.load(f) - env = SimpleFactory(**env_kwargs) + env = SimpleFactory(allow_no_op=False, allow_diagonal_movement=False, allow_square_movement=True, **env_kwargs) # Edit THIS: - model_path = out_path + model_path = out_path / '1_A2C_1622558379' model_files = list(natsorted(out_path.rglob('*.zip'))) this_model = model_files[0]