Rewards can now be set as parameter
This commit is contained in:
parent
823aa075b9
commit
3ce6302e8a
@ -14,7 +14,7 @@ from environments.factory.base.shadow_casting import Map
|
|||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.helpers import Constants as c
|
from environments.helpers import Constants as c
|
||||||
from environments.helpers import EnvActions as a
|
from environments.helpers import EnvActions as a
|
||||||
from environments.helpers import Rewards as r
|
from environments.helpers import RewardsBase
|
||||||
from environments.factory.base.objects import Agent, Floor, Action
|
from environments.factory.base.objects import Agent, Floor, Action
|
||||||
from environments.factory.base.registers import Actions, Entities, Agents, Doors, Floors, Walls, PlaceHolders, \
|
from environments.factory.base.registers import Actions, Entities, Agents, Doors, Floors, Walls, PlaceHolders, \
|
||||||
GlobalPositions
|
GlobalPositions
|
||||||
@ -80,6 +80,7 @@ class BaseFactory(gym.Env):
|
|||||||
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2),
|
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2),
|
||||||
mv_prop: MovementProperties = MovementProperties(),
|
mv_prop: MovementProperties = MovementProperties(),
|
||||||
obs_prop: ObservationProperties = ObservationProperties(),
|
obs_prop: ObservationProperties = ObservationProperties(),
|
||||||
|
rewards_base: RewardsBase = RewardsBase(),
|
||||||
parse_doors=False, done_at_collision=False, inject_agents: Union[None, List] = None,
|
parse_doors=False, done_at_collision=False, inject_agents: Union[None, List] = None,
|
||||||
verbose=False, doors_have_area=True, env_seed=time.time_ns(), individual_rewards=False,
|
verbose=False, doors_have_area=True, env_seed=time.time_ns(), individual_rewards=False,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
@ -88,6 +89,8 @@ class BaseFactory(gym.Env):
|
|||||||
mv_prop = MovementProperties(**mv_prop)
|
mv_prop = MovementProperties(**mv_prop)
|
||||||
if isinstance(obs_prop, dict):
|
if isinstance(obs_prop, dict):
|
||||||
obs_prop = ObservationProperties(**obs_prop)
|
obs_prop = ObservationProperties(**obs_prop)
|
||||||
|
if isinstance(rewards_base, dict):
|
||||||
|
rewards_base = RewardsBase(**rewards_base)
|
||||||
|
|
||||||
assert obs_prop.frames_to_stack != 1 and \
|
assert obs_prop.frames_to_stack != 1 and \
|
||||||
obs_prop.frames_to_stack >= 0, "'frames_to_stack' cannot be negative or 1."
|
obs_prop.frames_to_stack >= 0, "'frames_to_stack' cannot be negative or 1."
|
||||||
@ -100,6 +103,7 @@ class BaseFactory(gym.Env):
|
|||||||
self._base_rng = np.random.default_rng(self.env_seed)
|
self._base_rng = np.random.default_rng(self.env_seed)
|
||||||
self.mv_prop = mv_prop
|
self.mv_prop = mv_prop
|
||||||
self.obs_prop = obs_prop
|
self.obs_prop = obs_prop
|
||||||
|
self.rewards_base = rewards_base
|
||||||
self.level_name = level_name
|
self.level_name = level_name
|
||||||
self._level_shape = None
|
self._level_shape = None
|
||||||
self._obs_shape = None
|
self._obs_shape = None
|
||||||
@ -244,7 +248,7 @@ class BaseFactory(gym.Env):
|
|||||||
action_valid, reward = self._do_move_action(agent, action_obj)
|
action_valid, reward = self._do_move_action(agent, action_obj)
|
||||||
elif a.NOOP == action_obj:
|
elif a.NOOP == action_obj:
|
||||||
action_valid = c.VALID
|
action_valid = c.VALID
|
||||||
reward = dict(value=r.NOOP, reason=a.NOOP, info={f'{agent.name}_NOOP': 1, 'NOOP': 1})
|
reward = dict(value=self.rewards_base.NOOP, reason=a.NOOP, info={f'{agent.name}_NOOP': 1, 'NOOP': 1})
|
||||||
elif a.USE_DOOR == action_obj:
|
elif a.USE_DOOR == action_obj:
|
||||||
action_valid, reward = self._handle_door_interaction(agent)
|
action_valid, reward = self._handle_door_interaction(agent)
|
||||||
else:
|
else:
|
||||||
@ -323,7 +327,7 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
raise RuntimeError('This should not happen, since the door action should not be available.')
|
raise RuntimeError('This should not happen, since the door action should not be available.')
|
||||||
reward = dict(value=r.USE_DOOR_VALID if valid else r.USE_DOOR_FAIL,
|
reward = dict(value=self.rewards_base.USE_DOOR_VALID if valid else self.rewards_base.USE_DOOR_FAIL,
|
||||||
reason=a.USE_DOOR, info=info_dict)
|
reason=a.USE_DOOR, info=info_dict)
|
||||||
|
|
||||||
return valid, reward
|
return valid, reward
|
||||||
@ -518,7 +522,7 @@ class BaseFactory(gym.Env):
|
|||||||
# Agent seems to be trying to Leave the level
|
# Agent seems to be trying to Leave the level
|
||||||
self.print(f'{agent.name} tried to leave the level {agent.pos}. ({action.identifier})')
|
self.print(f'{agent.name} tried to leave the level {agent.pos}. ({action.identifier})')
|
||||||
info_dict.update({f'{agent.name}_wall_collide': 1, 'wall_collide': 1})
|
info_dict.update({f'{agent.name}_wall_collide': 1, 'wall_collide': 1})
|
||||||
reward_value = r.MOVEMENTS_VALID if valid else r.MOVEMENTS_FAIL
|
reward_value = self.rewards_base.MOVEMENTS_VALID if valid else self.rewards_base.MOVEMENTS_FAIL
|
||||||
reward = {'value': reward_value, 'reason': action.identifier, 'info': info_dict}
|
reward = {'value': reward_value, 'reason': action.identifier, 'info': info_dict}
|
||||||
return valid, reward
|
return valid, reward
|
||||||
|
|
||||||
@ -573,7 +577,9 @@ class BaseFactory(gym.Env):
|
|||||||
if collisions := agent.step_result['collisions']:
|
if collisions := agent.step_result['collisions']:
|
||||||
self.print(f't = {self._steps}\t{agent.name} has collisions with {collisions}')
|
self.print(f't = {self._steps}\t{agent.name} has collisions with {collisions}')
|
||||||
info[c.COLLISION] += 1
|
info[c.COLLISION] += 1
|
||||||
reward = {'value': r.COLLISION, 'reason': c.COLLISION, 'info': {f'{agent.name}_{c.COLLISION}': 1}}
|
reward = {'value': self.rewards_base.COLLISION,
|
||||||
|
'reason': c.COLLISION,
|
||||||
|
'info': {f'{agent.name}_{c.COLLISION}': 1}}
|
||||||
agent.step_result['rewards'].append(reward)
|
agent.step_result['rewards'].append(reward)
|
||||||
else:
|
else:
|
||||||
# No Collisions, nothing to do
|
# No Collisions, nothing to do
|
||||||
|
@ -8,7 +8,6 @@ from environments.factory.base.registers import EntityRegister, EnvObjectRegiste
|
|||||||
from environments.factory.base.renderer import RenderEntity
|
from environments.factory.base.renderer import RenderEntity
|
||||||
from environments.helpers import Constants as BaseConstants
|
from environments.helpers import Constants as BaseConstants
|
||||||
from environments.helpers import EnvActions as BaseActions
|
from environments.helpers import EnvActions as BaseActions
|
||||||
from environments.helpers import Rewards as BaseRewards
|
|
||||||
|
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
|
|
||||||
@ -25,10 +24,10 @@ class Actions(BaseActions):
|
|||||||
CHARGE = 'do_charge_action'
|
CHARGE = 'do_charge_action'
|
||||||
|
|
||||||
|
|
||||||
class Rewards(BaseRewards):
|
class RewardsBtry(NamedTuple):
|
||||||
CHARGE_VALID = 0.1
|
CHARGE_VALID: float = 0.1
|
||||||
CHARGE_FAIL = -0.1
|
CHARGE_FAIL: float = -0.1
|
||||||
BATTERY_DISCHARGED = -1.0
|
BATTERY_DISCHARGED: float = -1.0
|
||||||
|
|
||||||
|
|
||||||
class BatteryProperties(NamedTuple):
|
class BatteryProperties(NamedTuple):
|
||||||
@ -42,7 +41,6 @@ class BatteryProperties(NamedTuple):
|
|||||||
|
|
||||||
c = Constants
|
c = Constants
|
||||||
a = Actions
|
a = Actions
|
||||||
r = Rewards
|
|
||||||
|
|
||||||
|
|
||||||
class Battery(BoundingMixin, EnvObject):
|
class Battery(BoundingMixin, EnvObject):
|
||||||
@ -62,9 +60,9 @@ class Battery(BoundingMixin, EnvObject):
|
|||||||
if self.charge_level < 1:
|
if self.charge_level < 1:
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
self.charge_level = min(1, amount + self.charge_level)
|
self.charge_level = min(1, amount + self.charge_level)
|
||||||
return dict(valid=c.VALID, action=a.CHARGE, reward=r.CHARGE_VALID)
|
return c.VALID
|
||||||
else:
|
else:
|
||||||
return dict(valid=c.NOT_VALID, action=a.CHARGE, reward=r.CHARGE_FAIL)
|
return c.NOT_VALID
|
||||||
|
|
||||||
def decharge(self, amount) -> c:
|
def decharge(self, amount) -> c:
|
||||||
if self.charge_level != 0:
|
if self.charge_level != 0:
|
||||||
@ -133,8 +131,8 @@ class ChargePod(Entity):
|
|||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
if sum(guest for guest in self.tile.guests if 'agent' in guest.name.lower()) > 1:
|
if sum(guest for guest in self.tile.guests if 'agent' in guest.name.lower()) > 1:
|
||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
battery.do_charge_action(self.charge_rate)
|
valid = battery.do_charge_action(self.charge_rate)
|
||||||
return c.VALID
|
return valid
|
||||||
|
|
||||||
def summarize_state(self, n_steps=None) -> dict:
|
def summarize_state(self, n_steps=None) -> dict:
|
||||||
if n_steps == h.STEPS_START:
|
if n_steps == h.STEPS_START:
|
||||||
@ -152,10 +150,14 @@ class ChargePods(EntityRegister):
|
|||||||
|
|
||||||
class BatteryFactory(BaseFactory):
|
class BatteryFactory(BaseFactory):
|
||||||
|
|
||||||
def __init__(self, *args, btry_prop=BatteryProperties(), **kwargs):
|
def __init__(self, *args, btry_prop=BatteryProperties(), rewards_dest: RewardsBtry = RewardsBtry(),
|
||||||
|
**kwargs):
|
||||||
if isinstance(btry_prop, dict):
|
if isinstance(btry_prop, dict):
|
||||||
btry_prop = BatteryProperties(**btry_prop)
|
btry_prop = BatteryProperties(**btry_prop)
|
||||||
|
if isinstance(rewards_dest, dict):
|
||||||
|
rewards_dest = RewardsBtry(**rewards_dest)
|
||||||
self.btry_prop = btry_prop
|
self.btry_prop = btry_prop
|
||||||
|
self.rewards_dest = rewards_dest
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||||
@ -215,7 +217,8 @@ class BatteryFactory(BaseFactory):
|
|||||||
info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
|
info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
|
||||||
# info_dict = {f'{agent.name}_no_charger': 1}
|
# info_dict = {f'{agent.name}_no_charger': 1}
|
||||||
self.print(f'{agent.name} failed to charged batteries at {agent.pos}.')
|
self.print(f'{agent.name} failed to charged batteries at {agent.pos}.')
|
||||||
reward = dict(value=r.CHARGE_VALID if valid else r.CHARGE_FAIL, reason=a.CHARGE, info=info_dict)
|
reward = dict(value=self.rewards_dest.CHARGE_VALID if valid else self.rewards_dest.CHARGE_FAIL,
|
||||||
|
reason=a.CHARGE, info=info_dict)
|
||||||
return valid, reward
|
return valid, reward
|
||||||
|
|
||||||
def do_additional_actions(self, agent: Agent, action: Action) -> (bool, dict):
|
def do_additional_actions(self, agent: Agent, action: Action) -> (bool, dict):
|
||||||
@ -254,7 +257,9 @@ class BatteryFactory(BaseFactory):
|
|||||||
if self[c.BATTERIES].by_entity(agent).is_discharged:
|
if self[c.BATTERIES].by_entity(agent).is_discharged:
|
||||||
self.print(f'{agent.name} Battery is discharged!')
|
self.print(f'{agent.name} Battery is discharged!')
|
||||||
info_dict = {f'{agent.name}_{c.BATTERY_DISCHARGED}': 1}
|
info_dict = {f'{agent.name}_{c.BATTERY_DISCHARGED}': 1}
|
||||||
reward_event_dict.update({c.BATTERY_DISCHARGED: {'reward': r.BATTERY_DISCHARGED, 'info': info_dict}})
|
reward_event_dict.update({c.BATTERY_DISCHARGED: {'reward': self.rewards_dest.BATTERY_DISCHARGED,
|
||||||
|
'info': info_dict}}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# All Fine
|
# All Fine
|
||||||
pass
|
pass
|
||||||
|
@ -8,7 +8,6 @@ import random
|
|||||||
from environments.factory.base.base_factory import BaseFactory
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
from environments.helpers import Constants as BaseConstants
|
from environments.helpers import Constants as BaseConstants
|
||||||
from environments.helpers import EnvActions as BaseActions
|
from environments.helpers import EnvActions as BaseActions
|
||||||
from environments.helpers import Rewards as BaseRewards
|
|
||||||
from environments.factory.base.objects import Agent, Entity, Action
|
from environments.factory.base.objects import Agent, Entity, Action
|
||||||
from environments.factory.base.registers import Entities, EntityRegister
|
from environments.factory.base.registers import Entities, EntityRegister
|
||||||
|
|
||||||
@ -27,11 +26,11 @@ class Actions(BaseActions):
|
|||||||
WAIT_ON_DEST = 'WAIT'
|
WAIT_ON_DEST = 'WAIT'
|
||||||
|
|
||||||
|
|
||||||
class Rewards(BaseRewards):
|
class RewardsDest(NamedTuple):
|
||||||
|
|
||||||
WAIT_VALID = 0.1
|
WAIT_VALID: float = 0.1
|
||||||
WAIT_FAIL = -0.1
|
WAIT_FAIL: float = -0.1
|
||||||
DEST_REACHED = 5.0
|
DEST_REACHED: float = 5.0
|
||||||
|
|
||||||
|
|
||||||
class Destination(Entity):
|
class Destination(Entity):
|
||||||
@ -117,7 +116,7 @@ class DestModeOptions(object):
|
|||||||
|
|
||||||
class DestProperties(NamedTuple):
|
class DestProperties(NamedTuple):
|
||||||
n_dests: int = 1 # How many destinations are there
|
n_dests: int = 1 # How many destinations are there
|
||||||
dwell_time: int = 0 # How long does the agent need to "do_wait_action" on a destination
|
dwell_time: int = 0 # How long does the agent need to "wait" on a destination
|
||||||
spawn_frequency: int = 0
|
spawn_frequency: int = 0
|
||||||
spawn_in_other_zone: bool = True #
|
spawn_in_other_zone: bool = True #
|
||||||
spawn_mode: str = DestModeOptions.DONE
|
spawn_mode: str = DestModeOptions.DONE
|
||||||
@ -130,18 +129,20 @@ class DestProperties(NamedTuple):
|
|||||||
|
|
||||||
c = Constants
|
c = Constants
|
||||||
a = Actions
|
a = Actions
|
||||||
r = Rewards
|
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
||||||
class DestFactory(BaseFactory):
|
class DestFactory(BaseFactory):
|
||||||
# noinspection PyMissingConstructor
|
# noinspection PyMissingConstructor
|
||||||
|
|
||||||
def __init__(self, *args, dest_prop: DestProperties = DestProperties(),
|
def __init__(self, *args, dest_prop: DestProperties = DestProperties(), rewards_dest: RewardsDest = RewardsDest(),
|
||||||
env_seed=time.time_ns(), **kwargs):
|
env_seed=time.time_ns(), **kwargs):
|
||||||
if isinstance(dest_prop, dict):
|
if isinstance(dest_prop, dict):
|
||||||
dest_prop = DestProperties(**dest_prop)
|
dest_prop = DestProperties(**dest_prop)
|
||||||
|
if isinstance(rewards_dest, dict):
|
||||||
|
rewards_dest = RewardsDest(**rewards_dest)
|
||||||
self.dest_prop = dest_prop
|
self.dest_prop = dest_prop
|
||||||
|
self.rewards_dest = rewards_dest
|
||||||
kwargs.update(env_seed=env_seed)
|
kwargs.update(env_seed=env_seed)
|
||||||
self._dest_rng = np.random.default_rng(env_seed)
|
self._dest_rng = np.random.default_rng(env_seed)
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@ -179,7 +180,8 @@ class DestFactory(BaseFactory):
|
|||||||
valid = c.NOT_VALID
|
valid = c.NOT_VALID
|
||||||
self.print(f'{agent.name} just tried to do_wait_action do_wait_action at {agent.pos} but failed')
|
self.print(f'{agent.name} just tried to do_wait_action do_wait_action at {agent.pos} but failed')
|
||||||
info_dict = {f'{agent.name}_{a.WAIT_ON_DEST}_FAIL': 1}
|
info_dict = {f'{agent.name}_{a.WAIT_ON_DEST}_FAIL': 1}
|
||||||
reward = dict(value=r.WAIT_VALID if valid else r.WAIT_FAIL, reason=a.WAIT_ON_DEST, info=info_dict)
|
reward = dict(value=self.rewards_dest.WAIT_VALID if valid else self.rewards_dest.WAIT_FAIL,
|
||||||
|
reason=a.WAIT_ON_DEST, info=info_dict)
|
||||||
return valid, reward
|
return valid, reward
|
||||||
|
|
||||||
def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict):
|
def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict):
|
||||||
@ -258,7 +260,8 @@ class DestFactory(BaseFactory):
|
|||||||
self.print(f'{agent.name} just reached destination at {agent.pos}')
|
self.print(f'{agent.name} just reached destination at {agent.pos}')
|
||||||
self[c.DEST_REACHED].delete_env_object(reached_dest)
|
self[c.DEST_REACHED].delete_env_object(reached_dest)
|
||||||
info_dict = {f'{agent.name}_{c.DEST_REACHED}': 1}
|
info_dict = {f'{agent.name}_{c.DEST_REACHED}': 1}
|
||||||
reward_event_dict.update({c.DEST_REACHED: {'reward': r.DEST_REACHED, 'info': info_dict}})
|
reward_event_dict.update({c.DEST_REACHED: {'reward': self.rewards_dest.DEST_REACHED,
|
||||||
|
'info': info_dict}})
|
||||||
return reward_event_dict
|
return reward_event_dict
|
||||||
|
|
||||||
def render_assets_hook(self, mode='human'):
|
def render_assets_hook(self, mode='human'):
|
||||||
@ -270,13 +273,13 @@ class DestFactory(BaseFactory):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
from environments.utility_classes import AgentRenderOptions as ARO, ObservationProperties
|
from environments.utility_classes import AgentRenderOptions as aro, ObservationProperties
|
||||||
|
|
||||||
render = True
|
render = True
|
||||||
|
|
||||||
dest_probs = DestProperties(n_dests=2, spawn_frequency=5, spawn_mode=DestModeOptions.GROUPED)
|
dest_probs = DestProperties(n_dests=2, spawn_frequency=5, spawn_mode=DestModeOptions.GROUPED)
|
||||||
|
|
||||||
obs_props = ObservationProperties(render_agents=ARO.LEVEL, omit_agent_self=True, pomdp_r=2)
|
obs_props = ObservationProperties(render_agents=aro.LEVEL, omit_agent_self=True, pomdp_r=2)
|
||||||
|
|
||||||
move_props = {'allow_square_movement': True,
|
move_props = {'allow_square_movement': True,
|
||||||
'allow_diagonal_movement': False,
|
'allow_diagonal_movement': False,
|
||||||
|
@ -4,11 +4,9 @@ import random
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
# from algorithms.TSP_dirt_agent import TSPDirtAgent
|
|
||||||
from algorithms.TSP_dirt_agent import TSPDirtAgent
|
from algorithms.TSP_dirt_agent import TSPDirtAgent
|
||||||
from environments.helpers import Constants as BaseConstants
|
from environments.helpers import Constants as BaseConstants
|
||||||
from environments.helpers import EnvActions as BaseActions
|
from environments.helpers import EnvActions as BaseActions
|
||||||
from environments.helpers import Rewards as BaseRewards
|
|
||||||
|
|
||||||
from environments.factory.base.base_factory import BaseFactory
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
from environments.factory.base.objects import Agent, Action, Entity, Floor
|
from environments.factory.base.objects import Agent, Action, Entity, Floor
|
||||||
@ -26,10 +24,10 @@ class Actions(BaseActions):
|
|||||||
CLEAN_UP = 'do_cleanup_action'
|
CLEAN_UP = 'do_cleanup_action'
|
||||||
|
|
||||||
|
|
||||||
class Rewards(BaseRewards):
|
class RewardsDirt(NamedTuple):
|
||||||
CLEAN_UP_VALID = 0.5
|
CLEAN_UP_VALID: float = 0.5
|
||||||
CLEAN_UP_FAIL = -0.1
|
CLEAN_UP_FAIL: float = -0.1
|
||||||
CLEAN_UP_LAST_PIECE = 4.5
|
CLEAN_UP_LAST_PIECE: float = 4.5
|
||||||
|
|
||||||
|
|
||||||
class DirtProperties(NamedTuple):
|
class DirtProperties(NamedTuple):
|
||||||
@ -119,7 +117,6 @@ def entropy(x):
|
|||||||
|
|
||||||
c = Constants
|
c = Constants
|
||||||
a = Actions
|
a = Actions
|
||||||
r = Rewards
|
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
||||||
@ -138,10 +135,15 @@ class DirtFactory(BaseFactory):
|
|||||||
super_entities.update(({c.DIRT: dirt_register}))
|
super_entities.update(({c.DIRT: dirt_register}))
|
||||||
return super_entities
|
return super_entities
|
||||||
|
|
||||||
def __init__(self, *args, dirt_prop: DirtProperties = DirtProperties(), env_seed=time.time_ns(), **kwargs):
|
def __init__(self, *args,
|
||||||
|
dirt_prop: DirtProperties = DirtProperties(), rewards_dirt: RewardsDirt = RewardsDirt(),
|
||||||
|
env_seed=time.time_ns(), **kwargs):
|
||||||
if isinstance(dirt_prop, dict):
|
if isinstance(dirt_prop, dict):
|
||||||
dirt_prop = DirtProperties(**dirt_prop)
|
dirt_prop = DirtProperties(**dirt_prop)
|
||||||
|
if isinstance(rewards_dirt, dict):
|
||||||
|
rewards_dirt = RewardsDirt(**rewards_dirt)
|
||||||
self.dirt_prop = dirt_prop
|
self.dirt_prop = dirt_prop
|
||||||
|
self.rewards_dirt = rewards_dirt
|
||||||
self._dirt_rng = np.random.default_rng(env_seed)
|
self._dirt_rng = np.random.default_rng(env_seed)
|
||||||
self._dirt: DirtRegister
|
self._dirt: DirtRegister
|
||||||
kwargs.update(env_seed=env_seed)
|
kwargs.update(env_seed=env_seed)
|
||||||
@ -166,15 +168,15 @@ class DirtFactory(BaseFactory):
|
|||||||
valid = c.VALID
|
valid = c.VALID
|
||||||
self.print(f'{agent.name} did just clean up some dirt at {agent.pos}.')
|
self.print(f'{agent.name} did just clean up some dirt at {agent.pos}.')
|
||||||
info_dict = {f'{agent.name}_{a.CLEAN_UP}_VALID': 1, 'cleanup_valid': 1}
|
info_dict = {f'{agent.name}_{a.CLEAN_UP}_VALID': 1, 'cleanup_valid': 1}
|
||||||
reward = r.CLEAN_UP_VALID
|
reward = self.rewards_dirt.CLEAN_UP_VALID
|
||||||
else:
|
else:
|
||||||
valid = c.NOT_VALID
|
valid = c.NOT_VALID
|
||||||
self.print(f'{agent.name} just tried to clean up some dirt at {agent.pos}, but failed.')
|
self.print(f'{agent.name} just tried to clean up some dirt at {agent.pos}, but failed.')
|
||||||
info_dict = {f'{agent.name}_{a.CLEAN_UP}_FAIL': 1, 'cleanup_fail': 1}
|
info_dict = {f'{agent.name}_{a.CLEAN_UP}_FAIL': 1, 'cleanup_fail': 1}
|
||||||
reward = r.CLEAN_UP_FAIL
|
reward = self.rewards_dirt.CLEAN_UP_FAIL
|
||||||
|
|
||||||
if valid and self.dirt_prop.done_when_clean and (len(self[c.DIRT]) == 0):
|
if valid and self.dirt_prop.done_when_clean and (len(self[c.DIRT]) == 0):
|
||||||
reward += r.CLEAN_UP_LAST_PIECE
|
reward += self.rewards_dirt.CLEAN_UP_LAST_PIECE
|
||||||
self.print(f'{agent.name} picked up the last piece of dirt!')
|
self.print(f'{agent.name} picked up the last piece of dirt!')
|
||||||
info_dict = {f'{agent.name}_{a.CLEAN_UP}_LAST_PIECE': 1}
|
info_dict = {f'{agent.name}_{a.CLEAN_UP}_LAST_PIECE': 1}
|
||||||
return valid, dict(value=reward, reason=a.CLEAN_UP, info=info_dict)
|
return valid, dict(value=reward, reason=a.CLEAN_UP, info=info_dict)
|
||||||
|
@ -7,7 +7,6 @@ import random
|
|||||||
from environments.factory.base.base_factory import BaseFactory
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
from environments.helpers import Constants as BaseConstants
|
from environments.helpers import Constants as BaseConstants
|
||||||
from environments.helpers import EnvActions as BaseActions
|
from environments.helpers import EnvActions as BaseActions
|
||||||
from environments.helpers import Rewards as BaseRewards
|
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.factory.base.objects import Agent, Entity, Action, Floor
|
from environments.factory.base.objects import Agent, Entity, Action, Floor
|
||||||
from environments.factory.base.registers import Entities, EntityRegister, BoundEnvObjRegister, ObjectRegister
|
from environments.factory.base.registers import Entities, EntityRegister, BoundEnvObjRegister, ObjectRegister
|
||||||
@ -28,11 +27,11 @@ class Actions(BaseActions):
|
|||||||
ITEM_ACTION = 'ITEMACTION'
|
ITEM_ACTION = 'ITEMACTION'
|
||||||
|
|
||||||
|
|
||||||
class Rewards(BaseRewards):
|
class RewardsItem(NamedTuple):
|
||||||
DROP_OFF_VALID = 0.1
|
DROP_OFF_VALID: float = 0.1
|
||||||
DROP_OFF_FAIL = -0.1
|
DROP_OFF_FAIL: float = -0.1
|
||||||
PICK_UP_FAIL = -0.1
|
PICK_UP_FAIL: float = -0.1
|
||||||
PICK_UP_VALID = 0.1
|
PICK_UP_VALID: float = 0.1
|
||||||
|
|
||||||
|
|
||||||
class Item(Entity):
|
class Item(Entity):
|
||||||
@ -177,16 +176,19 @@ class ItemProperties(NamedTuple):
|
|||||||
|
|
||||||
c = Constants
|
c = Constants
|
||||||
a = Actions
|
a = Actions
|
||||||
r = Rewards
|
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
||||||
class ItemFactory(BaseFactory):
|
class ItemFactory(BaseFactory):
|
||||||
# noinspection PyMissingConstructor
|
# noinspection PyMissingConstructor
|
||||||
def __init__(self, *args, item_prop: ItemProperties = ItemProperties(), env_seed=time.time_ns(), **kwargs):
|
def __init__(self, *args, item_prop: ItemProperties = ItemProperties(), env_seed=time.time_ns(),
|
||||||
|
rewards_item: RewardsItem = RewardsItem(), **kwargs):
|
||||||
if isinstance(item_prop, dict):
|
if isinstance(item_prop, dict):
|
||||||
item_prop = ItemProperties(**item_prop)
|
item_prop = ItemProperties(**item_prop)
|
||||||
|
if isinstance(rewards_item, dict):
|
||||||
|
rewards_item = RewardsItem(**rewards_item)
|
||||||
self.item_prop = item_prop
|
self.item_prop = item_prop
|
||||||
|
self.rewards_item = rewards_item
|
||||||
kwargs.update(env_seed=env_seed)
|
kwargs.update(env_seed=env_seed)
|
||||||
self._item_rng = np.random.default_rng(env_seed)
|
self._item_rng = np.random.default_rng(env_seed)
|
||||||
assert (item_prop.n_items <= ((1 + kwargs.get('_pomdp_r', 0) * 2) ** 2)) or not kwargs.get('_pomdp_r', 0)
|
assert (item_prop.n_items <= ((1 + kwargs.get('_pomdp_r', 0) * 2) ** 2)) or not kwargs.get('_pomdp_r', 0)
|
||||||
@ -244,18 +246,19 @@ class ItemFactory(BaseFactory):
|
|||||||
else:
|
else:
|
||||||
self.print(f'{agent.name} just tried to drop off at {agent.pos}, but failed.')
|
self.print(f'{agent.name} just tried to drop off at {agent.pos}, but failed.')
|
||||||
info_dict = {f'{agent.name}_DROPOFF_FAIL': 1, 'DROPOFF_FAIL': 1}
|
info_dict = {f'{agent.name}_DROPOFF_FAIL': 1, 'DROPOFF_FAIL': 1}
|
||||||
reward = dict(value=r.DROP_OFF_VALID if valid else r.DROP_OFF_FAIL, reason=a.ITEM_ACTION, info=info_dict)
|
reward = dict(value=self.rewards_item.DROP_OFF_VALID if valid else self.rewards_item.DROP_OFF_FAIL,
|
||||||
|
reason=a.ITEM_ACTION, info=info_dict)
|
||||||
return valid, reward
|
return valid, reward
|
||||||
elif item := self[c.ITEM].by_pos(agent.pos):
|
elif item := self[c.ITEM].by_pos(agent.pos):
|
||||||
item.change_register(inventory)
|
item.change_register(inventory)
|
||||||
item.set_tile_to(self._NO_POS_TILE)
|
item.set_tile_to(self._NO_POS_TILE)
|
||||||
self.print(f'{agent.name} just picked up an item at {agent.pos}')
|
self.print(f'{agent.name} just picked up an item at {agent.pos}')
|
||||||
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1, f'{a.ITEM_ACTION}_VALID': 1}
|
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1, f'{a.ITEM_ACTION}_VALID': 1}
|
||||||
return c.VALID, dict(value=r.PICK_UP_VALID, reason=a.ITEM_ACTION, info=info_dict)
|
return c.VALID, dict(value=self.rewards_item.PICK_UP_VALID, reason=a.ITEM_ACTION, info=info_dict)
|
||||||
else:
|
else:
|
||||||
self.print(f'{agent.name} just tried to pick up an item at {agent.pos}, but failed.')
|
self.print(f'{agent.name} just tried to pick up an item at {agent.pos}, but failed.')
|
||||||
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_FAIL': 1, f'{a.ITEM_ACTION}_FAIL': 1}
|
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_FAIL': 1, f'{a.ITEM_ACTION}_FAIL': 1}
|
||||||
return c.NOT_VALID, dict(value=r.PICK_UP_FAIL, reason=a.ITEM_ACTION, info=info_dict)
|
return c.NOT_VALID, dict(value=self.rewards_item.PICK_UP_FAIL, reason=a.ITEM_ACTION, info=info_dict)
|
||||||
|
|
||||||
def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict):
|
def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict):
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
|
@ -76,19 +76,18 @@ class EnvActions:
|
|||||||
return list(itertools.chain(cls.square_move(), cls.diagonal_move()))
|
return list(itertools.chain(cls.square_move(), cls.diagonal_move()))
|
||||||
|
|
||||||
|
|
||||||
class Rewards:
|
class RewardsBase(NamedTuple):
|
||||||
|
MOVEMENTS_VALID: float = -0.001
|
||||||
MOVEMENTS_VALID = -0.00
|
MOVEMENTS_FAIL: float = -0.05
|
||||||
MOVEMENTS_FAIL = -0.10
|
NOOP: float = -0.01
|
||||||
NOOP = -0.01
|
USE_DOOR_VALID: float = -0.00
|
||||||
USE_DOOR_VALID = -0.00
|
USE_DOOR_FAIL: float = -0.01
|
||||||
USE_DOOR_FAIL = -0.10
|
COLLISION: float = -0.5
|
||||||
COLLISION = -0.5
|
|
||||||
|
|
||||||
|
|
||||||
m = EnvActions
|
m = EnvActions
|
||||||
c = Constants
|
c = Constants
|
||||||
r = Rewards
|
r = RewardsBase
|
||||||
|
|
||||||
ACTIONMAP = defaultdict(lambda: (0, 0),
|
ACTIONMAP = defaultdict(lambda: (0, 0),
|
||||||
{m.NORTH: (-1, 0), m.NORTHEAST: (-1, 1),
|
{m.NORTH: (-1, 0), m.NORTHEAST: (-1, 1),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user