Rewards can now be set as a parameter
This commit is contained in:
@ -8,7 +8,6 @@ import random
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.helpers import Constants as BaseConstants
|
||||
from environments.helpers import EnvActions as BaseActions
|
||||
from environments.helpers import Rewards as BaseRewards
|
||||
from environments.factory.base.objects import Agent, Entity, Action
|
||||
from environments.factory.base.registers import Entities, EntityRegister
|
||||
|
||||
@ -27,11 +26,11 @@ class Actions(BaseActions):
|
||||
WAIT_ON_DEST = 'WAIT'
|
||||
|
||||
|
||||
class Rewards(BaseRewards):
|
||||
class RewardsDest(NamedTuple):
|
||||
|
||||
WAIT_VALID = 0.1
|
||||
WAIT_FAIL = -0.1
|
||||
DEST_REACHED = 5.0
|
||||
WAIT_VALID: float = 0.1
|
||||
WAIT_FAIL: float = -0.1
|
||||
DEST_REACHED: float = 5.0
|
||||
|
||||
|
||||
class Destination(Entity):
|
||||
@ -117,7 +116,7 @@ class DestModeOptions(object):
|
||||
|
||||
class DestProperties(NamedTuple):
|
||||
n_dests: int = 1 # How many destinations are there
|
||||
dwell_time: int = 0 # How long does the agent need to "do_wait_action" on a destination
|
||||
dwell_time: int = 0 # How long does the agent need to "wait" on a destination
|
||||
spawn_frequency: int = 0
|
||||
spawn_in_other_zone: bool = True #
|
||||
spawn_mode: str = DestModeOptions.DONE
|
||||
@ -130,18 +129,20 @@ class DestProperties(NamedTuple):
|
||||
|
||||
c = Constants
|
||||
a = Actions
|
||||
r = Rewards
|
||||
|
||||
|
||||
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
||||
class DestFactory(BaseFactory):
|
||||
# noinspection PyMissingConstructor
|
||||
|
||||
def __init__(self, *args, dest_prop: DestProperties = DestProperties(),
|
||||
def __init__(self, *args, dest_prop: DestProperties = DestProperties(), rewards_dest: RewardsDest = RewardsDest(),
|
||||
env_seed=time.time_ns(), **kwargs):
|
||||
if isinstance(dest_prop, dict):
|
||||
dest_prop = DestProperties(**dest_prop)
|
||||
if isinstance(rewards_dest, dict):
|
||||
rewards_dest = RewardsDest(**rewards_dest)
|
||||
self.dest_prop = dest_prop
|
||||
self.rewards_dest = rewards_dest
|
||||
kwargs.update(env_seed=env_seed)
|
||||
self._dest_rng = np.random.default_rng(env_seed)
|
||||
super().__init__(*args, **kwargs)
|
||||
@ -179,7 +180,8 @@ class DestFactory(BaseFactory):
|
||||
valid = c.NOT_VALID
|
||||
self.print(f'{agent.name} just tried to do_wait_action do_wait_action at {agent.pos} but failed')
|
||||
info_dict = {f'{agent.name}_{a.WAIT_ON_DEST}_FAIL': 1}
|
||||
reward = dict(value=r.WAIT_VALID if valid else r.WAIT_FAIL, reason=a.WAIT_ON_DEST, info=info_dict)
|
||||
reward = dict(value=self.rewards_dest.WAIT_VALID if valid else self.rewards_dest.WAIT_FAIL,
|
||||
reason=a.WAIT_ON_DEST, info=info_dict)
|
||||
return valid, reward
|
||||
|
||||
def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict):
|
||||
@ -258,7 +260,8 @@ class DestFactory(BaseFactory):
|
||||
self.print(f'{agent.name} just reached destination at {agent.pos}')
|
||||
self[c.DEST_REACHED].delete_env_object(reached_dest)
|
||||
info_dict = {f'{agent.name}_{c.DEST_REACHED}': 1}
|
||||
reward_event_dict.update({c.DEST_REACHED: {'reward': r.DEST_REACHED, 'info': info_dict}})
|
||||
reward_event_dict.update({c.DEST_REACHED: {'reward': self.rewards_dest.DEST_REACHED,
|
||||
'info': info_dict}})
|
||||
return reward_event_dict
|
||||
|
||||
def render_assets_hook(self, mode='human'):
|
||||
@ -270,13 +273,13 @@ class DestFactory(BaseFactory):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from environments.utility_classes import AgentRenderOptions as ARO, ObservationProperties
|
||||
from environments.utility_classes import AgentRenderOptions as aro, ObservationProperties
|
||||
|
||||
render = True
|
||||
|
||||
dest_probs = DestProperties(n_dests=2, spawn_frequency=5, spawn_mode=DestModeOptions.GROUPED)
|
||||
|
||||
obs_props = ObservationProperties(render_agents=ARO.LEVEL, omit_agent_self=True, pomdp_r=2)
|
||||
obs_props = ObservationProperties(render_agents=aro.LEVEL, omit_agent_self=True, pomdp_r=2)
|
||||
|
||||
move_props = {'allow_square_movement': True,
|
||||
'allow_diagonal_movement': False,
|
||||
|
Reference in New Issue
Block a user