Debugging

This commit is contained in:
Steffen Illium
2022-01-11 10:54:02 +01:00
parent 435056f373
commit 3150757347
6 changed files with 67 additions and 58 deletions

View File

@ -35,7 +35,7 @@ class BaseFactory(gym.Env):
@property
def named_action_space(self):
return {x.identifier.value: idx for idx, x in enumerate(self._actions.values())}
return {x.identifier: idx for idx, x in enumerate(self._actions.values())}
@property
def observation_space(self):
@ -287,7 +287,7 @@ class BaseFactory(gym.Env):
doors.tick_doors()
# Finalize
reward, reward_info = self.build_reward_result()
reward, reward_info = self.build_reward_result(rewards)
info.update(reward_info)
if self._steps >= self.max_steps:
@ -313,8 +313,8 @@ class BaseFactory(gym.Env):
if door is not None:
door.use()
valid = c.VALID
self.print(f'{agent.name} just used a door {door.name}')
info_dict = {f'{agent.name}_door_use_{door.name}': 1}
self.print(f'{agent.name} just used a {door.name} at {door.pos}')
info_dict = {f'{agent.name}_door_use': 1}
# When he doesn't...
else:
valid = c.NOT_VALID
@ -478,8 +478,7 @@ class BaseFactory(gym.Env):
return oobs
def get_all_tiles_with_collisions(self) -> List[Tile]:
tiles = [x.tile for y in self._entities for x in y if
y.can_collide and not isinstance(y, WallTiles) and x.can_collide and len(x.tile.guests) > 1]
tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1]
if False:
tiles_with_collisions = list()
for tile in self[c.FLOOR]:
@ -503,11 +502,11 @@ class BaseFactory(gym.Env):
else:
valid = c.NOT_VALID
self.print(f'{agent.name} just hit the wall at {agent.pos}.')
info_dict.update({f'{agent.pos}_wall_collide': 1})
info_dict.update({f'{agent.name}_wall_collide': 1})
else:
# Agent seems to be trying to Leave the level
self.print(f'{agent.name} tried to leave the level {agent.pos}.')
info_dict.update({f'{agent.pos}_wall_collide': 1})
info_dict.update({f'{agent.name}_wall_collide': 1})
reward_value = r.MOVEMENTS_VALID if valid else r.MOVEMENTS_FAIL
reward = {'value': reward_value, 'reason': action.identifier, 'info': info_dict}
return valid, reward
@ -554,7 +553,7 @@ class BaseFactory(gym.Env):
def additional_per_agent_rewards(self, agent) -> List[dict]:
return []
def build_reward_result(self) -> (int, dict):
def build_reward_result(self, global_env_rewards: list) -> (int, dict):
# Returns: Reward, Info
info = defaultdict(lambda: 0.0)
@ -584,12 +583,14 @@ class BaseFactory(gym.Env):
combined_info_dict = dict(combined_info_dict)
combined_info_dict.update(info)
global_reward_sum = sum(global_env_rewards)
if self.individual_rewards:
self.print(f"rewards are {comb_rewards}")
reward = list(comb_rewards.values())
reward = [x + global_reward_sum for x in reward]
return reward, combined_info_dict
else:
reward = sum(comb_rewards.values())
reward = sum(comb_rewards.values()) + global_reward_sum
self.print(f"reward is {reward}")
return reward, combined_info_dict