Rewards can now be set as parameter
This commit is contained in:
@ -8,7 +8,6 @@ from environments.factory.base.registers import EntityRegister, EnvObjectRegiste
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
from environments.helpers import Constants as BaseConstants
|
||||
from environments.helpers import EnvActions as BaseActions
|
||||
from environments.helpers import Rewards as BaseRewards
|
||||
|
||||
from environments import helpers as h
|
||||
|
||||
@ -25,10 +24,10 @@ class Actions(BaseActions):
|
||||
CHARGE = 'do_charge_action'
|
||||
|
||||
|
||||
class Rewards(BaseRewards):
|
||||
CHARGE_VALID = 0.1
|
||||
CHARGE_FAIL = -0.1
|
||||
BATTERY_DISCHARGED = -1.0
|
||||
class RewardsBtry(NamedTuple):
|
||||
CHARGE_VALID: float = 0.1
|
||||
CHARGE_FAIL: float = -0.1
|
||||
BATTERY_DISCHARGED: float = -1.0
|
||||
|
||||
|
||||
class BatteryProperties(NamedTuple):
|
||||
@ -42,7 +41,6 @@ class BatteryProperties(NamedTuple):
|
||||
|
||||
c = Constants
|
||||
a = Actions
|
||||
r = Rewards
|
||||
|
||||
|
||||
class Battery(BoundingMixin, EnvObject):
|
||||
@ -62,9 +60,9 @@ class Battery(BoundingMixin, EnvObject):
|
||||
if self.charge_level < 1:
|
||||
# noinspection PyTypeChecker
|
||||
self.charge_level = min(1, amount + self.charge_level)
|
||||
return dict(valid=c.VALID, action=a.CHARGE, reward=r.CHARGE_VALID)
|
||||
return c.VALID
|
||||
else:
|
||||
return dict(valid=c.NOT_VALID, action=a.CHARGE, reward=r.CHARGE_FAIL)
|
||||
return c.NOT_VALID
|
||||
|
||||
def decharge(self, amount) -> c:
|
||||
if self.charge_level != 0:
|
||||
@ -133,8 +131,8 @@ class ChargePod(Entity):
|
||||
return c.NOT_VALID
|
||||
if sum(guest for guest in self.tile.guests if 'agent' in guest.name.lower()) > 1:
|
||||
return c.NOT_VALID
|
||||
battery.do_charge_action(self.charge_rate)
|
||||
return c.VALID
|
||||
valid = battery.do_charge_action(self.charge_rate)
|
||||
return valid
|
||||
|
||||
def summarize_state(self, n_steps=None) -> dict:
|
||||
if n_steps == h.STEPS_START:
|
||||
@ -152,10 +150,14 @@ class ChargePods(EntityRegister):
|
||||
|
||||
class BatteryFactory(BaseFactory):
|
||||
|
||||
def __init__(self, *args, btry_prop=BatteryProperties(), **kwargs):
|
||||
def __init__(self, *args, btry_prop=BatteryProperties(), rewards_dest: RewardsBtry = RewardsBtry(),
|
||||
**kwargs):
|
||||
if isinstance(btry_prop, dict):
|
||||
btry_prop = BatteryProperties(**btry_prop)
|
||||
if isinstance(rewards_dest, dict):
|
||||
rewards_dest = RewardsBtry(**rewards_dest)
|
||||
self.btry_prop = btry_prop
|
||||
self.rewards_dest = rewards_dest
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||
@ -215,7 +217,8 @@ class BatteryFactory(BaseFactory):
|
||||
info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
|
||||
# info_dict = {f'{agent.name}_no_charger': 1}
|
||||
self.print(f'{agent.name} failed to charged batteries at {agent.pos}.')
|
||||
reward = dict(value=r.CHARGE_VALID if valid else r.CHARGE_FAIL, reason=a.CHARGE, info=info_dict)
|
||||
reward = dict(value=self.rewards_dest.CHARGE_VALID if valid else self.rewards_dest.CHARGE_FAIL,
|
||||
reason=a.CHARGE, info=info_dict)
|
||||
return valid, reward
|
||||
|
||||
def do_additional_actions(self, agent: Agent, action: Action) -> (bool, dict):
|
||||
@ -254,7 +257,9 @@ class BatteryFactory(BaseFactory):
|
||||
if self[c.BATTERIES].by_entity(agent).is_discharged:
|
||||
self.print(f'{agent.name} Battery is discharged!')
|
||||
info_dict = {f'{agent.name}_{c.BATTERY_DISCHARGED}': 1}
|
||||
reward_event_dict.update({c.BATTERY_DISCHARGED: {'reward': r.BATTERY_DISCHARGED, 'info': info_dict}})
|
||||
reward_event_dict.update({c.BATTERY_DISCHARGED: {'reward': self.rewards_dest.BATTERY_DISCHARGED,
|
||||
'info': info_dict}}
|
||||
)
|
||||
else:
|
||||
# All Fine
|
||||
pass
|
||||
|
Reference in New Issue
Block a user