Experiments look good

This commit is contained in:
Steffen Illium
2022-01-15 12:37:58 +01:00
parent d29ccbbb71
commit 823aa075b9
14 changed files with 478 additions and 297 deletions

View File

@ -147,17 +147,17 @@ class DestFactory(BaseFactory):
super().__init__(*args, **kwargs)
@property
def additional_actions(self) -> Union[Action, List[Action]]:
def actions_hook(self) -> Union[Action, List[Action]]:
# noinspection PyUnresolvedReferences
super_actions = super().additional_actions
super_actions = super().actions_hook
if self.dest_prop.dwell_time:
super_actions.append(Action(enum_ident=a.WAIT_ON_DEST))
return super_actions
@property
def additional_entities(self) -> Dict[(Enum, Entities)]:
def entities_hook(self) -> Dict[(Enum, Entities)]:
# noinspection PyUnresolvedReferences
super_entities = super().additional_entities
super_entities = super().entities_hook
empty_tiles = self[c.FLOOR].empty_tiles[:self.dest_prop.n_dests]
destinations = Destinations.from_tiles(
@ -194,9 +194,9 @@ class DestFactory(BaseFactory):
else:
return super_action_result
def do_additional_reset(self) -> None:
def reset_hook(self) -> None:
# noinspection PyUnresolvedReferences
super().do_additional_reset()
super().reset_hook()
self._dest_spawn_timer = dict()
def trigger_destination_spawn(self):
@ -222,9 +222,9 @@ class DestFactory(BaseFactory):
else:
self.print('No Items are spawning, limit is reached.')
def do_additional_step(self) -> (List[dict], dict):
def step_hook(self) -> (List[dict], dict):
# noinspection PyUnresolvedReferences
super_reward_info = super().do_additional_step()
super_reward_info = super().step_hook()
for key, val in self._dest_spawn_timer.items():
self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1)
for dest in list(self[c.DEST].values()):
@ -244,14 +244,14 @@ class DestFactory(BaseFactory):
self.trigger_destination_spawn()
return super_reward_info
def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
additional_observations = super()._additional_observations()
def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
additional_observations = super().observations_hook()
additional_observations.update({c.DEST: self[c.DEST].as_array()})
return additional_observations
def additional_per_agent_reward(self, agent: Agent) -> Dict[str, dict]:
def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
# noinspection PyUnresolvedReferences
reward_event_dict = super().additional_per_agent_reward(agent)
reward_event_dict = super().per_agent_reward_hook(agent)
if len(self[c.DEST_REACHED]):
for reached_dest in list(self[c.DEST_REACHED]):
if agent.pos == reached_dest.pos:
@ -261,9 +261,9 @@ class DestFactory(BaseFactory):
reward_event_dict.update({c.DEST_REACHED: {'reward': r.DEST_REACHED, 'info': info_dict}})
return reward_event_dict
def render_additional_assets(self, mode='human'):
def render_assets_hook(self, mode='human'):
# noinspection PyUnresolvedReferences
additional_assets = super().render_additional_assets()
additional_assets = super().render_assets_hook()
destinations = [RenderEntity(c.DEST, dest.pos) for dest in self[c.DEST]]
additional_assets.extend(destinations)
return additional_assets