Mirror of https://github.com/illiumst/marl-factory-grid.git, synced 2026-01-15 23:41:39 +01:00
Added 'shared' dirt piles option for eval + Fixed usage of renderer + Added recorder option
@@ -2,6 +2,7 @@ import copy
 import os
 import random
 
+import imageio # requires ffmpeg install on operating system and imageio-ffmpeg package for python
 from scipy import signal
 import matplotlib.pyplot as plt
 import torch
@@ -79,6 +80,8 @@ class A2C:
             os.mkdir(self.results_path)
             # Save settings in results folder
             self.save_configs()
+        if self.cfg[nms.ENV]["record"]:
+            self.recorder = imageio.get_writer(f'{self.results_path}/pygame_recording.mp4', fps=5)
 
     def set_cfg(self, eval=False):
         if eval:
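
Note: for context, a minimal standalone sketch of the imageio writer lifecycle used above (not part of the commit; the output path and frame size are placeholders). get_writer needs the imageio-ffmpeg backend installed, frames are appended as (height, width, 3) uint8 arrays, and the file is only playable after close():

    import numpy as np
    import imageio

    # Hypothetical example: write 50 random 240x320 RGB frames at 5 fps.
    writer = imageio.get_writer('pygame_recording.mp4', fps=5)  # requires the imageio-ffmpeg package
    for _ in range(50):
        frame = np.random.randint(0, 255, size=(240, 320, 3), dtype=np.uint8)
        writer.append_data(frame)  # one frame per call, (height, width, channels)
    writer.close()  # finalizes the mp4 container; without this the file may be unreadable
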
@@ -422,6 +425,15 @@ class A2C:
         if self.cfg[nms.ALGORITHM]["pile_all_done"] in ["all", "distributed"]:
             if all([all(cleaned_dirt_piles[i].values()) for i in range(self.n_agents)]):
                 done = True
+        elif self.cfg[nms.ALGORITHM]["pile_all_done"] == "shared":
+            # End episode if both agents together have cleaned all dirt piles
+            meta_cleaned_dirt_piles = {pos: False for pos in dirt_piles_positions}
+            for agent_idx in range(self.n_agents):
+                for (pos, cleaned) in cleaned_dirt_piles[agent_idx].items():
+                    if cleaned:
+                        meta_cleaned_dirt_piles[pos] = True
+            if all(meta_cleaned_dirt_piles.values()):
+                done = True
 
         return reward, done
 
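
Note: the "shared" termination rule merges the per-agent bookkeeping — a pile counts as cleaned if any agent cleaned it, and the episode ends once the union covers every pile. A small standalone sketch of that logic (names are illustrative, not the repo's API):

    def shared_done(cleaned_dirt_piles, dirt_piles_positions):
        """True once the union of all agents' cleaned piles covers every pile position."""
        cleaned_anywhere = set()
        for per_agent in cleaned_dirt_piles:  # one dict {pos: bool} per agent
            cleaned_anywhere |= {pos for pos, cleaned in per_agent.items() if cleaned}
        return all(pos in cleaned_anywhere for pos in dirt_piles_positions)

    # Example: agent 0 cleaned (1, 1), agent 1 cleaned (3, 4) -> together all piles are covered.
    piles = [(1, 1), (3, 4)]
    agents = [{(1, 1): True, (3, 4): False}, {(1, 1): False, (3, 4): True}]
    assert shared_done(agents, piles)
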
@@ -484,8 +496,6 @@ class A2C:
     @torch.no_grad()
     def train_loop(self):
         env = self.factory
-        if self.cfg[nms.ENV][nms.TRAIN_RENDER]:
-            env.render()
         n_steps, max_steps = [self.cfg[nms.ALGORITHM][k] for k in [nms.N_STEPS, nms.MAX_STEPS]]
         global_steps, episode = 0, 0
         indices = self.distribute_indices(env)
@@ -497,6 +507,8 @@ class A2C:
         while global_steps < max_steps:
             print(global_steps)
             obs = env.reset() # !!!!!!!!Commented seems to work better? Only if a fixed spawnpoint is given
+            if self.cfg[nms.ENV][nms.TRAIN_RENDER]:
+                env.render()
             self.set_agent_spawnpoint(env)
             ordered_dirt_piles = self.get_ordered_dirt_piles(env, cleaned_dirt_piles, target_pile)
             # Reset current target pile at episode begin if all piles have to be cleaned in one episode
@@ -578,8 +590,6 @@ class A2C:
     def eval_loop(self, n_episodes, render=False):
         env = self.eval_factory
         self.set_cfg(eval=True)
-        if self.cfg[nms.ENV][nms.EVAL_RENDER]:
-            env.render()
        episode, results = 0, []
         dirt_piles_positions = self.get_dirt_piles_positions(env)
         indices = self.distribute_indices(env)
@@ -591,10 +601,15 @@ class A2C:
 
         while episode < n_episodes:
             obs = env.reset()
+            if self.cfg[nms.ENV][nms.EVAL_RENDER]:
+                if self.cfg[nms.ENV]["save_and_log"] and self.cfg[nms.ENV]["record"]:
+                    env.set_recorder(self.recorder)
+                env.render()
+                env._renderer.fps = 5
             self.set_agent_spawnpoint(env)
             """obs = list(obs.values())"""
             # Reset current target pile at episode begin if all piles have to be cleaned in one episode
-            if self.cfg[nms.ALGORITHM]["pile_all_done"] in ["all", "distributed"]:
+            if self.cfg[nms.ALGORITHM]["pile_all_done"] in ["all", "distributed", "shared"]:
                 target_pile = [partition[0] for partition in indices]
                 if self.cfg[nms.ALGORITHM]["pile_all_done"] == "distributed":
                     cleaned_dirt_piles = [{dirt_piles_positions[idx]: False for idx in indices[i]} for i in range(self.n_agents)]
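
Note on ordering: env._renderer.fps is set only after env.render(), presumably because the Factory creates its Renderer lazily on the first render() call, and the recorder is attached before that call so the very first rendered frame is already captured. A rough sketch of that lazy-initialization pattern (illustrative classes only, not the repo's actual Factory/Renderer):

    class _Renderer:
        def __init__(self, fps=20):
            self.fps = fps

    class _Env:
        def __init__(self):
            self._renderer = None   # created lazily; mirrors the "expensive - don't use" comment below
            self._recorder = None

        def set_recorder(self, recorder):
            self._recorder = recorder

        def render(self):
            if self._renderer is None:  # first call builds the renderer
                self._renderer = _Renderer()
            # ... draw, and append the frame if self._recorder is set ...

    env = _Env()
    env.set_recorder(object())  # recorder attached before the first frame is drawn
    env.render()                # renderer now exists ...
    env._renderer.fps = 5       # ... so its fps can be lowered for the recording
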
@@ -637,6 +652,10 @@ class A2C:
 
             episode += 1
 
+        # Properly finalize the video file
+        if self.cfg[nms.ENV]["save_and_log"] and self.cfg[nms.ENV]["record"]:
+            self.recorder.close()
+
     def plot_reward_development(self):
         smoothed_data = np.convolve(self.reward_development, np.ones(10) / 10, mode='valid')
         plt.plot(smoothed_data)
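
Note: the explicit close() is what flushes and finalizes the mp4. An equivalent, more defensive pattern (a suggestion, not what the commit does) is to use the writer as a context manager so the file is finalized even if evaluation raises:

    import numpy as np
    import imageio

    def run_eval(recorder):
        # placeholder for the evaluation loop; here it just appends a few blank frames
        for _ in range(10):
            recorder.append_data(np.zeros((240, 320, 3), dtype=np.uint8))

    # The writer is flushed and closed automatically when the block exits.
    with imageio.get_writer('pygame_recording.mp4', fps=5) as recorder:
        run_eval(recorder)
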
@@ -17,6 +17,7 @@ env:
   train_render: False
   eval_render: True
   save_and_log: True
+  record: False
 method: marl_factory_grid.algorithms.marl.LoopSEAC
 algorithm:
   gamma: 0.99
@@ -27,7 +28,7 @@ algorithm:
   advantage: "Advantage-AC" # Options: "Advantage-AC", "TD-Advantage-AC", "Reinforce"
   pile-order: "dynamic" # Use "dynamic" to see emergent phenomenon and "smart" to prevent it
   pile-observability: "single" # Options: "single", "all"
-  pile_all_done: "all" # Options: "single", "all" ("single" for training, "all" for eval)
+  pile_all_done: "shared" # Options: "single", "all" ("single" for training, "all" for eval), "shared"
   auxiliary_piles: False # Option that is only considered when pile-order = "agents"
   chunk-episode: 20000 # Chunk size. (0 = update networks with full episode at once)
 
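
Note: the new keys are read straight out of the parsed YAML. A small sketch of how they might be consumed (the loader shown is just PyYAML, not the repo's config machinery, and key constants such as nms.ENV are simplified to plain strings here):

    import yaml  # PyYAML

    cfg = yaml.safe_load("""
    env:
      train_render: False
      eval_render: True
      save_and_log: True
      record: False
    algorithm:
      pile_all_done: "shared"
    """)

    record_eval = cfg["env"]["save_and_log"] and cfg["env"]["record"]
    shared_termination = cfg["algorithm"]["pile_all_done"] == "shared"
    print(record_eval, shared_termination)  # prints: False True with the values above
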
@@ -16,7 +16,8 @@ env:
   individual_rewards: True
   train_render: False
   eval_render: True
-  save_and_log: False
+  save_and_log: True
+  record: False
 method: marl_factory_grid.algorithms.marl.LoopSEAC
 algorithm:
   gamma: 0.99
@@ -17,6 +17,7 @@ env:
   train_render: False
   eval_render: True
   save_and_log: True
+  record: False
 method: marl_factory_grid.algorithms.marl.LoopSEAC
 algorithm:
   gamma: 0.99
@@ -2,4 +2,7 @@ marl_factory_grid>environment>rules.py#SpawnEntity.on_reset()
 marl_factory_grid>environment>rewards.py
 marl_factory_grid>modules>clean_up>groups.py#DirtPiles.trigger_spawn()
 marl_factory_grid>environment>rules.py#AgentSpawnRule
 marl_factory_grid>utils>states.py#GameState.__init__()
+marl_factory_grid>environment>factory.py>Factory#render
+marl_factory_grid>environment>factory.py>Factory#set_recorder
+marl_factory_grid>utils>renderer.py>Renderer#render
@@ -17,6 +17,7 @@ env:
   train_render: False
   eval_render: True
   save_and_log: False
+  record: False
 method: marl_factory_grid.algorithms.marl.LoopSEAC
 algorithm:
   gamma: 0.99
@@ -109,6 +109,7 @@ class Factory(gym.Env):
 
         # expensive - don't use; unless required !
         self._renderer = None
+        self._recorder = None
 
         # Init entities
        entities = self.map.do_init()
@@ -277,7 +278,10 @@ class Factory(gym.Env):
         for render_entity in render_entities:
             if render_entity.name == c.AGENT:
                 render_entity.aux = self.obs_builder.curr_lightmaps[render_entity.real_name]
-        return self._renderer.render(render_entities)
+        return self._renderer.render(render_entities, self._recorder)
 
+    def set_recorder(self, recorder):
+        self._recorder = recorder
+
     def summarize_header(self):
         header = {'rec_step': self.state.curr_step}
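
Note: the recorder is injected into the Factory and merely passed through to the renderer; the environment itself never writes frames. A stripped-down sketch of that pass-through (illustrative classes, not the actual Factory/Renderer):

    class TinyRenderer:
        def render(self, entities, recorder):
            frame = b"..."              # stand-in for the pygame surface capture
            if recorder is not None:
                recorder.append(frame)  # the renderer stays unaware of the video backend
            return frame

    class TinyFactory:
        def __init__(self):
            self._renderer = TinyRenderer()
            self._recorder = None       # default: render without recording

        def set_recorder(self, recorder):
            self._recorder = recorder

        def render(self, entities=()):
            return self._renderer.render(entities, self._recorder)

    frames = []
    factory = TinyFactory()
    factory.set_recorder(frames)  # any frame sink works; imageio's writer exposes append_data()
    factory.render()
    assert len(frames) == 1
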
@@ -156,7 +156,7 @@ class Renderer:
         rects.append(dict(source=shape_surf, dest=visibility_rect))
         return rects
 
-    def render(self, entities):
+    def render(self, entities, recorder):
         """
         Renders the entities on the screen.
 
@@ -190,6 +190,11 @@ class Renderer:
         for blit in blits:
             self.screen.blit(**blit)
 
+        if recorder:
+            frame = pygame.surfarray.array3d(self.screen)
+            frame = np.transpose(frame, (1, 0, 2))  # Transpose to (height, width, channels)
+            recorder.append_data(frame)
+
         pygame.display.flip()
         self.clock.tick(self.fps)
         rgb_obs = pygame.surfarray.array3d(self.screen)
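
Note: pygame.surfarray.array3d returns a (width, height, 3) array, so the transpose is what puts the frame into the (height, width, 3) layout video writers expect. A self-contained sketch of the capture path (no display window needed; an off-screen Surface stands in for the real screen, and the output filename is a placeholder):

    import numpy as np
    import pygame
    import imageio

    surface = pygame.Surface((320, 240))       # width=320, height=240, off-screen
    surface.fill((40, 160, 40))

    writer = imageio.get_writer('capture_test.mp4', fps=5)
    frame = pygame.surfarray.array3d(surface)  # shape (320, 240, 3): (width, height, channels)
    frame = np.transpose(frame, (1, 0, 2))     # shape (240, 320, 3): (height, width, channels)
    writer.append_data(frame)
    writer.close()
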