Experiments look good
This commit is contained in:
@ -9,7 +9,7 @@ from environments.helpers import Constants as BaseConstants
|
||||
from environments.helpers import EnvActions as BaseActions
|
||||
from environments.helpers import Rewards as BaseRewards
|
||||
from environments import helpers as h
|
||||
from environments.factory.base.objects import Agent, Entity, Action, Tile
|
||||
from environments.factory.base.objects import Agent, Entity, Action, Floor
|
||||
from environments.factory.base.registers import Entities, EntityRegister, BoundEnvObjRegister, ObjectRegister
|
||||
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
@ -25,7 +25,7 @@ class Constants(BaseConstants):
|
||||
|
||||
|
||||
class Actions(BaseActions):
|
||||
ITEM_ACTION = 'item_action'
|
||||
ITEM_ACTION = 'ITEMACTION'
|
||||
|
||||
|
||||
class Rewards(BaseRewards):
|
||||
@ -62,7 +62,7 @@ class ItemRegister(EntityRegister):
|
||||
|
||||
_accepted_objects = Item
|
||||
|
||||
def spawn_items(self, tiles: List[Tile]):
|
||||
def spawn_items(self, tiles: List[Floor]):
|
||||
items = [Item(tile, self) for tile in tiles]
|
||||
self.register_additional_items(items)
|
||||
|
||||
@ -193,16 +193,16 @@ class ItemFactory(BaseFactory):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def additional_actions(self) -> Union[Action, List[Action]]:
|
||||
def actions_hook(self) -> Union[Action, List[Action]]:
|
||||
# noinspection PyUnresolvedReferences
|
||||
super_actions = super().additional_actions
|
||||
super_actions = super().actions_hook
|
||||
super_actions.append(Action(str_ident=a.ITEM_ACTION))
|
||||
return super_actions
|
||||
|
||||
@property
|
||||
def additional_entities(self) -> Dict[(str, Entities)]:
|
||||
def entities_hook(self) -> Dict[(str, Entities)]:
|
||||
# noinspection PyUnresolvedReferences
|
||||
super_entities = super().additional_entities
|
||||
super_entities = super().entities_hook
|
||||
|
||||
empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_drop_off_locations]
|
||||
drop_offs = DropOffLocations.from_tiles(
|
||||
@ -220,13 +220,13 @@ class ItemFactory(BaseFactory):
|
||||
super_entities.update({c.DROP_OFF: drop_offs, c.ITEM: item_register, c.INVENTORY: inventories})
|
||||
return super_entities
|
||||
|
||||
def _additional_per_agent_raw_observations(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_raw_observations = super()._additional_per_agent_raw_observations(agent)
|
||||
def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_raw_observations = super().per_agent_raw_observations_hook(agent)
|
||||
additional_raw_observations.update({c.INVENTORY: self[c.INVENTORY].by_entity(agent).as_array()})
|
||||
return additional_raw_observations
|
||||
|
||||
def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_observations = super()._additional_observations()
|
||||
def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_observations = super().observations_hook()
|
||||
additional_observations.update({c.ITEM: self[c.ITEM].as_array()})
|
||||
additional_observations.update({c.DROP_OFF: self[c.DROP_OFF].as_array()})
|
||||
return additional_observations
|
||||
@ -240,21 +240,21 @@ class ItemFactory(BaseFactory):
|
||||
valid = c.NOT_VALID
|
||||
if valid:
|
||||
self.print(f'{agent.name} just dropped of an item at {drop_off.pos}.')
|
||||
info_dict = {f'{agent.name}_DROPOFF_VALID': 1}
|
||||
info_dict = {f'{agent.name}_DROPOFF_VALID': 1, 'DROPOFF_VALID': 1}
|
||||
else:
|
||||
self.print(f'{agent.name} just tried to drop off at {agent.pos}, but failed.')
|
||||
info_dict = {f'{agent.name}_DROPOFF_FAIL': 1}
|
||||
info_dict = {f'{agent.name}_DROPOFF_FAIL': 1, 'DROPOFF_FAIL': 1}
|
||||
reward = dict(value=r.DROP_OFF_VALID if valid else r.DROP_OFF_FAIL, reason=a.ITEM_ACTION, info=info_dict)
|
||||
return valid, reward
|
||||
elif item := self[c.ITEM].by_pos(agent.pos):
|
||||
item.change_register(inventory)
|
||||
item.set_tile_to(self._NO_POS_TILE)
|
||||
self.print(f'{agent.name} just picked up an item at {agent.pos}')
|
||||
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1}
|
||||
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1, f'{a.ITEM_ACTION}_VALID': 1}
|
||||
return c.VALID, dict(value=r.PICK_UP_VALID, reason=a.ITEM_ACTION, info=info_dict)
|
||||
else:
|
||||
self.print(f'{agent.name} just tried to pick up an item at {agent.pos}, but failed.')
|
||||
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_FAIL': 1}
|
||||
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_FAIL': 1, f'{a.ITEM_ACTION}_FAIL': 1}
|
||||
return c.NOT_VALID, dict(value=r.PICK_UP_FAIL, reason=a.ITEM_ACTION, info=info_dict)
|
||||
|
||||
def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict):
|
||||
@ -269,9 +269,9 @@ class ItemFactory(BaseFactory):
|
||||
else:
|
||||
return action_result
|
||||
|
||||
def do_additional_reset(self) -> None:
|
||||
def reset_hook(self) -> None:
|
||||
# noinspection PyUnresolvedReferences
|
||||
super().do_additional_reset()
|
||||
super().reset_hook()
|
||||
self._next_item_spawn = self.item_prop.spawn_frequency
|
||||
self.trigger_item_spawn()
|
||||
|
||||
@ -284,9 +284,9 @@ class ItemFactory(BaseFactory):
|
||||
else:
|
||||
self.print('No Items are spawning, limit is reached.')
|
||||
|
||||
def do_additional_step(self) -> (List[dict], dict):
|
||||
def step_hook(self) -> (List[dict], dict):
|
||||
# noinspection PyUnresolvedReferences
|
||||
super_reward_info = super().do_additional_step()
|
||||
super_reward_info = super().step_hook()
|
||||
for item in list(self[c.ITEM].values()):
|
||||
if item.auto_despawn >= 1:
|
||||
item.set_auto_despawn(item.auto_despawn-1)
|
||||
@ -301,9 +301,9 @@ class ItemFactory(BaseFactory):
|
||||
self._next_item_spawn = max(0, self._next_item_spawn-1)
|
||||
return super_reward_info
|
||||
|
||||
def render_additional_assets(self, mode='human'):
|
||||
def render_assets_hook(self, mode='human'):
|
||||
# noinspection PyUnresolvedReferences
|
||||
additional_assets = super().render_additional_assets()
|
||||
additional_assets = super().render_assets_hook()
|
||||
items = [RenderEntity(c.ITEM, item.tile.pos) for item in self[c.ITEM] if item.tile != self._NO_POS_TILE]
|
||||
additional_assets.extend(items)
|
||||
drop_offs = [RenderEntity(c.DROP_OFF, drop_off.tile.pos) for drop_off in self[c.DROP_OFF]]
|
||||
@ -314,7 +314,7 @@ class ItemFactory(BaseFactory):
|
||||
if __name__ == '__main__':
|
||||
from environments.utility_classes import AgentRenderOptions as aro, ObservationProperties
|
||||
|
||||
render = False
|
||||
render = True
|
||||
|
||||
item_probs = ItemProperties(n_items=30, n_drop_off_locations=6)
|
||||
|
||||
@ -336,18 +336,18 @@ if __name__ == '__main__':
|
||||
obs_space = factory.observation_space
|
||||
obs_space_named = factory.named_observation_space
|
||||
|
||||
for epoch in range(4):
|
||||
for epoch in range(400):
|
||||
random_actions = [[random.randint(0, n_actions) for _
|
||||
in range(factory.n_agents)] for _
|
||||
in range(factory.max_steps + 1)]
|
||||
env_state = factory.reset()
|
||||
r = 0
|
||||
rwrd = 0
|
||||
for agent_i_action in random_actions:
|
||||
env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
|
||||
r += step_r
|
||||
rwrd += step_r
|
||||
if render:
|
||||
factory.render()
|
||||
if done_bool:
|
||||
break
|
||||
print(f'Factory run {epoch} done, reward is:\n {r}')
|
||||
print(f'Factory run {epoch} done, reward is:\n {rwrd}')
|
||||
pass
|
||||
|
Reference in New Issue
Block a user