Rework for performance

This commit is contained in:
Steffen Illium
2022-01-10 15:54:22 +01:00
parent 78bf19f7f4
commit 435056f373
10 changed files with 525 additions and 469 deletions

View File

@ -7,9 +7,10 @@ import random
from environments.factory.base.base_factory import BaseFactory
from environments.helpers import Constants as BaseConstants
from environments.helpers import EnvActions as BaseActions
from environments.helpers import Rewards as BaseRewards
from environments import helpers as h
from environments.factory.base.objects import Agent, Entity, Action, Tile
from environments.factory.base.registers import Entities, EntityRegister, BoundRegisterMixin, ObjectRegister
from environments.factory.base.registers import Entities, EntityRegister, BoundEnvObjRegister, ObjectRegister
from environments.factory.base.renderer import RenderEntity
@ -23,10 +24,17 @@ class Constants(BaseConstants):
DROP_OFF = 'Drop_Off'
class EnvActions(BaseActions):
class Actions(BaseActions):
ITEM_ACTION = 'item_action'
class Rewards(BaseRewards):
DROP_OFF_VALID = 0.1
DROP_OFF_FAIL = -0.1
PICK_UP_FAIL = -0.1
PICK_UP_VALID = 0.1
class Item(Entity):
def __init__(self, *args, **kwargs):
@ -37,10 +45,6 @@ class Item(Entity):
def auto_despawn(self):
return self._auto_despawn
@property
def can_collide(self):
return False
@property
def encoding(self):
# Edit this if you want items to be drawn in the ops differently
@ -68,7 +72,7 @@ class ItemRegister(EntityRegister):
del self[item]
class Inventory(BoundRegisterMixin):
class Inventory(BoundEnvObjRegister):
@property
def name(self):
@ -131,10 +135,6 @@ class Inventories(ObjectRegister):
class DropOffLocation(Entity):
@property
def can_collide(self):
return False
@property
def encoding(self):
return Constants.ITEM_DROP_OFF
@ -176,7 +176,8 @@ class ItemProperties(NamedTuple):
c = Constants
a = EnvActions
a = Actions
r = Rewards
# noinspection PyAttributeOutsideInit, PyAbstractClass
@ -230,37 +231,43 @@ class ItemFactory(BaseFactory):
additional_observations.update({c.DROP_OFF: self[c.DROP_OFF].as_array()})
return additional_observations
def do_item_action(self, agent: Agent):
def do_item_action(self, agent: Agent) -> (dict, dict):
inventory = self[c.INVENTORY].by_entity(agent)
if drop_off := self[c.DROP_OFF].by_pos(agent.pos):
if inventory:
valid = drop_off.place_item(inventory.pop())
return valid
else:
return c.NOT_VALID
valid = c.NOT_VALID
if valid:
self.print(f'{agent.name} just dropped of an item at {drop_off.pos}.')
info_dict = {f'{agent.name}_DROPOFF_VALID': 1}
else:
self.print(f'{agent.name} just tried to drop off at {agent.pos}, but failed.')
info_dict = {f'{agent.name}_DROPOFF_FAIL': 1}
reward = dict(value=r.DROP_OFF_VALID if valid else r.DROP_OFF_FAIL, reason=a.ITEM_ACTION, info=info_dict)
return valid, reward
elif item := self[c.ITEM].by_pos(agent.pos):
try:
inventory.register_item(item)
item.change_register(inventory)
self[c.ITEM].delete_env_object(item)
item.set_tile_to(self._NO_POS_TILE)
return c.VALID
except RuntimeError:
return c.NOT_VALID
item.change_register(inventory)
item.set_tile_to(self._NO_POS_TILE)
self.print(f'{agent.name} just picked up an item at {agent.pos}')
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1}
return c.VALID, dict(value=r.PICK_UP_VALID, reason=a.ITEM_ACTION, info=info_dict)
else:
return c.NOT_VALID
self.print(f'{agent.name} just tried to pick up an item at {agent.pos}, but failed.')
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_FAIL': 1}
return c.NOT_VALID, dict(value=r.PICK_UP_FAIL, reason=a.ITEM_ACTION, info=info_dict)
def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]:
def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict):
# noinspection PyUnresolvedReferences
valid = super().do_additional_actions(agent, action)
if valid is None:
action_result = super().do_additional_actions(agent, action)
if action_result is None:
if action == a.ITEM_ACTION:
valid = self.do_item_action(agent)
return valid
action_result = self.do_item_action(agent)
return action_result
else:
return None
else:
return valid
return action_result
def do_additional_reset(self) -> None:
# noinspection PyUnresolvedReferences
@ -277,9 +284,9 @@ class ItemFactory(BaseFactory):
else:
self.print('No Items are spawning, limit is reached.')
def do_additional_step(self) -> dict:
def do_additional_step(self) -> (List[dict], dict):
# noinspection PyUnresolvedReferences
info_dict = super().do_additional_step()
super_reward_info = super().do_additional_step()
for item in list(self[c.ITEM].values()):
if item.auto_despawn >= 1:
item.set_auto_despawn(item.auto_despawn-1)
@ -292,35 +299,7 @@ class ItemFactory(BaseFactory):
self.trigger_item_spawn()
else:
self._next_item_spawn = max(0, self._next_item_spawn-1)
return info_dict
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
# noinspection PyUnresolvedReferences
reward, info_dict = super().calculate_additional_reward(agent)
if a.ITEM_ACTION == agent.temp_action:
if agent.temp_valid:
if drop_off := self[c.DROP_OFF].by_pos(agent.pos):
info_dict.update({f'{agent.name}_item_drop_off': 1})
info_dict.update(item_drop_off=1)
self.print(f'{agent.name} just dropped of an item at {drop_off.pos}.')
reward += 1
else:
info_dict.update({f'{agent.name}_item_pickup': 1})
info_dict.update(item_pickup=1)
self.print(f'{agent.name} just picked up an item at {agent.pos}')
reward += 0.2
else:
if self[c.DROP_OFF].by_pos(agent.pos):
info_dict.update({f'{agent.name}_failed_drop_off': 1})
info_dict.update(failed_drop_off=1)
self.print(f'{agent.name} just tried to drop off at {agent.pos}, but failed.')
reward -= 0.1
else:
info_dict.update({f'{agent.name}_failed_item_action': 1})
info_dict.update(failed_pick_up=1)
self.print(f'{agent.name} just tried to pick up an item at {agent.pos}, but failed.')
reward -= 0.1
return reward, info_dict
return super_reward_info
def render_additional_assets(self, mode='human'):
# noinspection PyUnresolvedReferences
@ -335,9 +314,9 @@ class ItemFactory(BaseFactory):
if __name__ == '__main__':
from environments.utility_classes import AgentRenderOptions as aro, ObservationProperties
render = True
render = False
item_probs = ItemProperties(n_items=30)
item_probs = ItemProperties(n_items=30, n_drop_off_locations=6)
obs_props = ObservationProperties(render_agents=aro.SEPERATE, omit_agent_self=True, pomdp_r=2)
@ -345,7 +324,7 @@ if __name__ == '__main__':
'allow_diagonal_movement': True,
'allow_no_op': False}
factory = ItemFactory(n_agents=2, done_at_collision=False,
factory = ItemFactory(n_agents=6, done_at_collision=False,
level_name='rooms', max_steps=400,
obs_prop=obs_props, parse_doors=True,
record_episodes=True, verbose=True,