Rework for performance
This commit is contained in:
@ -6,18 +6,32 @@ import numpy as np
|
||||
import random
|
||||
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.helpers import Constants as c, Constants
|
||||
from environments import helpers as h
|
||||
from environments.helpers import Constants as BaseConstants
|
||||
from environments.helpers import EnvActions as BaseActions
|
||||
from environments.helpers import Rewards as BaseRewards
|
||||
from environments.factory.base.objects import Agent, Entity, Action
|
||||
from environments.factory.base.registers import Entities, EntityRegister
|
||||
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
|
||||
|
||||
class Constants(BaseConstants):
|
||||
# Destination Env
|
||||
DEST = 'Destination'
|
||||
DESTINATION = 1
|
||||
DESTINATION_DONE = 0.5
|
||||
DEST_REACHED = 'ReachedDestination'
|
||||
|
||||
|
||||
DESTINATION = 1
|
||||
DESTINATION_DONE = 0.5
|
||||
class Actions(BaseActions):
|
||||
WAIT_ON_DEST = 'WAIT'
|
||||
|
||||
|
||||
class Rewards(BaseRewards):
|
||||
|
||||
WAIT_VALID = 0.1
|
||||
WAIT_FAIL = -0.1
|
||||
DEST_REACHED = 5.0
|
||||
|
||||
|
||||
class Destination(Entity):
|
||||
@ -30,20 +44,16 @@ class Destination(Entity):
|
||||
def currently_dwelling_names(self):
|
||||
return self._per_agent_times.keys()
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return DESTINATION
|
||||
return c.DESTINATION
|
||||
|
||||
def __init__(self, *args, dwell_time: int = 0, **kwargs):
|
||||
super(Destination, self).__init__(*args, **kwargs)
|
||||
self.dwell_time = dwell_time
|
||||
self._per_agent_times = defaultdict(lambda: dwell_time)
|
||||
|
||||
def wait(self, agent: Agent):
|
||||
def do_wait_action(self, agent: Agent):
|
||||
self._per_agent_times[agent.name] -= 1
|
||||
return c.VALID
|
||||
|
||||
@ -52,7 +62,7 @@ class Destination(Entity):
|
||||
|
||||
@property
|
||||
def is_considered_reached(self):
|
||||
agent_at_position = any(c.AGENT.name.lower() in x.name.lower() for x in self.tile.guests_that_can_collide)
|
||||
agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in self.tile.guests_that_can_collide)
|
||||
return (agent_at_position and not self.dwell_time) or any(x == 0 for x in self._per_agent_times.values())
|
||||
|
||||
def agent_is_dwelling(self, agent: Agent):
|
||||
@ -67,15 +77,19 @@ class Destination(Entity):
|
||||
class Destinations(EntityRegister):
|
||||
|
||||
_accepted_objects = Destination
|
||||
_light_blocking = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.is_blocking_light = False
|
||||
self.can_be_shadowed = False
|
||||
|
||||
def as_array(self):
|
||||
self._array[:] = c.FREE_CELL.value
|
||||
self._array[:] = c.FREE_CELL
|
||||
# ToDo: Switch to new Style Array Put
|
||||
# indices = list(zip(range(len(cls)), *zip(*[x.pos for x in cls])))
|
||||
# np.put(cls._array, [np.ravel_multi_index(x, cls._array.shape) for x in indices], cls.encodings)
|
||||
for item in self:
|
||||
if item.pos != c.NO_POS.value:
|
||||
if item.pos != c.NO_POS:
|
||||
self._array[0, item.x, item.y] = item.encoding
|
||||
return self._array
|
||||
|
||||
@ -85,10 +99,11 @@ class Destinations(EntityRegister):
|
||||
|
||||
class ReachedDestinations(Destinations):
|
||||
_accepted_objects = Destination
|
||||
_light_blocking = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ReachedDestinations, self).__init__(*args, **kwargs)
|
||||
self.can_be_shadowed = False
|
||||
self.is_blocking_light = False
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
return {}
|
||||
@ -102,7 +117,7 @@ class DestModeOptions(object):
|
||||
|
||||
class DestProperties(NamedTuple):
|
||||
n_dests: int = 1 # How many destinations are there
|
||||
dwell_time: int = 0 # How long does the agent need to "wait" on a destination
|
||||
dwell_time: int = 0 # How long does the agent need to "do_wait_action" on a destination
|
||||
spawn_frequency: int = 0
|
||||
spawn_in_other_zone: bool = True #
|
||||
spawn_mode: str = DestModeOptions.DONE
|
||||
@ -113,6 +128,11 @@ class DestProperties(NamedTuple):
|
||||
assert (spawn_mode == DestModeOptions.DONE) != bool(spawn_frequency)
|
||||
|
||||
|
||||
c = Constants
|
||||
a = Actions
|
||||
r = Rewards
|
||||
|
||||
|
||||
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
||||
class DestFactory(BaseFactory):
|
||||
# noinspection PyMissingConstructor
|
||||
@ -131,7 +151,7 @@ class DestFactory(BaseFactory):
|
||||
# noinspection PyUnresolvedReferences
|
||||
super_actions = super().additional_actions
|
||||
if self.dest_prop.dwell_time:
|
||||
super_actions.append(Action(enum_ident=h.EnvActions.WAIT_ON_DEST))
|
||||
super_actions.append(Action(enum_ident=a.WAIT_ON_DEST))
|
||||
return super_actions
|
||||
|
||||
@property
|
||||
@ -147,27 +167,32 @@ class DestFactory(BaseFactory):
|
||||
)
|
||||
reached_destinations = ReachedDestinations(level_shape=self._level_shape)
|
||||
|
||||
super_entities.update({c.DESTINATION: destinations, c.REACHEDDESTINATION: reached_destinations})
|
||||
super_entities.update({c.DEST: destinations, c.DEST_REACHED: reached_destinations})
|
||||
return super_entities
|
||||
|
||||
def wait(self, agent: Agent):
|
||||
if destiantion := self[c.DESTINATION].by_pos(agent.pos):
|
||||
valid = destiantion.wait(agent)
|
||||
return valid
|
||||
def do_wait_action(self, agent: Agent) -> (dict, dict):
|
||||
if destination := self[c.DEST].by_pos(agent.pos):
|
||||
valid = destination.do_wait_action(agent)
|
||||
self.print(f'{agent.name} just waited at {agent.pos}')
|
||||
info_dict = {f'{agent.name}_{a.WAIT_ON_DEST}_VALID': 1}
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
valid = c.NOT_VALID
|
||||
self.print(f'{agent.name} just tried to do_wait_action do_wait_action at {agent.pos} but failed')
|
||||
info_dict = {f'{agent.name}_{a.WAIT_ON_DEST}_FAIL': 1}
|
||||
reward = dict(value=r.WAIT_VALID if valid else r.WAIT_FAIL, reason=a.WAIT_ON_DEST, info=info_dict)
|
||||
return valid, reward
|
||||
|
||||
def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]:
|
||||
def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict):
|
||||
# noinspection PyUnresolvedReferences
|
||||
valid = super().do_additional_actions(agent, action)
|
||||
if valid is None:
|
||||
if action == h.EnvActions.WAIT_ON_DEST:
|
||||
valid = self.wait(agent)
|
||||
return valid
|
||||
super_action_result = super().do_additional_actions(agent, action)
|
||||
if super_action_result is None:
|
||||
if action == a.WAIT_ON_DEST:
|
||||
action_result = self.do_wait_action(agent)
|
||||
return action_result
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return valid
|
||||
return super_action_result
|
||||
|
||||
def do_additional_reset(self) -> None:
|
||||
# noinspection PyUnresolvedReferences
|
||||
@ -180,14 +205,14 @@ class DestFactory(BaseFactory):
|
||||
if destinations_to_spawn:
|
||||
n_dest_to_spawn = len(destinations_to_spawn)
|
||||
if self.dest_prop.spawn_mode != DestModeOptions.GROUPED:
|
||||
destinations = [Destination(tile) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
||||
self[c.DESTINATION].register_additional_items(destinations)
|
||||
destinations = [Destination(tile, c.DEST) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
||||
self[c.DEST].register_additional_items(destinations)
|
||||
for dest in destinations_to_spawn:
|
||||
del self._dest_spawn_timer[dest]
|
||||
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
|
||||
elif self.dest_prop.spawn_mode == DestModeOptions.GROUPED and n_dest_to_spawn == self.dest_prop.n_dests:
|
||||
destinations = [Destination(tile) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
||||
self[c.DESTINATION].register_additional_items(destinations)
|
||||
destinations = [Destination(tile, self[c.DEST]) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
||||
self[c.DEST].register_additional_items(destinations)
|
||||
for dest in destinations_to_spawn:
|
||||
del self._dest_spawn_timer[dest]
|
||||
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
|
||||
@ -197,15 +222,14 @@ class DestFactory(BaseFactory):
|
||||
else:
|
||||
self.print('No Items are spawning, limit is reached.')
|
||||
|
||||
def do_additional_step(self) -> dict:
|
||||
def do_additional_step(self) -> (List[dict], dict):
|
||||
# noinspection PyUnresolvedReferences
|
||||
info_dict = super().do_additional_step()
|
||||
super_reward_info = super().do_additional_step()
|
||||
for key, val in self._dest_spawn_timer.items():
|
||||
self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1)
|
||||
for dest in list(self[c.DESTINATION].values()):
|
||||
for dest in list(self[c.DEST].values()):
|
||||
if dest.is_considered_reached:
|
||||
self[c.REACHEDDESTINATION].register_item(dest)
|
||||
self[c.DESTINATION].delete_env_object(dest)
|
||||
dest.change_register(self[c.DEST])
|
||||
self._dest_spawn_timer[dest.name] = 0
|
||||
self.print(f'{dest.name} is reached now, removing...')
|
||||
else:
|
||||
@ -218,41 +242,29 @@ class DestFactory(BaseFactory):
|
||||
dest.leave(agent)
|
||||
self.print(f'{agent.name} left the destination early.')
|
||||
self.trigger_destination_spawn()
|
||||
return info_dict
|
||||
return super_reward_info
|
||||
|
||||
def _additional_observations(self) -> Dict[Constants, np.typing.ArrayLike]:
|
||||
def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_observations = super()._additional_observations()
|
||||
additional_observations.update({c.DESTINATION: self[c.DESTINATION].as_array()})
|
||||
additional_observations.update({c.DEST: self[c.DEST].as_array()})
|
||||
return additional_observations
|
||||
|
||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
||||
def additional_per_agent_reward(self, agent: Agent) -> Dict[str, dict]:
|
||||
# noinspection PyUnresolvedReferences
|
||||
reward, info_dict = super().calculate_additional_reward(agent)
|
||||
if h.EnvActions.WAIT_ON_DEST == agent.temp_action:
|
||||
if agent.temp_valid:
|
||||
info_dict.update({f'{agent.name}_waiting_at_dest': 1})
|
||||
info_dict.update(agent_waiting_at_dest=1)
|
||||
self.print(f'{agent.name} just waited at {agent.pos}')
|
||||
reward += 0.1
|
||||
else:
|
||||
info_dict.update({f'{agent.name}_tried_failed': 1})
|
||||
info_dict.update(agent_waiting_failed=1)
|
||||
self.print(f'{agent.name} just tried to wait wait at {agent.pos} but failed')
|
||||
reward -= 0.1
|
||||
if len(self[c.REACHEDDESTINATION]):
|
||||
for reached_dest in list(self[c.REACHEDDESTINATION]):
|
||||
reward_event_dict = super().additional_per_agent_reward(agent)
|
||||
if len(self[c.DEST_REACHED]):
|
||||
for reached_dest in list(self[c.DEST_REACHED]):
|
||||
if agent.pos == reached_dest.pos:
|
||||
info_dict.update({f'{agent.name}_reached_destination': 1})
|
||||
info_dict.update(agent_reached_destination=1)
|
||||
self.print(f'{agent.name} just reached destination at {agent.pos}')
|
||||
reward += 0.5
|
||||
self[c.REACHEDDESTINATION].delete_env_object(reached_dest)
|
||||
return reward, info_dict
|
||||
self[c.DEST_REACHED].delete_env_object(reached_dest)
|
||||
info_dict = {f'{agent.name}_{c.DEST_REACHED}': 1}
|
||||
reward_event_dict.update({c.DEST_REACHED: {'reward': r.DEST_REACHED, 'info': info_dict}})
|
||||
return reward_event_dict
|
||||
|
||||
def render_additional_assets(self, mode='human'):
|
||||
# noinspection PyUnresolvedReferences
|
||||
additional_assets = super().render_additional_assets()
|
||||
destinations = [RenderEntity(c.DESTINATION.value, dest.pos) for dest in self[c.DESTINATION]]
|
||||
destinations = [RenderEntity(c.DEST, dest.pos) for dest in self[c.DEST]]
|
||||
additional_assets.extend(destinations)
|
||||
return additional_assets
|
||||
|
||||
|
Reference in New Issue
Block a user