mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-12-13 18:50:38 +01:00
Equalize rendering for TSP and RL agents
This commit is contained in:
@@ -286,6 +286,7 @@ class A2C:
|
|||||||
updated_indices = []
|
updated_indices = []
|
||||||
if len(affected_agents[door_positions[0]]) == 0:
|
if len(affected_agents[door_positions[0]]) == 0:
|
||||||
# Remove auxiliary piles for all agents
|
# Remove auxiliary piles for all agents
|
||||||
|
# (In the config, every pile with an even-numbered index is defined as an auxiliary pile)
|
||||||
updated_indices = [[ele for ele in lst if ele % 2 != 0] for lst in indices]
|
updated_indices = [[ele for ele in lst if ele % 2 != 0] for lst in indices]
|
||||||
else:
|
else:
|
||||||
for distance, agent_indices in affected_agents[door_positions[0]].items():
|
for distance, agent_indices in affected_agents[door_positions[0]].items():
|
||||||
@@ -430,6 +431,10 @@ class A2C:
|
|||||||
reward[idx] += 50 # 1
|
reward[idx] += 50 # 1
|
||||||
cleaned_dirt_piles[idx][pos] = True
|
cleaned_dirt_piles[idx][pos] = True
|
||||||
|
|
||||||
|
# Indicate to the renderer that this dirt pile can be hidden
|
||||||
|
dirt_at_position = env.state['DirtPiles'].by_pos(pos)
|
||||||
|
dirt_at_position[0].set_new_amount(0)
|
||||||
|
|
||||||
if self.cfg[nms.ALGORITHM]["pile_all_done"] in ["all", "distributed"]:
|
if self.cfg[nms.ALGORITHM]["pile_all_done"] in ["all", "distributed"]:
|
||||||
if all([all(cleaned_dirt_piles[i].values()) for i in range(self.n_agents)]):
|
if all([all(cleaned_dirt_piles[i].values()) for i in range(self.n_agents)]):
|
||||||
done = True
|
done = True
|
||||||
@@ -603,12 +608,17 @@ class A2C:
|
|||||||
|
|
||||||
while episode < n_episodes:
|
while episode < n_episodes:
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
|
self.set_agent_spawnpoint(env)
|
||||||
if self.cfg[nms.ENV][nms.EVAL_RENDER]:
|
if self.cfg[nms.ENV][nms.EVAL_RENDER]:
|
||||||
if self.cfg[nms.ENV]["save_and_log"] and self.cfg[nms.ENV]["record"]:
|
if self.cfg[nms.ENV]["save_and_log"] and self.cfg[nms.ENV]["record"]:
|
||||||
env.set_recorder(self.recorder)
|
env.set_recorder(self.recorder)
|
||||||
|
if self.cfg[nms.ALGORITHM]["auxiliary_piles"]:
|
||||||
|
# Don't render auxiliary piles
|
||||||
|
auxiliary_piles = [pile for idx, pile in enumerate(env.state.entities['DirtPiles']) if idx % 2 == 0]
|
||||||
|
for pile in auxiliary_piles:
|
||||||
|
pile.set_new_amount(0)
|
||||||
env.render()
|
env.render()
|
||||||
env._renderer.fps = 5
|
env._renderer.fps = 5
|
||||||
self.set_agent_spawnpoint(env)
|
|
||||||
"""obs = list(obs.values())"""
|
"""obs = list(obs.values())"""
|
||||||
# Reset current target pile at episode begin if all piles have to be cleaned in one episode
|
# Reset current target pile at episode begin if all piles have to be cleaned in one episode
|
||||||
if self.cfg[nms.ALGORITHM]["pile_all_done"] in ["all", "distributed", "shared"]:
|
if self.cfg[nms.ALGORITHM]["pile_all_done"] in ["all", "distributed", "shared"]:
|
||||||
|
|||||||
@@ -24,9 +24,9 @@ class TSPDirtAgent(TSPBaseAgent):
|
|||||||
"""
|
"""
|
||||||
dirt_at_position = self._env.state[di.DIRT].by_pos(self.state.pos)
|
dirt_at_position = self._env.state[di.DIRT].by_pos(self.state.pos)
|
||||||
if dirt_at_position:
|
if dirt_at_position:
|
||||||
# Translate the action_object to an integer to have the same output as any other model
|
# Remove dirt from map
|
||||||
action = di.CLEAN_UP
|
self._env.state[di.DIRT].delete_env_object(dirt_at_position[0])
|
||||||
elif door := self._door_is_close(self._env.state):
|
if door := self._door_is_close(self._env.state):
|
||||||
action = self._use_door_or_move(door, di.DIRT)
|
action = self._use_door_or_move(door, di.DIRT)
|
||||||
else:
|
else:
|
||||||
action = self._predict_move(di.DIRT)
|
action = self._predict_move(di.DIRT)
|
||||||
|
|||||||
@@ -272,7 +272,21 @@ class Factory(gym.Env):
|
|||||||
global Renderer
|
global Renderer
|
||||||
self._renderer = Renderer(self.map.level_shape, view_radius=self.conf.pomdp_r, fps=10)
|
self._renderer = Renderer(self.map.level_shape, view_radius=self.conf.pomdp_r, fps=10)
|
||||||
|
|
||||||
|
# Hide dirt piles where all dirt was cleaned
|
||||||
render_entities = self.state.entities.render()
|
render_entities = self.state.entities.render()
|
||||||
|
if 'DirtPiles' in list(self.state.entities.keys()):
|
||||||
|
for pile in self.state.entities['DirtPiles']:
|
||||||
|
if pile.amount <= 0:
|
||||||
|
render_entities = [entity for entity in render_entities if not (entity.name == 'DirtPiles' and entity.pos == pile.pos)]
|
||||||
|
|
||||||
|
# Mask dirt piles as Destinations (relevant for RL-agents) # TODO
|
||||||
|
if self.conf['General']['level_name'] == 'two_rooms':
|
||||||
|
if 'DirtPiles' in list(self.state.entities.keys()):
|
||||||
|
for entity in render_entities:
|
||||||
|
if entity.name == 'DirtPiles':
|
||||||
|
entity.name = 'Destinations'
|
||||||
|
entity.value = 1
|
||||||
|
|
||||||
if self.conf.pomdp_r:
|
if self.conf.pomdp_r:
|
||||||
for render_entity in render_entities:
|
for render_entity in render_entities:
|
||||||
if render_entity.name == c.AGENT:
|
if render_entity.name == c.AGENT:
|
||||||
|
|||||||
@@ -80,11 +80,14 @@ def run_tsp_setting(config_name, emergent_phenomenon):
|
|||||||
break
|
break
|
||||||
while not done:
|
while not done:
|
||||||
a = [x.predict() for x in agents]
|
a = [x.predict() for x in agents]
|
||||||
|
# Terminate as soon as all dirt piles are collected. This ensures that the implementation
|
||||||
|
# of the TSP agent is equivalent to that of the RL agent
|
||||||
|
if 'DirtPiles' in list(factory.state.entities.keys()) and factory.state.entities['DirtPiles'].global_amount == 0.0:
|
||||||
|
break
|
||||||
obs_type, _, _, done, info = factory.step(a)
|
obs_type, _, _, done, info = factory.step(a)
|
||||||
if render:
|
if render:
|
||||||
factory.render()
|
factory.render()
|
||||||
if done:
|
if done:
|
||||||
print(f'Episode {episode} done...')
|
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user