Rewards can now be set as parameter
This commit is contained in:
@ -4,11 +4,9 @@ import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
# from algorithms.TSP_dirt_agent import TSPDirtAgent
|
||||
from algorithms.TSP_dirt_agent import TSPDirtAgent
|
||||
from environments.helpers import Constants as BaseConstants
|
||||
from environments.helpers import EnvActions as BaseActions
|
||||
from environments.helpers import Rewards as BaseRewards
|
||||
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.factory.base.objects import Agent, Action, Entity, Floor
|
||||
@ -26,10 +24,10 @@ class Actions(BaseActions):
|
||||
CLEAN_UP = 'do_cleanup_action'
|
||||
|
||||
|
||||
class Rewards(BaseRewards):
|
||||
CLEAN_UP_VALID = 0.5
|
||||
CLEAN_UP_FAIL = -0.1
|
||||
CLEAN_UP_LAST_PIECE = 4.5
|
||||
class RewardsDirt(NamedTuple):
|
||||
CLEAN_UP_VALID: float = 0.5
|
||||
CLEAN_UP_FAIL: float = -0.1
|
||||
CLEAN_UP_LAST_PIECE: float = 4.5
|
||||
|
||||
|
||||
class DirtProperties(NamedTuple):
|
||||
@ -119,7 +117,6 @@ def entropy(x):
|
||||
|
||||
c = Constants
|
||||
a = Actions
|
||||
r = Rewards
|
||||
|
||||
|
||||
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
||||
@ -138,10 +135,15 @@ class DirtFactory(BaseFactory):
|
||||
super_entities.update(({c.DIRT: dirt_register}))
|
||||
return super_entities
|
||||
|
||||
def __init__(self, *args, dirt_prop: DirtProperties = DirtProperties(), env_seed=time.time_ns(), **kwargs):
|
||||
def __init__(self, *args,
|
||||
dirt_prop: DirtProperties = DirtProperties(), rewards_dirt: RewardsDirt = RewardsDirt(),
|
||||
env_seed=time.time_ns(), **kwargs):
|
||||
if isinstance(dirt_prop, dict):
|
||||
dirt_prop = DirtProperties(**dirt_prop)
|
||||
if isinstance(rewards_dirt, dict):
|
||||
rewards_dirt = RewardsDirt(**rewards_dirt)
|
||||
self.dirt_prop = dirt_prop
|
||||
self.rewards_dirt = rewards_dirt
|
||||
self._dirt_rng = np.random.default_rng(env_seed)
|
||||
self._dirt: DirtRegister
|
||||
kwargs.update(env_seed=env_seed)
|
||||
@ -166,15 +168,15 @@ class DirtFactory(BaseFactory):
|
||||
valid = c.VALID
|
||||
self.print(f'{agent.name} did just clean up some dirt at {agent.pos}.')
|
||||
info_dict = {f'{agent.name}_{a.CLEAN_UP}_VALID': 1, 'cleanup_valid': 1}
|
||||
reward = r.CLEAN_UP_VALID
|
||||
reward = self.rewards_dirt.CLEAN_UP_VALID
|
||||
else:
|
||||
valid = c.NOT_VALID
|
||||
self.print(f'{agent.name} just tried to clean up some dirt at {agent.pos}, but failed.')
|
||||
info_dict = {f'{agent.name}_{a.CLEAN_UP}_FAIL': 1, 'cleanup_fail': 1}
|
||||
reward = r.CLEAN_UP_FAIL
|
||||
reward = self.rewards_dirt.CLEAN_UP_FAIL
|
||||
|
||||
if valid and self.dirt_prop.done_when_clean and (len(self[c.DIRT]) == 0):
|
||||
reward += r.CLEAN_UP_LAST_PIECE
|
||||
reward += self.rewards_dirt.CLEAN_UP_LAST_PIECE
|
||||
self.print(f'{agent.name} picked up the last piece of dirt!')
|
||||
info_dict = {f'{agent.name}_{a.CLEAN_UP}_LAST_PIECE': 1}
|
||||
return valid, dict(value=reward, reason=a.CLEAN_UP, info=info_dict)
|
||||
|
Reference in New Issue
Block a user