Merge remote-tracking branch 'origin/main' into main

This commit is contained in:
romue 2021-06-04 17:17:36 +02:00
commit b70b12edda
5 changed files with 44 additions and 23 deletions

View File

@ -12,7 +12,7 @@ from environments import helpers as h
class MovementProperties(NamedTuple): class MovementProperties(NamedTuple):
allow_square_movement: bool = False allow_square_movement: bool = True
allow_diagonal_movement: bool = False allow_diagonal_movement: bool = False
allow_no_op: bool = False allow_no_op: bool = False
@ -111,6 +111,10 @@ class StateSlice(Register):
class BaseFactory(gym.Env): class BaseFactory(gym.Env):
# def __setattr__(self, key, value):
# if isinstance(value, dict):
@property @property
def action_space(self): def action_space(self):
return spaces.Discrete(self._actions.n) return spaces.Discrete(self._actions.n)

View File

@ -0,0 +1,13 @@
###############
#------#------#
#---#--#------#
#--------#----#
#------#------#
#------#------#
###-#######-###
#----##-------#
#-----#----#--#
#-------------#
#-----#-------#
#-----#-------#
###############

View File

@ -26,7 +26,7 @@ class Renderer:
self.grid_lines = grid_lines self.grid_lines = grid_lines
self.view_radius = view_radius self.view_radius = view_radius
pygame.init() pygame.init()
self.screen_size = (grid_h*cell_size, grid_w*cell_size) self.screen_size = (grid_w*cell_size, grid_h*cell_size)
self.screen = pygame.display.set_mode(self.screen_size) self.screen = pygame.display.set_mode(self.screen_size)
self.clock = pygame.time.Clock() self.clock = pygame.time.Clock()
assets = list((Path(__file__).parent / 'assets').rglob('*.png')) assets = list((Path(__file__).parent / 'assets').rglob('*.png'))
@ -36,7 +36,7 @@ class Renderer:
def fill_bg(self): def fill_bg(self):
self.screen.fill(Renderer.BG_COLOR) self.screen.fill(Renderer.BG_COLOR)
if self.grid_lines: if self.grid_lines:
h, w = self.screen_size w, h = self.screen_size
for x in range(0, w, self.cell_size): for x in range(0, w, self.cell_size):
for y in range(0, h, self.cell_size): for y in range(0, h, self.cell_size):
rect = pygame.Rect(x, y, self.cell_size, self.cell_size) rect = pygame.Rect(x, y, self.cell_size, self.cell_size)
@ -81,7 +81,8 @@ class Renderer:
shape_surf.set_alpha(64) shape_surf.set_alpha(64)
blits.appendleft(dict(source=shape_surf, dest=visibility_rect)) blits.appendleft(dict(source=shape_surf, dest=visibility_rect))
blits.append(bp) blits.append(bp)
for blit in blits: self.screen.blit(**blit) for blit in blits:
self.screen.blit(**blit)
pygame.display.flip() pygame.display.flip()
self.clock.tick(self.fps) self.clock.tick(self.fps)

View File

@ -1,11 +1,12 @@
from collections import OrderedDict from collections import OrderedDict
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path
from typing import List, Union, NamedTuple from typing import List, Union, NamedTuple
import random import random
import numpy as np import numpy as np
from environments.factory.base_factory import BaseFactory, AgentState from environments.factory.base_factory import BaseFactory, AgentState, MovementProperties
from environments import helpers as h from environments import helpers as h
from environments.logging.monitor import MonitorCallback from environments.logging.monitor import MonitorCallback
@ -186,9 +187,12 @@ if __name__ == '__main__':
render = True render = True
dirt_props = DirtProperties() dirt_props = DirtProperties()
factory = SimpleFactory(n_agents=2, dirt_properties=dirt_props) move_props = MovementProperties(allow_diagonal_movement=False, allow_no_op=False)
factory = SimpleFactory(n_agents=2, dirt_properties=dirt_props, movement_properties=move_props, level='rooms',
pomdp_radius=2)
n_actions = factory.action_space.n - 1 n_actions = factory.action_space.n - 1
with MonitorCallback(factory):
for epoch in range(100): for epoch in range(100):
random_actions = [(random.randint(0, n_actions), random.randint(0, n_actions)) for _ in range(200)] random_actions = [(random.randint(0, n_actions), random.randint(0, n_actions)) for _ in range(200)]
env_state, this_reward, done_bool, _ = factory.reset() env_state, this_reward, done_bool, _ = factory.reset()

11
main.py
View File

@ -102,24 +102,23 @@ if __name__ == '__main__':
out_path = None out_path = None
# for modeL_type in [PPO, A2C, RegDQN, DQN]: for modeL_type in [PPO, A2C, RegDQN, DQN]:
modeL_type = PPO
for coef in [0.01, 0.1, 0.25]:
for seed in range(3): for seed in range(3):
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=None, max_steps=400, env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=None, max_steps=400,
movement_properties=move_props, movement_properties=move_props, level='rooms',
omit_agent_slice_in_obs=True) omit_agent_slice_in_obs=True)
env.save_params(Path('debug_out', 'yaml.txt')) env.save_params(Path('debug_out', 'yaml.txt'))
# env = FrameStack(env, 4) # env = FrameStack(env, 4)
model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu') kwargs = dict(ent_coef=0.01) if isinstance(modeL_type, (PPO, A2C)) else {}
model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs)
out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}' out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}'
# identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}' # identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}'
identifier = f'{seed}_{str(coef).replace(".", "")}_{time_stamp}' identifier = f'{seed}_{modeL_type.__class__.__name__}_{time_stamp}'
out_path /= identifier out_path /= identifier
callbacks = CallbackList( callbacks = CallbackList(