Rewards can now be set as parameter

This commit is contained in:
Steffen Illium
2022-01-17 11:21:07 +01:00
parent 823aa075b9
commit 3ce6302e8a
6 changed files with 79 additions and 61 deletions

View File

@ -8,7 +8,6 @@ import random
from environments.factory.base.base_factory import BaseFactory
from environments.helpers import Constants as BaseConstants
from environments.helpers import EnvActions as BaseActions
from environments.helpers import Rewards as BaseRewards
from environments.factory.base.objects import Agent, Entity, Action
from environments.factory.base.registers import Entities, EntityRegister
@ -27,11 +26,11 @@ class Actions(BaseActions):
WAIT_ON_DEST = 'WAIT'
class Rewards(BaseRewards):
class RewardsDest(NamedTuple):
WAIT_VALID = 0.1
WAIT_FAIL = -0.1
DEST_REACHED = 5.0
WAIT_VALID: float = 0.1
WAIT_FAIL: float = -0.1
DEST_REACHED: float = 5.0
class Destination(Entity):
@ -117,7 +116,7 @@ class DestModeOptions(object):
class DestProperties(NamedTuple):
n_dests: int = 1 # How many destinations are there
dwell_time: int = 0 # How long does the agent need to "do_wait_action" on a destination
dwell_time: int = 0 # How long does the agent need to "wait" on a destination
spawn_frequency: int = 0
spawn_in_other_zone: bool = True #
spawn_mode: str = DestModeOptions.DONE
@ -130,18 +129,20 @@ class DestProperties(NamedTuple):
c = Constants
a = Actions
r = Rewards
# noinspection PyAttributeOutsideInit, PyAbstractClass
class DestFactory(BaseFactory):
# noinspection PyMissingConstructor
def __init__(self, *args, dest_prop: DestProperties = DestProperties(),
def __init__(self, *args, dest_prop: DestProperties = DestProperties(), rewards_dest: RewardsDest = RewardsDest(),
env_seed=time.time_ns(), **kwargs):
if isinstance(dest_prop, dict):
dest_prop = DestProperties(**dest_prop)
if isinstance(rewards_dest, dict):
rewards_dest = RewardsDest(**rewards_dest)
self.dest_prop = dest_prop
self.rewards_dest = rewards_dest
kwargs.update(env_seed=env_seed)
self._dest_rng = np.random.default_rng(env_seed)
super().__init__(*args, **kwargs)
@ -179,7 +180,8 @@ class DestFactory(BaseFactory):
valid = c.NOT_VALID
self.print(f'{agent.name} just tried to do_wait_action do_wait_action at {agent.pos} but failed')
info_dict = {f'{agent.name}_{a.WAIT_ON_DEST}_FAIL': 1}
reward = dict(value=r.WAIT_VALID if valid else r.WAIT_FAIL, reason=a.WAIT_ON_DEST, info=info_dict)
reward = dict(value=self.rewards_dest.WAIT_VALID if valid else self.rewards_dest.WAIT_FAIL,
reason=a.WAIT_ON_DEST, info=info_dict)
return valid, reward
def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict):
@ -258,7 +260,8 @@ class DestFactory(BaseFactory):
self.print(f'{agent.name} just reached destination at {agent.pos}')
self[c.DEST_REACHED].delete_env_object(reached_dest)
info_dict = {f'{agent.name}_{c.DEST_REACHED}': 1}
reward_event_dict.update({c.DEST_REACHED: {'reward': r.DEST_REACHED, 'info': info_dict}})
reward_event_dict.update({c.DEST_REACHED: {'reward': self.rewards_dest.DEST_REACHED,
'info': info_dict}})
return reward_event_dict
def render_assets_hook(self, mode='human'):
@ -270,13 +273,13 @@ class DestFactory(BaseFactory):
if __name__ == '__main__':
from environments.utility_classes import AgentRenderOptions as ARO, ObservationProperties
from environments.utility_classes import AgentRenderOptions as aro, ObservationProperties
render = True
dest_probs = DestProperties(n_dests=2, spawn_frequency=5, spawn_mode=DestModeOptions.GROUPED)
obs_props = ObservationProperties(render_agents=ARO.LEVEL, omit_agent_self=True, pomdp_r=2)
obs_props = ObservationProperties(render_agents=aro.LEVEL, omit_agent_self=True, pomdp_r=2)
move_props = {'allow_square_movement': True,
'allow_diagonal_movement': False,