Experiments look good
This commit is contained in:
@ -1,5 +1,4 @@
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import List, Union, NamedTuple, Dict
|
||||
import random
|
||||
|
||||
@ -12,7 +11,7 @@ from environments.helpers import EnvActions as BaseActions
|
||||
from environments.helpers import Rewards as BaseRewards
|
||||
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.factory.base.objects import Agent, Action, Entity, Tile
|
||||
from environments.factory.base.objects import Agent, Action, Entity, Floor
|
||||
from environments.factory.base.registers import Entities, EntityRegister
|
||||
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
@ -43,7 +42,6 @@ class DirtProperties(NamedTuple):
|
||||
max_local_amount: int = 2 # Max dirt amount per tile.
|
||||
max_global_amount: int = 20 # Max dirt amount in the whole environment.
|
||||
dirt_smear_amount: float = 0.2 # Agents smear dirt, when not cleaning up in place.
|
||||
agent_can_interact: bool = True # Whether the agents can interact with the dirt in this environment.
|
||||
done_when_clean: bool = True
|
||||
|
||||
|
||||
@ -89,7 +87,7 @@ class DirtRegister(EntityRegister):
|
||||
self._dirt_properties: DirtProperties = dirt_properties
|
||||
|
||||
def spawn_dirt(self, then_dirty_tiles) -> bool:
|
||||
if isinstance(then_dirty_tiles, Tile):
|
||||
if isinstance(then_dirty_tiles, Floor):
|
||||
then_dirty_tiles = [then_dirty_tiles]
|
||||
for tile in then_dirty_tiles:
|
||||
if not self.amount > self.dirt_properties.max_global_amount:
|
||||
@ -128,15 +126,14 @@ r = Rewards
|
||||
class DirtFactory(BaseFactory):
|
||||
|
||||
@property
|
||||
def additional_actions(self) -> Union[Action, List[Action]]:
|
||||
super_actions = super().additional_actions
|
||||
if self.dirt_prop.agent_can_interact:
|
||||
super_actions.append(Action(str_ident=a.CLEAN_UP))
|
||||
def actions_hook(self) -> Union[Action, List[Action]]:
|
||||
super_actions = super().actions_hook
|
||||
super_actions.append(Action(str_ident=a.CLEAN_UP))
|
||||
return super_actions
|
||||
|
||||
@property
|
||||
def additional_entities(self) -> Dict[(Enum, Entities)]:
|
||||
super_entities = super().additional_entities
|
||||
def entities_hook(self) -> Dict[(str, Entities)]:
|
||||
super_entities = super().entities_hook
|
||||
dirt_register = DirtRegister(self.dirt_prop, self._level_shape)
|
||||
super_entities.update(({c.DIRT: dirt_register}))
|
||||
return super_entities
|
||||
@ -148,10 +145,11 @@ class DirtFactory(BaseFactory):
|
||||
self._dirt_rng = np.random.default_rng(env_seed)
|
||||
self._dirt: DirtRegister
|
||||
kwargs.update(env_seed=env_seed)
|
||||
# TODO: Reset ---> document this
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def render_additional_assets(self, mode='human'):
|
||||
additional_assets = super().render_additional_assets()
|
||||
def render_assets_hook(self, mode='human'):
|
||||
additional_assets = super().render_assets_hook()
|
||||
dirt = [RenderEntity('dirt', dirt.tile.pos, min(0.15 + dirt.amount, 1.5), 'scale')
|
||||
for dirt in self[c.DIRT]]
|
||||
additional_assets.extend(dirt)
|
||||
@ -167,12 +165,12 @@ class DirtFactory(BaseFactory):
|
||||
dirt.set_new_amount(max(new_dirt_amount, c.FREE_CELL.value))
|
||||
valid = c.VALID
|
||||
self.print(f'{agent.name} did just clean up some dirt at {agent.pos}.')
|
||||
info_dict = {f'{agent.name}_{a.CLEAN_UP}_VALID': 1}
|
||||
info_dict = {f'{agent.name}_{a.CLEAN_UP}_VALID': 1, 'cleanup_valid': 1}
|
||||
reward = r.CLEAN_UP_VALID
|
||||
else:
|
||||
valid = c.NOT_VALID
|
||||
self.print(f'{agent.name} just tried to clean up some dirt at {agent.pos}, but failed.')
|
||||
info_dict = {f'{agent.name}_{a.CLEAN_UP}_FAIL': 1}
|
||||
info_dict = {f'{agent.name}_{a.CLEAN_UP}_FAIL': 1, 'cleanup_fail': 1}
|
||||
reward = r.CLEAN_UP_FAIL
|
||||
|
||||
if valid and self.dirt_prop.done_when_clean and (len(self[c.DIRT]) == 0):
|
||||
@ -195,8 +193,8 @@ class DirtFactory(BaseFactory):
|
||||
n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt)))
|
||||
self[c.DIRT].spawn_dirt(free_for_dirt[:n_dirt_tiles])
|
||||
|
||||
def do_additional_step(self) -> (List[dict], dict):
|
||||
super_reward_info = super().do_additional_step()
|
||||
def step_hook(self) -> (List[dict], dict):
|
||||
super_reward_info = super().step_hook()
|
||||
if smear_amount := self.dirt_prop.dirt_smear_amount:
|
||||
for agent in self[c.AGENT]:
|
||||
if agent.temp_valid and agent.last_pos != c.NO_POS:
|
||||
@ -229,8 +227,8 @@ class DirtFactory(BaseFactory):
|
||||
else:
|
||||
return action_result
|
||||
|
||||
def do_additional_reset(self) -> None:
|
||||
super().do_additional_reset()
|
||||
def reset_hook(self) -> None:
|
||||
super().reset_hook()
|
||||
self.trigger_dirt_spawn(initial_spawn=True)
|
||||
self._next_dirt_spawn = self.dirt_prop.spawn_frequency if self.dirt_prop.spawn_frequency else -1
|
||||
|
||||
@ -242,13 +240,13 @@ class DirtFactory(BaseFactory):
|
||||
return all_cleaned, super_dict
|
||||
return super_done, super_dict
|
||||
|
||||
def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_observations = super()._additional_observations()
|
||||
def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_observations = super().observations_hook()
|
||||
additional_observations.update({c.DIRT: self[c.DIRT].as_array()})
|
||||
return additional_observations
|
||||
|
||||
def gather_additional_info(self, agent: Agent) -> dict:
|
||||
event_reward_dict = super().additional_per_agent_reward(agent)
|
||||
event_reward_dict = super().per_agent_reward_hook(agent)
|
||||
info_dict = dict()
|
||||
|
||||
dirt = [dirt.amount for dirt in self[c.DIRT]]
|
||||
@ -280,8 +278,7 @@ if __name__ == '__main__':
|
||||
max_local_amount=1,
|
||||
spawn_frequency=0,
|
||||
max_spawn_ratio=0.05,
|
||||
dirt_smear_amount=0.0,
|
||||
agent_can_interact=True
|
||||
dirt_smear_amount=0.0
|
||||
)
|
||||
|
||||
obs_props = ObservationProperties(render_agents=aro.COMBINED, omit_agent_self=True,
|
||||
@ -294,13 +291,13 @@ if __name__ == '__main__':
|
||||
global_timings = []
|
||||
for i in range(10):
|
||||
|
||||
factory = DirtFactory(n_agents=1, done_at_collision=False,
|
||||
factory = DirtFactory(n_agents=10, done_at_collision=False,
|
||||
level_name='rooms', max_steps=1000,
|
||||
doors_have_area=False,
|
||||
obs_prop=obs_props, parse_doors=True,
|
||||
verbose=True,
|
||||
mv_prop=move_props, dirt_prop=dirt_props,
|
||||
inject_agents=[TSPDirtAgent],
|
||||
# inject_agents=[TSPDirtAgent],
|
||||
)
|
||||
|
||||
# noinspection DuplicatedCode
|
||||
@ -318,11 +315,11 @@ if __name__ == '__main__':
|
||||
env_state = factory.reset()
|
||||
if render:
|
||||
factory.render()
|
||||
tsp_agent = factory.get_injected_agents()[0]
|
||||
# tsp_agent = factory.get_injected_agents()[0]
|
||||
|
||||
rwrd = 0
|
||||
for agent_i_action in random_actions:
|
||||
agent_i_action = tsp_agent.predict()
|
||||
# agent_i_action = tsp_agent.predict()
|
||||
env_state, step_rwrd, done_bool, info_obj = factory.step(agent_i_action)
|
||||
rwrd += step_rwrd
|
||||
if render:
|
||||
|
Reference in New Issue
Block a user