added target agent test and fixed tsp agents

This commit is contained in:
Chanumask
2024-01-18 13:32:30 +01:00
parent ecf53e7d64
commit 51612812b0
6 changed files with 100 additions and 27 deletions

View File

@ -5,6 +5,7 @@ from tqdm import trange
from marl_factory_grid.algorithms.static.TSP_dirt_agent import TSPDirtAgent
from marl_factory_grid.algorithms.static.TSP_item_agent import TSPItemAgent
from marl_factory_grid.algorithms.static.TSP_target_agent import TSPTargetAgent
from marl_factory_grid.environment.factory import Factory
if __name__ == '__main__':
@ -17,19 +18,18 @@ if __name__ == '__main__':
# Env Init
factory = Factory(path)
for episode in trange(5):
for episode in trange(1):
_ = factory.reset()
done = False
if render:
factory.render()
action_spaces = factory.action_space
agents = [TSPDirtAgent(factory, 0), TSPItemAgent(factory, 1)]
agents = [TSPDirtAgent(factory, 0), TSPItemAgent(factory, 1), TSPTargetAgent(factory, 2)]
while not done:
a = [randint(0, x.n - 1) for x in action_spaces]
a = [x.predict() for x in agents]
obs_type, _, _, done, info = factory.step(a)
if render:
factory.render()
if done:
print(f'Episode {episode} done...')
break