Merge remote-tracking branch 'origin/main' into main

This commit is contained in:
romue 2021-06-04 17:17:36 +02:00
commit b70b12edda
5 changed files with 44 additions and 23 deletions

View File

@@ -12,7 +12,7 @@ from environments import helpers as h
class MovementProperties(NamedTuple):
    """Configuration flags for which movement actions an agent may take.

    The rendered diff left both the old and new defaults for
    ``allow_square_movement`` in place (False and True on consecutive
    lines); a NamedTuple must not declare the same field twice, so this
    keeps only the post-change default of ``True``.
    """
    # Cardinal (N/S/E/W) moves are enabled by default.
    allow_square_movement: bool = True
    # Diagonal moves are opt-in.
    allow_diagonal_movement: bool = False
    # "Do nothing" action is opt-in.
    allow_no_op: bool = False
@@ -111,6 +111,10 @@ class StateSlice(Register):
class BaseFactory(gym.Env):
# def __setattr__(self, key, value):
# if isinstance(value, dict):
@property
def action_space(self):
    """Gym action space: one discrete choice per registered action."""
    action_count = self._actions.n
    return spaces.Discrete(action_count)

View File

@@ -0,0 +1,13 @@
###############
#------#------#
#---#--#------#
#--------#----#
#------#------#
#------#------#
###-#######-###
#----##-------#
#-----#----#--#
#-------------#
#-----#-------#
#-----#-------#
###############

View File

@@ -26,7 +26,7 @@ class Renderer:
self.grid_lines = grid_lines
self.view_radius = view_radius
pygame.init()
self.screen_size = (grid_h*cell_size, grid_w*cell_size)
self.screen_size = (grid_w*cell_size, grid_h*cell_size)
self.screen = pygame.display.set_mode(self.screen_size)
self.clock = pygame.time.Clock()
assets = list((Path(__file__).parent / 'assets').rglob('*.png'))
@@ -36,7 +36,7 @@ class Renderer:
def fill_bg(self):
self.screen.fill(Renderer.BG_COLOR)
if self.grid_lines:
h, w = self.screen_size
w, h = self.screen_size
for x in range(0, w, self.cell_size):
for y in range(0, h, self.cell_size):
rect = pygame.Rect(x, y, self.cell_size, self.cell_size)
@@ -81,7 +81,8 @@ class Renderer:
shape_surf.set_alpha(64)
blits.appendleft(dict(source=shape_surf, dest=visibility_rect))
blits.append(bp)
for blit in blits: self.screen.blit(**blit)
for blit in blits:
self.screen.blit(**blit)
pygame.display.flip()
self.clock.tick(self.fps)

View File

@@ -1,11 +1,12 @@
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import List, Union, NamedTuple
import random
import numpy as np
from environments.factory.base_factory import BaseFactory, AgentState
from environments.factory.base_factory import BaseFactory, AgentState, MovementProperties
from environments import helpers as h
from environments.logging.monitor import MonitorCallback
@@ -186,16 +187,19 @@ if __name__ == '__main__':
render = True
dirt_props = DirtProperties()
factory = SimpleFactory(n_agents=2, dirt_properties=dirt_props)
move_props = MovementProperties(allow_diagonal_movement=False, allow_no_op=False)
factory = SimpleFactory(n_agents=2, dirt_properties=dirt_props, movement_properties=move_props, level='rooms',
pomdp_radius=2)
n_actions = factory.action_space.n - 1
with MonitorCallback(factory):
for epoch in range(100):
random_actions = [(random.randint(0, n_actions), random.randint(0, n_actions)) for _ in range(200)]
env_state, this_reward, done_bool, _ = factory.reset()
for agent_i_action in random_actions:
env_state, reward, done_bool, info_obj = factory.step(agent_i_action)
if render:
factory.render()
if done_bool:
break
print(f'Factory run {epoch} done, reward is:\n {reward}')
for epoch in range(100):
random_actions = [(random.randint(0, n_actions), random.randint(0, n_actions)) for _ in range(200)]
env_state, this_reward, done_bool, _ = factory.reset()
for agent_i_action in random_actions:
env_state, reward, done_bool, info_obj = factory.step(agent_i_action)
if render:
factory.render()
if done_bool:
break
print(f'Factory run {epoch} done, reward is:\n {reward}')

11
main.py
View File

@@ -102,24 +102,23 @@ if __name__ == '__main__':
out_path = None
# for modeL_type in [PPO, A2C, RegDQN, DQN]:
modeL_type = PPO
for coef in [0.01, 0.1, 0.25]:
for modeL_type in [PPO, A2C, RegDQN, DQN]:
for seed in range(3):
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=None, max_steps=400,
movement_properties=move_props,
movement_properties=move_props, level='rooms',
omit_agent_slice_in_obs=True)
env.save_params(Path('debug_out', 'yaml.txt'))
# env = FrameStack(env, 4)
model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu')
kwargs = dict(ent_coef=0.01) if isinstance(modeL_type, (PPO, A2C)) else {}
model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs)
out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}'
# identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}'
identifier = f'{seed}_{str(coef).replace(".", "")}_{time_stamp}'
identifier = f'{seed}_{modeL_type.__class__.__name__}_{time_stamp}'
out_path /= identifier
callbacks = CallbackList(