2022-04-11 16:15:44 +02:00

314 lines
12 KiB
Python

import time
from collections import defaultdict
from enum import Enum
from typing import List, Union, NamedTuple, Dict
import numpy as np
import random
from environments.factory.base.base_factory import BaseFactory
from environments.helpers import Constants as BaseConstants
from environments.helpers import EnvActions as BaseActions
from environments.factory.base.objects import Agent, Entity, Action
from environments.factory.base.registers import Entities, EntityCollection
from environments.factory.base.renderer import RenderEntity
class Constants(BaseConstants):
    """Identifiers and cell values of the destination environment, on top of the base constants."""

    # -- Destination environment ------------------------------------------------
    # Key of the collection holding open destinations.
    DEST = 'Destination'
    # Value written into the observation array for an open destination.
    DESTINATION = 1
    # Not referenced elsewhere in this file; kept for external users (TODO(review): confirm).
    DESTINATION_DONE = 0.5
    # Key of the collection holding reached destinations.
    DEST_REACHED = 'ReachedDestination'
class Actions(BaseActions):
    """Action identifiers of the destination environment, extending the base actions."""

    # Offered to agents only when DestProperties.dwell_time is non-zero (see DestFactory.actions_hook).
    WAIT_ON_DEST = 'WAIT'
class RewardsDest(NamedTuple):
    """Reward values paid out by the destination environment."""

    WAIT_VALID: float = 0.1     # waiting while standing on a destination
    WAIT_FAIL: float = -0.1     # trying to wait where there is no destination
    DEST_REACHED: float = 5.0   # agent standing on a destination that has been reached
class Destination(Entity):
    """A target cell that agents must reach and, optionally, dwell on.

    ``_per_agent_times`` maps an agent name to the number of wait actions that agent
    still has to perform here. An entry is created on the agent's first wait (via the
    defaultdict) and starts at ``dwell_time``.
    """

    @property
    def any_agent_has_dwelled(self):
        """True if at least one agent currently has a dwell counter on this destination."""
        return bool(self._per_agent_times)

    @property
    def currently_dwelling_names(self):
        """Names of agents that currently hold a dwell counter (a live keys view)."""
        return self._per_agent_times.keys()

    @property
    def encoding(self):
        """Value written into the observation array for this destination."""
        return c.DESTINATION

    def __init__(self, *args, dwell_time: int = 0, **kwargs):
        super(Destination, self).__init__(*args, **kwargs)
        self.dwell_time = dwell_time
        # Each agent starts with the full dwell budget; every wait action decrements it.
        self._per_agent_times = defaultdict(lambda: dwell_time)

    def do_wait_action(self, agent: Agent):
        """Count one wait action of ``agent`` on this destination; always reported as valid."""
        self._per_agent_times[agent.name] -= 1
        return c.VALID

    def leave(self, agent: Agent):
        """Drop ``agent``'s dwell progress. Does nothing if the agent never waited here."""
        self._per_agent_times.pop(agent.name, None)

    @property
    def is_considered_reached(self):
        """Whether this destination counts as reached.

        Without a dwell time, the destination is reached as soon as a colliding guest
        whose name contains ``c.AGENT`` stands on its tile. With a dwell time, some
        agent's counter must have run down to zero.
        """
        agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in self.tile.guests_that_can_collide)
        # ``<= 0`` rather than ``== 0``, so an overshooting counter still counts as reached.
        return (agent_at_position and not self.dwell_time) or any(x <= 0 for x in self._per_agent_times.values())

    def agent_is_dwelling(self, agent: Agent):
        """Whether ``agent`` has already waited here at least once.

        Uses ``.get`` so the query does not insert a new entry into the defaultdict.
        Such an entry would make the agent appear in ``currently_dwelling_names``.
        """
        return self._per_agent_times.get(agent.name, self.dwell_time) < self.dwell_time

    def summarize_state(self, n_steps=None) -> dict:
        """Extend the parent summary with a snapshot of the per-agent dwell counters."""
        state_summary = super().summarize_state(n_steps=n_steps)
        # Plain-dict copy: callers get no live reference, and no lambda factory (which cannot be pickled).
        state_summary.update(per_agent_times=dict(self._per_agent_times))
        return state_summary
class Destinations(EntityCollection):
    """Collection of open destinations; it does not block light and cannot be shadowed."""

    _accepted_objects = Destination

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_blocking_light = False
        self.can_be_shadowed = False

    def as_array(self):
        """Write every positioned destination into channel 0 of ``self._array`` and return it.

        The whole array is reset to ``c.FREE_CELL`` first. Destinations whose position
        is ``c.NO_POS`` are skipped.
        """
        self._array[:] = c.FREE_CELL
        # ToDo: Switch to new Style Array Put
        # indices = list(zip(range(len(cls)), *zip(*[x.pos for x in cls])))
        # np.put(cls._array, [np.ravel_multi_index(x, cls._array.shape) for x in indices], cls.encodings)
        for item in self:
            if item.pos != c.NO_POS:
                self._array[0, item.x, item.y] = item.encoding
        return self._array

    def __repr__(self):
        # __repr__ must return a str; returning None makes repr() raise TypeError.
        return super().__repr__()
class ReachedDestinations(Destinations):
    """Collection for destinations that have been reached.

    ``DestFactory.per_agent_reward_hook`` deletes an entry once an agent standing on it
    has been rewarded.
    """

    _accepted_objects = Destination

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Same flags the parent already sets; stated explicitly for this collection.
        self.is_blocking_light = False
        self.can_be_shadowed = False

    def summarize_states(self, n_steps=None):
        """Always return an empty summary."""
        return {}
class DestModeOptions:
    """String identifiers for how reached destinations are respawned.

    In ``DestFactory.trigger_destination_spawn``, every mode except GROUPED respawns
    each due destination on its own. GROUPED waits until all ``n_dests`` are due.
    """

    DONE = 'DONE'          # zero-delay mode: DestProperties pairs it with spawn_frequency == 0
    GROUPED = 'GROUPED'    # respawn only once every destination is due
    PER_DEST = 'PER_DEST'  # respawn each destination individually
class _DestPropertiesBase(NamedTuple):
    """Field layout and defaults of :class:`DestProperties` (validation lives in the subclass)."""

    n_dests: int = 1                        # How many destinations are there
    dwell_time: int = 0                     # How many wait actions an agent must spend on a destination
    spawn_frequency: int = 0                # Steps a reached destination waits before it is respawned
    spawn_in_other_zone: bool = True        # Not read in this file -- TODO(review): confirm it is used
    spawn_mode: str = DestModeOptions.DONE  # One of the DestModeOptions values


class DestProperties(_DestPropertiesBase):
    """Destination settings for :class:`DestFactory`, validated on construction.

    Validation lives in ``__new__``. Assertions placed in a NamedTuple body run once, at
    class creation, so they would only ever check the defaults. ``typing.NamedTuple``
    forbids overriding ``__new__`` in its own body, hence this thin subclass.
    ``_make``/``_replace`` build instances through ``tuple.__new__`` and skip the check.

    Raises:
        ValueError: if ``n_dests``, ``dwell_time`` or ``spawn_frequency`` is negative.
            Also raised if ``spawn_mode`` is DONE with a non-zero ``spawn_frequency``,
            or another mode is used with ``spawn_frequency == 0``.
    """

    __slots__ = ()

    def __new__(cls, *args, **kwargs):
        self = super().__new__(cls, *args, **kwargs)
        if self.dwell_time < 0:
            raise ValueError('dwell_time cannot be < 0!')
        if self.spawn_frequency < 0:
            raise ValueError('spawn_frequency cannot be < 0!')
        if self.n_dests < 0:
            raise ValueError('n_destinations cannot be < 0!')
        # DONE is the zero-delay mode; every other mode needs a positive spawn delay.
        if (self.spawn_mode == DestModeOptions.DONE) == bool(self.spawn_frequency):
            raise ValueError('spawn_mode DONE requires spawn_frequency == 0; '
                             'other spawn modes require spawn_frequency > 0.')
        return self
# Short module-level aliases used throughout this file.
c, a = Constants, Actions
# noinspection PyAttributeOutsideInit, PyAbstractClass
class DestFactory(BaseFactory):
    """Factory environment in which agents have to reach (and optionally dwell on) destinations.

    Lifecycle of a destination:
    - Once reached, it is moved to the ``c.DEST_REACHED`` collection.
    - It is rewarded when an agent standing on it is evaluated.
    - It is respawned as configured by ``dest_prop``.
    """

    # noinspection PyMissingConstructor
    def __init__(self, *args, dest_prop: DestProperties = DestProperties(), rewards_dest: RewardsDest = RewardsDest(),
                 env_seed=None, **kwargs):
        """Set up the destination settings, then initialise the base factory.

        Args:
            dest_prop: destination settings, as ``DestProperties`` or a dict of its fields.
            rewards_dest: reward values, as ``RewardsDest`` or a dict of its fields.
            env_seed: seed forwarded to ``BaseFactory`` and used for this factory's own RNG.
                ``None`` (the default) draws a fresh ``time.time_ns()`` seed per instance.
        """
        if isinstance(dest_prop, dict):
            dest_prop = DestProperties(**dest_prop)
        if isinstance(rewards_dest, dict):
            rewards_dest = RewardsDest(**rewards_dest)
        # Resolved here instead of in the signature. A ``time.time_ns()`` default is evaluated
        # once, at class definition, so every unseeded factory would share the same seed.
        if env_seed is None:
            env_seed = time.time_ns()
        self.dest_prop = dest_prop
        self.rewards_dest = rewards_dest
        kwargs.update(env_seed=env_seed)
        self._dest_rng = np.random.default_rng(env_seed)
        super().__init__(*args, **kwargs)

    @property
    def actions_hook(self) -> Union[Action, List[Action]]:
        """Base actions, plus the wait action when destinations require dwelling."""
        # noinspection PyUnresolvedReferences
        super_actions = super().actions_hook
        if self.dest_prop.dwell_time:
            super_actions.append(Action(enum_ident=a.WAIT_ON_DEST))
        return super_actions

    @property
    def entities_hook(self) -> Dict[Enum, Entities]:
        """Base entities, plus the open and reached destination collections."""
        # noinspection PyUnresolvedReferences
        super_entities = super().entities_hook

        empty_tiles = self[c.FLOOR].empty_tiles[:self.dest_prop.n_dests]
        destinations = Destinations.from_tiles(
            empty_tiles, self._level_shape,
            entity_kwargs=dict(
                dwell_time=self.dest_prop.dwell_time)
        )
        reached_destinations = ReachedDestinations(level_shape=self._level_shape)

        super_entities.update({c.DEST: destinations, c.DEST_REACHED: reached_destinations})
        return super_entities

    def do_wait_action(self, agent: Agent) -> (dict, dict):
        """Let ``agent`` wait on the destination under it.

        Returns:
            ``(valid, reward)``. ``valid`` is ``c.NOT_VALID`` when there is no destination
            at ``agent.pos``.
        """
        if destination := self[c.DEST].by_pos(agent.pos):
            valid = destination.do_wait_action(agent)
            self.print(f'{agent.name} just waited at {agent.pos}')
            info_dict = {f'{agent.name}_{a.WAIT_ON_DEST}_VALID': 1}
        else:
            valid = c.NOT_VALID
            self.print(f'{agent.name} just tried to wait at {agent.pos} but failed')
            info_dict = {f'{agent.name}_{a.WAIT_ON_DEST}_FAIL': 1}
        reward = dict(value=self.rewards_dest.WAIT_VALID if valid else self.rewards_dest.WAIT_FAIL,
                      reason=a.WAIT_ON_DEST, info=info_dict)
        return valid, reward

    def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict):
        """Handle the wait action when no parent class has handled ``action``.

        Returns None for actions nobody handles.
        """
        # noinspection PyUnresolvedReferences
        super_action_result = super().do_additional_actions(agent, action)
        if super_action_result is not None:
            return super_action_result
        if action == a.WAIT_ON_DEST:
            return self.do_wait_action(agent)
        return None

    def reset_hook(self) -> None:
        """Reset the base factory and clear all pending respawn timers."""
        # noinspection PyUnresolvedReferences
        super().reset_hook()
        self._dest_spawn_timer = dict()

    def trigger_destination_spawn(self):
        """Respawn destinations whose timer has reached ``dest_prop.spawn_frequency``.

        In GROUPED mode nothing is spawned until every one of the ``n_dests`` destinations
        is due. All other modes spawn each due destination individually.
        """
        destinations_to_spawn = [key for key, val in self._dest_spawn_timer.items()
                                 if val == self.dest_prop.spawn_frequency]
        if not destinations_to_spawn:
            self.print('No Items are spawning, limit is reached.')
            return

        n_dest_to_spawn = len(destinations_to_spawn)
        if self.dest_prop.spawn_mode == DestModeOptions.GROUPED and n_dest_to_spawn != self.dest_prop.n_dests:
            self.print(f'{n_dest_to_spawn} new destinations could be spawned, but waiting for all.')
            return

        # New destinations belong to the DEST collection and keep the configured dwell time,
        # matching how entities_hook builds the initial ones.
        destinations = [Destination(tile, self[c.DEST], dwell_time=self.dest_prop.dwell_time)
                        for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
        self[c.DEST].add_additional_items(destinations)
        for dest in destinations_to_spawn:
            del self._dest_spawn_timer[dest]
        self.print(f'{n_dest_to_spawn} new destinations have been spawned')

    def step_hook(self) -> (List[dict], dict):
        """Per-step bookkeeping, run after the base step hook.

        Steps performed:
        - Advance the respawn timers.
        - Retire reached destinations.
        - Drop dwell counters of agents that walked away.
        - Trigger respawns.
        """
        # noinspection PyUnresolvedReferences
        super_reward_info = super().step_hook()
        # Advance every pending respawn timer, saturating at spawn_frequency.
        for key, val in self._dest_spawn_timer.items():
            self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, val + 1)
        for dest in list(self[c.DEST].values()):
            if dest.is_considered_reached:
                # Move it into the reached collection, where per_agent_reward_hook pays it out.
                dest.change_parent_collection(self[c.DEST_REACHED])
                self._dest_spawn_timer[dest.name] = 0
                self.print(f'{dest.name} is reached now, removing...')
            else:
                # Snapshot the names: dest.leave() deletes from the dict behind this view.
                for agent_name in list(dest.currently_dwelling_names):
                    agent = self[c.AGENT].by_name(agent_name)
                    if agent.pos == dest.pos:
                        self.print(f'{agent.name} is still waiting.')
                    else:
                        dest.leave(agent)
                        self.print(f'{agent.name} left the destination early.')
        self.trigger_destination_spawn()
        return super_reward_info

    def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
        """Base observations, plus the destination layer."""
        additional_observations = super().observations_hook()
        additional_observations.update({c.DEST: self[c.DEST].as_array()})
        return additional_observations

    def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
        """Reward ``agent`` for standing on a reached destination, then delete that destination."""
        # noinspection PyUnresolvedReferences
        reward_event_dict = super().per_agent_reward_hook(agent)
        # Iterate over a copy: matching entries are deleted from the collection inside the loop.
        for reached_dest in list(self[c.DEST_REACHED]):
            if agent.pos == reached_dest.pos:
                self.print(f'{agent.name} just reached destination at {agent.pos}')
                self[c.DEST_REACHED].delete_env_object(reached_dest)
                info_dict = {f'{agent.name}_{c.DEST_REACHED}': 1}
                reward_event_dict.update({c.DEST_REACHED: {'reward': self.rewards_dest.DEST_REACHED,
                                                           'info': info_dict}})
        return reward_event_dict

    def render_assets_hook(self, mode='human'):
        """Base render assets, plus one ``RenderEntity`` per open destination.

        ``mode`` is accepted but not forwarded to the parent hook.
        """
        # noinspection PyUnresolvedReferences
        additional_assets = super().render_assets_hook()
        destinations = [RenderEntity(c.DEST, dest.pos) for dest in self[c.DEST]]
        additional_assets.extend(destinations)
        return additional_assets
if __name__ == '__main__':
    # Manual smoke test: run a few episodes with random actions and print each episode's return.
    from environments.utility_classes import AgentRenderOptions as aro, ObservationProperties

    render = True

    dest_probs = DestProperties(n_dests=2, spawn_frequency=5, spawn_mode=DestModeOptions.GROUPED)
    obs_props = ObservationProperties(render_agents=aro.LEVEL, omit_agent_self=True, pomdp_r=2)
    move_props = dict(allow_square_movement=True,
                      allow_diagonal_movement=False,
                      allow_no_op=False)

    factory = DestFactory(n_agents=10, done_at_collision=False, level_name='rooms', max_steps=400,
                          obs_prop=obs_props, parse_doors=True, verbose=True,
                          mv_prop=move_props, dest_prop=dest_probs)

    # noinspection DuplicatedCode
    highest_action = factory.action_space.n - 1  # random.randint bounds are inclusive
    _ = factory.observation_space

    for epoch in range(4):
        # Pre-draw one action per agent for every step (outer loop: steps, inner loop: agents).
        action_plan = [[random.randint(0, highest_action) for _ in range(factory.n_agents)]
                       for _ in range(factory.max_steps + 1)]
        env_state = factory.reset()
        episode_return = 0
        for step_actions in action_plan:
            env_state, step_r, done_bool, info_obj = factory.step(step_actions)
            episode_return += step_r
            if render:
                factory.render()
            if done_bool:
                break
        print(f'Factory run {epoch} done, reward is:\n {episode_return}')