Merge remote-tracking branch 'origin/main' into main
This commit is contained in:
commit
b70b12edda
@ -12,7 +12,7 @@ from environments import helpers as h
|
||||
|
||||
|
||||
class MovementProperties(NamedTuple):
|
||||
allow_square_movement: bool = False
|
||||
allow_square_movement: bool = True
|
||||
allow_diagonal_movement: bool = False
|
||||
allow_no_op: bool = False
|
||||
|
||||
@ -111,6 +111,10 @@ class StateSlice(Register):
|
||||
|
||||
class BaseFactory(gym.Env):
|
||||
|
||||
# def __setattr__(self, key, value):
|
||||
# if isinstance(value, dict):
|
||||
|
||||
|
||||
@property
|
||||
def action_space(self):
|
||||
return spaces.Discrete(self._actions.n)
|
||||
|
13
environments/factory/levels/rooms.txt
Normal file
13
environments/factory/levels/rooms.txt
Normal file
@ -0,0 +1,13 @@
|
||||
###############
|
||||
#------#------#
|
||||
#---#--#------#
|
||||
#--------#----#
|
||||
#------#------#
|
||||
#------#------#
|
||||
###-#######-###
|
||||
#----##-------#
|
||||
#-----#----#--#
|
||||
#-------------#
|
||||
#-----#-------#
|
||||
#-----#-------#
|
||||
###############
|
@ -26,7 +26,7 @@ class Renderer:
|
||||
self.grid_lines = grid_lines
|
||||
self.view_radius = view_radius
|
||||
pygame.init()
|
||||
self.screen_size = (grid_h*cell_size, grid_w*cell_size)
|
||||
self.screen_size = (grid_w*cell_size, grid_h*cell_size)
|
||||
self.screen = pygame.display.set_mode(self.screen_size)
|
||||
self.clock = pygame.time.Clock()
|
||||
assets = list((Path(__file__).parent / 'assets').rglob('*.png'))
|
||||
@ -36,7 +36,7 @@ class Renderer:
|
||||
def fill_bg(self):
|
||||
self.screen.fill(Renderer.BG_COLOR)
|
||||
if self.grid_lines:
|
||||
h, w = self.screen_size
|
||||
w, h = self.screen_size
|
||||
for x in range(0, w, self.cell_size):
|
||||
for y in range(0, h, self.cell_size):
|
||||
rect = pygame.Rect(x, y, self.cell_size, self.cell_size)
|
||||
@ -81,7 +81,8 @@ class Renderer:
|
||||
shape_surf.set_alpha(64)
|
||||
blits.appendleft(dict(source=shape_surf, dest=visibility_rect))
|
||||
blits.append(bp)
|
||||
for blit in blits: self.screen.blit(**blit)
|
||||
for blit in blits:
|
||||
self.screen.blit(**blit)
|
||||
pygame.display.flip()
|
||||
self.clock.tick(self.fps)
|
||||
|
||||
|
@ -1,11 +1,12 @@
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Union, NamedTuple
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
from environments.factory.base_factory import BaseFactory, AgentState
|
||||
from environments.factory.base_factory import BaseFactory, AgentState, MovementProperties
|
||||
from environments import helpers as h
|
||||
|
||||
from environments.logging.monitor import MonitorCallback
|
||||
@ -186,16 +187,19 @@ if __name__ == '__main__':
|
||||
render = True
|
||||
|
||||
dirt_props = DirtProperties()
|
||||
factory = SimpleFactory(n_agents=2, dirt_properties=dirt_props)
|
||||
move_props = MovementProperties(allow_diagonal_movement=False, allow_no_op=False)
|
||||
factory = SimpleFactory(n_agents=2, dirt_properties=dirt_props, movement_properties=move_props, level='rooms',
|
||||
pomdp_radius=2)
|
||||
|
||||
n_actions = factory.action_space.n - 1
|
||||
with MonitorCallback(factory):
|
||||
for epoch in range(100):
|
||||
random_actions = [(random.randint(0, n_actions), random.randint(0, n_actions)) for _ in range(200)]
|
||||
env_state, this_reward, done_bool, _ = factory.reset()
|
||||
for agent_i_action in random_actions:
|
||||
env_state, reward, done_bool, info_obj = factory.step(agent_i_action)
|
||||
if render:
|
||||
factory.render()
|
||||
if done_bool:
|
||||
break
|
||||
print(f'Factory run {epoch} done, reward is:\n {reward}')
|
||||
|
||||
for epoch in range(100):
|
||||
random_actions = [(random.randint(0, n_actions), random.randint(0, n_actions)) for _ in range(200)]
|
||||
env_state, this_reward, done_bool, _ = factory.reset()
|
||||
for agent_i_action in random_actions:
|
||||
env_state, reward, done_bool, info_obj = factory.step(agent_i_action)
|
||||
if render:
|
||||
factory.render()
|
||||
if done_bool:
|
||||
break
|
||||
print(f'Factory run {epoch} done, reward is:\n {reward}')
|
||||
|
11
main.py
11
main.py
@ -102,24 +102,23 @@ if __name__ == '__main__':
|
||||
|
||||
out_path = None
|
||||
|
||||
# for modeL_type in [PPO, A2C, RegDQN, DQN]:
|
||||
modeL_type = PPO
|
||||
for coef in [0.01, 0.1, 0.25]:
|
||||
for modeL_type in [PPO, A2C, RegDQN, DQN]:
|
||||
for seed in range(3):
|
||||
|
||||
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=None, max_steps=400,
|
||||
movement_properties=move_props,
|
||||
movement_properties=move_props, level='rooms',
|
||||
omit_agent_slice_in_obs=True)
|
||||
env.save_params(Path('debug_out', 'yaml.txt'))
|
||||
|
||||
# env = FrameStack(env, 4)
|
||||
|
||||
model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu')
|
||||
kwargs = dict(ent_coef=0.01) if isinstance(modeL_type, (PPO, A2C)) else {}
|
||||
model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs)
|
||||
|
||||
out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}'
|
||||
|
||||
# identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}'
|
||||
identifier = f'{seed}_{str(coef).replace(".", "")}_{time_stamp}'
|
||||
identifier = f'{seed}_{modeL_type.__class__.__name__}_{time_stamp}'
|
||||
out_path /= identifier
|
||||
|
||||
callbacks = CallbackList(
|
||||
|
Loading…
x
Reference in New Issue
Block a user