mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-12-14 03:00:37 +01:00
Reword TSP runner
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import pickle
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
@@ -8,14 +9,14 @@ from marl_factory_grid.algorithms.tsp.contortions import get_dirt_quadrant_tsp_a
|
|||||||
|
|
||||||
|
|
||||||
def dirt_quadrant_multi_agent_tsp_eval(emergent_phenomenon):
|
def dirt_quadrant_multi_agent_tsp_eval(emergent_phenomenon):
|
||||||
run_tsp_setting("dirt_quadrant", emergent_phenomenon)
|
run_tsp_setting("dirt_quadrant", emergent_phenomenon, log=False)
|
||||||
|
|
||||||
|
|
||||||
def two_rooms_multi_agent_tsp_eval(emergent_phenomenon):
|
def two_rooms_multi_agent_tsp_eval(emergent_phenomenon):
|
||||||
run_tsp_setting("two_rooms", emergent_phenomenon)
|
run_tsp_setting("two_rooms", emergent_phenomenon, log=False)
|
||||||
|
|
||||||
|
|
||||||
def run_tsp_setting(config_name, emergent_phenomenon, n_episodes=1):
|
def run_tsp_setting(config_name, emergent_phenomenon, n_episodes=1, log=False):
|
||||||
# Render at each step?
|
# Render at each step?
|
||||||
render = True
|
render = True
|
||||||
|
|
||||||
@@ -35,8 +36,13 @@ def run_tsp_setting(config_name, emergent_phenomenon, n_episodes=1):
|
|||||||
with open(f"{results_path}/env_config.txt", "w") as txt_file:
|
with open(f"{results_path}/env_config.txt", "w") as txt_file:
|
||||||
txt_file.write(str(factory.conf))
|
txt_file.write(str(factory.conf))
|
||||||
|
|
||||||
|
still_existing_dirt_piles = []
|
||||||
|
reached_flags = []
|
||||||
|
|
||||||
for episode in trange(n_episodes):
|
for episode in trange(n_episodes):
|
||||||
_ = factory.reset()
|
_ = factory.reset()
|
||||||
|
still_existing_dirt_piles.append([])
|
||||||
|
reached_flags.append([])
|
||||||
done = False
|
done = False
|
||||||
if render:
|
if render:
|
||||||
factory.render()
|
factory.render()
|
||||||
@@ -48,6 +54,7 @@ def run_tsp_setting(config_name, emergent_phenomenon, n_episodes=1):
|
|||||||
else:
|
else:
|
||||||
print("Config name does not exist. Abort...")
|
print("Config name does not exist. Abort...")
|
||||||
break
|
break
|
||||||
|
ep_steps = 0
|
||||||
while not done:
|
while not done:
|
||||||
a = [x.predict() for x in agents]
|
a = [x.predict() for x in agents]
|
||||||
# Have this condition, to terminate as soon as all dirt piles are collected. This ensures that the implementation
|
# Have this condition, to terminate as soon as all dirt piles are collected. This ensures that the implementation
|
||||||
@@ -55,7 +62,33 @@ def run_tsp_setting(config_name, emergent_phenomenon, n_episodes=1):
|
|||||||
if 'DirtPiles' in list(factory.state.entities.keys()) and factory.state.entities['DirtPiles'].global_amount == 0.0:
|
if 'DirtPiles' in list(factory.state.entities.keys()) and factory.state.entities['DirtPiles'].global_amount == 0.0:
|
||||||
break
|
break
|
||||||
obs_type, _, _, done, info = factory.step(a)
|
obs_type, _, _, done, info = factory.step(a)
|
||||||
|
if 'DirtPiles' in list(factory.state.entities.keys()):
|
||||||
|
still_existing_dirt_piles[-1].append(len(factory.state.entities['DirtPiles']))
|
||||||
|
if 'Destinations' in list(factory.state.entities.keys()):
|
||||||
|
reached_flags[-1].append(sum([1 for ele in [x.was_reached() for x in factory.state['Destinations']] if ele]))
|
||||||
|
ep_steps += 1
|
||||||
if render:
|
if render:
|
||||||
factory.render()
|
factory.render()
|
||||||
if done:
|
if done:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
cleaned_dirt_piles_per_step = []
|
||||||
|
if 'DirtPiles' in list(factory.state.entities.keys()):
|
||||||
|
for ep in still_existing_dirt_piles:
|
||||||
|
cleaned_dirt_piles_per_step.append([max(ep)-ep[idx] for idx, value in enumerate(ep)])
|
||||||
|
# Remove first element and add last element where all dirt piles have been collected
|
||||||
|
del cleaned_dirt_piles_per_step[-1][0]
|
||||||
|
cleaned_dirt_piles_per_step[-1].append(max(still_existing_dirt_piles[-1]))
|
||||||
|
|
||||||
|
# Add last entry to reached_flags
|
||||||
|
print(ep_steps)
|
||||||
|
print(reached_flags)
|
||||||
|
print(cleaned_dirt_piles_per_step)
|
||||||
|
|
||||||
|
if log:
|
||||||
|
if 'DirtPiles' in list(factory.state.entities.keys()):
|
||||||
|
metrics_data = {"cleaned_dirt_piles_per_step": cleaned_dirt_piles_per_step}
|
||||||
|
if 'Destinations' in list(factory.state.entities.keys()):
|
||||||
|
metrics_data = {"reached_flags": reached_flags}
|
||||||
|
with open(f"{results_path}/metrics", "wb") as pickle_file:
|
||||||
|
pickle.dump(metrics_data, pickle_file)
|
||||||
Reference in New Issue
Block a user