Rewards can now be set as parameter
This commit is contained in:
@ -7,7 +7,6 @@ import random
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.helpers import Constants as BaseConstants
|
||||
from environments.helpers import EnvActions as BaseActions
|
||||
from environments.helpers import Rewards as BaseRewards
|
||||
from environments import helpers as h
|
||||
from environments.factory.base.objects import Agent, Entity, Action, Floor
|
||||
from environments.factory.base.registers import Entities, EntityRegister, BoundEnvObjRegister, ObjectRegister
|
||||
@ -28,11 +27,11 @@ class Actions(BaseActions):
|
||||
ITEM_ACTION = 'ITEMACTION'
|
||||
|
||||
|
||||
class Rewards(BaseRewards):
|
||||
DROP_OFF_VALID = 0.1
|
||||
DROP_OFF_FAIL = -0.1
|
||||
PICK_UP_FAIL = -0.1
|
||||
PICK_UP_VALID = 0.1
|
||||
class RewardsItem(NamedTuple):
|
||||
DROP_OFF_VALID: float = 0.1
|
||||
DROP_OFF_FAIL: float = -0.1
|
||||
PICK_UP_FAIL: float = -0.1
|
||||
PICK_UP_VALID: float = 0.1
|
||||
|
||||
|
||||
class Item(Entity):
|
||||
@ -177,16 +176,19 @@ class ItemProperties(NamedTuple):
|
||||
|
||||
c = Constants
|
||||
a = Actions
|
||||
r = Rewards
|
||||
|
||||
|
||||
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
||||
class ItemFactory(BaseFactory):
|
||||
# noinspection PyMissingConstructor
|
||||
def __init__(self, *args, item_prop: ItemProperties = ItemProperties(), env_seed=time.time_ns(), **kwargs):
|
||||
def __init__(self, *args, item_prop: ItemProperties = ItemProperties(), env_seed=time.time_ns(),
|
||||
rewards_item: RewardsItem = RewardsItem(), **kwargs):
|
||||
if isinstance(item_prop, dict):
|
||||
item_prop = ItemProperties(**item_prop)
|
||||
if isinstance(rewards_item, dict):
|
||||
rewards_item = RewardsItem(**rewards_item)
|
||||
self.item_prop = item_prop
|
||||
self.rewards_item = rewards_item
|
||||
kwargs.update(env_seed=env_seed)
|
||||
self._item_rng = np.random.default_rng(env_seed)
|
||||
assert (item_prop.n_items <= ((1 + kwargs.get('_pomdp_r', 0) * 2) ** 2)) or not kwargs.get('_pomdp_r', 0)
|
||||
@ -244,18 +246,19 @@ class ItemFactory(BaseFactory):
|
||||
else:
|
||||
self.print(f'{agent.name} just tried to drop off at {agent.pos}, but failed.')
|
||||
info_dict = {f'{agent.name}_DROPOFF_FAIL': 1, 'DROPOFF_FAIL': 1}
|
||||
reward = dict(value=r.DROP_OFF_VALID if valid else r.DROP_OFF_FAIL, reason=a.ITEM_ACTION, info=info_dict)
|
||||
reward = dict(value=self.rewards_item.DROP_OFF_VALID if valid else self.rewards_item.DROP_OFF_FAIL,
|
||||
reason=a.ITEM_ACTION, info=info_dict)
|
||||
return valid, reward
|
||||
elif item := self[c.ITEM].by_pos(agent.pos):
|
||||
item.change_register(inventory)
|
||||
item.set_tile_to(self._NO_POS_TILE)
|
||||
self.print(f'{agent.name} just picked up an item at {agent.pos}')
|
||||
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1, f'{a.ITEM_ACTION}_VALID': 1}
|
||||
return c.VALID, dict(value=r.PICK_UP_VALID, reason=a.ITEM_ACTION, info=info_dict)
|
||||
return c.VALID, dict(value=self.rewards_item.PICK_UP_VALID, reason=a.ITEM_ACTION, info=info_dict)
|
||||
else:
|
||||
self.print(f'{agent.name} just tried to pick up an item at {agent.pos}, but failed.')
|
||||
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_FAIL': 1, f'{a.ITEM_ACTION}_FAIL': 1}
|
||||
return c.NOT_VALID, dict(value=r.PICK_UP_FAIL, reason=a.ITEM_ACTION, info=info_dict)
|
||||
return c.NOT_VALID, dict(value=self.rewards_item.PICK_UP_FAIL, reason=a.ITEM_ACTION, info=info_dict)
|
||||
|
||||
def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict):
|
||||
# noinspection PyUnresolvedReferences
|
||||
|
Reference in New Issue
Block a user