Mirror of https://github.com/illiumst/marl-factory-grid.git (synced 2025-07-08 02:21:36 +02:00)
Added 'shared' dirt piles option for eval + Fixed usage of renderer + Added recorder option
@@ -2,6 +2,7 @@ import copy
 import os
 import random

+import imageio # requires ffmpeg install on operating system and imageio-ffmpeg package for python
 from scipy import signal
 import matplotlib.pyplot as plt
 import torch
@@ -79,6 +80,8 @@ class A2C:
         os.mkdir(self.results_path)
         # Save settings in results folder
         self.save_configs()
+        if self.cfg[nms.ENV]["record"]:
+            self.recorder = imageio.get_writer(f'{self.results_path}/pygame_recording.mp4', fps=5)

     def set_cfg(self, eval=False):
         if eval:
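For context, a minimal sketch of the recording lifecycle behind the writer created above (not part of the commit): imageio.get_writer opens an mp4 writer, append_data is called once per rendered frame, and close finalizes the file. It assumes the imageio-ffmpeg backend is installed; the filename and frame contents are placeholders.

import imageio
import numpy as np

# Same call as in A2C.__init__ above; requires imageio-ffmpeg for mp4 output.
writer = imageio.get_writer('pygame_recording.mp4', fps=5)
try:
    for _ in range(10):
        frame = np.zeros((256, 320, 3), dtype=np.uint8)  # placeholder RGB frame (height, width, channels)
        writer.append_data(frame)                         # one call per rendered frame
finally:
    writer.close()                                        # finalizes the mp4 container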
@@ -422,6 +425,15 @@ class A2C:
         if self.cfg[nms.ALGORITHM]["pile_all_done"] in ["all", "distributed"]:
             if all([all(cleaned_dirt_piles[i].values()) for i in range(self.n_agents)]):
                 done = True
+        elif self.cfg[nms.ALGORITHM]["pile_all_done"] == "shared":
+            # End episode if both agents together have cleaned all dirt piles
+            meta_cleaned_dirt_piles = {pos: False for pos in dirt_piles_positions}
+            for agent_idx in range(self.n_agents):
+                for (pos, cleaned) in cleaned_dirt_piles[agent_idx].items():
+                    if cleaned:
+                        meta_cleaned_dirt_piles[pos] = True
+            if all(meta_cleaned_dirt_piles.values()):
+                done = True

         return reward, done
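The "shared" branch added above ends the episode once every pile has been cleaned by at least one agent, regardless of which agent cleaned it. A standalone sketch of that merge; the positions and per-agent dictionaries below are illustrative, not taken from the repo:

# Each agent tracks its own pile -> cleaned flags; "shared" OR-merges them per position.
dirt_piles_positions = [(1, 1), (3, 4), (5, 2)]
cleaned_dirt_piles = [
    {(1, 1): True,  (3, 4): False, (5, 2): False},   # agent 0
    {(1, 1): False, (3, 4): True,  (5, 2): True},    # agent 1
]

meta_cleaned_dirt_piles = {pos: False for pos in dirt_piles_positions}
for agent_piles in cleaned_dirt_piles:
    for pos, cleaned in agent_piles.items():
        if cleaned:
            meta_cleaned_dirt_piles[pos] = True

done = all(meta_cleaned_dirt_piles.values())  # True: together the agents covered every pile
print(done)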
@@ -484,8 +496,6 @@ class A2C:
     @torch.no_grad()
     def train_loop(self):
         env = self.factory
-        if self.cfg[nms.ENV][nms.TRAIN_RENDER]:
-            env.render()
         n_steps, max_steps = [self.cfg[nms.ALGORITHM][k] for k in [nms.N_STEPS, nms.MAX_STEPS]]
         global_steps, episode = 0, 0
         indices = self.distribute_indices(env)
@@ -497,6 +507,8 @@ class A2C:
         while global_steps < max_steps:
             print(global_steps)
             obs = env.reset() # !!!!!!!!Commented seems to work better? Only if a fixed spawnpoint is given
+            if self.cfg[nms.ENV][nms.TRAIN_RENDER]:
+                env.render()
             self.set_agent_spawnpoint(env)
             ordered_dirt_piles = self.get_ordered_dirt_piles(env, cleaned_dirt_piles, target_pile)
             # Reset current target pile at episode begin if all piles have to be cleaned in one episode
@@ -578,8 +590,6 @@ class A2C:
     def eval_loop(self, n_episodes, render=False):
         env = self.eval_factory
         self.set_cfg(eval=True)
-        if self.cfg[nms.ENV][nms.EVAL_RENDER]:
-            env.render()
         episode, results = 0, []
         dirt_piles_positions = self.get_dirt_piles_positions(env)
         indices = self.distribute_indices(env)
@@ -591,10 +601,15 @@ class A2C:

         while episode < n_episodes:
             obs = env.reset()
+            if self.cfg[nms.ENV][nms.EVAL_RENDER]:
+                if self.cfg[nms.ENV]["save_and_log"] and self.cfg[nms.ENV]["record"]:
+                    env.set_recorder(self.recorder)
+                env.render()
+                env._renderer.fps = 5
             self.set_agent_spawnpoint(env)
             """obs = list(obs.values())"""
             # Reset current target pile at episode begin if all piles have to be cleaned in one episode
-            if self.cfg[nms.ALGORITHM]["pile_all_done"] in ["all", "distributed"]:
+            if self.cfg[nms.ALGORITHM]["pile_all_done"] in ["all", "distributed", "shared"]:
                 target_pile = [partition[0] for partition in indices]
                 if self.cfg[nms.ALGORITHM]["pile_all_done"] == "distributed":
                     cleaned_dirt_piles = [{dirt_piles_positions[idx]: False for idx in indices[i]} for i in range(self.n_agents)]
@@ -637,6 +652,10 @@ class A2C:

             episode += 1

+        # Properly finalize the video file
+        if self.cfg[nms.ENV]["save_and_log"] and self.cfg[nms.ENV]["record"]:
+            self.recorder.close()
+
     def plot_reward_development(self):
         smoothed_data = np.convolve(self.reward_development, np.ones(10) / 10, mode='valid')
         plt.plot(smoothed_data)
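plot_reward_development in the hunk above smooths the per-episode rewards with a length-10 moving average before plotting. A small sketch of what np.convolve with mode='valid' does to such a series, using stand-in data:

import numpy as np

reward_development = np.arange(25, dtype=float)  # stand-in for the recorded per-episode rewards
smoothed = np.convolve(reward_development, np.ones(10) / 10, mode='valid')
print(len(reward_development), len(smoothed))    # 25 16: 'valid' drops window-1 points
print(smoothed[0])                               # 4.5, the mean of the first 10 values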
@@ -17,6 +17,7 @@ env:
   train_render: False
   eval_render: True
   save_and_log: True
+  record: False
 method: marl_factory_grid.algorithms.marl.LoopSEAC
 algorithm:
   gamma: 0.99
@@ -27,7 +28,7 @@ algorithm:
   advantage: "Advantage-AC" # Options: "Advantage-AC", "TD-Advantage-AC", "Reinforce"
   pile-order: "dynamic" # Use "dynamic" to see emergent phenomenon and "smart" to prevent it
   pile-observability: "single" # Options: "single", "all"
-  pile_all_done: "all" # Options: "single", "all" ("single" for training, "all" for eval)
+  pile_all_done: "shared" # Options: "single", "all" ("single" for training, "all" for eval), "shared"
   auxiliary_piles: False # Option that is only considered when pile-order = "agents"
   chunk-episode: 20000 # Chunk size. (0 = update networks with full episode at once)

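The new keys are read back through self.cfg in the code above. A sketch of how the flags gate the new behaviour, assuming the YAML is loaded into a plain nested dict with PyYAML; the repo's own config loader and the nms name constants are not shown here:

import yaml  # PyYAML, assumed for this illustration

cfg = yaml.safe_load("""
env:
  eval_render: True
  save_and_log: True
  record: False
algorithm:
  pile_all_done: "shared"
""")

record_video = cfg["env"]["save_and_log"] and cfg["env"]["record"]   # gate used before attaching/closing the writer in eval_loop
use_shared_done = cfg["algorithm"]["pile_all_done"] == "shared"       # selects the merged termination rule
print(record_video, use_shared_done)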
@@ -16,7 +16,8 @@ env:
   individual_rewards: True
   train_render: False
   eval_render: True
-  save_and_log: False
+  save_and_log: True
+  record: False
 method: marl_factory_grid.algorithms.marl.LoopSEAC
 algorithm:
   gamma: 0.99
@@ -17,6 +17,7 @@ env:
   train_render: False
   eval_render: True
   save_and_log: True
+  record: False
 method: marl_factory_grid.algorithms.marl.LoopSEAC
 algorithm:
   gamma: 0.99
@@ -2,4 +2,7 @@ marl_factory_grid>environment>rules.py#SpawnEntity.on_reset()
 marl_factory_grid>environment>rewards.py
 marl_factory_grid>modules>clean_up>groups.py#DirtPiles.trigger_spawn()
 marl_factory_grid>environment>rules.py#AgentSpawnRule
 marl_factory_grid>utils>states.py#GameState.__init__()
+marl_factory_grid>environment>factory.py>Factory#render
+marl_factory_grid>environment>factory.py>Factory#set_recorder
+marl_factory_grid>utils>renderer.py>Renderer#render
@@ -17,6 +17,7 @@ env:
   train_render: False
   eval_render: True
   save_and_log: False
+  record: False
 method: marl_factory_grid.algorithms.marl.LoopSEAC
 algorithm:
   gamma: 0.99
@@ -109,6 +109,7 @@ class Factory(gym.Env):

         # expensive - don't use; unless required !
         self._renderer = None
+        self._recorder = None

         # Init entities
         entities = self.map.do_init()
@@ -277,7 +278,10 @@
         for render_entity in render_entities:
             if render_entity.name == c.AGENT:
                 render_entity.aux = self.obs_builder.curr_lightmaps[render_entity.real_name]
-        return self._renderer.render(render_entities)
+        return self._renderer.render(render_entities, self._recorder)
+
+    def set_recorder(self, recorder):
+        self._recorder = recorder

     def summarize_header(self):
         header = {'rec_step': self.state.curr_step}
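The Factory change above is a simple pass-through: the env stores an optional recorder and forwards it to the renderer on every render call. A reduced sketch of that pattern; FakeRenderer and FakeRecorder are illustration stand-ins, not classes from the repo:

class FakeRenderer:
    def render(self, entities, recorder):
        if recorder is not None:
            recorder.append_data(f"frame with {len(entities)} entities")
        return entities

class FakeRecorder:
    def __init__(self):
        self.frames = []
    def append_data(self, frame):
        self.frames.append(frame)

class Env:
    def __init__(self):
        self._renderer = FakeRenderer()
        self._recorder = None          # stays None unless recording is requested
    def set_recorder(self, recorder):
        self._recorder = recorder      # injected from the training/eval code
    def render(self):
        return self._renderer.render(["agent", "dirt"], self._recorder)

env = Env()
env.render()                           # no recorder set: nothing is captured
rec = FakeRecorder()
env.set_recorder(rec)
env.render()
print(rec.frames)                      # ['frame with 2 entities']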
@@ -156,7 +156,7 @@ class Renderer:
         rects.append(dict(source=shape_surf, dest=visibility_rect))
         return rects

-    def render(self, entities):
+    def render(self, entities, recorder):
         """
         Renders the entities on the screen.

@@ -190,6 +190,11 @@
         for blit in blits:
             self.screen.blit(**blit)

+        if recorder:
+            frame = pygame.surfarray.array3d(self.screen)
+            frame = np.transpose(frame, (1, 0, 2)) # Transpose to (height, width, channels)
+            recorder.append_data(frame)
+
         pygame.display.flip()
         self.clock.tick(self.fps)
         rgb_obs = pygame.surfarray.array3d(self.screen)
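The frame capture added to Renderer.render relies on pygame.surfarray.array3d, which returns the surface as a (width, height, 3) array, hence the transpose before handing the frame to the writer. A self-contained sketch using an off-screen surface; the size and colour are placeholders, and any object with an append_data(ndarray) method (such as the imageio writer above) would accept the result:

import numpy as np
import pygame

pygame.init()
screen = pygame.Surface((300, 200))        # stand-in for the renderer's display surface
screen.fill((30, 144, 255))

frame = pygame.surfarray.array3d(screen)   # shape (width, height, 3)
frame = np.transpose(frame, (1, 0, 2))     # reorder to (height, width, channels) for video writers
print(frame.shape)                         # (200, 300, 3)
pygame.quit()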