mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 15:26:43 +02:00
Merge remote-tracking branch 'origin/main' into main
This commit is contained in:
commit
b70b12edda
@ -12,7 +12,7 @@ from environments import helpers as h
|
|||||||
|
|
||||||
|
|
||||||
class MovementProperties(NamedTuple):
|
class MovementProperties(NamedTuple):
|
||||||
allow_square_movement: bool = False
|
allow_square_movement: bool = True
|
||||||
allow_diagonal_movement: bool = False
|
allow_diagonal_movement: bool = False
|
||||||
allow_no_op: bool = False
|
allow_no_op: bool = False
|
||||||
|
|
||||||
@ -111,6 +111,10 @@ class StateSlice(Register):
|
|||||||
|
|
||||||
class BaseFactory(gym.Env):
|
class BaseFactory(gym.Env):
|
||||||
|
|
||||||
|
# def __setattr__(self, key, value):
|
||||||
|
# if isinstance(value, dict):
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def action_space(self):
|
def action_space(self):
|
||||||
return spaces.Discrete(self._actions.n)
|
return spaces.Discrete(self._actions.n)
|
||||||
|
13
environments/factory/levels/rooms.txt
Normal file
13
environments/factory/levels/rooms.txt
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
###############
|
||||||
|
#------#------#
|
||||||
|
#---#--#------#
|
||||||
|
#--------#----#
|
||||||
|
#------#------#
|
||||||
|
#------#------#
|
||||||
|
###-#######-###
|
||||||
|
#----##-------#
|
||||||
|
#-----#----#--#
|
||||||
|
#-------------#
|
||||||
|
#-----#-------#
|
||||||
|
#-----#-------#
|
||||||
|
###############
|
@ -26,7 +26,7 @@ class Renderer:
|
|||||||
self.grid_lines = grid_lines
|
self.grid_lines = grid_lines
|
||||||
self.view_radius = view_radius
|
self.view_radius = view_radius
|
||||||
pygame.init()
|
pygame.init()
|
||||||
self.screen_size = (grid_h*cell_size, grid_w*cell_size)
|
self.screen_size = (grid_w*cell_size, grid_h*cell_size)
|
||||||
self.screen = pygame.display.set_mode(self.screen_size)
|
self.screen = pygame.display.set_mode(self.screen_size)
|
||||||
self.clock = pygame.time.Clock()
|
self.clock = pygame.time.Clock()
|
||||||
assets = list((Path(__file__).parent / 'assets').rglob('*.png'))
|
assets = list((Path(__file__).parent / 'assets').rglob('*.png'))
|
||||||
@ -36,7 +36,7 @@ class Renderer:
|
|||||||
def fill_bg(self):
|
def fill_bg(self):
|
||||||
self.screen.fill(Renderer.BG_COLOR)
|
self.screen.fill(Renderer.BG_COLOR)
|
||||||
if self.grid_lines:
|
if self.grid_lines:
|
||||||
h, w = self.screen_size
|
w, h = self.screen_size
|
||||||
for x in range(0, w, self.cell_size):
|
for x in range(0, w, self.cell_size):
|
||||||
for y in range(0, h, self.cell_size):
|
for y in range(0, h, self.cell_size):
|
||||||
rect = pygame.Rect(x, y, self.cell_size, self.cell_size)
|
rect = pygame.Rect(x, y, self.cell_size, self.cell_size)
|
||||||
@ -81,7 +81,8 @@ class Renderer:
|
|||||||
shape_surf.set_alpha(64)
|
shape_surf.set_alpha(64)
|
||||||
blits.appendleft(dict(source=shape_surf, dest=visibility_rect))
|
blits.appendleft(dict(source=shape_surf, dest=visibility_rect))
|
||||||
blits.append(bp)
|
blits.append(bp)
|
||||||
for blit in blits: self.screen.blit(**blit)
|
for blit in blits:
|
||||||
|
self.screen.blit(**blit)
|
||||||
pygame.display.flip()
|
pygame.display.flip()
|
||||||
self.clock.tick(self.fps)
|
self.clock.tick(self.fps)
|
||||||
|
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
from typing import List, Union, NamedTuple
|
from typing import List, Union, NamedTuple
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from environments.factory.base_factory import BaseFactory, AgentState
|
from environments.factory.base_factory import BaseFactory, AgentState, MovementProperties
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
|
|
||||||
from environments.logging.monitor import MonitorCallback
|
from environments.logging.monitor import MonitorCallback
|
||||||
@ -186,9 +187,12 @@ if __name__ == '__main__':
|
|||||||
render = True
|
render = True
|
||||||
|
|
||||||
dirt_props = DirtProperties()
|
dirt_props = DirtProperties()
|
||||||
factory = SimpleFactory(n_agents=2, dirt_properties=dirt_props)
|
move_props = MovementProperties(allow_diagonal_movement=False, allow_no_op=False)
|
||||||
|
factory = SimpleFactory(n_agents=2, dirt_properties=dirt_props, movement_properties=move_props, level='rooms',
|
||||||
|
pomdp_radius=2)
|
||||||
|
|
||||||
n_actions = factory.action_space.n - 1
|
n_actions = factory.action_space.n - 1
|
||||||
with MonitorCallback(factory):
|
|
||||||
for epoch in range(100):
|
for epoch in range(100):
|
||||||
random_actions = [(random.randint(0, n_actions), random.randint(0, n_actions)) for _ in range(200)]
|
random_actions = [(random.randint(0, n_actions), random.randint(0, n_actions)) for _ in range(200)]
|
||||||
env_state, this_reward, done_bool, _ = factory.reset()
|
env_state, this_reward, done_bool, _ = factory.reset()
|
||||||
|
11
main.py
11
main.py
@ -102,24 +102,23 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
out_path = None
|
out_path = None
|
||||||
|
|
||||||
# for modeL_type in [PPO, A2C, RegDQN, DQN]:
|
for modeL_type in [PPO, A2C, RegDQN, DQN]:
|
||||||
modeL_type = PPO
|
|
||||||
for coef in [0.01, 0.1, 0.25]:
|
|
||||||
for seed in range(3):
|
for seed in range(3):
|
||||||
|
|
||||||
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=None, max_steps=400,
|
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=None, max_steps=400,
|
||||||
movement_properties=move_props,
|
movement_properties=move_props, level='rooms',
|
||||||
omit_agent_slice_in_obs=True)
|
omit_agent_slice_in_obs=True)
|
||||||
env.save_params(Path('debug_out', 'yaml.txt'))
|
env.save_params(Path('debug_out', 'yaml.txt'))
|
||||||
|
|
||||||
# env = FrameStack(env, 4)
|
# env = FrameStack(env, 4)
|
||||||
|
|
||||||
model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu')
|
kwargs = dict(ent_coef=0.01) if isinstance(modeL_type, (PPO, A2C)) else {}
|
||||||
|
model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs)
|
||||||
|
|
||||||
out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}'
|
out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}'
|
||||||
|
|
||||||
# identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}'
|
# identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}'
|
||||||
identifier = f'{seed}_{str(coef).replace(".", "")}_{time_stamp}'
|
identifier = f'{seed}_{modeL_type.__class__.__name__}_{time_stamp}'
|
||||||
out_path /= identifier
|
out_path /= identifier
|
||||||
|
|
||||||
callbacks = CallbackList(
|
callbacks = CallbackList(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user