Experiments look good

This commit is contained in:
Steffen Illium
2022-01-15 12:37:58 +01:00
parent d29ccbbb71
commit 823aa075b9
14 changed files with 478 additions and 297 deletions

View File

@ -9,7 +9,7 @@ from environments.helpers import Constants as BaseConstants
from environments.helpers import EnvActions as BaseActions
from environments.helpers import Rewards as BaseRewards
from environments import helpers as h
from environments.factory.base.objects import Agent, Entity, Action, Tile
from environments.factory.base.objects import Agent, Entity, Action, Floor
from environments.factory.base.registers import Entities, EntityRegister, BoundEnvObjRegister, ObjectRegister
from environments.factory.base.renderer import RenderEntity
@ -25,7 +25,7 @@ class Constants(BaseConstants):
class Actions(BaseActions):
ITEM_ACTION = 'item_action'
ITEM_ACTION = 'ITEMACTION'
class Rewards(BaseRewards):
@ -62,7 +62,7 @@ class ItemRegister(EntityRegister):
_accepted_objects = Item
def spawn_items(self, tiles: List[Tile]):
def spawn_items(self, tiles: List[Floor]):
items = [Item(tile, self) for tile in tiles]
self.register_additional_items(items)
@ -193,16 +193,16 @@ class ItemFactory(BaseFactory):
super().__init__(*args, **kwargs)
@property
def additional_actions(self) -> Union[Action, List[Action]]:
def actions_hook(self) -> Union[Action, List[Action]]:
# noinspection PyUnresolvedReferences
super_actions = super().additional_actions
super_actions = super().actions_hook
super_actions.append(Action(str_ident=a.ITEM_ACTION))
return super_actions
@property
def additional_entities(self) -> Dict[(str, Entities)]:
def entities_hook(self) -> Dict[(str, Entities)]:
# noinspection PyUnresolvedReferences
super_entities = super().additional_entities
super_entities = super().entities_hook
empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_drop_off_locations]
drop_offs = DropOffLocations.from_tiles(
@ -220,13 +220,13 @@ class ItemFactory(BaseFactory):
super_entities.update({c.DROP_OFF: drop_offs, c.ITEM: item_register, c.INVENTORY: inventories})
return super_entities
def _additional_per_agent_raw_observations(self, agent) -> Dict[str, np.typing.ArrayLike]:
additional_raw_observations = super()._additional_per_agent_raw_observations(agent)
def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
additional_raw_observations = super().per_agent_raw_observations_hook(agent)
additional_raw_observations.update({c.INVENTORY: self[c.INVENTORY].by_entity(agent).as_array()})
return additional_raw_observations
def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
additional_observations = super()._additional_observations()
def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
additional_observations = super().observations_hook()
additional_observations.update({c.ITEM: self[c.ITEM].as_array()})
additional_observations.update({c.DROP_OFF: self[c.DROP_OFF].as_array()})
return additional_observations
@ -240,21 +240,21 @@ class ItemFactory(BaseFactory):
valid = c.NOT_VALID
if valid:
self.print(f'{agent.name} just dropped of an item at {drop_off.pos}.')
info_dict = {f'{agent.name}_DROPOFF_VALID': 1}
info_dict = {f'{agent.name}_DROPOFF_VALID': 1, 'DROPOFF_VALID': 1}
else:
self.print(f'{agent.name} just tried to drop off at {agent.pos}, but failed.')
info_dict = {f'{agent.name}_DROPOFF_FAIL': 1}
info_dict = {f'{agent.name}_DROPOFF_FAIL': 1, 'DROPOFF_FAIL': 1}
reward = dict(value=r.DROP_OFF_VALID if valid else r.DROP_OFF_FAIL, reason=a.ITEM_ACTION, info=info_dict)
return valid, reward
elif item := self[c.ITEM].by_pos(agent.pos):
item.change_register(inventory)
item.set_tile_to(self._NO_POS_TILE)
self.print(f'{agent.name} just picked up an item at {agent.pos}')
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1}
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1, f'{a.ITEM_ACTION}_VALID': 1}
return c.VALID, dict(value=r.PICK_UP_VALID, reason=a.ITEM_ACTION, info=info_dict)
else:
self.print(f'{agent.name} just tried to pick up an item at {agent.pos}, but failed.')
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_FAIL': 1}
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_FAIL': 1, f'{a.ITEM_ACTION}_FAIL': 1}
return c.NOT_VALID, dict(value=r.PICK_UP_FAIL, reason=a.ITEM_ACTION, info=info_dict)
def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict):
@ -269,9 +269,9 @@ class ItemFactory(BaseFactory):
else:
return action_result
def do_additional_reset(self) -> None:
def reset_hook(self) -> None:
# noinspection PyUnresolvedReferences
super().do_additional_reset()
super().reset_hook()
self._next_item_spawn = self.item_prop.spawn_frequency
self.trigger_item_spawn()
@ -284,9 +284,9 @@ class ItemFactory(BaseFactory):
else:
self.print('No Items are spawning, limit is reached.')
def do_additional_step(self) -> (List[dict], dict):
def step_hook(self) -> (List[dict], dict):
# noinspection PyUnresolvedReferences
super_reward_info = super().do_additional_step()
super_reward_info = super().step_hook()
for item in list(self[c.ITEM].values()):
if item.auto_despawn >= 1:
item.set_auto_despawn(item.auto_despawn-1)
@ -301,9 +301,9 @@ class ItemFactory(BaseFactory):
self._next_item_spawn = max(0, self._next_item_spawn-1)
return super_reward_info
def render_additional_assets(self, mode='human'):
def render_assets_hook(self, mode='human'):
# noinspection PyUnresolvedReferences
additional_assets = super().render_additional_assets()
additional_assets = super().render_assets_hook()
items = [RenderEntity(c.ITEM, item.tile.pos) for item in self[c.ITEM] if item.tile != self._NO_POS_TILE]
additional_assets.extend(items)
drop_offs = [RenderEntity(c.DROP_OFF, drop_off.tile.pos) for drop_off in self[c.DROP_OFF]]
@ -314,7 +314,7 @@ class ItemFactory(BaseFactory):
if __name__ == '__main__':
from environments.utility_classes import AgentRenderOptions as aro, ObservationProperties
render = False
render = True
item_probs = ItemProperties(n_items=30, n_drop_off_locations=6)
@ -336,18 +336,18 @@ if __name__ == '__main__':
obs_space = factory.observation_space
obs_space_named = factory.named_observation_space
for epoch in range(4):
for epoch in range(400):
random_actions = [[random.randint(0, n_actions) for _
in range(factory.n_agents)] for _
in range(factory.max_steps + 1)]
env_state = factory.reset()
r = 0
rwrd = 0
for agent_i_action in random_actions:
env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
r += step_r
rwrd += step_r
if render:
factory.render()
if done_bool:
break
print(f'Factory run {epoch} done, reward is:\n {r}')
print(f'Factory run {epoch} done, reward is:\n {rwrd}')
pass