mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-08 02:21:36 +02:00
Reword TSP runner
This commit is contained in:
@ -1,4 +1,5 @@
|
||||
import os
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm import trange
|
||||
@ -8,14 +9,14 @@ from marl_factory_grid.algorithms.tsp.contortions import get_dirt_quadrant_tsp_a
|
||||
|
||||
|
||||
def dirt_quadrant_multi_agent_tsp_eval(emergent_phenomenon):
|
||||
run_tsp_setting("dirt_quadrant", emergent_phenomenon)
|
||||
run_tsp_setting("dirt_quadrant", emergent_phenomenon, log=False)
|
||||
|
||||
|
||||
def two_rooms_multi_agent_tsp_eval(emergent_phenomenon):
|
||||
run_tsp_setting("two_rooms", emergent_phenomenon)
|
||||
run_tsp_setting("two_rooms", emergent_phenomenon, log=False)
|
||||
|
||||
|
||||
def run_tsp_setting(config_name, emergent_phenomenon, n_episodes=1):
|
||||
def run_tsp_setting(config_name, emergent_phenomenon, n_episodes=1, log=False):
|
||||
# Render at each step?
|
||||
render = True
|
||||
|
||||
@ -35,8 +36,13 @@ def run_tsp_setting(config_name, emergent_phenomenon, n_episodes=1):
|
||||
with open(f"{results_path}/env_config.txt", "w") as txt_file:
|
||||
txt_file.write(str(factory.conf))
|
||||
|
||||
still_existing_dirt_piles = []
|
||||
reached_flags = []
|
||||
|
||||
for episode in trange(n_episodes):
|
||||
_ = factory.reset()
|
||||
still_existing_dirt_piles.append([])
|
||||
reached_flags.append([])
|
||||
done = False
|
||||
if render:
|
||||
factory.render()
|
||||
@ -48,6 +54,7 @@ def run_tsp_setting(config_name, emergent_phenomenon, n_episodes=1):
|
||||
else:
|
||||
print("Config name does not exist. Abort...")
|
||||
break
|
||||
ep_steps = 0
|
||||
while not done:
|
||||
a = [x.predict() for x in agents]
|
||||
# Have this condition, to terminate as soon as all dirt piles are collected. This ensures that the implementation
|
||||
@ -55,7 +62,33 @@ def run_tsp_setting(config_name, emergent_phenomenon, n_episodes=1):
|
||||
if 'DirtPiles' in list(factory.state.entities.keys()) and factory.state.entities['DirtPiles'].global_amount == 0.0:
|
||||
break
|
||||
obs_type, _, _, done, info = factory.step(a)
|
||||
if 'DirtPiles' in list(factory.state.entities.keys()):
|
||||
still_existing_dirt_piles[-1].append(len(factory.state.entities['DirtPiles']))
|
||||
if 'Destinations' in list(factory.state.entities.keys()):
|
||||
reached_flags[-1].append(sum([1 for ele in [x.was_reached() for x in factory.state['Destinations']] if ele]))
|
||||
ep_steps += 1
|
||||
if render:
|
||||
factory.render()
|
||||
if done:
|
||||
break
|
||||
break
|
||||
|
||||
cleaned_dirt_piles_per_step = []
|
||||
if 'DirtPiles' in list(factory.state.entities.keys()):
|
||||
for ep in still_existing_dirt_piles:
|
||||
cleaned_dirt_piles_per_step.append([max(ep)-ep[idx] for idx, value in enumerate(ep)])
|
||||
# Remove first element and add last element where all dirt piles have been collected
|
||||
del cleaned_dirt_piles_per_step[-1][0]
|
||||
cleaned_dirt_piles_per_step[-1].append(max(still_existing_dirt_piles[-1]))
|
||||
|
||||
# Add last entry to reached_flags
|
||||
print(ep_steps)
|
||||
print(reached_flags)
|
||||
print(cleaned_dirt_piles_per_step)
|
||||
|
||||
if log:
|
||||
if 'DirtPiles' in list(factory.state.entities.keys()):
|
||||
metrics_data = {"cleaned_dirt_piles_per_step": cleaned_dirt_piles_per_step}
|
||||
if 'Destinations' in list(factory.state.entities.keys()):
|
||||
metrics_data = {"reached_flags": reached_flags}
|
||||
with open(f"{results_path}/metrics", "wb") as pickle_file:
|
||||
pickle.dump(metrics_data, pickle_file)
|
Reference in New Issue
Block a user