From 2589a06d02e27cd59b9bacd90b6949cd5f26b9fe Mon Sep 17 00:00:00 2001 From: steffen-illium Date: Mon, 7 Jun 2021 16:14:29 +0200 Subject: [PATCH 1/3] in progress --- environments/factory/base_factory.py | 26 ++++++++++++++ environments/helpers.py | 53 +++++++++------------------- main.py | 15 ++++---- reload_agent.py | 2 +- 4 files changed, 52 insertions(+), 44 deletions(-) diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index 565f5a6..c708a9b 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -17,6 +17,32 @@ class MovementProperties(NamedTuple): allow_no_op: bool = False +class Entity(): + + @property + def pos(self): + return self._pos + + def __init__(self, pos): + self._pos = pos + + def check_agent_move(state: np.ndarray, dim: int, action: str): + agent_slice = state[dim] # horizontal slice from state tensor + agent_pos = np.argwhere(agent_slice == 1) + if len(agent_pos) > 1: + raise AssertionError('Only one agent per slice is allowed.') + x, y = agent_pos[0] + + # Actions + x_diff, y_diff = ACTIONMAP[action] + x_new = x + x_diff + y_new = y + y_diff + + + + return (x, y), (x_new, y_new), valid + + class AgentState: def __init__(self, i: int, action: int): diff --git a/environments/helpers.py b/environments/helpers.py index 1df6cd2..5b7224e 100644 --- a/environments/helpers.py +++ b/environments/helpers.py @@ -1,3 +1,6 @@ +from collections import defaultdict +from typing import Tuple + import numpy as np from pathlib import Path @@ -11,6 +14,13 @@ IS_OCCUPIED_CELL = 1 TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles'] IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amount', 'dirty_tile_count'] +ACTIONMAP = defaultdict(lambda: (0, 0), dict(north=(-1, 0), east=(0, 1), + south=(1, 0), west=(0, -1), + north_east=(-1, +1), south_east=(1, 1), + south_west=(+1, -1), north_west=(-1, -1) + ) + ) + # Utility functions def parse_level(path): @@ -28,48 +38,19 @@ def one_hot_level(level, wall_char=WALL): return binary_grid -def check_agent_move(state, dim, action): - agent_slice = state[dim] # horizontal slice from state tensor - agent_pos = np.argwhere(agent_slice == 1) - if len(agent_pos) > 1: - raise AssertionError('Only one agent per slice is allowed.') - x, y = agent_pos[0] - x_new, y_new = x, y - # Actions - if action == 0: # North - x_new -= 1 - elif action == 1: # East - y_new += 1 - elif action == 2: # South - x_new += 1 - elif action == 3: # West - y_new -= 1 - elif action == 4: # NE - x_new -= 1 - y_new += 1 - elif action == 5: # SE - x_new += 1 - y_new += 1 - elif action == 6: # SW - x_new += 1 - y_new -= 1 - elif action == 7: # NW - x_new -= 1 - y_new -= 1 - else: - pass +def check_position(state: np.ndarray, position_to_check: Tuple[int, int], dim: int = 0): + x, y = position_to_check + agent_slice = state[dim] # Check if agent colides with grid boundrys valid = not ( - x_new < 0 or y_new < 0 - or x_new >= agent_slice.shape[0] - or y_new >= agent_slice.shape[0] + x < 0 or y < 0 + or x >= agent_slice.shape[0] + or y >= agent_slice.shape[0] ) # Check for collision with level walls - valid = valid and not state[LEVEL_IDX][x_new, y_new] - - return (x, y), (x_new, y_new), valid + valid = valid and not state[LEVEL_IDX][x, y] if __name__ == '__main__': diff --git a/main.py b/main.py index 4a983fd..e60f998 100644 --- a/main.py +++ b/main.py @@ -35,8 +35,8 @@ def combine_runs(run_path: Union[str, PathLike]): df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'}) columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS] - roll_n = 30 - skip_n = 20 + roll_n = 50 + skip_n = 40 non_overlapp_window = df.groupby(['Run', 'Episode']).rolling(roll_n, min_periods=1).mean() @@ -68,8 +68,8 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run', 'model': 'Model'}) columns = [col for col in df.columns if col in parameter] - roll_n = 30 - skip_n = 10 + roll_n = 40 + skip_n = 20 non_overlapp_window = df.groupby(['Model', 'Run', 'Episode']).rolling(roll_n, min_periods=1).mean() @@ -85,14 +85,15 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List if __name__ == '__main__': - compare_runs(Path('debug_out'), 1623052687, ['agent_0_vs_level']) + compare_runs(Path('debug_out'), 1623052687, ['step_reward']) exit() from stable_baselines3 import PPO, DQN, A2C from algorithms.reg_dqn import RegDQN # from sb3_contrib import QRDQN - dirt_props = DirtProperties() + dirt_props = DirtProperties(clean_amount=3, gain_amount=0.2, max_global_amount=30, + max_local_amount=5, spawn_frequency=3) move_props = MovementProperties(allow_diagonal_movement=False, allow_square_movement=True, allow_no_op=False) @@ -100,7 +101,7 @@ if __name__ == '__main__': out_path = None - for modeL_type in [PPO, A2C, RegDQN, DQN]: + for modeL_type in [PPO, A2C]: # , RegDQN, DQN]: for seed in range(3): env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=3, max_steps=400, diff --git a/reload_agent.py b/reload_agent.py index cfa3383..b5a8607 100644 --- a/reload_agent.py +++ b/reload_agent.py @@ -28,7 +28,7 @@ if __name__ == '__main__': this_model = model_files[0] model = PPO.load(this_model) - evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=True, render=True) + evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=False, render=True) print(evaluation_result) env.close() From 62c141aa1c8bdb2c4f7d4d86987d8d7c1237b8cd Mon Sep 17 00:00:00 2001 From: steffen-illium Date: Mon, 7 Jun 2021 16:52:06 +0200 Subject: [PATCH 2/3] actions are now checked by string --- environments/factory/base_factory.py | 36 +++++++++++++--------------- environments/helpers.py | 14 +++++------ main.py | 4 ++-- 3 files changed, 26 insertions(+), 28 deletions(-) diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index c708a9b..1c8b921 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -26,22 +26,6 @@ class Entity(): def __init__(self, pos): self._pos = pos - def check_agent_move(state: np.ndarray, dim: int, action: str): - agent_slice = state[dim] # horizontal slice from state tensor - agent_pos = np.argwhere(agent_slice == 1) - if len(agent_pos) > 1: - raise AssertionError('Only one agent per slice is allowed.') - x, y = agent_pos[0] - - # Actions - x_diff, y_diff = ACTIONMAP[action] - x_new = x + x_diff - y_new = y + y_diff - - - - return (x, y), (x_new, y_new), valid - class AgentState: @@ -312,9 +296,7 @@ class BaseFactory(gym.Env): self._state[agent_i + h.AGENT_START_IDX, x_new, y_new] = h.IS_OCCUPIED_CELL def move_or_colide(self, agent_i: int, action: int) -> ((int, int), bool): - old_pos, new_pos, valid = h.check_agent_move(state=self._state, - dim=agent_i + h.AGENT_START_IDX, - action=action) + old_pos, new_pos, valid = self._check_agent_move(agent_i=agent_i, action=self._actions[action]) if valid: # Does not collide width level boundaries self.do_move(agent_i, old_pos, new_pos) @@ -323,6 +305,22 @@ class BaseFactory(gym.Env): # Agent seems to be trying to collide in this step return old_pos, valid + def _check_agent_move(self, agent_i, action: str): + agent_slice = self._state[h.AGENT_START_IDX + agent_i] # horizontal slice from state tensor + agent_pos = np.argwhere(agent_slice == 1) + if len(agent_pos) > 1: + raise AssertionError('Only one agent per slice is allowed.') + x, y = agent_pos[0] + + # Actions + x_diff, y_diff = h.ACTIONMAP[action] + x_new = x + x_diff + y_new = y + y_diff + + valid = h.check_position(self._state[h.LEVEL_IDX], (x_new, y_new)) + + return (x, y), (x_new, y_new), valid + def agent_i_position(self, agent_i: int) -> (int, int): positions = np.argwhere(self._state[h.AGENT_START_IDX + agent_i] == h.IS_OCCUPIED_CELL) assert positions.shape[0] == 1 diff --git a/environments/helpers.py b/environments/helpers.py index 5b7224e..64670a3 100644 --- a/environments/helpers.py +++ b/environments/helpers.py @@ -38,19 +38,19 @@ def one_hot_level(level, wall_char=WALL): return binary_grid -def check_position(state: np.ndarray, position_to_check: Tuple[int, int], dim: int = 0): - x, y = position_to_check - agent_slice = state[dim] +def check_position(slice_to_check_against: np.ndarray, position_to_check: Tuple[int, int]): + x_pos, y_pos = position_to_check # Check if agent colides with grid boundrys valid = not ( - x < 0 or y < 0 - or x >= agent_slice.shape[0] - or y >= agent_slice.shape[0] + x_pos < 0 or y_pos < 0 + or x_pos >= slice_to_check_against.shape[0] + or y_pos >= slice_to_check_against.shape[0] ) # Check for collision with level walls - valid = valid and not state[LEVEL_IDX][x, y] + valid = valid and not slice_to_check_against[x_pos, y_pos] + return valid if __name__ == '__main__': diff --git a/main.py b/main.py index e60f998..f77bf8a 100644 --- a/main.py +++ b/main.py @@ -85,8 +85,8 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List if __name__ == '__main__': - compare_runs(Path('debug_out'), 1623052687, ['step_reward']) - exit() + # compare_runs(Path('debug_out'), 1623052687, ['step_reward']) + # exit() from stable_baselines3 import PPO, DQN, A2C from algorithms.reg_dqn import RegDQN From cf2378a734f765c55aa3dc8918d811cbb5b36ffe Mon Sep 17 00:00:00 2001 From: steffen-illium Date: Wed, 9 Jun 2021 13:12:49 +0200 Subject: [PATCH 3/3] multi_agent observation when n_agent more then 1 --- environments/factory/base_factory.py | 167 ++++++------------------- environments/factory/simple_factory.py | 47 +++---- environments/logging/plotting.py | 10 +- environments/utility_classes.py | 127 +++++++++++++++++++ main.py | 4 +- main_test.py | 75 +++++++++++ 6 files changed, 271 insertions(+), 159 deletions(-) create mode 100644 environments/utility_classes.py create mode 100644 main_test.py diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index 1c8b921..0fa06ae 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -1,6 +1,6 @@ from argparse import Namespace from pathlib import Path -from typing import List, Union, Iterable, NamedTuple +from typing import List, Union, Iterable import gym import numpy as np @@ -9,115 +9,7 @@ from gym import spaces import yaml from environments import helpers as h - - -class MovementProperties(NamedTuple): - allow_square_movement: bool = True - allow_diagonal_movement: bool = False - allow_no_op: bool = False - - -class Entity(): - - @property - def pos(self): - return self._pos - - def __init__(self, pos): - self._pos = pos - - -class AgentState: - - def __init__(self, i: int, action: int): - self.i = i - self.action = action - - self.collision_vector = None - self.action_valid = None - self.pos = None - self.info = {} - - @property - def collisions(self): - return np.argwhere(self.collision_vector != 0).flatten() - - def update(self, **kwargs): # is this hacky?? o.0 - for key, value in kwargs.items(): - if hasattr(self, key): - self.__setattr__(key, value) - else: - raise AttributeError(f'"{key}" cannot be updated, this attr is not a part of {self.__class__.__name__}') - - -class Register: - - @property - def n(self): - return len(self) - - def __init__(self): - self._register = dict() - - def __len__(self): - return len(self._register) - - def __add__(self, other: Union[str, List[str]]): - other = other if isinstance(other, list) else [other] - assert all([isinstance(x, str) for x in other]), f'All item names have to be of type {str}.' - self._register.update({key+len(self._register): value for key, value in enumerate(other)}) - return self - - def register_additional_items(self, other: Union[str, List[str]]): - self_with_additional_items = self + other - return self_with_additional_items - - def keys(self): - return self._register.keys() - - def items(self): - return self._register.items() - - def __getitem__(self, item): - return self._register[item] - - def by_name(self, item): - return list(self._register.keys())[list(self._register.values()).index(item)] - - def __repr__(self): - return f'{self.__class__.__name__}({self._register})' - - -class Actions(Register): - - @property - def movement_actions(self): - return self._movement_actions - - def __init__(self, movement_properties: MovementProperties): - self.allow_no_op = movement_properties.allow_no_op - self.allow_diagonal_movement = movement_properties.allow_diagonal_movement - self.allow_square_movement = movement_properties.allow_square_movement - # FIXME: There is a bug in helpers because there actions are ints. and the order matters. - assert not(self.allow_square_movement is False and self.allow_diagonal_movement is True), \ - "There is a bug in helpers!!!" - super(Actions, self).__init__() - - if self.allow_square_movement: - self + ['north', 'east', 'south', 'west'] - if self.allow_diagonal_movement: - self + ['north-east', 'south-east', 'south-west', 'north-west'] - self._movement_actions = self._register.copy() - if self.allow_no_op: - self + 'no-op' - - -class StateSlice(Register): - - def __init__(self, n_agents: int): - super(StateSlice, self).__init__() - offset = 1 - self.register_additional_items(['level', *[f'agent#{i}' for i in range(offset, n_agents+offset)]]) +from environments.utility_classes import Actions, StateSlice, AgentState, MovementProperties # noinspection PyAttributeOutsideInit @@ -148,9 +40,11 @@ class BaseFactory(gym.Env): def movement_actions(self): return self._actions.movement_actions - def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None, + def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = 0, movement_properties: MovementProperties = MovementProperties(), + combin_agent_slices_in_obs: bool = False, omit_agent_slice_in_obs=False, **kwargs): + assert combin_agent_slices_in_obs != omit_agent_slice_in_obs, 'Both options are exclusive' self.movement_properties = movement_properties self.level_name = level_name @@ -158,6 +52,7 @@ class BaseFactory(gym.Env): self.n_agents = n_agents self.max_steps = max_steps self.pomdp_radius = pomdp_radius + self.combin_agent_slices_in_obs = combin_agent_slices_in_obs self.omit_agent_slice_in_obs = omit_agent_slice_in_obs self.done_at_collision = False @@ -185,7 +80,7 @@ class BaseFactory(gym.Env): raise NotImplementedError('Please register additional actions ') def reset(self) -> (np.ndarray, int, bool, dict): - self.steps = 0 + self._steps = 0 self._agent_states = [] # Agent placement ... agents = np.zeros((self.n_agents, *self._level.shape), dtype=np.int8) @@ -202,17 +97,25 @@ class BaseFactory(gym.Env): # Returns State return None - def _return_state(self): + def _get_observations(self) -> np.ndarray: + if self.n_agents == 1: + obs = self._build_per_agent_obs(0) + elif self.n_agents >= 2: + obs = np.stack([self._build_per_agent_obs(agent_i) for agent_i in range(self.n_agents)]) + return obs + + def _build_per_agent_obs(self, agent_i: int) -> np.ndarray: if self.pomdp_radius: - pos = self._agent_states[0].pos - # pos = [agent_state.pos for agent_state in self.agent_states] - # obs = [] ... list comprehension... pos per agent - x0, x1 = max(0, pos[0] - self.pomdp_radius), pos[0] + self.pomdp_radius + 1 - y0, y1 = max(0, pos[1] - self.pomdp_radius), pos[1] + self.pomdp_radius + 1 + global_pos = self._agent_states[agent_i].pos + x0, x1 = max(0, global_pos[0] - self.pomdp_radius), global_pos[0] + self.pomdp_radius + 1 + y0, y1 = max(0, global_pos[1] - self.pomdp_radius), global_pos[1] + self.pomdp_radius + 1 obs = self._state[:, x0:x1, y0:y1] if obs.shape[1] != self.pomdp_radius * 2 + 1 or obs.shape[2] != self.pomdp_radius * 2 + 1: obs_padded = np.full((obs.shape[0], self.pomdp_radius * 2 + 1, self.pomdp_radius * 2 + 1), 1) - a_pos = np.argwhere(obs[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0] + try: + a_pos = np.argwhere(obs[h.AGENT_START_IDX + agent_i] == h.IS_OCCUPIED_CELL)[0] + except IndexError: + print('NO') obs_padded[:, abs(a_pos[0]-self.pomdp_radius):abs(a_pos[0]-self.pomdp_radius)+obs.shape[1], abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs @@ -223,7 +126,13 @@ class BaseFactory(gym.Env): obs_new = obs[[key for key, val in self._state_slices.items() if 'agent' not in val]] return obs_new else: - return obs + if self.combin_agent_slices_in_obs: + agent_obs = np.sum(obs[[key for key, val in self._state_slices.items() if 'agent' in val]], + axis=0, keepdims=True) + obs = np.concatenate((obs[:h.AGENT_START_IDX], agent_obs, obs[h.AGENT_START_IDX+self.n_agents:])) + return obs + else: + return obs def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool): raise NotImplementedError @@ -231,16 +140,16 @@ class BaseFactory(gym.Env): def step(self, actions): actions = [actions] if isinstance(actions, int) or np.isscalar(actions) else actions assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]' - self.steps += 1 + self._steps += 1 done = False # Move this in a seperate function? agent_states = list() for agent_i, action in enumerate(actions): agent_i_state = AgentState(agent_i, action) - if self._is_moving_action(action): + if self._actions.is_moving_action(action): pos, valid = self.move_or_colide(agent_i, action) - elif self._is_no_op(action): + elif self._actions.is_no_op(action): pos, valid = self.agent_i_position(agent_i), True else: pos, valid = self.do_additional_actions(agent_i, action) @@ -256,24 +165,18 @@ class BaseFactory(gym.Env): self._agent_states = agent_states reward, info = self.calculate_reward(agent_states) - if self.steps >= self.max_steps: + if self._steps >= self.max_steps: done = True - info.update(step_reward=reward, step=self.steps) + info.update(step_reward=reward, step=self._steps) return None, reward, done, info - def _is_moving_action(self, action): - return action in self._actions.movement_actions - - def _is_no_op(self, action): - return self._actions[action] == 'no-op' - def check_all_collisions(self, agent_states: List[AgentState], collisions: int) -> np.ndarray: collision_vecs = np.zeros((len(agent_states), collisions)) # n_agents x n_slices for agent_state in agent_states: # Register only collisions of moving agents - if self._is_moving_action(agent_state.action): + if self._actions.is_moving_action(agent_state.action): collision_vecs[agent_state.i] = self.check_collisions(agent_state) return collision_vecs diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py index c1f052d..85168c7 100644 --- a/environments/factory/simple_factory.py +++ b/environments/factory/simple_factory.py @@ -1,16 +1,14 @@ from collections import OrderedDict -from dataclasses import dataclass -from pathlib import Path from typing import List, Union, NamedTuple import random import numpy as np -from environments.factory.base_factory import BaseFactory, AgentState, MovementProperties +from environments.factory.base_factory import BaseFactory from environments import helpers as h -from environments.logging.monitor import MonitorCallback from environments.factory.renderer import Renderer, Entity +from environments.utility_classes import AgentState, MovementProperties DIRT_INDEX = -1 CLEAN_UP_ACTION = 'clean_up' @@ -25,13 +23,16 @@ class DirtProperties(NamedTuple): max_global_amount: int = 20 # Max dirt amount in the whole environment. +# noinspection PyAttributeOutsideInit class SimpleFactory(BaseFactory): @property def additional_actions(self) -> Union[str, List[str]]: return CLEAN_UP_ACTION - def _is_clean_up_action(self, action): + def _is_clean_up_action(self, action: Union[str, int]): + if isinstance(action, str): + action = self._actions.by_name(action) return self._actions[action] == CLEAN_UP_ACTION def __init__(self, *args, dirt_properties: DirtProperties, verbose=False, **kwargs): @@ -47,9 +48,9 @@ class SimpleFactory(BaseFactory): height, width = self._state.shape[1:] self._renderer = Renderer(width, height, view_radius=self.pomdp_radius) - dirt = [Entity('dirt', [x, y], min(0.15 + self._state[DIRT_INDEX, x, y], 1.5), 'scale') - for x, y in np.argwhere(self._state[DIRT_INDEX] > h.IS_FREE_CELL)] - walls = [Entity('wall', pos) for pos in np.argwhere(self._state[h.LEVEL_IDX] > h.IS_FREE_CELL)] + dirt = [Entity('dirt', [x, y], min(0.15 + self._state[DIRT_INDEX, x, y], 1.5), 'scale') + for x, y in np.argwhere(self._state[DIRT_INDEX] > h.IS_FREE_CELL)] + walls = [Entity('wall', pos) for pos in np.argwhere(self._state[h.LEVEL_IDX] > h.IS_FREE_CELL)] def asset_str(agent): if any([x is None for x in [self._state_slices[j] for j in agent.collisions]]): @@ -93,17 +94,18 @@ class SimpleFactory(BaseFactory): return pos, cleanup_was_sucessfull def step(self, actions): - _, r, done, info = super(SimpleFactory, self).step(actions) + _, reward, done, info = super(SimpleFactory, self).step(actions) if not self._next_dirt_spawn: self.spawn_dirt() self._next_dirt_spawn = self.dirt_properties.spawn_frequency else: self._next_dirt_spawn -= 1 - obs = self._return_state() - return obs, r, done, info + + obs = self._get_observations() + return obs, reward, done, info def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool): - if action != self._is_moving_action(action): + if action != self._actions.is_moving_action(action): if self._is_clean_up_action(action): agent_i_pos = self.agent_i_position(agent_i) _, valid = self.clean_up(agent_i_pos) @@ -119,7 +121,7 @@ class SimpleFactory(BaseFactory): self._state = np.concatenate((self._state, dirt_slice)) # dirt is now the last slice self.spawn_dirt() self._next_dirt_spawn = self.dirt_properties.spawn_frequency - obs = self._return_state() + obs = self._get_observations() return obs def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict): @@ -141,7 +143,7 @@ class SimpleFactory(BaseFactory): if entity != self._state_slices.by_name("dirt")] if list_of_collisions: - self.print(f't = {self.steps}\tAgent {agent_state.i} has collisions with ' + self.print(f't = {self._steps}\tAgent {agent_state.i} has collisions with ' f'{list_of_collisions}') if self._is_clean_up_action(agent_state.action): @@ -155,7 +157,7 @@ class SimpleFactory(BaseFactory): f'at {agent_state.pos}, but was unsucsessfull.') info_dict.update(failed_cleanup_attempt=1) - elif self._is_moving_action(agent_state.action): + elif self._actions.is_moving_action(agent_state.action): if agent_state.action_valid: # info_dict.update(movement=1) reward -= 0.00 @@ -185,10 +187,11 @@ class SimpleFactory(BaseFactory): if __name__ == '__main__': render = True - import yaml - with Path(r'C:\Users\steff\projects\f_iks\debug_out\yaml.txt').open('r') as f: - env_kwargs = yaml.load(f) - factory = SimpleFactory(**env_kwargs) + + move_props = MovementProperties(allow_diagonal_movement=True, allow_square_movement=True) + dirt_props = DirtProperties() + factory = SimpleFactory(movement_properties=move_props, dirt_properties=dirt_props, n_agents=2, + combin_agent_slices_in_obs=True, omit_agent_slice_in_obs=False) # dirt_props = DirtProperties() # move_props = MovementProperties(allow_diagonal_movement=False, allow_no_op=False) @@ -200,10 +203,12 @@ if __name__ == '__main__': for epoch in range(100): random_actions = [[random.randint(0, n_actions) for _ in range(factory.n_agents)] for _ in range(200)] env_state = factory.reset() + r = 0 for agent_i_action in random_actions: - env_state, reward, done_bool, info_obj = factory.step(agent_i_action) + env_state, step_r, done_bool, info_obj = factory.step(agent_i_action) + r += step_r if render: factory.render() if done_bool: break - print(f'Factory run {epoch} done, reward is:\n {reward}') + print(f'Factory run {epoch} done, reward is:\n {r}') diff --git a/environments/logging/plotting.py b/environments/logging/plotting.py index 91585a2..323b99f 100644 --- a/environments/logging/plotting.py +++ b/environments/logging/plotting.py @@ -32,13 +32,15 @@ def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None) hue_order = sorted(list(df[hue].unique())) try: sns.set(rc={'text.usetex': True}, style='whitegrid') - _ = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE, - hue_order=hue_order, hue=hue, style=style) + lineplot = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE, + hue_order=hue_order, hue=hue, style=style) + lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}') plot(filepath, ext=ext) # plot raises errors not lineplot! except (FileNotFoundError, RuntimeError): print('Struggling to plot Figure using LaTeX - going back to normal.') plt.close('all') sns.set(rc={'text.usetex': False}, style='whitegrid') - sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style, - ci=95, palette=PALETTE, hue_order=hue_order) + lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style, + ci=95, palette=PALETTE, hue_order=hue_order) + lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}') plot(filepath, ext=ext) diff --git a/environments/utility_classes.py b/environments/utility_classes.py new file mode 100644 index 0000000..b49609c --- /dev/null +++ b/environments/utility_classes.py @@ -0,0 +1,127 @@ +from typing import Union, List, NamedTuple +import numpy as np + + +class MovementProperties(NamedTuple): + allow_square_movement: bool = True + allow_diagonal_movement: bool = False + allow_no_op: bool = False + +# Preperations for Entities (not used yet) +class Entity: + + @property + def pos(self): + return self._pos + + @property + def identifier(self): + return self._identifier + + def __init__(self, identifier, pos): + self._pos = pos + self._identifier = identifier + + +class AgentState: + + def __init__(self, i: int, action: int): + self.i = i + self.action = action + + self.collision_vector = None + self.action_valid = None + self.pos = None + self.info = {} + + @property + def collisions(self): + return np.argwhere(self.collision_vector != 0).flatten() + + def update(self, **kwargs): # is this hacky?? o.0 + for key, value in kwargs.items(): + if hasattr(self, key): + self.__setattr__(key, value) + else: + raise AttributeError(f'"{key}" cannot be updated, this attr is not a part of {self.__class__.__name__}') + + +class Register: + + @property + def n(self): + return len(self) + + def __init__(self): + self._register = dict() + + def __len__(self): + return len(self._register) + + def __add__(self, other: Union[str, List[str]]): + other = other if isinstance(other, list) else [other] + assert all([isinstance(x, str) for x in other]), f'All item names have to be of type {str}.' + self._register.update({key+len(self._register): value for key, value in enumerate(other)}) + return self + + def register_additional_items(self, other: Union[str, List[str]]): + self_with_additional_items = self + other + return self_with_additional_items + + def keys(self): + return self._register.keys() + + def items(self): + return self._register.items() + + def __getitem__(self, item): + return self._register[item] + + def by_name(self, item): + return list(self._register.keys())[list(self._register.values()).index(item)] + + def __repr__(self): + return f'{self.__class__.__name__}({self._register})' + + +class Actions(Register): + + @property + def movement_actions(self): + return self._movement_actions + + def __init__(self, movement_properties: MovementProperties): + self.allow_no_op = movement_properties.allow_no_op + self.allow_diagonal_movement = movement_properties.allow_diagonal_movement + self.allow_square_movement = movement_properties.allow_square_movement + # FIXME: There is a bug in helpers because there actions are ints. and the order matters. + # assert not(self.allow_square_movement is False and self.allow_diagonal_movement is True), \ + # "There is a bug in helpers!!!" + super(Actions, self).__init__() + + if self.allow_square_movement: + self + ['north', 'east', 'south', 'west'] + if self.allow_diagonal_movement: + self + ['north_east', 'south_east', 'south_west', 'north_west'] + self._movement_actions = self._register.copy() + if self.allow_no_op: + self + 'no-op' + + def is_moving_action(self, action: Union[str, int]): + if isinstance(action, str): + return action in self.movement_actions.values() + else: + return self[action] in self.movement_actions.values() + + def is_no_op(self, action: Union[str, int]): + if isinstance(action, str): + action = self.by_name(action) + return self[action] == 'no-op' + + +class StateSlice(Register): + + def __init__(self, n_agents: int): + super(StateSlice, self).__init__() + offset = 1 + self.register_additional_items(['level', *[f'agent#{i}' for i in range(offset, n_agents+offset)]]) diff --git a/main.py b/main.py index f77bf8a..63b7588 100644 --- a/main.py +++ b/main.py @@ -94,7 +94,7 @@ if __name__ == '__main__': dirt_props = DirtProperties(clean_amount=3, gain_amount=0.2, max_global_amount=30, max_local_amount=5, spawn_frequency=3) - move_props = MovementProperties(allow_diagonal_movement=False, + move_props = MovementProperties(allow_diagonal_movement=True, allow_square_movement=True, allow_no_op=False) time_stamp = int(time.time()) @@ -123,7 +123,7 @@ if __name__ == '__main__': [MonitorCallback(filepath=out_path / f'monitor_{identifier}.pick', plotting=False)] ) - model.learn(total_timesteps=int(5e5), callback=callbacks) + model.learn(total_timesteps=int(1e5), callback=callbacks) save_path = out_path / f'model_{identifier}.zip' save_path.parent.mkdir(parents=True, exist_ok=True) diff --git a/main_test.py b/main_test.py new file mode 100644 index 0000000..5113b68 --- /dev/null +++ b/main_test.py @@ -0,0 +1,75 @@ +# foreign imports +import warnings + +from pathlib import Path +import yaml +from natsort import natsorted + +from stable_baselines3.common.callbacks import CallbackList +from stable_baselines3 import PPO, DQN, A2C + +# our imports +from environments.factory.simple_factory import SimpleFactory +from environments.logging.monitor import MonitorCallback +from algorithms.reg_dqn import RegDQN +from main import compare_runs, combine_runs + +warnings.filterwarnings('ignore', category=FutureWarning) +warnings.filterwarnings('ignore', category=UserWarning) +model_mapping = dict(A2C=A2C, PPO=PPO, DQN=DQN, RegDQN=RegDQN) + + +if __name__ == '__main__': + + # get n policies pi_1, ..., pi_n trained in single agent setting + # rewards = [] + # repeat for x eval runs + # total reward = rollout game for y steps with n policies in multi-agent setting + # rewards += [total reward] + # boxplot total rewards + + run_id = '1623078961' + model_name = 'PPO' + + # ----------------------- + out_path = Path(__file__).parent / 'debug_out' + + # from sb3_contrib import QRDQN + model_path = out_path / f'{model_name}_{run_id}' + model_files = list(natsorted(model_path.rglob('model_*.zip'))) + this_model = model_files[0] + render = True + + model = model_mapping[model_name].load(this_model) + + for seed in range(3): + with (model_path / f'env_{model_path.name}.yaml').open('r') as f: + env_kwargs = yaml.load(f, Loader=yaml.FullLoader) + env_kwargs.update(n_agents=2) + env = SimpleFactory(**env_kwargs) + + exp_out_path = model_path / 'exp' + callbacks = CallbackList( + [MonitorCallback(filepath=exp_out_path / f'future_exp_name', plotting=True)] + ) + + n_actions = env.action_space.n + + for epoch in range(100): + observations = env.reset() + if render: + env.render() + done_bool = False + r = 0 + while not done_bool: + actions = [model.predict(obs, deterministic=False)[0] for obs in observations] + + obs, r, done_bool, info_obj = env.step(actions) + if render: + env.render() + if done_bool: + break + print(f'Factory run {epoch} done, reward is:\n {r}') + + if out_path: + combine_runs(out_path.parent)