diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index 2918ed7..0256f81 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -12,7 +12,7 @@ from environments import helpers as h class MovementProperties(NamedTuple): - allow_square_movement: bool = False + allow_square_movement: bool = True allow_diagonal_movement: bool = False allow_no_op: bool = False @@ -111,6 +111,10 @@ class StateSlice(Register): class BaseFactory(gym.Env): + # def __setattr__(self, key, value): + # if isinstance(value, dict): + + @property def action_space(self): return spaces.Discrete(self._actions.n) diff --git a/environments/factory/levels/rooms.txt b/environments/factory/levels/rooms.txt new file mode 100644 index 0000000..83d2e9c --- /dev/null +++ b/environments/factory/levels/rooms.txt @@ -0,0 +1,13 @@ +############### +#------#------# +#---#--#------# +#--------#----# +#------#------# +#------#------# +###-#######-### +#----##-------# +#-----#----#--# +#-------------# +#-----#-------# +#-----#-------# +############### \ No newline at end of file diff --git a/environments/factory/renderer.py b/environments/factory/renderer.py index b598509..fdb56d2 100644 --- a/environments/factory/renderer.py +++ b/environments/factory/renderer.py @@ -26,7 +26,7 @@ class Renderer: self.grid_lines = grid_lines self.view_radius = view_radius pygame.init() - self.screen_size = (grid_h*cell_size, grid_w*cell_size) + self.screen_size = (grid_w*cell_size, grid_h*cell_size) self.screen = pygame.display.set_mode(self.screen_size) self.clock = pygame.time.Clock() assets = list((Path(__file__).parent / 'assets').rglob('*.png')) @@ -36,7 +36,7 @@ class Renderer: def fill_bg(self): self.screen.fill(Renderer.BG_COLOR) if self.grid_lines: - h, w = self.screen_size + w, h = self.screen_size for x in range(0, w, self.cell_size): for y in range(0, h, self.cell_size): rect = pygame.Rect(x, y, self.cell_size, self.cell_size) @@ -81,7 +81,8 @@ class Renderer: shape_surf.set_alpha(64) blits.appendleft(dict(source=shape_surf, dest=visibility_rect)) blits.append(bp) - for blit in blits: self.screen.blit(**blit) + for blit in blits: + self.screen.blit(**blit) pygame.display.flip() self.clock.tick(self.fps) diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py index 2dde6df..20969f5 100644 --- a/environments/factory/simple_factory.py +++ b/environments/factory/simple_factory.py @@ -1,11 +1,12 @@ from collections import OrderedDict from dataclasses import dataclass +from pathlib import Path from typing import List, Union, NamedTuple import random import numpy as np -from environments.factory.base_factory import BaseFactory, AgentState +from environments.factory.base_factory import BaseFactory, AgentState, MovementProperties from environments import helpers as h from environments.logging.monitor import MonitorCallback @@ -186,16 +187,19 @@ if __name__ == '__main__': render = True dirt_props = DirtProperties() - factory = SimpleFactory(n_agents=2, dirt_properties=dirt_props) + move_props = MovementProperties(allow_diagonal_movement=False, allow_no_op=False) + factory = SimpleFactory(n_agents=2, dirt_properties=dirt_props, movement_properties=move_props, level='rooms', + pomdp_radius=2) + n_actions = factory.action_space.n - 1 - with MonitorCallback(factory): - for epoch in range(100): - random_actions = [(random.randint(0, n_actions), random.randint(0, n_actions)) for _ in range(200)] - env_state, this_reward, done_bool, _ = factory.reset() - for agent_i_action in random_actions: - env_state, reward, done_bool, info_obj = factory.step(agent_i_action) - if render: - factory.render() - if done_bool: - break - print(f'Factory run {epoch} done, reward is:\n {reward}') + + for epoch in range(100): + random_actions = [(random.randint(0, n_actions), random.randint(0, n_actions)) for _ in range(200)] + env_state, this_reward, done_bool, _ = factory.reset() + for agent_i_action in random_actions: + env_state, reward, done_bool, info_obj = factory.step(agent_i_action) + if render: + factory.render() + if done_bool: + break + print(f'Factory run {epoch} done, reward is:\n {reward}') diff --git a/main.py b/main.py index d06d030..a9f8093 100644 --- a/main.py +++ b/main.py @@ -102,24 +102,23 @@ if __name__ == '__main__': out_path = None - # for modeL_type in [PPO, A2C, RegDQN, DQN]: - modeL_type = PPO - for coef in [0.01, 0.1, 0.25]: + for modeL_type in [PPO, A2C, RegDQN, DQN]: for seed in range(3): env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=None, max_steps=400, - movement_properties=move_props, + movement_properties=move_props, level='rooms', omit_agent_slice_in_obs=True) env.save_params(Path('debug_out', 'yaml.txt')) # env = FrameStack(env, 4) - model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu') + kwargs = dict(ent_coef=0.01) if isinstance(modeL_type, (PPO, A2C)) else {} + model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs) out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}' # identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}' - identifier = f'{seed}_{str(coef).replace(".", "")}_{time_stamp}' + identifier = f'{seed}_{modeL_type.__class__.__name__}_{time_stamp}' out_path /= identifier callbacks = CallbackList(