Rework for performance
This commit is contained in:
@ -7,9 +7,10 @@ import random
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.helpers import Constants as BaseConstants
|
||||
from environments.helpers import EnvActions as BaseActions
|
||||
from environments.helpers import Rewards as BaseRewards
|
||||
from environments import helpers as h
|
||||
from environments.factory.base.objects import Agent, Entity, Action, Tile
|
||||
from environments.factory.base.registers import Entities, EntityRegister, BoundRegisterMixin, ObjectRegister
|
||||
from environments.factory.base.registers import Entities, EntityRegister, BoundEnvObjRegister, ObjectRegister
|
||||
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
|
||||
@ -23,10 +24,17 @@ class Constants(BaseConstants):
|
||||
DROP_OFF = 'Drop_Off'
|
||||
|
||||
|
||||
class EnvActions(BaseActions):
|
||||
class Actions(BaseActions):
|
||||
ITEM_ACTION = 'item_action'
|
||||
|
||||
|
||||
class Rewards(BaseRewards):
|
||||
DROP_OFF_VALID = 0.1
|
||||
DROP_OFF_FAIL = -0.1
|
||||
PICK_UP_FAIL = -0.1
|
||||
PICK_UP_VALID = 0.1
|
||||
|
||||
|
||||
class Item(Entity):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@ -37,10 +45,6 @@ class Item(Entity):
|
||||
def auto_despawn(self):
|
||||
return self._auto_despawn
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
# Edit this if you want items to be drawn in the ops differently
|
||||
@ -68,7 +72,7 @@ class ItemRegister(EntityRegister):
|
||||
del self[item]
|
||||
|
||||
|
||||
class Inventory(BoundRegisterMixin):
|
||||
class Inventory(BoundEnvObjRegister):
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
@ -131,10 +135,6 @@ class Inventories(ObjectRegister):
|
||||
|
||||
class DropOffLocation(Entity):
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return Constants.ITEM_DROP_OFF
|
||||
@ -176,7 +176,8 @@ class ItemProperties(NamedTuple):
|
||||
|
||||
|
||||
c = Constants
|
||||
a = EnvActions
|
||||
a = Actions
|
||||
r = Rewards
|
||||
|
||||
|
||||
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
||||
@ -230,37 +231,43 @@ class ItemFactory(BaseFactory):
|
||||
additional_observations.update({c.DROP_OFF: self[c.DROP_OFF].as_array()})
|
||||
return additional_observations
|
||||
|
||||
def do_item_action(self, agent: Agent):
|
||||
def do_item_action(self, agent: Agent) -> (dict, dict):
|
||||
inventory = self[c.INVENTORY].by_entity(agent)
|
||||
if drop_off := self[c.DROP_OFF].by_pos(agent.pos):
|
||||
if inventory:
|
||||
valid = drop_off.place_item(inventory.pop())
|
||||
return valid
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
valid = c.NOT_VALID
|
||||
if valid:
|
||||
self.print(f'{agent.name} just dropped of an item at {drop_off.pos}.')
|
||||
info_dict = {f'{agent.name}_DROPOFF_VALID': 1}
|
||||
else:
|
||||
self.print(f'{agent.name} just tried to drop off at {agent.pos}, but failed.')
|
||||
info_dict = {f'{agent.name}_DROPOFF_FAIL': 1}
|
||||
reward = dict(value=r.DROP_OFF_VALID if valid else r.DROP_OFF_FAIL, reason=a.ITEM_ACTION, info=info_dict)
|
||||
return valid, reward
|
||||
elif item := self[c.ITEM].by_pos(agent.pos):
|
||||
try:
|
||||
inventory.register_item(item)
|
||||
item.change_register(inventory)
|
||||
self[c.ITEM].delete_env_object(item)
|
||||
item.set_tile_to(self._NO_POS_TILE)
|
||||
return c.VALID
|
||||
except RuntimeError:
|
||||
return c.NOT_VALID
|
||||
item.change_register(inventory)
|
||||
item.set_tile_to(self._NO_POS_TILE)
|
||||
self.print(f'{agent.name} just picked up an item at {agent.pos}')
|
||||
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1}
|
||||
return c.VALID, dict(value=r.PICK_UP_VALID, reason=a.ITEM_ACTION, info=info_dict)
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
self.print(f'{agent.name} just tried to pick up an item at {agent.pos}, but failed.')
|
||||
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_FAIL': 1}
|
||||
return c.NOT_VALID, dict(value=r.PICK_UP_FAIL, reason=a.ITEM_ACTION, info=info_dict)
|
||||
|
||||
def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]:
|
||||
def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict):
|
||||
# noinspection PyUnresolvedReferences
|
||||
valid = super().do_additional_actions(agent, action)
|
||||
if valid is None:
|
||||
action_result = super().do_additional_actions(agent, action)
|
||||
if action_result is None:
|
||||
if action == a.ITEM_ACTION:
|
||||
valid = self.do_item_action(agent)
|
||||
return valid
|
||||
action_result = self.do_item_action(agent)
|
||||
return action_result
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return valid
|
||||
return action_result
|
||||
|
||||
def do_additional_reset(self) -> None:
|
||||
# noinspection PyUnresolvedReferences
|
||||
@ -277,9 +284,9 @@ class ItemFactory(BaseFactory):
|
||||
else:
|
||||
self.print('No Items are spawning, limit is reached.')
|
||||
|
||||
def do_additional_step(self) -> dict:
|
||||
def do_additional_step(self) -> (List[dict], dict):
|
||||
# noinspection PyUnresolvedReferences
|
||||
info_dict = super().do_additional_step()
|
||||
super_reward_info = super().do_additional_step()
|
||||
for item in list(self[c.ITEM].values()):
|
||||
if item.auto_despawn >= 1:
|
||||
item.set_auto_despawn(item.auto_despawn-1)
|
||||
@ -292,35 +299,7 @@ class ItemFactory(BaseFactory):
|
||||
self.trigger_item_spawn()
|
||||
else:
|
||||
self._next_item_spawn = max(0, self._next_item_spawn-1)
|
||||
return info_dict
|
||||
|
||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
||||
# noinspection PyUnresolvedReferences
|
||||
reward, info_dict = super().calculate_additional_reward(agent)
|
||||
if a.ITEM_ACTION == agent.temp_action:
|
||||
if agent.temp_valid:
|
||||
if drop_off := self[c.DROP_OFF].by_pos(agent.pos):
|
||||
info_dict.update({f'{agent.name}_item_drop_off': 1})
|
||||
info_dict.update(item_drop_off=1)
|
||||
self.print(f'{agent.name} just dropped of an item at {drop_off.pos}.')
|
||||
reward += 1
|
||||
else:
|
||||
info_dict.update({f'{agent.name}_item_pickup': 1})
|
||||
info_dict.update(item_pickup=1)
|
||||
self.print(f'{agent.name} just picked up an item at {agent.pos}')
|
||||
reward += 0.2
|
||||
else:
|
||||
if self[c.DROP_OFF].by_pos(agent.pos):
|
||||
info_dict.update({f'{agent.name}_failed_drop_off': 1})
|
||||
info_dict.update(failed_drop_off=1)
|
||||
self.print(f'{agent.name} just tried to drop off at {agent.pos}, but failed.')
|
||||
reward -= 0.1
|
||||
else:
|
||||
info_dict.update({f'{agent.name}_failed_item_action': 1})
|
||||
info_dict.update(failed_pick_up=1)
|
||||
self.print(f'{agent.name} just tried to pick up an item at {agent.pos}, but failed.')
|
||||
reward -= 0.1
|
||||
return reward, info_dict
|
||||
return super_reward_info
|
||||
|
||||
def render_additional_assets(self, mode='human'):
|
||||
# noinspection PyUnresolvedReferences
|
||||
@ -335,9 +314,9 @@ class ItemFactory(BaseFactory):
|
||||
if __name__ == '__main__':
|
||||
from environments.utility_classes import AgentRenderOptions as aro, ObservationProperties
|
||||
|
||||
render = True
|
||||
render = False
|
||||
|
||||
item_probs = ItemProperties(n_items=30)
|
||||
item_probs = ItemProperties(n_items=30, n_drop_off_locations=6)
|
||||
|
||||
obs_props = ObservationProperties(render_agents=aro.SEPERATE, omit_agent_self=True, pomdp_r=2)
|
||||
|
||||
@ -345,7 +324,7 @@ if __name__ == '__main__':
|
||||
'allow_diagonal_movement': True,
|
||||
'allow_no_op': False}
|
||||
|
||||
factory = ItemFactory(n_agents=2, done_at_collision=False,
|
||||
factory = ItemFactory(n_agents=6, done_at_collision=False,
|
||||
level_name='rooms', max_steps=400,
|
||||
obs_prop=obs_props, parse_doors=True,
|
||||
record_episodes=True, verbose=True,
|
||||
|
Reference in New Issue
Block a user