Rework for performance

This commit is contained in:
Steffen Illium
2022-01-10 15:54:22 +01:00
parent 78bf19f7f4
commit 435056f373
10 changed files with 525 additions and 469 deletions

View File

@ -6,18 +6,32 @@ import numpy as np
import random
from environments.factory.base.base_factory import BaseFactory
from environments.helpers import Constants as c, Constants
from environments import helpers as h
from environments.helpers import Constants as BaseConstants
from environments.helpers import EnvActions as BaseActions
from environments.helpers import Rewards as BaseRewards
from environments.factory.base.objects import Agent, Entity, Action
from environments.factory.base.registers import Entities, EntityRegister
from environments.factory.base.renderer import RenderEntity
class Constants(BaseConstants):
# Destination Env
DEST = 'Destination'
DESTINATION = 1
DESTINATION_DONE = 0.5
DEST_REACHED = 'ReachedDestination'
DESTINATION = 1
DESTINATION_DONE = 0.5
class Actions(BaseActions):
WAIT_ON_DEST = 'WAIT'
class Rewards(BaseRewards):
WAIT_VALID = 0.1
WAIT_FAIL = -0.1
DEST_REACHED = 5.0
class Destination(Entity):
@ -30,20 +44,16 @@ class Destination(Entity):
def currently_dwelling_names(self):
return self._per_agent_times.keys()
@property
def can_collide(self):
return False
@property
def encoding(self):
return DESTINATION
return c.DESTINATION
def __init__(self, *args, dwell_time: int = 0, **kwargs):
super(Destination, self).__init__(*args, **kwargs)
self.dwell_time = dwell_time
self._per_agent_times = defaultdict(lambda: dwell_time)
def wait(self, agent: Agent):
def do_wait_action(self, agent: Agent):
self._per_agent_times[agent.name] -= 1
return c.VALID
@ -52,7 +62,7 @@ class Destination(Entity):
@property
def is_considered_reached(self):
agent_at_position = any(c.AGENT.name.lower() in x.name.lower() for x in self.tile.guests_that_can_collide)
agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in self.tile.guests_that_can_collide)
return (agent_at_position and not self.dwell_time) or any(x == 0 for x in self._per_agent_times.values())
def agent_is_dwelling(self, agent: Agent):
@ -67,15 +77,19 @@ class Destination(Entity):
class Destinations(EntityRegister):
_accepted_objects = Destination
_light_blocking = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_blocking_light = False
self.can_be_shadowed = False
def as_array(self):
self._array[:] = c.FREE_CELL.value
self._array[:] = c.FREE_CELL
# ToDo: Switch to new Style Array Put
# indices = list(zip(range(len(cls)), *zip(*[x.pos for x in cls])))
# np.put(cls._array, [np.ravel_multi_index(x, cls._array.shape) for x in indices], cls.encodings)
for item in self:
if item.pos != c.NO_POS.value:
if item.pos != c.NO_POS:
self._array[0, item.x, item.y] = item.encoding
return self._array
@ -85,10 +99,11 @@ class Destinations(EntityRegister):
class ReachedDestinations(Destinations):
_accepted_objects = Destination
_light_blocking = False
def __init__(self, *args, **kwargs):
super(ReachedDestinations, self).__init__(*args, **kwargs)
self.can_be_shadowed = False
self.is_blocking_light = False
def summarize_states(self, n_steps=None):
return {}
@ -102,7 +117,7 @@ class DestModeOptions(object):
class DestProperties(NamedTuple):
n_dests: int = 1 # How many destinations are there
dwell_time: int = 0 # How long does the agent need to "wait" on a destination
dwell_time: int = 0 # Number of wait actions an agent has to perform on a destination before it counts as reached
spawn_frequency: int = 0
spawn_in_other_zone: bool = True #
spawn_mode: str = DestModeOptions.DONE
@ -113,6 +128,11 @@ class DestProperties(NamedTuple):
assert (spawn_mode == DestModeOptions.DONE) != bool(spawn_frequency)
c = Constants
a = Actions
r = Rewards
# noinspection PyAttributeOutsideInit, PyAbstractClass
class DestFactory(BaseFactory):
# noinspection PyMissingConstructor
@ -131,7 +151,7 @@ class DestFactory(BaseFactory):
# noinspection PyUnresolvedReferences
super_actions = super().additional_actions
if self.dest_prop.dwell_time:
super_actions.append(Action(enum_ident=h.EnvActions.WAIT_ON_DEST))
super_actions.append(Action(enum_ident=a.WAIT_ON_DEST))
return super_actions
@property
@ -147,27 +167,32 @@ class DestFactory(BaseFactory):
)
reached_destinations = ReachedDestinations(level_shape=self._level_shape)
super_entities.update({c.DESTINATION: destinations, c.REACHEDDESTINATION: reached_destinations})
super_entities.update({c.DEST: destinations, c.DEST_REACHED: reached_destinations})
return super_entities
def do_wait_action(self, agent: Agent) -> (dict, dict):
    """Let ``agent`` perform the wait action on the destination at its position.

    A destination is looked up at ``agent.pos`` in the ``c.DEST`` register.
    If one is found, the wait is delegated to that destination; otherwise the
    action is marked as not valid. Both outcomes are logged via ``self.print``.

    :param agent: The agent that issued the ``a.WAIT_ON_DEST`` action.
    :return: ``(valid, reward)`` - the validity flag of the action and a reward
             dict with the keys ``value``, ``reason`` and ``info``.
             NOTE(review): the ``(dict, dict)`` annotation does not describe
             the first element (a validity flag); kept as-is for callers.
    """
    if destination := self[c.DEST].by_pos(agent.pos):
        valid = destination.do_wait_action(agent)
        self.print(f'{agent.name} just waited at {agent.pos}')
        info_dict = {f'{agent.name}_{a.WAIT_ON_DEST}_VALID': 1}
    else:
        # No destination under the agent: the action fails, but it still has to
        # yield the full (valid, reward) pair, so do not return early here.
        valid = c.NOT_VALID
        self.print(f'{agent.name} just tried to wait at {agent.pos} but failed')
        info_dict = {f'{agent.name}_{a.WAIT_ON_DEST}_FAIL': 1}
    reward = dict(value=r.WAIT_VALID if valid else r.WAIT_FAIL, reason=a.WAIT_ON_DEST, info=info_dict)
    return valid, reward
def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict):
    """Resolve ``action`` for ``agent``, adding the wait-on-destination action.

    The parent implementation is asked first. Only when it returns ``None``
    is ``action`` compared against ``a.WAIT_ON_DEST``.

    :param agent: The acting agent.
    :param action: The action to resolve.
    :return: The parent's result if it is not ``None``; otherwise the result of
             ``do_wait_action`` when ``action`` equals ``a.WAIT_ON_DEST``;
             otherwise ``None``.
    """
    # noinspection PyUnresolvedReferences
    super_action_result = super().do_additional_actions(agent, action)
    if super_action_result is not None:
        return super_action_result
    if action == a.WAIT_ON_DEST:
        return self.do_wait_action(agent)
    return None
def do_additional_reset(self) -> None:
# noinspection PyUnresolvedReferences
@ -180,14 +205,14 @@ class DestFactory(BaseFactory):
if destinations_to_spawn:
n_dest_to_spawn = len(destinations_to_spawn)
if self.dest_prop.spawn_mode != DestModeOptions.GROUPED:
destinations = [Destination(tile) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
self[c.DESTINATION].register_additional_items(destinations)
destinations = [Destination(tile, c.DEST) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
self[c.DEST].register_additional_items(destinations)
for dest in destinations_to_spawn:
del self._dest_spawn_timer[dest]
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
elif self.dest_prop.spawn_mode == DestModeOptions.GROUPED and n_dest_to_spawn == self.dest_prop.n_dests:
destinations = [Destination(tile) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
self[c.DESTINATION].register_additional_items(destinations)
destinations = [Destination(tile, self[c.DEST]) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
self[c.DEST].register_additional_items(destinations)
for dest in destinations_to_spawn:
del self._dest_spawn_timer[dest]
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
@ -197,15 +222,14 @@ class DestFactory(BaseFactory):
else:
self.print('No Items are spawning, limit is reached.')
def do_additional_step(self) -> dict:
def do_additional_step(self) -> (List[dict], dict):
# noinspection PyUnresolvedReferences
info_dict = super().do_additional_step()
super_reward_info = super().do_additional_step()
for key, val in self._dest_spawn_timer.items():
self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1)
for dest in list(self[c.DESTINATION].values()):
for dest in list(self[c.DEST].values()):
if dest.is_considered_reached:
self[c.REACHEDDESTINATION].register_item(dest)
self[c.DESTINATION].delete_env_object(dest)
dest.change_register(self[c.DEST])
self._dest_spawn_timer[dest.name] = 0
self.print(f'{dest.name} is reached now, removing...')
else:
@ -218,41 +242,29 @@ class DestFactory(BaseFactory):
dest.leave(agent)
self.print(f'{agent.name} left the destination early.')
self.trigger_destination_spawn()
return info_dict
return super_reward_info
def _additional_observations(self) -> Dict[Constants, np.typing.ArrayLike]:
def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
additional_observations = super()._additional_observations()
additional_observations.update({c.DESTINATION: self[c.DESTINATION].as_array()})
additional_observations.update({c.DEST: self[c.DEST].as_array()})
return additional_observations
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
def additional_per_agent_reward(self, agent: Agent) -> Dict[str, dict]:
# noinspection PyUnresolvedReferences
reward, info_dict = super().calculate_additional_reward(agent)
if h.EnvActions.WAIT_ON_DEST == agent.temp_action:
if agent.temp_valid:
info_dict.update({f'{agent.name}_waiting_at_dest': 1})
info_dict.update(agent_waiting_at_dest=1)
self.print(f'{agent.name} just waited at {agent.pos}')
reward += 0.1
else:
info_dict.update({f'{agent.name}_tried_failed': 1})
info_dict.update(agent_waiting_failed=1)
self.print(f'{agent.name} just tried to wait wait at {agent.pos} but failed')
reward -= 0.1
if len(self[c.REACHEDDESTINATION]):
for reached_dest in list(self[c.REACHEDDESTINATION]):
reward_event_dict = super().additional_per_agent_reward(agent)
if len(self[c.DEST_REACHED]):
for reached_dest in list(self[c.DEST_REACHED]):
if agent.pos == reached_dest.pos:
info_dict.update({f'{agent.name}_reached_destination': 1})
info_dict.update(agent_reached_destination=1)
self.print(f'{agent.name} just reached destination at {agent.pos}')
reward += 0.5
self[c.REACHEDDESTINATION].delete_env_object(reached_dest)
return reward, info_dict
self[c.DEST_REACHED].delete_env_object(reached_dest)
info_dict = {f'{agent.name}_{c.DEST_REACHED}': 1}
reward_event_dict.update({c.DEST_REACHED: {'reward': r.DEST_REACHED, 'info': info_dict}})
return reward_event_dict
def render_additional_assets(self, mode='human'):
# noinspection PyUnresolvedReferences
additional_assets = super().render_additional_assets()
destinations = [RenderEntity(c.DESTINATION.value, dest.pos) for dest in self[c.DESTINATION]]
destinations = [RenderEntity(c.DEST, dest.pos) for dest in self[c.DEST]]
additional_assets.extend(destinations)
return additional_assets