Rewards can now be set as parameter

This commit is contained in:
Steffen Illium
2022-01-17 11:21:07 +01:00
parent 823aa075b9
commit 3ce6302e8a
6 changed files with 79 additions and 61 deletions

View File

@ -8,7 +8,6 @@ from environments.factory.base.registers import EntityRegister, EnvObjectRegiste
from environments.factory.base.renderer import RenderEntity
from environments.helpers import Constants as BaseConstants
from environments.helpers import EnvActions as BaseActions
from environments.helpers import Rewards as BaseRewards
from environments import helpers as h
@ -25,10 +24,10 @@ class Actions(BaseActions):
CHARGE = 'do_charge_action'
class Rewards(BaseRewards):
CHARGE_VALID = 0.1
CHARGE_FAIL = -0.1
BATTERY_DISCHARGED = -1.0
class RewardsBtry(NamedTuple):
CHARGE_VALID: float = 0.1
CHARGE_FAIL: float = -0.1
BATTERY_DISCHARGED: float = -1.0
class BatteryProperties(NamedTuple):
@ -42,7 +41,6 @@ class BatteryProperties(NamedTuple):
c = Constants
a = Actions
r = Rewards
class Battery(BoundingMixin, EnvObject):
@ -62,9 +60,9 @@ class Battery(BoundingMixin, EnvObject):
if self.charge_level < 1:
# noinspection PyTypeChecker
self.charge_level = min(1, amount + self.charge_level)
return dict(valid=c.VALID, action=a.CHARGE, reward=r.CHARGE_VALID)
return c.VALID
else:
return dict(valid=c.NOT_VALID, action=a.CHARGE, reward=r.CHARGE_FAIL)
return c.NOT_VALID
def decharge(self, amount) -> c:
if self.charge_level != 0:
@ -133,8 +131,8 @@ class ChargePod(Entity):
return c.NOT_VALID
if sum(guest for guest in self.tile.guests if 'agent' in guest.name.lower()) > 1:
return c.NOT_VALID
battery.do_charge_action(self.charge_rate)
return c.VALID
valid = battery.do_charge_action(self.charge_rate)
return valid
def summarize_state(self, n_steps=None) -> dict:
if n_steps == h.STEPS_START:
@ -152,10 +150,14 @@ class ChargePods(EntityRegister):
class BatteryFactory(BaseFactory):
def __init__(self, *args, btry_prop=BatteryProperties(), **kwargs):
def __init__(self, *args, btry_prop=BatteryProperties(), rewards_dest: RewardsBtry = RewardsBtry(),
**kwargs):
if isinstance(btry_prop, dict):
btry_prop = BatteryProperties(**btry_prop)
if isinstance(rewards_dest, dict):
rewards_dest = RewardsBtry(**rewards_dest)
self.btry_prop = btry_prop
self.rewards_dest = rewards_dest
super().__init__(*args, **kwargs)
def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
@ -215,7 +217,8 @@ class BatteryFactory(BaseFactory):
info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
# info_dict = {f'{agent.name}_no_charger': 1}
self.print(f'{agent.name} failed to charged batteries at {agent.pos}.')
reward = dict(value=r.CHARGE_VALID if valid else r.CHARGE_FAIL, reason=a.CHARGE, info=info_dict)
reward = dict(value=self.rewards_dest.CHARGE_VALID if valid else self.rewards_dest.CHARGE_FAIL,
reason=a.CHARGE, info=info_dict)
return valid, reward
def do_additional_actions(self, agent: Agent, action: Action) -> (bool, dict):
@ -254,7 +257,9 @@ class BatteryFactory(BaseFactory):
if self[c.BATTERIES].by_entity(agent).is_discharged:
self.print(f'{agent.name} Battery is discharged!')
info_dict = {f'{agent.name}_{c.BATTERY_DISCHARGED}': 1}
reward_event_dict.update({c.BATTERY_DISCHARGED: {'reward': r.BATTERY_DISCHARGED, 'info': info_dict}})
reward_event_dict.update({c.BATTERY_DISCHARGED: {'reward': self.rewards_dest.BATTERY_DISCHARGED,
'info': info_dict}}
)
else:
# All Fine
pass