Experiments look good

This commit is contained in:
Steffen Illium
2022-01-15 12:37:58 +01:00
parent d29ccbbb71
commit 823aa075b9
14 changed files with 478 additions and 297 deletions

View File

@ -1,5 +1,4 @@
import time
from enum import Enum
from typing import List, Union, NamedTuple, Dict
import random
@ -12,7 +11,7 @@ from environments.helpers import EnvActions as BaseActions
from environments.helpers import Rewards as BaseRewards
from environments.factory.base.base_factory import BaseFactory
from environments.factory.base.objects import Agent, Action, Entity, Tile
from environments.factory.base.objects import Agent, Action, Entity, Floor
from environments.factory.base.registers import Entities, EntityRegister
from environments.factory.base.renderer import RenderEntity
@ -43,7 +42,6 @@ class DirtProperties(NamedTuple):
max_local_amount: int = 2 # Max dirt amount per tile.
max_global_amount: int = 20 # Max dirt amount in the whole environment.
dirt_smear_amount: float = 0.2 # Agents smear dirt, when not cleaning up in place.
agent_can_interact: bool = True # Whether the agents can interact with the dirt in this environment.
done_when_clean: bool = True
@ -89,7 +87,7 @@ class DirtRegister(EntityRegister):
self._dirt_properties: DirtProperties = dirt_properties
def spawn_dirt(self, then_dirty_tiles) -> bool:
if isinstance(then_dirty_tiles, Tile):
if isinstance(then_dirty_tiles, Floor):
then_dirty_tiles = [then_dirty_tiles]
for tile in then_dirty_tiles:
if not self.amount > self.dirt_properties.max_global_amount:
@ -128,15 +126,14 @@ r = Rewards
class DirtFactory(BaseFactory):
@property
def additional_actions(self) -> Union[Action, List[Action]]:
super_actions = super().additional_actions
if self.dirt_prop.agent_can_interact:
super_actions.append(Action(str_ident=a.CLEAN_UP))
def actions_hook(self) -> Union[Action, List[Action]]:
super_actions = super().actions_hook
super_actions.append(Action(str_ident=a.CLEAN_UP))
return super_actions
@property
def additional_entities(self) -> Dict[(Enum, Entities)]:
super_entities = super().additional_entities
def entities_hook(self) -> Dict[(str, Entities)]:
super_entities = super().entities_hook
dirt_register = DirtRegister(self.dirt_prop, self._level_shape)
super_entities.update(({c.DIRT: dirt_register}))
return super_entities
@ -148,10 +145,11 @@ class DirtFactory(BaseFactory):
self._dirt_rng = np.random.default_rng(env_seed)
self._dirt: DirtRegister
kwargs.update(env_seed=env_seed)
# TODO: Reset ---> document this
super().__init__(*args, **kwargs)
def render_additional_assets(self, mode='human'):
additional_assets = super().render_additional_assets()
def render_assets_hook(self, mode='human'):
additional_assets = super().render_assets_hook()
dirt = [RenderEntity('dirt', dirt.tile.pos, min(0.15 + dirt.amount, 1.5), 'scale')
for dirt in self[c.DIRT]]
additional_assets.extend(dirt)
@ -167,12 +165,12 @@ class DirtFactory(BaseFactory):
dirt.set_new_amount(max(new_dirt_amount, c.FREE_CELL.value))
valid = c.VALID
self.print(f'{agent.name} did just clean up some dirt at {agent.pos}.')
info_dict = {f'{agent.name}_{a.CLEAN_UP}_VALID': 1}
info_dict = {f'{agent.name}_{a.CLEAN_UP}_VALID': 1, 'cleanup_valid': 1}
reward = r.CLEAN_UP_VALID
else:
valid = c.NOT_VALID
self.print(f'{agent.name} just tried to clean up some dirt at {agent.pos}, but failed.')
info_dict = {f'{agent.name}_{a.CLEAN_UP}_FAIL': 1}
info_dict = {f'{agent.name}_{a.CLEAN_UP}_FAIL': 1, 'cleanup_fail': 1}
reward = r.CLEAN_UP_FAIL
if valid and self.dirt_prop.done_when_clean and (len(self[c.DIRT]) == 0):
@ -195,8 +193,8 @@ class DirtFactory(BaseFactory):
n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt)))
self[c.DIRT].spawn_dirt(free_for_dirt[:n_dirt_tiles])
def do_additional_step(self) -> (List[dict], dict):
super_reward_info = super().do_additional_step()
def step_hook(self) -> (List[dict], dict):
super_reward_info = super().step_hook()
if smear_amount := self.dirt_prop.dirt_smear_amount:
for agent in self[c.AGENT]:
if agent.temp_valid and agent.last_pos != c.NO_POS:
@ -229,8 +227,8 @@ class DirtFactory(BaseFactory):
else:
return action_result
def do_additional_reset(self) -> None:
super().do_additional_reset()
def reset_hook(self) -> None:
super().reset_hook()
self.trigger_dirt_spawn(initial_spawn=True)
self._next_dirt_spawn = self.dirt_prop.spawn_frequency if self.dirt_prop.spawn_frequency else -1
@ -242,13 +240,13 @@ class DirtFactory(BaseFactory):
return all_cleaned, super_dict
return super_done, super_dict
def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
additional_observations = super()._additional_observations()
def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
additional_observations = super().observations_hook()
additional_observations.update({c.DIRT: self[c.DIRT].as_array()})
return additional_observations
def gather_additional_info(self, agent: Agent) -> dict:
event_reward_dict = super().additional_per_agent_reward(agent)
event_reward_dict = super().per_agent_reward_hook(agent)
info_dict = dict()
dirt = [dirt.amount for dirt in self[c.DIRT]]
@ -280,8 +278,7 @@ if __name__ == '__main__':
max_local_amount=1,
spawn_frequency=0,
max_spawn_ratio=0.05,
dirt_smear_amount=0.0,
agent_can_interact=True
dirt_smear_amount=0.0
)
obs_props = ObservationProperties(render_agents=aro.COMBINED, omit_agent_self=True,
@ -294,13 +291,13 @@ if __name__ == '__main__':
global_timings = []
for i in range(10):
factory = DirtFactory(n_agents=1, done_at_collision=False,
factory = DirtFactory(n_agents=10, done_at_collision=False,
level_name='rooms', max_steps=1000,
doors_have_area=False,
obs_prop=obs_props, parse_doors=True,
verbose=True,
mv_prop=move_props, dirt_prop=dirt_props,
inject_agents=[TSPDirtAgent],
# inject_agents=[TSPDirtAgent],
)
# noinspection DuplicatedCode
@ -318,11 +315,11 @@ if __name__ == '__main__':
env_state = factory.reset()
if render:
factory.render()
tsp_agent = factory.get_injected_agents()[0]
# tsp_agent = factory.get_injected_agents()[0]
rwrd = 0
for agent_i_action in random_actions:
agent_i_action = tsp_agent.predict()
# agent_i_action = tsp_agent.predict()
env_state, step_rwrd, done_bool, info_obj = factory.step(agent_i_action)
rwrd += step_rwrd
if render: