diff --git a/environments/factory/assets/agents/move.png b/environments/factory/assets/agents/move.png new file mode 100644 index 0000000..2a56ae4 Binary files /dev/null and b/environments/factory/assets/agents/move.png differ diff --git a/environments/factory/assets/agents/valid.png b/environments/factory/assets/agents/valid.png index ae7c768..8341ab6 100644 Binary files a/environments/factory/assets/agents/valid.png and b/environments/factory/assets/agents/valid.png differ diff --git a/environments/factory/base/base_factory.py b/environments/factory/base/base_factory.py index 3e96e90..7749898 100644 --- a/environments/factory/base/base_factory.py +++ b/environments/factory/base/base_factory.py @@ -60,9 +60,12 @@ class BaseFactory(gym.Env): omit_agent_in_obs=False, done_at_collision=False, cast_shadows=True, verbose=False, doors_have_area=True, env_seed=time.time_ns(), **kwargs): assert frames_to_stack != 1 and frames_to_stack >= 0, "'frames_to_stack' cannot be negative or 1." + if kwargs: + print(f'Following kwargs were passed, but ignored: {kwargs}') # Attribute Assignment self.env_seed = env_seed + self.seed(env_seed) self._base_rng = np.random.default_rng(self.env_seed) self.movement_properties = movement_properties self.level_name = level_name @@ -85,11 +88,6 @@ class BaseFactory(gym.Env): self.parse_doors = parse_doors self.doors_have_area = doors_have_area - # Actions - self._actions = Actions(self.movement_properties, can_use_doors=self.parse_doors) - if additional_actions := self.additional_actions: - self._actions.register_additional_items(additional_actions) - # Reset self.reset() @@ -123,11 +121,17 @@ class BaseFactory(gym.Env): self.NO_POS_TILE = Tile(c.NO_POS.value) # Doors - parsed_doors = h.one_hot_level(parsed_level, c.DOOR) - if np.any(parsed_doors): - door_tiles = [floor.by_pos(pos) for pos in np.argwhere(parsed_doors == c.OCCUPIED_CELL.value)] - doors = Doors.from_tiles(door_tiles, self._level_shape, context=floor, is_blocking_light=True) - entities.update({c.DOORS: doors}) + if self.parse_doors: + parsed_doors = h.one_hot_level(parsed_level, c.DOOR) + if np.any(parsed_doors): + door_tiles = [floor.by_pos(pos) for pos in np.argwhere(parsed_doors == c.OCCUPIED_CELL.value)] + doors = Doors.from_tiles(door_tiles, self._level_shape, context=floor, is_blocking_light=True) + entities.update({c.DOORS: doors}) + + # Actions + self._actions = Actions(self.movement_properties, can_use_doors=self.parse_doors) + if additional_actions := self.additional_actions: + self._actions.register_additional_items(additional_actions) # Agents agents = Agents.from_tiles(floor.empty_tiles[:self.n_agents], self._level_shape) @@ -155,8 +159,8 @@ class BaseFactory(gym.Env): # Optionally Pad this obs cube for pomdp cases if r := self.pomdp_r: x, y = self._level_shape + # was c.SHADOW self._padded_obs_cube = np.full((obs_cube_z, x + r*2, y + r*2), c.SHADOWED_CELL.value, dtype=np.float32) - # self._padded_obs_cube[0] = c.OCCUPIED_CELL.value self._padded_obs_cube[:, r:r+x, r:r+y] = self._obs_cube def reset(self) -> (np.ndarray, int, bool, dict): @@ -170,7 +174,10 @@ class BaseFactory(gym.Env): return obs def step(self, actions): - actions = [actions] if isinstance(actions, int) or np.isscalar(actions) else actions + + if self.n_agents == 1: + actions = [int(actions)] + assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]' self._steps += 1 done = False @@ -180,9 +187,10 @@ class BaseFactory(gym.Env): # Move this in a seperate function? for action, agent in zip(actions, self[c.AGENT]): - agent.clear_temp_sate() + agent.clear_temp_state() action_obj = self._actions[int(action)] - if self._actions.is_moving_action(action_obj): + self.print(f'Action #{action} has been resolved to: {action_obj}') + if h.MovingAction.is_member(action_obj): valid = self._move_or_colide(agent, action_obj) elif h.EnvActions.NOOP == agent.temp_action: valid = c.VALID @@ -210,7 +218,8 @@ class BaseFactory(gym.Env): # Step the door close intervall if self.parse_doors: - self[c.DOORS].tick_doors() + if doors := self[c.DOORS]: + doors.tick_doors() # Finalize reward, reward_info = self.calculate_reward() @@ -229,15 +238,18 @@ class BaseFactory(gym.Env): return obs, reward, done, info def _handle_door_interaction(self, agent) -> c: - # Check if agent really is standing on a door: - if self.doors_have_area: - door = self[c.DOORS].get_near_position(agent.pos) - else: - door = self[c.DOORS].by_pos(agent.pos) - if door is not None: - door.use() - return c.VALID - # When he doesn't... + if doors := self[c.DOORS]: + # Check if agent really is standing on a door: + if self.doors_have_area: + door = doors.get_near_position(agent.pos) + else: + door = doors.by_pos(agent.pos) + if door is not None: + door.use() + return c.VALID + # When he doesn't... + else: + return c.NOT_VALID else: return c.NOT_VALID @@ -284,8 +296,9 @@ class BaseFactory(gym.Env): state_array_dict[c.AGENT][0, agent.x, agent.y] += agent.encoding if r := self.pomdp_r: + self._padded_obs_cube[:] = c.SHADOWED_CELL.value # Was c.SHADOW + # self._padded_obs_cube[0] = c.OCCUPIED_CELL.value x, y = self._level_shape - self._padded_obs_cube[:] = c.SHADOWED_CELL.value self._padded_obs_cube[:, r:r + x, r:r + y] = self._obs_cube global_x, global_y = map(sum, zip(agent.pos, (r, r))) x0, x1 = max(0, global_x - self.pomdp_r), global_x + self.pomdp_r + 1 @@ -297,20 +310,22 @@ class BaseFactory(gym.Env): if self.cast_shadows: obs_block_light = [obs[idx] != c.OCCUPIED_CELL.value for idx in shadowing_idxs] door_shadowing = False - if door := self[c.DOORS].by_pos(agent.pos): - if door.is_closed: - for group in door.connectivity_subgroups: - if agent.last_pos not in group: - door_shadowing = True - if self.pomdp_r: - blocking = [tuple(np.subtract(x, agent.pos) + (self.pomdp_r, self.pomdp_r)) - for x in group] - xs, ys = zip(*blocking) - else: - xs, ys = zip(*group) + if self.parse_doors: + if doors := self[c.DOORS]: + if door := doors.by_pos(agent.pos): + if door.is_closed: + for group in door.connectivity_subgroups: + if agent.last_pos not in group: + door_shadowing = True + if self.pomdp_r: + blocking = [tuple(np.subtract(x, agent.pos) + (self.pomdp_r, self.pomdp_r)) + for x in group] + xs, ys = zip(*blocking) + else: + xs, ys = zip(*group) - # noinspection PyUnresolvedReferences - obs_block_light[0][xs, ys] = False + # noinspection PyUnresolvedReferences + obs_block_light[0][xs, ys] = False light_block_map = Map((np.prod(obs_block_light, axis=0) != True).astype(int)) if self.pomdp_r: @@ -361,22 +376,24 @@ class BaseFactory(gym.Env): return tile, valid if self.parse_doors and agent.last_pos != c.NO_POS: - if door := self[c.DOORS].by_pos(new_tile.pos): - if door.can_collide: - return agent.tile, c.NOT_VALID - else: # door.is_closed: - pass + if doors := self[c.DOORS]: + if self.doors_have_area: + if door := doors.by_pos(new_tile.pos): + if door.can_collide: + return agent.tile, c.NOT_VALID + else: # door.is_closed: + pass - if door := self[c.DOORS].by_pos(agent.pos): - if door.is_open: - pass - else: # door.is_closed: - if door.is_linked(agent.last_pos, new_tile.pos): + if door := doors.by_pos(agent.pos): + if door.is_open: pass - else: - return agent.tile, c.NOT_VALID - else: - pass + else: # door.is_closed: + if door.is_linked(agent.last_pos, new_tile.pos): + pass + else: + return agent.tile, c.NOT_VALID + else: + pass else: pass @@ -391,7 +408,9 @@ class BaseFactory(gym.Env): if self._actions.is_moving_action(agent.temp_action): if agent.temp_valid: # info_dict.update(movement=1) - reward -= 0.00 + # info_dict.update({f'{agent.name}_failed_action': 1}) + # reward += 0.00 + pass else: # self.print('collision') reward -= 0.01 @@ -400,16 +419,17 @@ class BaseFactory(gym.Env): elif h.EnvActions.USE_DOOR == agent.temp_action: if agent.temp_valid: + # reward += 0.00 self.print(f'{agent.name} did just use the door at {agent.pos}.') info_dict.update(door_used=1) else: - reward -= 0.00 + # reward -= 0.00 self.print(f'{agent.name} just tried to use a door at {agent.pos}, but failed.') info_dict.update({f'{agent.name}_failed_action': 1}) info_dict.update({f'{agent.name}_failed_door_open': 1}) elif h.EnvActions.NOOP == agent.temp_action: info_dict.update(no_op=1) - reward -= 0.00 + # reward -= 0.00 additional_reward, additional_info_dict = self.calculate_additional_reward(agent) reward += additional_reward diff --git a/environments/factory/base/objects.py b/environments/factory/base/objects.py index aa67b14..fda6607 100644 --- a/environments/factory/base/objects.py +++ b/environments/factory/base/objects.py @@ -24,15 +24,27 @@ class Object: @property def identifier(self): - return self._enum_ident - - def __init__(self, enum_ident: Union[Enum, None] = None, is_blocking_light=False, **kwargs): - self._enum_ident = enum_ident if self._enum_ident is not None: - self._name = f'{self.__class__.__name__}[{self._enum_ident.name}]' + return self._enum_ident + elif self._str_ident is not None: + return self._str_ident else: + return self._name + + def __init__(self, str_ident: Union[str, None] = None, enum_ident: Union[Enum, None] = None, is_blocking_light=False, **kwargs): + self._str_ident = str_ident + self._enum_ident = enum_ident + + if self._enum_ident is not None and self._str_ident is None: + self._name = f'{self.__class__.__name__}[{self._enum_ident.name}]' + elif self._str_ident is not None and self._enum_ident is None: + self._name = f'{self.__class__.__name__}[{self._str_ident}]' + elif self._str_ident is None and self._enum_ident is None: self._name = f'{self.__class__.__name__}#{self._u_idx}' - Object._u_idx += 1 + Object._u_idx += 1 + else: + raise ValueError('Please use either of the idents.') + self._is_blocking_light = is_blocking_light if kwargs: print(f'Following kwargs were passed, but ignored: {kwargs}') @@ -166,7 +178,7 @@ class Door(Entity): @property def encoding(self): - return 1 if self.is_closed else -1 + return 1 if self.is_closed else 0.5 @property def access_area(self): @@ -274,10 +286,10 @@ class Agent(MoveableEntity): def __init__(self, *args, **kwargs): super(Agent, self).__init__(*args, **kwargs) - self.clear_temp_sate() + self.clear_temp_state() # noinspection PyAttributeOutsideInit - def clear_temp_sate(self): + def clear_temp_state(self): # for attr in self.__dict__: # if attr.startswith('temp'): self.temp_collisions = [] diff --git a/environments/factory/base/registers.py b/environments/factory/base/registers.py index 6cb9b21..4f09641 100644 --- a/environments/factory/base/registers.py +++ b/environments/factory/base/registers.py @@ -53,7 +53,10 @@ class Register: return next(v for i, v in enumerate(self._register.values()) if i == item) except StopIteration: return None - return self._register[item] + try: + return self._register[item] + except KeyError: + return None def __repr__(self): return f'{self.__class__.__name__}({self._register})' @@ -84,8 +87,8 @@ class EntityObjectRegister(ObjectRegister, ABC): @classmethod def from_tiles(cls, tiles, *args, **kwargs): # objects_name = cls._accepted_objects.__name__ - entities = [cls._accepted_objects(tile, **kwargs) - for tile in tiles] + entities = [cls._accepted_objects(tile, str_ident=i, **kwargs) + for i, tile in enumerate(tiles)] register_obj = cls(*args) register_obj.register_additional_items(entities) return register_obj @@ -294,10 +297,10 @@ class Actions(Register): if self.allow_square_movement: self.register_additional_items([self._accepted_objects(enum_ident=direction) - for direction in h.ManhattanMoves]) + for direction in h.MovingAction.square()]) if self.allow_diagonal_movement: self.register_additional_items([self._accepted_objects(enum_ident=direction) - for direction in h.DiagonalMoves]) + for direction in h.MovingAction.diagonal()]) self._movement_actions = self._register.copy() if self.can_use_doors: self.register_additional_items([self._accepted_objects(enum_ident=h.EnvActions.USE_DOOR)]) diff --git a/environments/factory/renderer.py b/environments/factory/renderer.py index 42491db..a4ca734 100644 --- a/environments/factory/renderer.py +++ b/environments/factory/renderer.py @@ -79,14 +79,15 @@ class Renderer: rects = [] for i, j in product(range(-self.view_radius, self.view_radius+1), range(-self.view_radius, self.view_radius+1)): - if bool(view[self.view_radius+j, self.view_radius+i]): - visibility_rect = bp['dest'].copy() - visibility_rect.centerx += i*self.cell_size - visibility_rect.centery += j*self.cell_size - shape_surf = pygame.Surface(visibility_rect.size, pygame.SRCALPHA) - pygame.draw.rect(shape_surf, self.AGENT_VIEW_COLOR, shape_surf.get_rect()) - shape_surf.set_alpha(64) - rects.append(dict(source=shape_surf, dest=visibility_rect)) + if view is not None: + if bool(view[self.view_radius+j, self.view_radius+i]): + visibility_rect = bp['dest'].copy() + visibility_rect.centerx += i*self.cell_size + visibility_rect.centery += j*self.cell_size + shape_surf = pygame.Surface(visibility_rect.size, pygame.SRCALPHA) + pygame.draw.rect(shape_surf, self.AGENT_VIEW_COLOR, shape_surf.get_rect()) + shape_surf.set_alpha(64) + rects.append(dict(source=shape_surf, dest=visibility_rect)) return rects def render(self, entities): diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py index 667e145..1df8bef 100644 --- a/environments/factory/simple_factory.py +++ b/environments/factory/simple_factory.py @@ -94,6 +94,10 @@ class DirtRegister(MovingEntityObjectRegister): return c.NOT_VALID return c.VALID + def __repr__(self): + s = super(DirtRegister, self).__repr__() + return f'{s[:-1]}, {self.amount})' + def softmax(x): """Compute softmax values for each sets of scores in x.""" @@ -149,7 +153,10 @@ class SimpleFactory(BaseFactory): return c.NOT_VALID def trigger_dirt_spawn(self): - free_for_dirt = self[c.FLOOR].empty_tiles + free_for_dirt = [x for x in self[c.FLOOR] + if len(x.guests) == 0 or (len(x.guests) == 1 and isinstance(next(y for y in x.guests), Dirt)) + ] + self._dirt_rng.shuffle(free_for_dirt) new_spawn = self._dirt_rng.uniform(0, self.dirt_properties.max_spawn_ratio) n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt))) self[c.DIRT].spawn_dirt(free_for_dirt[:n_dirt_tiles]) @@ -216,7 +223,7 @@ class SimpleFactory(BaseFactory): self.print(f'{agent.name} did just clean up some dirt at {agent.pos}.') info_dict.update(dirt_cleaned=1) else: - reward -= 0.00 + reward -= 0.01 self.print(f'{agent.name} just tried to clean up some dirt at {agent.pos}, but failed.') info_dict.update({f'{agent.name}_failed_action': 1}) info_dict.update({f'{agent.name}_failed_action': 1}) @@ -235,8 +242,8 @@ if __name__ == '__main__': factory = SimpleFactory(n_agents=1, done_at_collision=False, frames_to_stack=0, level_name='rooms', max_steps=400, combin_agent_obs=True, - omit_agent_in_obs=True, parse_doors=True, pomdp_r=2, - record_episodes=False, verbose=True + omit_agent_in_obs=True, parse_doors=False, pomdp_r=2, + record_episodes=False, verbose=True, cast_shadows=False ) # noinspection DuplicatedCode diff --git a/environments/helpers.py b/environments/helpers.py index b9de4b1..c9127b9 100644 --- a/environments/helpers.py +++ b/environments/helpers.py @@ -50,19 +50,28 @@ class Constants(Enum): return bool(self.value) -class ManhattanMoves(Enum): +class MovingAction(Enum): NORTH = 'north' EAST = 'east' SOUTH = 'south' WEST = 'west' - - -class DiagonalMoves(Enum): NORTHEAST = 'north_east' SOUTHEAST = 'south_east' SOUTHWEST = 'south_west' NORTHWEST = 'north_west' + @classmethod + def is_member(cls, other): + return any([other == direction for direction in cls]) + + @classmethod + def square(cls): + return [cls.NORTH, cls.EAST, cls.SOUTH, cls.WEST] + + @classmethod + def diagonal(cls): + return [cls.NORTHEAST, cls.SOUTHEAST, cls.SOUTHWEST, cls.NORTHWEST] + class EnvActions(Enum): NOOP = 'no_op' @@ -71,14 +80,13 @@ class EnvActions(Enum): ITEM_ACTION = 'item_action' -d = DiagonalMoves -m = ManhattanMoves +m = MovingAction c = Constants -ACTIONMAP = defaultdict(lambda: (0, 0), {m.NORTH: (-1, 0), d.NORTHEAST: (-1, +1), - m.EAST: (0, 1), d.SOUTHEAST: (1, 1), - m.SOUTH: (1, 0), d.SOUTHWEST: (+1, -1), - m.WEST: (0, -1), d.NORTHWEST: (-1, -1) +ACTIONMAP = defaultdict(lambda: (0, 0), {m.NORTH: (-1, 0), m.NORTHEAST: (-1, +1), + m.EAST: (0, 1), m.SOUTHEAST: (1, 1), + m.SOUTH: (1, 0), m.SOUTHWEST: (+1, -1), + m.WEST: (0, -1), m.NORTHWEST: (-1, -1) } ) @@ -126,8 +134,10 @@ def asset_str(agent): return 'agent_collision', 'blank' elif not agent.temp_valid or c.LEVEL.name in col_names or c.AGENT.name in col_names: return c.AGENT.value, 'invalid' - elif agent.temp_valid: + elif agent.temp_valid and not MovingAction.is_member(agent.temp_action): return c.AGENT.value, 'valid' + elif agent.temp_valid and MovingAction.is_member(agent.temp_action): + return c.AGENT.value, 'move' else: return c.AGENT.value, 'idle' diff --git a/environments/logging/monitor.py b/environments/logging/monitor.py index 93786d0..9ded10b 100644 --- a/environments/logging/monitor.py +++ b/environments/logging/monitor.py @@ -1,4 +1,5 @@ import pickle +from collections import defaultdict from pathlib import Path from typing import List, Dict @@ -17,7 +18,7 @@ class MonitorCallback(BaseCallback): super(MonitorCallback, self).__init__() self.filepath = Path(filepath) self._monitor_df = pd.DataFrame() - self._monitor_dict = dict() + self._monitor_dicts = defaultdict(dict) self.plotting = plotting self.started = False self.closed = False @@ -69,16 +70,22 @@ class MonitorCallback(BaseCallback): def _on_step(self, alt_infos: List[Dict] = None, alt_dones: List[bool] = None) -> bool: infos = alt_infos or self.locals.get('infos', []) - dones = alt_dones or self.locals.get('dones', None) or self.locals.get('done', [None]) - for _, info in enumerate(infos): - self._monitor_dict[self.num_timesteps] = {key: val for key, val in info.items() - if key not in ['terminal_observation', 'episode'] - and not key.startswith('rec_')} + if alt_dones is not None: + dones = alt_dones + elif self.locals.get('dones', None) is not None: + dones =self.locals.get('dones', None) + elif self.locals.get('dones', None) is not None: + dones = self.locals.get('done', [None]) + else: + dones = [] - for env_idx, done in enumerate(dones): + for env_idx, (info, done) in enumerate(zip(infos, dones)): + self._monitor_dicts[env_idx][self.num_timesteps - env_idx] = {key: val for key, val in info.items() + if key not in ['terminal_observation', 'episode'] + and not key.startswith('rec_')} if done: - env_monitor_df = pd.DataFrame.from_dict(self._monitor_dict, orient='index') - self._monitor_dict = dict() + env_monitor_df = pd.DataFrame.from_dict(self._monitor_dicts[env_idx], orient='index') + self._monitor_dicts[env_idx] = dict() columns = [col for col in env_monitor_df.columns if col not in IGNORED_DF_COLUMNS] env_monitor_df = env_monitor_df.aggregate( {col: 'mean' if col.endswith('ount') else 'sum' for col in columns} diff --git a/main.py b/main.py index 5967248..d64a472 100644 --- a/main.py +++ b/main.py @@ -8,6 +8,7 @@ import time import pandas as pd from stable_baselines3.common.callbacks import CallbackList +from stable_baselines3.common.vec_env import SubprocVecEnv from environments.factory.double_task_factory import DoubleTaskFactory, ItemProperties from environments.factory.simple_factory import DirtProperties, SimpleFactory @@ -84,8 +85,20 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List print('Plotting done.') +def make_env(env_kwargs_dict): + + def _init(): + with SimpleFactory(**env_kwargs_dict) as init_env: + return init_env + + return _init + + if __name__ == '__main__': + # combine_runs(Path('debug_out') / 'A2C_1630314192') + # exit() + # compare_runs(Path('debug_out'), 1623052687, ['step_reward']) # exit() @@ -93,65 +106,67 @@ if __name__ == '__main__': from algorithms.reg_dqn import RegDQN # from sb3_contrib import QRDQN - dirt_props = DirtProperties(clean_amount=1, gain_amount=0.1, max_global_amount=20, - max_local_amount=1, spawn_frequency=5, max_spawn_ratio=0.05, - dirt_smear_amount=0.0, agent_can_interact=False) + dirt_props = DirtProperties(clean_amount=2, gain_amount=0.1, max_global_amount=20, + max_local_amount=1, spawn_frequency=3, max_spawn_ratio=0.05, + dirt_smear_amount=0.0, agent_can_interact=True) item_props = ItemProperties(n_items=5, agent_can_interact=True) - move_props = MovementProperties(allow_diagonal_movement=True, + move_props = MovementProperties(allow_diagonal_movement=False, allow_square_movement=True, allow_no_op=False) - train_steps = 6e5 + train_steps = 1e6 time_stamp = int(time.time()) out_path = None for modeL_type in [A2C, PPO, DQN]: # ,RegDQN, QRDQN]: for seed in range(3): + env_kwargs = dict(n_agents=1, + # with_dirt=True, + # item_properties=item_props, + dirt_properties=dirt_props, + movement_properties=move_props, + pomdp_r=2, max_steps=400, parse_doors=True, + level_name='simple', frames_to_stack=6, + omit_agent_in_obs=True, combin_agent_obs=True, record_episodes=False, + cast_shadows=True, doors_have_area=False, env_seed=seed, verbose=False, + ) - with SimpleFactory(n_agents=1, - # with_dirt=True, - # item_properties=item_props, - dirt_properties=dirt_props, - movement_properties=move_props, - pomdp_radius=2, max_steps=500, parse_doors=True, - level_name='rooms', frames_to_stack=3, - omit_agent_in_obs=True, combin_agent_slices_in_obs=True, record_episodes=False, - cast_shadows=True, doors_have_area=False, seed=seed, verbose=False, - ) as env: + # env = make_env(env_kwargs)() + env = SubprocVecEnv([make_env(env_kwargs) for _ in range(12)], start_method="spawn") - if modeL_type.__name__ in ["PPO", "A2C"]: - kwargs = dict(ent_coef=0.01) - elif modeL_type.__name__ in ["RegDQN", "DQN", "QRDQN"]: - kwargs = dict(buffer_size=50000, - learning_starts=64, - batch_size=64, - target_update_interval=5000, - exploration_fraction=0.25, - exploration_final_eps=0.025) - else: - raise NameError(f'The model "{model.__name__}" has the wrong name.') - model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs) + if modeL_type.__name__ in ["PPO", "A2C"]: + kwargs = dict(ent_coef=0.01) + elif modeL_type.__name__ in ["RegDQN", "DQN", "QRDQN"]: + kwargs = dict(buffer_size=50000, + learning_starts=64, + batch_size=64, + target_update_interval=5000, + exploration_fraction=0.25, + exploration_final_eps=0.025) + else: + raise NameError(f'The model "{modeL_type.__name__}" has the wrong name.') + model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs) - out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}' + out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}' - # identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}' - identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}' - out_path /= identifier + # identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}' + identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}' + out_path /= identifier - callbacks = CallbackList( - [MonitorCallback(filepath=out_path / f'monitor_{identifier}.pick', plotting=False), - RecorderCallback(filepath=out_path / f'recorder_{identifier}.json', occupation_map=False, - trajectory_map=False - )] - ) + callbacks = CallbackList( + [MonitorCallback(filepath=out_path / f'monitor_{identifier}.pick', plotting=False), + RecorderCallback(filepath=out_path / f'recorder_{identifier}.json', occupation_map=False, + trajectory_map=False + )] + ) - model.learn(total_timesteps=int(train_steps), callback=callbacks) + model.learn(total_timesteps=int(train_steps), callback=callbacks) - save_path = out_path / f'model_{identifier}.zip' - save_path.parent.mkdir(parents=True, exist_ok=True) - model.save(save_path) - env.save_params(out_path.parent / f'env_{model.__class__.__name__}_{time_stamp}.yaml') - print("Model Trained and saved") + save_path = out_path / f'model_{identifier}.zip' + save_path.parent.mkdir(parents=True, exist_ok=True) + model.save(save_path) + env.env_method('save_params', out_path.parent / f'env_{model.__class__.__name__}_{time_stamp}.yaml') + print("Model Trained and saved") print("Model Group Done.. Plotting...") if out_path: diff --git a/reload_agent.py b/reload_agent.py index f018df7..80b5e49 100644 --- a/reload_agent.py +++ b/reload_agent.py @@ -3,7 +3,7 @@ from pathlib import Path import yaml from natsort import natsorted -from stable_baselines3 import PPO +from stable_baselines3 import PPO, DQN, A2C from stable_baselines3.common.evaluation import evaluate_policy from environments.factory.simple_factory import DirtProperties, SimpleFactory @@ -12,16 +12,19 @@ from environments.factory.double_task_factory import ItemProperties, DoubleTaskF warnings.filterwarnings('ignore', category=FutureWarning) warnings.filterwarnings('ignore', category=UserWarning) +model_map = dict(PPO=PPO, DQN=DQN, A2C=A2C) if __name__ == '__main__': - model_name = 'A2C_1630073286' + model_name = 'A2C_1630414444' run_id = 0 + seed=69 out_path = Path(__file__).parent / 'debug_out' model_path = out_path / model_name with (model_path / f'env_{model_name}.yaml').open('r') as f: env_kwargs = yaml.load(f, Loader=yaml.FullLoader) + env_kwargs.update(verbose=True, env_seed=seed) if False: env_kwargs.update(dirt_properties=DirtProperties(clean_amount=1, gain_amount=0.1, max_global_amount=20, max_local_amount=1, spawn_frequency=5, max_spawn_ratio=0.05, @@ -30,9 +33,10 @@ if __name__ == '__main__': with SimpleFactory(**env_kwargs) as env: # Edit THIS: + env.seed(seed) model_files = list(natsorted((model_path / f'{run_id}_{model_name}').rglob('model_*.zip'))) this_model = model_files[0] - - model = PPO.load(this_model) - evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=False, render=True) + model_cls = next(val for key, val in model_map.items() if key in model_name) + model = model_cls.load(this_model) + evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=True, render=True) print(evaluation_result)