mirror of
				https://github.com/illiumst/marl-factory-grid.git
				synced 2025-10-31 04:37:25 +01:00 
			
		
		
		
	Experiments look good
This commit is contained in:
		| @@ -15,8 +15,8 @@ from environments import helpers as h | ||||
| from environments.helpers import Constants as c | ||||
| from environments.helpers import EnvActions as a | ||||
| from environments.helpers import Rewards as r | ||||
| from environments.factory.base.objects import Agent, Tile, Action | ||||
| from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles, PlaceHolders, \ | ||||
| from environments.factory.base.objects import Agent, Floor, Action | ||||
| from environments.factory.base.registers import Actions, Entities, Agents, Doors, Floors, Walls, PlaceHolders, \ | ||||
|     GlobalPositions | ||||
| from environments.utility_classes import MovementProperties, ObservationProperties, MarlFrameStack | ||||
| from environments.utility_classes import AgentRenderOptions as a_obs | ||||
| @@ -121,7 +121,7 @@ class BaseFactory(gym.Env): | ||||
|         self.doors_have_area = doors_have_area | ||||
|         self.individual_rewards = individual_rewards | ||||
|  | ||||
|         # Reset | ||||
|         # TODO: Reset ---> document this | ||||
|         self.reset() | ||||
|  | ||||
|     def __getitem__(self, item): | ||||
| @@ -141,21 +141,21 @@ class BaseFactory(gym.Env): | ||||
|         self._obs_shape = self._level_shape if not self.obs_prop.pomdp_r else (self.pomdp_diameter, ) * 2 | ||||
|  | ||||
|         # Walls | ||||
|         walls = WallTiles.from_argwhere_coordinates( | ||||
|         walls = Walls.from_argwhere_coordinates( | ||||
|             np.argwhere(level_array == c.OCCUPIED_CELL), | ||||
|             self._level_shape | ||||
|         ) | ||||
|         self._entities.register_additional_items({c.WALLS: walls}) | ||||
|  | ||||
|         # Floor | ||||
|         floor = FloorTiles.from_argwhere_coordinates( | ||||
|         floor = Floors.from_argwhere_coordinates( | ||||
|             np.argwhere(level_array == c.FREE_CELL), | ||||
|             self._level_shape | ||||
|         ) | ||||
|         self._entities.register_additional_items({c.FLOOR: floor}) | ||||
|  | ||||
|         # NOPOS | ||||
|         self._NO_POS_TILE = Tile(c.NO_POS, None) | ||||
|         self._NO_POS_TILE = Floor(c.NO_POS, None) | ||||
|  | ||||
|         # Doors | ||||
|         if self.parse_doors: | ||||
| @@ -170,7 +170,7 @@ class BaseFactory(gym.Env): | ||||
|  | ||||
|         # Actions | ||||
|         self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors) | ||||
|         if additional_actions := self.additional_actions: | ||||
|         if additional_actions := self.actions_hook: | ||||
|             self._actions.register_additional_items(additional_actions) | ||||
|  | ||||
|         # Agents | ||||
| @@ -202,7 +202,7 @@ class BaseFactory(gym.Env): | ||||
|             self._entities.register_additional_items({c.AGENT_PLACEHOLDER: placeholder}) | ||||
|  | ||||
|         # Additional Entities from SubEnvs | ||||
|         if additional_entities := self.additional_entities: | ||||
|         if additional_entities := self.entities_hook: | ||||
|             self._entities.register_additional_items(additional_entities) | ||||
|  | ||||
|         if self.obs_prop.show_global_position_info: | ||||
| @@ -217,7 +217,7 @@ class BaseFactory(gym.Env): | ||||
|  | ||||
|     def reset(self) -> (np.typing.ArrayLike, int, bool, dict): | ||||
|         _ = self._base_init_env() | ||||
|         self.do_additional_reset() | ||||
|         self.reset_hook() | ||||
|  | ||||
|         self._steps = 0 | ||||
|  | ||||
| @@ -233,7 +233,7 @@ class BaseFactory(gym.Env): | ||||
|         self._steps += 1 | ||||
|  | ||||
|         # Pre step Hook for later use | ||||
|         self.hook_pre_step() | ||||
|         self.pre_step_hook() | ||||
|  | ||||
|         for action, agent in zip(actions, self[c.AGENT]): | ||||
|             agent.clear_temp_state() | ||||
| @@ -244,7 +244,7 @@ class BaseFactory(gym.Env): | ||||
|                 action_valid, reward = self._do_move_action(agent, action_obj) | ||||
|             elif a.NOOP == action_obj: | ||||
|                 action_valid = c.VALID | ||||
|                 reward = dict(value=r.NOOP, reason=a.NOOP, info={f'{agent.pos}_NOOP': 1}) | ||||
|                 reward = dict(value=r.NOOP, reason=a.NOOP, info={f'{agent.name}_NOOP': 1, 'NOOP': 1}) | ||||
|             elif a.USE_DOOR == action_obj: | ||||
|                 action_valid, reward = self._handle_door_interaction(agent) | ||||
|             else: | ||||
| @@ -258,7 +258,7 @@ class BaseFactory(gym.Env): | ||||
|             agent.step_result = step_result | ||||
|  | ||||
|         # Additional step and Reward, Info Init | ||||
|         rewards, info = self.do_additional_step() | ||||
|         rewards, info = self.step_hook() | ||||
|         # Todo: Make this faster, so that only tiles of entities that can collide are searched. | ||||
|         tiles_with_collisions = self.get_all_tiles_with_collisions() | ||||
|         for tile in tiles_with_collisions: | ||||
| @@ -297,7 +297,7 @@ class BaseFactory(gym.Env): | ||||
|             info.update(self._summarize_state()) | ||||
|  | ||||
|         # Post step Hook for later use | ||||
|         info.update(self.hook_post_step()) | ||||
|         info.update(self.post_step_hook()) | ||||
|  | ||||
|         obs, _ = self._build_observations() | ||||
|  | ||||
| @@ -314,11 +314,11 @@ class BaseFactory(gym.Env): | ||||
|                 door.use() | ||||
|                 valid = c.VALID | ||||
|                 self.print(f'{agent.name} just used a {door.name} at {door.pos}') | ||||
|                 info_dict = {f'{agent.name}_door_use': 1} | ||||
|                 info_dict = {f'{agent.name}_door_use': 1, f'door_use': 1} | ||||
|             # When he doesn't... | ||||
|             else: | ||||
|                 valid = c.NOT_VALID | ||||
|                 info_dict = {f'{agent.name}_failed_door_use': 1} | ||||
|                 info_dict = {f'{agent.name}_failed_door_use': 1, 'failed_door_use': 1} | ||||
|                 self.print(f'{agent.name} just tried to use a door at {agent.pos}, but there is none.') | ||||
|  | ||||
|         else: | ||||
| @@ -334,7 +334,7 @@ class BaseFactory(gym.Env): | ||||
|         per_agent_obsn = dict() | ||||
|         # General Observations | ||||
|         lvl_obs = self[c.WALLS].as_array() | ||||
|         door_obs = self[c.DOORS].as_array() | ||||
|         door_obs = self[c.DOORS].as_array() if self.parse_doors else None | ||||
|         if self.obs_prop.render_agents == a_obs.NOT: | ||||
|             global_agent_obs = None | ||||
|         elif self.obs_prop.omit_agent_self and self.n_agents == 1: | ||||
| @@ -342,7 +342,7 @@ class BaseFactory(gym.Env): | ||||
|         else: | ||||
|             global_agent_obs = self[c.AGENT].as_array().copy() | ||||
|         placeholder_obs = self[c.AGENT_PLACEHOLDER].as_array() if self[c.AGENT_PLACEHOLDER] else None | ||||
|         add_obs_dict = self._additional_observations() | ||||
|         add_obs_dict = self.observations_hook() | ||||
|  | ||||
|         for agent_idx, agent in enumerate(self[c.AGENT]): | ||||
|             obs_dict = dict() | ||||
| @@ -367,17 +367,17 @@ class BaseFactory(gym.Env): | ||||
|  | ||||
|             obs_dict[c.WALLS] = lvl_obs | ||||
|             if self.obs_prop.render_agents in [a_obs.SEPERATE, a_obs.COMBINED] and agent_obs is not None: | ||||
|                 obs_dict[c.AGENT] = agent_obs | ||||
|                 obs_dict[c.AGENT] = agent_obs[:] | ||||
|             if self[c.AGENT_PLACEHOLDER] and placeholder_obs is not None: | ||||
|                 obs_dict[c.AGENT_PLACEHOLDER] = placeholder_obs | ||||
|             if self.parse_doors and door_obs is not None: | ||||
|                 obs_dict[c.DOORS] = door_obs | ||||
|                 obs_dict[c.DOORS] = door_obs[:] | ||||
|             obs_dict.update(add_obs_dict) | ||||
|             obsn = np.vstack(list(obs_dict.values())) | ||||
|             if self.obs_prop.pomdp_r: | ||||
|                 obsn = self._do_pomdp_cutout(agent, obsn) | ||||
|  | ||||
|             raw_obs = self._additional_per_agent_raw_observations(agent) | ||||
|             raw_obs = self.per_agent_raw_observations_hook(agent) | ||||
|             raw_obs = {key: np.expand_dims(val, 0) if val.ndim != 3 else val for key, val in raw_obs.items()} | ||||
|             obsn = np.vstack((obsn, *raw_obs.values())) | ||||
|  | ||||
| @@ -387,6 +387,12 @@ class BaseFactory(gym.Env): | ||||
|                                               zip(keys, idxs, list(idxs[1:]) + [idxs[-1]+1, ])} | ||||
|  | ||||
|             # Shadow Casting | ||||
|             if agent.step_result is not None: | ||||
|                 pass | ||||
|             else: | ||||
|                 assert self._steps == 0 | ||||
|                 agent.step_result = {'action_name': a.NOOP, 'action_valid': True, | ||||
|                                      'collisions': [], 'lightmap': None} | ||||
|             if self.obs_prop.cast_shadows: | ||||
|                 try: | ||||
|                     light_block_obs = [obs_idx for key, obs_idx in per_agent_expl_idx[agent.name].items() | ||||
| @@ -430,17 +436,15 @@ class BaseFactory(gym.Env): | ||||
|                 if door_shadowing: | ||||
|                     # noinspection PyUnboundLocalVariable | ||||
|                     light_block_map[xs, ys] = 0 | ||||
|                 if agent.step_result: | ||||
|                     agent.step_result['lightmap'] = light_block_map | ||||
|                     pass | ||||
|                 else: | ||||
|                     assert self._steps == 0 | ||||
|                     agent.step_result = {'action_name': a.NOOP, 'action_valid': True, | ||||
|                                          'collisions': [], 'lightmap': light_block_map} | ||||
|  | ||||
|                 agent.step_result['lightmap'] = light_block_map | ||||
|  | ||||
|                 obsn[shadowed_obs] = ((obsn[shadowed_obs] * light_block_map) + 0.) - (1 - light_block_map) | ||||
|             else: | ||||
|                 pass | ||||
|                 if self._pomdp_r: | ||||
|                     agent.step_result['lightmap'] = np.ones(self._obs_shape) | ||||
|                 else: | ||||
|                     agent.step_result['lightmap'] = None | ||||
|  | ||||
|             per_agent_obsn[agent.name] = obsn | ||||
|  | ||||
| @@ -484,7 +488,7 @@ class BaseFactory(gym.Env): | ||||
|             oobs = np.pad(oobs, ((0, 0), (x0_pad, x1_pad), (y0_pad, y1_pad)), 'constant') | ||||
|         return oobs | ||||
|  | ||||
|     def get_all_tiles_with_collisions(self) -> List[Tile]: | ||||
|     def get_all_tiles_with_collisions(self) -> List[Floor]: | ||||
|         tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1] | ||||
|         if False: | ||||
|             tiles_with_collisions = list() | ||||
| @@ -503,22 +507,22 @@ class BaseFactory(gym.Env): | ||||
|             valid = agent.move(new_tile) | ||||
|             if valid: | ||||
|                 # This will spam your logs, beware! | ||||
|                 # self.print(f'{agent.name} just moved from {agent.last_pos} to {agent.pos}.') | ||||
|                 # info_dict.update({f'{agent.pos}_move': 1}) | ||||
|                 self.print(f'{agent.name} just moved {action.identifier} from {agent.last_pos} to {agent.pos}.') | ||||
|                 info_dict.update({f'{agent.name}_move': 1, 'move': 1}) | ||||
|                 pass | ||||
|             else: | ||||
|                 valid = c.NOT_VALID | ||||
|                 self.print(f'{agent.name} just hit the wall at {agent.pos}.') | ||||
|                 info_dict.update({f'{agent.name}_wall_collide': 1}) | ||||
|                 self.print(f'{agent.name} just hit the wall at {agent.pos}. ({action.identifier})') | ||||
|                 info_dict.update({f'{agent.name}_wall_collide': 1, 'wall_collide': 1}) | ||||
|         else: | ||||
|             # Agent seems to be trying to Leave the level | ||||
|             self.print(f'{agent.name} tried to leave the level {agent.pos}.') | ||||
|             info_dict.update({f'{agent.name}_wall_collide': 1}) | ||||
|             self.print(f'{agent.name} tried to leave the level {agent.pos}. ({action.identifier})') | ||||
|             info_dict.update({f'{agent.name}_wall_collide': 1, 'wall_collide': 1}) | ||||
|         reward_value = r.MOVEMENTS_VALID if valid else r.MOVEMENTS_FAIL | ||||
|         reward = {'value': reward_value, 'reason': action.identifier, 'info': info_dict} | ||||
|         return valid, reward | ||||
|  | ||||
|     def _check_agent_move(self, agent, action: Action) -> (Tile, bool): | ||||
|     def _check_agent_move(self, agent, action: Action) -> (Floor, bool): | ||||
|         # Actions | ||||
|         x_diff, y_diff = h.ACTIONMAP[action.identifier] | ||||
|         x_new = agent.x + x_diff | ||||
| @@ -556,10 +560,6 @@ class BaseFactory(gym.Env): | ||||
|  | ||||
|         return new_tile, valid | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def additional_per_agent_rewards(self, agent) -> List[dict]: | ||||
|         return [] | ||||
|  | ||||
|     def build_reward_result(self, global_env_rewards: list) -> (int, dict): | ||||
|         # Returns: Reward, Info | ||||
|         info = defaultdict(lambda: 0.0) | ||||
| @@ -567,7 +567,7 @@ class BaseFactory(gym.Env): | ||||
|         # Gather additional sub-env rewards and calculate collisions | ||||
|         for agent in self[c.AGENT]: | ||||
|  | ||||
|             rewards = self.additional_per_agent_rewards(agent) | ||||
|             rewards = self.per_agent_reward_hook(agent) | ||||
|             for reward in rewards: | ||||
|                 agent.step_result['rewards'].append(reward) | ||||
|             if collisions := agent.step_result['collisions']: | ||||
| @@ -601,6 +601,12 @@ class BaseFactory(gym.Env): | ||||
|             self.print(f"reward is {reward}") | ||||
|         return reward, combined_info_dict | ||||
|  | ||||
|     def start_recording(self): | ||||
|         """Set the ``_record_episodes`` flag to ``True``. | ||||
|  | ||||
|         NOTE(review): presumably this switches episode recording on; the code | ||||
|         that reads the flag is not visible here - confirm. | ||||
|         """ | ||||
|         self._record_episodes = True | ||||
|  | ||||
|     def stop_recording(self): | ||||
|         """Set the ``_record_episodes`` flag to ``False``. | ||||
|  | ||||
|         NOTE(review): presumably this switches episode recording off; the code | ||||
|         that reads the flag is not visible here - confirm. | ||||
|         """ | ||||
|         self._record_episodes = False | ||||
|  | ||||
|     # noinspection PyGlobalUndefined | ||||
|     def render(self, mode='human'): | ||||
|         if not self._renderer:  # lazy init | ||||
| @@ -621,7 +627,7 @@ class BaseFactory(gym.Env): | ||||
|             for i, door in enumerate(self[c.DOORS]): | ||||
|                 name, state = 'door_open' if door.is_open else 'door_closed', 'blank' | ||||
|                 doors.append(RenderEntity(name, door.pos, 1, 'none', state, i + 1)) | ||||
|         additional_assets = self.render_additional_assets() | ||||
|         additional_assets = self.render_assets_hook() | ||||
|  | ||||
|         return self._renderer.render(walls + doors + additional_assets + agents) | ||||
|  | ||||
| @@ -652,7 +658,8 @@ class BaseFactory(gym.Env): | ||||
|  | ||||
|     # Properties which are called by the base class to extend beyond attributes of the base class | ||||
|     @property | ||||
|     def additional_actions(self) -> Union[Action, List[Action]]: | ||||
|     @abc.abstractmethod | ||||
|     def actions_hook(self) -> Union[Action, List[Action]]: | ||||
|         """ | ||||
|         When inheriting from this base class, you must implement this method! | ||||
|  | ||||
| @@ -662,7 +669,8 @@ class BaseFactory(gym.Env): | ||||
|         return [] | ||||
|  | ||||
|     @property | ||||
|     def additional_entities(self) -> Dict[(str, Entities)]: | ||||
|     @abc.abstractmethod | ||||
|     def entities_hook(self) -> Dict[(str, Entities)]: | ||||
|         """ | ||||
|         When inheriting from this base class, you must implement this method! | ||||
|  | ||||
| @@ -674,27 +682,39 @@ class BaseFactory(gym.Env): | ||||
|     # Functions which provide additions to functions of the base class | ||||
|     #  Always call super!!!!!! | ||||
|     @abc.abstractmethod | ||||
|     def do_additional_reset(self) -> None: | ||||
|     def reset_hook(self) -> None: | ||||
|         pass | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def do_additional_step(self) -> (List[dict], dict): | ||||
|         return [], {} | ||||
|     def pre_step_hook(self) -> None: | ||||
|         pass | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def do_additional_actions(self, agent: Agent, action: Action) -> (bool, dict): | ||||
|         """Hook for handling actions on behalf of a sub-environment. | ||||
|  | ||||
|         The base implementation handles nothing and returns ``None``. | ||||
|         NOTE(review): the annotation suggests overrides return a | ||||
|         ``(valid, reward)`` pair for actions they handle; the caller is not | ||||
|         visible here - confirm. | ||||
|  | ||||
|         :param agent: The agent that issued the action. | ||||
|         :param action: The action to handle. | ||||
|         :return: ``None`` in the base class. | ||||
|         """ | ||||
|         return None | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def step_hook(self) -> (List[dict], dict): | ||||
|         """Hook called by ``step`` after the per-agent action loop. | ||||
|  | ||||
|         ``step`` unpacks the result as ``rewards, info``. The base | ||||
|         implementation contributes nothing; overrides are expected to call | ||||
|         ``super().step_hook()`` and extend its result. | ||||
|  | ||||
|         :return: A list of reward dicts and an info dict, both empty here. | ||||
|         """ | ||||
|         return [], {} | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def check_additional_done(self) -> (bool, dict): | ||||
|         """Hook for additional episode-termination checks. | ||||
|  | ||||
|         The base implementation returns ``(False, {})``. | ||||
|         NOTE(review): presumably the pair is ``(done, info)``; the caller is | ||||
|         not visible here - confirm. | ||||
|  | ||||
|         :return: ``False`` and an empty dict in the base class. | ||||
|         """ | ||||
|         return False, {} | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]: | ||||
|     def observations_hook(self) -> Dict[str, np.typing.ArrayLike]: | ||||
|         return {} | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def _additional_per_agent_raw_observations(self, agent) -> Dict[str, np.typing.ArrayLike]: | ||||
|     def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]: | ||||
|         return {} | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def post_step_hook(self) -> dict: | ||||
|         """Hook called near the end of ``step``; the result is merged into ``info``. | ||||
|  | ||||
|         ``step`` calls ``info.update(self.post_step_hook())``, so every key | ||||
|         returned here ends up in the step's info dict. | ||||
|  | ||||
|         :return: Additional info entries; empty in the base class. | ||||
|         """ | ||||
|         return {} | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]: | ||||
|         additional_raw_observations = {} | ||||
|         if self.obs_prop.show_global_position_info: | ||||
|             global_pos_obs = np.zeros(self._obs_shape) | ||||
| @@ -703,19 +723,5 @@ class BaseFactory(gym.Env): | ||||
|         return additional_raw_observations | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def additional_per_agent_reward(self, agent: Agent) -> Dict[str, dict]: | ||||
|         return {} | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def render_additional_assets(self): | ||||
|     def render_assets_hook(self): | ||||
|         return [] | ||||
|  | ||||
|     # Hooks for in between operations. | ||||
|     #  Always call super!!!!!! | ||||
|     @abc.abstractmethod | ||||
|     def hook_pre_step(self) -> None: | ||||
|         pass | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def hook_post_step(self) -> dict: | ||||
|         return {} | ||||
|   | ||||
| @@ -9,10 +9,11 @@ from environments.helpers import Constants as c | ||||
| import itertools | ||||
|  | ||||
| ########################################################################## | ||||
| # ##################### Base Object Definition ######################### # | ||||
| # ##################### Base Object Building Blocks ######################### # | ||||
| ########################################################################## | ||||
|  | ||||
|  | ||||
| # TODO: Missing Documentation | ||||
| class Object: | ||||
|  | ||||
|     """Generell Objects for Organisation and Maintanance such as Actions etc...""" | ||||
| @@ -53,8 +54,10 @@ class Object: | ||||
|  | ||||
|     def __eq__(self, other) -> bool: | ||||
|         return other == self.identifier | ||||
| # Base | ||||
|  | ||||
|  | ||||
| # TODO: Missing Documentation | ||||
| class EnvObject(Object): | ||||
|  | ||||
|     """Objects that hold Information that are observable, but have no position on the env grid. Inventories etc...""" | ||||
| @@ -78,27 +81,10 @@ class EnvObject(Object): | ||||
|         self._register.delete_env_object(self) | ||||
|         self._register = register | ||||
|         return self._register == register | ||||
| # With Rendering | ||||
|  | ||||
|  | ||||
| class BoundingMixin(Object): | ||||
|  | ||||
|     @property | ||||
|     def bound_entity(self): | ||||
|         return self._bound_entity | ||||
|  | ||||
|     def __init__(self,entity_to_be_bound, *args, **kwargs): | ||||
|         super(BoundingMixin, self).__init__(*args, **kwargs) | ||||
|         assert entity_to_be_bound is not None | ||||
|         self._bound_entity = entity_to_be_bound | ||||
|  | ||||
|     @property | ||||
|     def name(self): | ||||
|         return f'{super(BoundingMixin, self).name}({self._bound_entity.name})' | ||||
|  | ||||
|     def belongs_to_entity(self, entity): | ||||
|         return entity == self.bound_entity | ||||
|  | ||||
|  | ||||
| # TODO: Missing Documentation | ||||
| class Entity(EnvObject): | ||||
|     """Full Env Entity that lives on the env Grid. Doors, Items, Dirt etc...""" | ||||
|  | ||||
| @@ -133,8 +119,10 @@ class Entity(EnvObject): | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return super(Entity, self).__repr__() + f'(@{self.pos})' | ||||
| # With Position in Env | ||||
|  | ||||
|  | ||||
| # TODO: Missing Documentation | ||||
| class MoveableEntity(Entity): | ||||
|  | ||||
|     @property | ||||
| @@ -169,6 +157,27 @@ class MoveableEntity(Entity): | ||||
|             return c.VALID | ||||
|         else: | ||||
|             return c.NOT_VALID | ||||
| # Can Move | ||||
|  | ||||
|  | ||||
| # TODO: Missing Documentation | ||||
| class BoundingMixin(Object): | ||||
|     """Mixin that binds an object to exactly one other entity. | ||||
|  | ||||
|     The entity is stored at construction time, exposed via ``bound_entity`` | ||||
|     and appended in parentheses to the name reported by the parent class. | ||||
|     """ | ||||
|  | ||||
|     @property | ||||
|     def bound_entity(self): | ||||
|         """Return the entity this object was bound to in ``__init__``.""" | ||||
|         return self._bound_entity | ||||
|  | ||||
|     def __init__(self,entity_to_be_bound, *args, **kwargs): | ||||
|         """Store ``entity_to_be_bound``; pass all other arguments to the parent. | ||||
|  | ||||
|         :param entity_to_be_bound: Entity to bind to; must not be ``None`` | ||||
|             (checked with an ``assert`` after the parent initialiser ran). | ||||
|         """ | ||||
|         super(BoundingMixin, self).__init__(*args, **kwargs) | ||||
|         assert entity_to_be_bound is not None | ||||
|         self._bound_entity = entity_to_be_bound | ||||
|  | ||||
|     @property | ||||
|     def name(self): | ||||
|         """Return the parent's name followed by the bound entity's name in parentheses.""" | ||||
|         return f'{super(BoundingMixin, self).name}({self._bound_entity.name})' | ||||
|  | ||||
|     def belongs_to_entity(self, entity): | ||||
|         """Return whether ``entity`` compares equal to the bound entity.""" | ||||
|         return entity == self.bound_entity | ||||
|  | ||||
|  | ||||
| ########################################################################## | ||||
| @@ -216,7 +225,7 @@ class GlobalPosition(BoundingMixin, EnvObject): | ||||
|         self._normalized = normalized | ||||
|  | ||||
|  | ||||
| class Tile(EnvObject): | ||||
| class Floor(EnvObject): | ||||
|  | ||||
|     @property | ||||
|     def encoding(self): | ||||
| @@ -243,7 +252,7 @@ class Tile(EnvObject): | ||||
|         return self._pos | ||||
|  | ||||
|     def __init__(self, pos, *args, **kwargs): | ||||
|         super(Tile, self).__init__(*args, **kwargs) | ||||
|         super(Floor, self).__init__(*args, **kwargs) | ||||
|         self._guests = dict() | ||||
|         self._pos = tuple(pos) | ||||
|  | ||||
| @@ -277,7 +286,7 @@ class Tile(EnvObject): | ||||
|         return dict(name=self.name, x=int(self.x), y=int(self.y)) | ||||
|  | ||||
|  | ||||
| class Wall(Tile): | ||||
| class Wall(Floor): | ||||
|  | ||||
|     @property | ||||
|     def can_collide(self): | ||||
| @@ -302,7 +311,7 @@ class Door(Entity): | ||||
|     @property | ||||
|     def encoding(self): | ||||
|         # This is important, as shadowing is checked by occupation value | ||||
|         return c.OCCUPIED_CELL if self.is_closed else 2 | ||||
|         return c.OCCUPIED_CELL if self.is_closed else 0.5 | ||||
|  | ||||
|     @property | ||||
|     def str_state(self): | ||||
| @@ -396,5 +405,5 @@ class Agent(MoveableEntity): | ||||
|  | ||||
|     def summarize_state(self, **kwargs): | ||||
|         state_dict = super().summarize_state(**kwargs) | ||||
|         state_dict.update(valid=bool(self.temp_action_result['valid']), action=str(self.temp_action_result['action'])) | ||||
|         state_dict.update(valid=bool(self.step_result['action_valid']), action=str(self.step_result['action_name'])) | ||||
|         return state_dict | ||||
|   | ||||
| @@ -6,7 +6,7 @@ from typing import List, Union, Dict, Tuple | ||||
| import numpy as np | ||||
| import six | ||||
|  | ||||
| from environments.factory.base.objects import Entity, Tile, Agent, Door, Action, Wall, PlaceHolder, GlobalPosition, \ | ||||
| from environments.factory.base.objects import Entity, Floor, Agent, Door, Action, Wall, PlaceHolder, GlobalPosition, \ | ||||
|     Object, EnvObject | ||||
| from environments.utility_classes import MovementProperties | ||||
| from environments import helpers as h | ||||
| @@ -271,12 +271,9 @@ class GlobalPositions(EnvObjectRegister): | ||||
|  | ||||
|     _accepted_objects = GlobalPosition | ||||
|  | ||||
|     is_blocking_light = False | ||||
|     can_be_shadowed = False | ||||
|     can_collide = False | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super(GlobalPositions, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs) | ||||
|         super(GlobalPositions, self).__init__(*args, is_per_agent=True, individual_slices=True, is_blocking_light = False, | ||||
|                                               can_be_shadowed = False, can_collide = False, **kwargs) | ||||
|  | ||||
|     def as_array(self): | ||||
|         # FIXME DEBUG!!! make this lazy? | ||||
| @@ -377,7 +374,7 @@ class Entities(ObjectRegister): | ||||
|         return found_entities | ||||
|  | ||||
|  | ||||
| class WallTiles(EntityRegister): | ||||
| class Walls(EntityRegister): | ||||
|     _accepted_objects = Wall | ||||
|  | ||||
|     def as_array(self): | ||||
| @@ -390,9 +387,9 @@ class WallTiles(EntityRegister): | ||||
|         return self._array | ||||
|  | ||||
|     def __init__(self, *args, is_blocking_light=True, **kwargs): | ||||
|         super(WallTiles, self).__init__(*args, individual_slices=False, | ||||
|                                         can_collide=True, | ||||
|                                         is_blocking_light=is_blocking_light, **kwargs) | ||||
|         super(Walls, self).__init__(*args, individual_slices=False, | ||||
|                                     can_collide=True, | ||||
|                                     is_blocking_light=is_blocking_light, **kwargs) | ||||
|         self._value = c.OCCUPIED_CELL | ||||
|  | ||||
|     @classmethod | ||||
| @@ -411,16 +408,16 @@ class WallTiles(EntityRegister): | ||||
|  | ||||
|     def summarize_states(self, n_steps=None): | ||||
|         if n_steps == h.STEPS_START: | ||||
|             return super(WallTiles, self).summarize_states(n_steps=n_steps) | ||||
|             return super(Walls, self).summarize_states(n_steps=n_steps) | ||||
|         else: | ||||
|             return {} | ||||
|  | ||||
|  | ||||
| class FloorTiles(WallTiles): | ||||
|     _accepted_objects = Tile | ||||
| class Floors(Walls): | ||||
|     _accepted_objects = Floor | ||||
|  | ||||
|     def __init__(self, *args, is_blocking_light=False, **kwargs): | ||||
|         super(FloorTiles, self).__init__(*args, is_blocking_light=is_blocking_light, **kwargs) | ||||
|         super(Floors, self).__init__(*args, is_blocking_light=is_blocking_light, **kwargs) | ||||
|         self._value = c.FREE_CELL | ||||
|  | ||||
|     @property | ||||
| @@ -430,7 +427,7 @@ class FloorTiles(WallTiles): | ||||
|         return tiles | ||||
|  | ||||
|     @property | ||||
|     def empty_tiles(self) -> List[Tile]: | ||||
|     def empty_tiles(self) -> List[Floor]: | ||||
|         tiles = [tile for tile in self if tile.is_empty()] | ||||
|         random.shuffle(tiles) | ||||
|         return tiles | ||||
|   | ||||
| @@ -158,19 +158,19 @@ class BatteryFactory(BaseFactory): | ||||
|         self.btry_prop = btry_prop | ||||
|         super().__init__(*args, **kwargs) | ||||
|  | ||||
|     def _additional_per_agent_raw_observations(self, agent) -> Dict[str, np.typing.ArrayLike]: | ||||
|         additional_raw_observations = super()._additional_per_agent_raw_observations(agent) | ||||
|     def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]: | ||||
|         additional_raw_observations = super().per_agent_raw_observations_hook(agent) | ||||
|         additional_raw_observations.update({c.BATTERIES: self[c.BATTERIES].as_array_by_entity(agent)}) | ||||
|         return additional_raw_observations | ||||
|  | ||||
|     def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]: | ||||
|         additional_observations = super()._additional_observations() | ||||
|     def observations_hook(self) -> Dict[str, np.typing.ArrayLike]: | ||||
|         additional_observations = super().observations_hook() | ||||
|         additional_observations.update({c.CHARGE_PODS: self[c.CHARGE_PODS].as_array()}) | ||||
|         return additional_observations | ||||
|  | ||||
|     @property | ||||
|     def additional_entities(self): | ||||
|         super_entities = super().additional_entities | ||||
|     def entities_hook(self): | ||||
|         super_entities = super().entities_hook | ||||
|  | ||||
|         empty_tiles = self[c.FLOOR].empty_tiles[:self.btry_prop.charge_locations] | ||||
|         charge_pods = ChargePods.from_tiles( | ||||
| @@ -185,8 +185,8 @@ class BatteryFactory(BaseFactory): | ||||
|         super_entities.update({c.BATTERIES: batteries, c.CHARGE_PODS: charge_pods}) | ||||
|         return super_entities | ||||
|  | ||||
|     def do_additional_step(self) -> (List[dict], dict): | ||||
|         super_reward_info = super(BatteryFactory, self).do_additional_step() | ||||
|     def step_hook(self) -> (List[dict], dict): | ||||
|         super_reward_info = super(BatteryFactory, self).step_hook() | ||||
|  | ||||
|         # Decharge | ||||
|         batteries = self[c.BATTERIES] | ||||
| @@ -230,7 +230,7 @@ class BatteryFactory(BaseFactory): | ||||
|             return action_result | ||||
|         pass | ||||
|  | ||||
|     def do_additional_reset(self) -> None: | ||||
|     def reset_hook(self) -> None: | ||||
|         # There is Nothing to reset. | ||||
|         pass | ||||
|  | ||||
| @@ -249,8 +249,8 @@ class BatteryFactory(BaseFactory): | ||||
|                 pass | ||||
|         pass | ||||
|  | ||||
|     def additional_per_agent_reward(self, agent: Agent) -> Dict[str, dict]: | ||||
|         reward_event_dict = super(BatteryFactory, self).additional_per_agent_reward(agent) | ||||
|     def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]: | ||||
|         reward_event_dict = super(BatteryFactory, self).per_agent_reward_hook(agent) | ||||
|         if self[c.BATTERIES].by_entity(agent).is_discharged: | ||||
|             self.print(f'{agent.name} Battery is discharged!') | ||||
|             info_dict = {f'{agent.name}_{c.BATTERY_DISCHARGED}': 1} | ||||
| @@ -260,9 +260,9 @@ class BatteryFactory(BaseFactory): | ||||
|             pass | ||||
|         return reward_event_dict | ||||
|  | ||||
|     def render_additional_assets(self): | ||||
|     def render_assets_hook(self): | ||||
|         # noinspection PyUnresolvedReferences | ||||
|         additional_assets = super().render_additional_assets() | ||||
|         additional_assets = super().render_assets_hook() | ||||
|         charge_pods = [RenderEntity(c.CHARGE_PODS, charge_pod.tile.pos) for charge_pod in self[c.CHARGE_PODS]] | ||||
|         additional_assets.extend(charge_pods) | ||||
|         return additional_assets | ||||
|   | ||||
| @@ -147,17 +147,17 @@ class DestFactory(BaseFactory): | ||||
|         super().__init__(*args, **kwargs) | ||||
|  | ||||
|     @property | ||||
|     def additional_actions(self) -> Union[Action, List[Action]]: | ||||
|     def actions_hook(self) -> Union[Action, List[Action]]: | ||||
|         # noinspection PyUnresolvedReferences | ||||
|         super_actions = super().additional_actions | ||||
|         super_actions = super().actions_hook | ||||
|         if self.dest_prop.dwell_time: | ||||
|             super_actions.append(Action(enum_ident=a.WAIT_ON_DEST)) | ||||
|         return super_actions | ||||
|  | ||||
|     @property | ||||
|     def additional_entities(self) -> Dict[(Enum, Entities)]: | ||||
|     def entities_hook(self) -> Dict[(Enum, Entities)]: | ||||
|         # noinspection PyUnresolvedReferences | ||||
|         super_entities = super().additional_entities | ||||
|         super_entities = super().entities_hook | ||||
|  | ||||
|         empty_tiles = self[c.FLOOR].empty_tiles[:self.dest_prop.n_dests] | ||||
|         destinations = Destinations.from_tiles( | ||||
| @@ -194,9 +194,9 @@ class DestFactory(BaseFactory): | ||||
|         else: | ||||
|             return super_action_result | ||||
|  | ||||
|     def do_additional_reset(self) -> None: | ||||
|     def reset_hook(self) -> None: | ||||
|         # noinspection PyUnresolvedReferences | ||||
|         super().do_additional_reset() | ||||
|         super().reset_hook() | ||||
|         self._dest_spawn_timer = dict() | ||||
|  | ||||
|     def trigger_destination_spawn(self): | ||||
| @@ -222,9 +222,9 @@ class DestFactory(BaseFactory): | ||||
|         else: | ||||
|             self.print('No Items are spawning, limit is reached.') | ||||
|  | ||||
|     def do_additional_step(self) -> (List[dict], dict): | ||||
|     def step_hook(self) -> (List[dict], dict): | ||||
|         # noinspection PyUnresolvedReferences | ||||
|         super_reward_info = super().do_additional_step() | ||||
|         super_reward_info = super().step_hook() | ||||
|         for key, val in self._dest_spawn_timer.items(): | ||||
|             self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1) | ||||
|         for dest in list(self[c.DEST].values()): | ||||
| @@ -244,14 +244,14 @@ class DestFactory(BaseFactory): | ||||
|         self.trigger_destination_spawn() | ||||
|         return super_reward_info | ||||
|  | ||||
|     def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]: | ||||
|         additional_observations = super()._additional_observations() | ||||
|     def observations_hook(self) -> Dict[str, np.typing.ArrayLike]: | ||||
|         additional_observations = super().observations_hook() | ||||
|         additional_observations.update({c.DEST: self[c.DEST].as_array()}) | ||||
|         return additional_observations | ||||
|  | ||||
|     def additional_per_agent_reward(self, agent: Agent) -> Dict[str, dict]: | ||||
|     def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]: | ||||
|         # noinspection PyUnresolvedReferences | ||||
|         reward_event_dict = super().additional_per_agent_reward(agent) | ||||
|         reward_event_dict = super().per_agent_reward_hook(agent) | ||||
|         if len(self[c.DEST_REACHED]): | ||||
|             for reached_dest in list(self[c.DEST_REACHED]): | ||||
|                 if agent.pos == reached_dest.pos: | ||||
| @@ -261,9 +261,9 @@ class DestFactory(BaseFactory): | ||||
|                     reward_event_dict.update({c.DEST_REACHED: {'reward': r.DEST_REACHED, 'info': info_dict}}) | ||||
|         return reward_event_dict | ||||
|  | ||||
|     def render_additional_assets(self, mode='human'): | ||||
|     def render_assets_hook(self, mode='human'): | ||||
|         # noinspection PyUnresolvedReferences | ||||
|         additional_assets = super().render_additional_assets() | ||||
|         additional_assets = super().render_assets_hook() | ||||
|         destinations = [RenderEntity(c.DEST, dest.pos) for dest in self[c.DEST]] | ||||
|         additional_assets.extend(destinations) | ||||
|         return additional_assets | ||||
|   | ||||
| @@ -1,5 +1,4 @@ | ||||
| import time | ||||
| from enum import Enum | ||||
| from typing import List, Union, NamedTuple, Dict | ||||
| import random | ||||
|  | ||||
| @@ -12,7 +11,7 @@ from environments.helpers import EnvActions as BaseActions | ||||
| from environments.helpers import Rewards as BaseRewards | ||||
|  | ||||
| from environments.factory.base.base_factory import BaseFactory | ||||
| from environments.factory.base.objects import Agent, Action, Entity, Tile | ||||
| from environments.factory.base.objects import Agent, Action, Entity, Floor | ||||
| from environments.factory.base.registers import Entities, EntityRegister | ||||
|  | ||||
| from environments.factory.base.renderer import RenderEntity | ||||
| @@ -43,7 +42,6 @@ class DirtProperties(NamedTuple): | ||||
|     max_local_amount: int = 2               # Max dirt amount per tile. | ||||
|     max_global_amount: int = 20             # Max dirt amount in the whole environment. | ||||
|     dirt_smear_amount: float = 0.2          # Agents smear dirt, when not cleaning up in place. | ||||
|     agent_can_interact: bool = True         # Whether the agents can interact with the dirt in this environment. | ||||
|     done_when_clean: bool = True | ||||
|  | ||||
|  | ||||
| @@ -89,7 +87,7 @@ class DirtRegister(EntityRegister): | ||||
|         self._dirt_properties: DirtProperties = dirt_properties | ||||
|  | ||||
|     def spawn_dirt(self, then_dirty_tiles) -> bool: | ||||
|         if isinstance(then_dirty_tiles, Tile): | ||||
|         if isinstance(then_dirty_tiles, Floor): | ||||
|             then_dirty_tiles = [then_dirty_tiles] | ||||
|         for tile in then_dirty_tiles: | ||||
|             if not self.amount > self.dirt_properties.max_global_amount: | ||||
| @@ -128,15 +126,14 @@ r = Rewards | ||||
| class DirtFactory(BaseFactory): | ||||
|  | ||||
|     @property | ||||
|     def additional_actions(self) -> Union[Action, List[Action]]: | ||||
|         super_actions = super().additional_actions | ||||
|         if self.dirt_prop.agent_can_interact: | ||||
|             super_actions.append(Action(str_ident=a.CLEAN_UP)) | ||||
|     def actions_hook(self) -> Union[Action, List[Action]]: | ||||
|         super_actions = super().actions_hook | ||||
|         super_actions.append(Action(str_ident=a.CLEAN_UP)) | ||||
|         return super_actions | ||||
|  | ||||
|     @property | ||||
|     def additional_entities(self) -> Dict[(Enum, Entities)]: | ||||
|         super_entities = super().additional_entities | ||||
|     def entities_hook(self) -> Dict[(str, Entities)]: | ||||
|         super_entities = super().entities_hook | ||||
|         dirt_register = DirtRegister(self.dirt_prop, self._level_shape) | ||||
|         super_entities.update(({c.DIRT: dirt_register})) | ||||
|         return super_entities | ||||
| @@ -148,10 +145,11 @@ class DirtFactory(BaseFactory): | ||||
|         self._dirt_rng = np.random.default_rng(env_seed) | ||||
|         self._dirt: DirtRegister | ||||
|         kwargs.update(env_seed=env_seed) | ||||
|         # TODO: Reset ---> document this | ||||
|         super().__init__(*args, **kwargs) | ||||
|  | ||||
|     def render_additional_assets(self, mode='human'): | ||||
|         additional_assets = super().render_additional_assets() | ||||
|     def render_assets_hook(self, mode='human'): | ||||
|         additional_assets = super().render_assets_hook() | ||||
|         dirt = [RenderEntity('dirt', dirt.tile.pos, min(0.15 + dirt.amount, 1.5), 'scale') | ||||
|                 for dirt in self[c.DIRT]] | ||||
|         additional_assets.extend(dirt) | ||||
| @@ -167,12 +165,12 @@ class DirtFactory(BaseFactory): | ||||
|                 dirt.set_new_amount(max(new_dirt_amount, c.FREE_CELL.value)) | ||||
|             valid = c.VALID | ||||
|             self.print(f'{agent.name} did just clean up some dirt at {agent.pos}.') | ||||
|             info_dict = {f'{agent.name}_{a.CLEAN_UP}_VALID': 1} | ||||
|             info_dict = {f'{agent.name}_{a.CLEAN_UP}_VALID': 1, 'cleanup_valid': 1} | ||||
|             reward = r.CLEAN_UP_VALID | ||||
|         else: | ||||
|             valid = c.NOT_VALID | ||||
|             self.print(f'{agent.name} just tried to clean up some dirt at {agent.pos}, but failed.') | ||||
|             info_dict = {f'{agent.name}_{a.CLEAN_UP}_FAIL': 1} | ||||
|             info_dict = {f'{agent.name}_{a.CLEAN_UP}_FAIL': 1, 'cleanup_fail': 1} | ||||
|             reward = r.CLEAN_UP_FAIL | ||||
|  | ||||
|         if valid and self.dirt_prop.done_when_clean and (len(self[c.DIRT]) == 0): | ||||
| @@ -195,8 +193,8 @@ class DirtFactory(BaseFactory): | ||||
|         n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt))) | ||||
|         self[c.DIRT].spawn_dirt(free_for_dirt[:n_dirt_tiles]) | ||||
|  | ||||
|     def do_additional_step(self) -> (List[dict], dict): | ||||
|         super_reward_info = super().do_additional_step() | ||||
|     def step_hook(self) -> (List[dict], dict): | ||||
|         super_reward_info = super().step_hook() | ||||
|         if smear_amount := self.dirt_prop.dirt_smear_amount: | ||||
|             for agent in self[c.AGENT]: | ||||
|                 if agent.temp_valid and agent.last_pos != c.NO_POS: | ||||
| @@ -229,8 +227,8 @@ class DirtFactory(BaseFactory): | ||||
|         else: | ||||
|             return action_result | ||||
|  | ||||
|     def do_additional_reset(self) -> None: | ||||
|         super().do_additional_reset() | ||||
|     def reset_hook(self) -> None: | ||||
|         super().reset_hook() | ||||
|         self.trigger_dirt_spawn(initial_spawn=True) | ||||
|         self._next_dirt_spawn = self.dirt_prop.spawn_frequency if self.dirt_prop.spawn_frequency else -1 | ||||
|  | ||||
| @@ -242,13 +240,13 @@ class DirtFactory(BaseFactory): | ||||
|                 return all_cleaned, super_dict | ||||
|         return super_done, super_dict | ||||
|  | ||||
|     def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]: | ||||
|         additional_observations = super()._additional_observations() | ||||
|     def observations_hook(self) -> Dict[str, np.typing.ArrayLike]: | ||||
|         additional_observations = super().observations_hook() | ||||
|         additional_observations.update({c.DIRT: self[c.DIRT].as_array()}) | ||||
|         return additional_observations | ||||
|  | ||||
|     def gather_additional_info(self, agent: Agent) -> dict: | ||||
|         event_reward_dict = super().additional_per_agent_reward(agent) | ||||
|         event_reward_dict = super().per_agent_reward_hook(agent) | ||||
|         info_dict = dict() | ||||
|  | ||||
|         dirt = [dirt.amount for dirt in self[c.DIRT]] | ||||
| @@ -280,8 +278,7 @@ if __name__ == '__main__': | ||||
|         max_local_amount=1, | ||||
|         spawn_frequency=0, | ||||
|         max_spawn_ratio=0.05, | ||||
|         dirt_smear_amount=0.0, | ||||
|         agent_can_interact=True | ||||
|         dirt_smear_amount=0.0 | ||||
|     ) | ||||
|  | ||||
|     obs_props = ObservationProperties(render_agents=aro.COMBINED, omit_agent_self=True, | ||||
| @@ -294,13 +291,13 @@ if __name__ == '__main__': | ||||
|     global_timings = [] | ||||
|     for i in range(10): | ||||
|  | ||||
|         factory = DirtFactory(n_agents=1, done_at_collision=False, | ||||
|         factory = DirtFactory(n_agents=10, done_at_collision=False, | ||||
|                               level_name='rooms', max_steps=1000, | ||||
|                               doors_have_area=False, | ||||
|                               obs_prop=obs_props, parse_doors=True, | ||||
|                               verbose=True, | ||||
|                               mv_prop=move_props, dirt_prop=dirt_props, | ||||
|                               inject_agents=[TSPDirtAgent], | ||||
|                               # inject_agents=[TSPDirtAgent], | ||||
|                               ) | ||||
|  | ||||
|         # noinspection DuplicatedCode | ||||
| @@ -318,11 +315,11 @@ if __name__ == '__main__': | ||||
|             env_state = factory.reset() | ||||
|             if render: | ||||
|                 factory.render() | ||||
|             tsp_agent = factory.get_injected_agents()[0] | ||||
|             # tsp_agent = factory.get_injected_agents()[0] | ||||
|  | ||||
|             rwrd = 0 | ||||
|             for agent_i_action in random_actions: | ||||
|                 agent_i_action = tsp_agent.predict() | ||||
|                 # agent_i_action = tsp_agent.predict() | ||||
|                 env_state, step_rwrd, done_bool, info_obj = factory.step(agent_i_action) | ||||
|                 rwrd += step_rwrd | ||||
|                 if render: | ||||
|   | ||||
							
								
								
									
										58
									
								
								environments/factory/factory_dirt_stationary_machines.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								environments/factory/factory_dirt_stationary_machines.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,58 @@ | ||||
| from typing import Dict, List, Union | ||||
|  | ||||
| import numpy as np | ||||
|  | ||||
| from environments.factory.base.objects import Agent, Entity, Action | ||||
| from environments.factory.factory_dirt import Dirt, DirtRegister, DirtFactory | ||||
| from environments.factory.base.objects import Floor | ||||
| from environments.factory.base.registers import Floors, Entities, EntityRegister | ||||
|  | ||||
|  | ||||
| class Machines(EntityRegister): | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|  | ||||
|  | ||||
| class Machine(Entity): | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|  | ||||
|  | ||||
| class StationaryMachinesDirtFactory(DirtFactory): | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         self._machine_coords = [(6, 6), (12, 13)] | ||||
|         super().__init__(*args, **kwargs) | ||||
|  | ||||
|     def entities_hook(self) -> Dict[(str, Entities)]: | ||||
|         super_entities = super().entities_hook() | ||||
|  | ||||
|         return super_entities | ||||
|  | ||||
|     def reset_hook(self) -> None: | ||||
|                 pass | ||||
|  | ||||
|     def observations_hook(self) -> Dict[str, np.typing.ArrayLike]: | ||||
|         pass | ||||
|  | ||||
|     def actions_hook(self) -> Union[Action, List[Action]]: | ||||
|         pass | ||||
|  | ||||
|     def step_hook(self) -> (List[dict], dict): | ||||
|  | ||||
|         pass | ||||
|  | ||||
|     def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]: | ||||
|         super_per_agent_raw_observations = super().per_agent_raw_observations_hook(agent) | ||||
|         return super_per_agent_raw_observations | ||||
|  | ||||
|     def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]: | ||||
|         pass | ||||
|  | ||||
|     def pre_step_hook(self) -> None: | ||||
|         pass | ||||
|  | ||||
|     def post_step_hook(self) -> dict: | ||||
|         pass | ||||
| @@ -9,7 +9,7 @@ from environments.helpers import Constants as BaseConstants | ||||
| from environments.helpers import EnvActions as BaseActions | ||||
| from environments.helpers import Rewards as BaseRewards | ||||
| from environments import helpers as h | ||||
| from environments.factory.base.objects import Agent, Entity, Action, Tile | ||||
| from environments.factory.base.objects import Agent, Entity, Action, Floor | ||||
| from environments.factory.base.registers import Entities, EntityRegister, BoundEnvObjRegister, ObjectRegister | ||||
|  | ||||
| from environments.factory.base.renderer import RenderEntity | ||||
| @@ -25,7 +25,7 @@ class Constants(BaseConstants): | ||||
|  | ||||
|  | ||||
| class Actions(BaseActions): | ||||
|     ITEM_ACTION     = 'item_action' | ||||
|     ITEM_ACTION     = 'ITEMACTION' | ||||
|  | ||||
|  | ||||
| class Rewards(BaseRewards): | ||||
| @@ -62,7 +62,7 @@ class ItemRegister(EntityRegister): | ||||
|  | ||||
|     _accepted_objects = Item | ||||
|  | ||||
|     def spawn_items(self, tiles: List[Tile]): | ||||
|     def spawn_items(self, tiles: List[Floor]): | ||||
|         items = [Item(tile, self) for tile in tiles] | ||||
|         self.register_additional_items(items) | ||||
|  | ||||
| @@ -193,16 +193,16 @@ class ItemFactory(BaseFactory): | ||||
|         super().__init__(*args, **kwargs) | ||||
|  | ||||
|     @property | ||||
|     def additional_actions(self) -> Union[Action, List[Action]]: | ||||
|     def actions_hook(self) -> Union[Action, List[Action]]: | ||||
|         # noinspection PyUnresolvedReferences | ||||
|         super_actions = super().additional_actions | ||||
|         super_actions = super().actions_hook | ||||
|         super_actions.append(Action(str_ident=a.ITEM_ACTION)) | ||||
|         return super_actions | ||||
|  | ||||
|     @property | ||||
|     def additional_entities(self) -> Dict[(str, Entities)]: | ||||
|     def entities_hook(self) -> Dict[(str, Entities)]: | ||||
|         # noinspection PyUnresolvedReferences | ||||
|         super_entities = super().additional_entities | ||||
|         super_entities = super().entities_hook | ||||
|  | ||||
|         empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_drop_off_locations] | ||||
|         drop_offs = DropOffLocations.from_tiles( | ||||
| @@ -220,13 +220,13 @@ class ItemFactory(BaseFactory): | ||||
|         super_entities.update({c.DROP_OFF: drop_offs, c.ITEM: item_register, c.INVENTORY: inventories}) | ||||
|         return super_entities | ||||
|  | ||||
|     def _additional_per_agent_raw_observations(self, agent) -> Dict[str, np.typing.ArrayLike]: | ||||
|         additional_raw_observations = super()._additional_per_agent_raw_observations(agent) | ||||
|     def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]: | ||||
|         additional_raw_observations = super().per_agent_raw_observations_hook(agent) | ||||
|         additional_raw_observations.update({c.INVENTORY: self[c.INVENTORY].by_entity(agent).as_array()}) | ||||
|         return additional_raw_observations | ||||
|  | ||||
|     def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]: | ||||
|         additional_observations = super()._additional_observations() | ||||
|     def observations_hook(self) -> Dict[str, np.typing.ArrayLike]: | ||||
|         additional_observations = super().observations_hook() | ||||
|         additional_observations.update({c.ITEM: self[c.ITEM].as_array()}) | ||||
|         additional_observations.update({c.DROP_OFF: self[c.DROP_OFF].as_array()}) | ||||
|         return additional_observations | ||||
| @@ -240,21 +240,21 @@ class ItemFactory(BaseFactory): | ||||
|                 valid = c.NOT_VALID | ||||
|             if valid: | ||||
|                 self.print(f'{agent.name} just dropped of an item at {drop_off.pos}.') | ||||
|                 info_dict = {f'{agent.name}_DROPOFF_VALID': 1} | ||||
|                 info_dict = {f'{agent.name}_DROPOFF_VALID': 1, 'DROPOFF_VALID': 1} | ||||
|             else: | ||||
|                 self.print(f'{agent.name} just tried to drop off at {agent.pos}, but failed.') | ||||
|                 info_dict = {f'{agent.name}_DROPOFF_FAIL': 1} | ||||
|                 info_dict = {f'{agent.name}_DROPOFF_FAIL': 1, 'DROPOFF_FAIL': 1} | ||||
|             reward = dict(value=r.DROP_OFF_VALID if valid else r.DROP_OFF_FAIL, reason=a.ITEM_ACTION, info=info_dict) | ||||
|             return valid, reward | ||||
|         elif item := self[c.ITEM].by_pos(agent.pos): | ||||
|             item.change_register(inventory) | ||||
|             item.set_tile_to(self._NO_POS_TILE) | ||||
|             self.print(f'{agent.name} just picked up an item at {agent.pos}') | ||||
|             info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1} | ||||
|             info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1, f'{a.ITEM_ACTION}_VALID': 1} | ||||
|             return c.VALID, dict(value=r.PICK_UP_VALID, reason=a.ITEM_ACTION, info=info_dict) | ||||
|         else: | ||||
|             self.print(f'{agent.name} just tried to pick up an item at {agent.pos}, but failed.') | ||||
|             info_dict = {f'{agent.name}_{a.ITEM_ACTION}_FAIL': 1} | ||||
|             info_dict = {f'{agent.name}_{a.ITEM_ACTION}_FAIL': 1, f'{a.ITEM_ACTION}_FAIL': 1} | ||||
|             return c.NOT_VALID, dict(value=r.PICK_UP_FAIL, reason=a.ITEM_ACTION, info=info_dict) | ||||
|  | ||||
|     def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict): | ||||
| @@ -269,9 +269,9 @@ class ItemFactory(BaseFactory): | ||||
|         else: | ||||
|             return action_result | ||||
|  | ||||
|     def do_additional_reset(self) -> None: | ||||
|     def reset_hook(self) -> None: | ||||
|         # noinspection PyUnresolvedReferences | ||||
|         super().do_additional_reset() | ||||
|         super().reset_hook() | ||||
|         self._next_item_spawn = self.item_prop.spawn_frequency | ||||
|         self.trigger_item_spawn() | ||||
|  | ||||
| @@ -284,9 +284,9 @@ class ItemFactory(BaseFactory): | ||||
|         else: | ||||
|             self.print('No Items are spawning, limit is reached.') | ||||
|  | ||||
|     def do_additional_step(self) -> (List[dict], dict): | ||||
|     def step_hook(self) -> (List[dict], dict): | ||||
|         # noinspection PyUnresolvedReferences | ||||
|         super_reward_info = super().do_additional_step() | ||||
|         super_reward_info = super().step_hook() | ||||
|         for item in list(self[c.ITEM].values()): | ||||
|             if item.auto_despawn >= 1: | ||||
|                 item.set_auto_despawn(item.auto_despawn-1) | ||||
| @@ -301,9 +301,9 @@ class ItemFactory(BaseFactory): | ||||
|             self._next_item_spawn = max(0, self._next_item_spawn-1) | ||||
|         return super_reward_info | ||||
|  | ||||
|     def render_additional_assets(self, mode='human'): | ||||
|     def render_assets_hook(self, mode='human'): | ||||
|         # noinspection PyUnresolvedReferences | ||||
|         additional_assets = super().render_additional_assets() | ||||
|         additional_assets = super().render_assets_hook() | ||||
|         items = [RenderEntity(c.ITEM, item.tile.pos) for item in self[c.ITEM] if item.tile != self._NO_POS_TILE] | ||||
|         additional_assets.extend(items) | ||||
|         drop_offs = [RenderEntity(c.DROP_OFF, drop_off.tile.pos) for drop_off in self[c.DROP_OFF]] | ||||
| @@ -314,7 +314,7 @@ class ItemFactory(BaseFactory): | ||||
| if __name__ == '__main__': | ||||
|     from environments.utility_classes import AgentRenderOptions as aro, ObservationProperties | ||||
|  | ||||
|     render = False | ||||
|     render = True | ||||
|  | ||||
|     item_probs = ItemProperties(n_items=30, n_drop_off_locations=6) | ||||
|  | ||||
| @@ -336,18 +336,18 @@ if __name__ == '__main__': | ||||
|     obs_space = factory.observation_space | ||||
|     obs_space_named = factory.named_observation_space | ||||
|  | ||||
|     for epoch in range(4): | ||||
|     for epoch in range(400): | ||||
|         random_actions = [[random.randint(0, n_actions) for _ | ||||
|                            in range(factory.n_agents)] for _ | ||||
|                           in range(factory.max_steps + 1)] | ||||
|         env_state = factory.reset() | ||||
|         r = 0 | ||||
|         rwrd = 0 | ||||
|         for agent_i_action in random_actions: | ||||
|             env_state, step_r, done_bool, info_obj = factory.step(agent_i_action) | ||||
|             r += step_r | ||||
|             rwrd += step_r | ||||
|             if render: | ||||
|                 factory.render() | ||||
|             if done_bool: | ||||
|                 break | ||||
|         print(f'Factory run {epoch} done, reward is:\n    {r}') | ||||
|         print(f'Factory run {epoch} done, reward is:\n    {rwrd}') | ||||
| pass | ||||
|   | ||||
| @@ -1,5 +1,6 @@ | ||||
| import pickle | ||||
| from collections import defaultdict | ||||
| from os import PathLike | ||||
| from pathlib import Path | ||||
| from typing import List, Dict, Union | ||||
|  | ||||
| @@ -9,14 +10,17 @@ from environments.helpers import IGNORED_DF_COLUMNS | ||||
|  | ||||
| import pandas as pd | ||||
|  | ||||
| from plotting.compare_runs import plot_single_run | ||||
|  | ||||
|  | ||||
| class EnvMonitor(BaseCallback): | ||||
|  | ||||
|     ext = 'png' | ||||
|  | ||||
|     def __init__(self, env): | ||||
|     def __init__(self, env, filepath: Union[str, PathLike] = None): | ||||
|         super(EnvMonitor, self).__init__() | ||||
|         self.unwrapped = env | ||||
|         self._filepath = filepath | ||||
|         self._monitor_df = pd.DataFrame() | ||||
|         self._monitor_dicts = defaultdict(dict) | ||||
|  | ||||
| @@ -67,8 +71,10 @@ class EnvMonitor(BaseCallback): | ||||
|             pass | ||||
|         return | ||||
|  | ||||
|     def save_run(self, filepath: Union[Path, str]): | ||||
|     def save_run(self, filepath: Union[Path, str], auto_plotting_keys=None): | ||||
|         filepath = Path(filepath) | ||||
|         filepath.parent.mkdir(exist_ok=True, parents=True) | ||||
|         with filepath.open('wb') as f: | ||||
|             pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL) | ||||
|         if auto_plotting_keys: | ||||
|             plot_single_run(filepath, column_keys=auto_plotting_keys) | ||||
|   | ||||
| @@ -24,14 +24,12 @@ class EnvRecorder(BaseCallback): | ||||
|                 self._entities = [entities] | ||||
|         else: | ||||
|             self._entities = entities | ||||
|         self.started = False | ||||
|         self.closed = False | ||||
|  | ||||
|     def __getattr__(self, item): | ||||
|         return getattr(self.unwrapped, item) | ||||
|  | ||||
|     def reset(self): | ||||
|         self.unwrapped._record_episodes = True | ||||
|         self.unwrapped.start_recording() | ||||
|         return self.unwrapped.reset() | ||||
|  | ||||
|     def _on_training_start(self) -> None: | ||||
| @@ -57,6 +55,14 @@ class EnvRecorder(BaseCallback): | ||||
|         else: | ||||
|             pass | ||||
|  | ||||
|     def step(self, actions): | ||||
|         step_result = self.unwrapped.step(actions) | ||||
|         # 0, 1,     2    ,     3    =    idx | ||||
|         # _, _, done_bool, info_obj = step_result | ||||
|         self._read_info(0, step_result[3]) | ||||
|         self._read_done(0, step_result[2]) | ||||
|         return step_result | ||||
|  | ||||
|     def save_records(self, filepath: Union[Path, str], save_occupation_map=False, save_trajectory_map=False): | ||||
|         filepath = Path(filepath) | ||||
|         filepath.parent.mkdir(exist_ok=True, parents=True) | ||||
|   | ||||
| @@ -10,6 +10,45 @@ from environments.helpers import IGNORED_DF_COLUMNS, MODEL_MAP | ||||
| from plotting.plotting import prepare_plot | ||||
|  | ||||
|  | ||||
| def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None): | ||||
|     run_path = Path(run_path) | ||||
|     df_list = list() | ||||
|     if run_path.is_dir(): | ||||
|         monitor_file = next(run_path.glob('*monitor*.pick')) | ||||
|     elif run_path.exists() and run_path.is_file(): | ||||
|         monitor_file = run_path | ||||
|     else: | ||||
|         raise ValueError | ||||
|  | ||||
|     with monitor_file.open('rb') as f: | ||||
|         monitor_df = pickle.load(f) | ||||
|  | ||||
|         monitor_df = monitor_df.fillna(0) | ||||
|         df_list.append(monitor_df) | ||||
|  | ||||
|     df = pd.concat(df_list,  ignore_index=True) | ||||
|     df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode']) | ||||
|     if column_keys is not None: | ||||
|         columns = [col for col in column_keys if col in df.columns] | ||||
|     else: | ||||
|         columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS] | ||||
|  | ||||
|     roll_n = 50 | ||||
|  | ||||
|     non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean() | ||||
|  | ||||
|     df_melted = df[columns + ['Episode']].reset_index().melt(id_vars=['Episode'], | ||||
|                                                                 value_vars=columns, var_name="Measurement", | ||||
|                                                                 value_name="Score") | ||||
|  | ||||
|     if df_melted['Episode'].max() > 800: | ||||
|         skip_n = round(df_melted['Episode'].max() * 0.02) | ||||
|         df_melted = df_melted[df_melted['Episode'] % skip_n == 0] | ||||
|  | ||||
|     prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex) | ||||
|     print('Plotting done.') | ||||
|  | ||||
|  | ||||
| def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False): | ||||
|     run_path = Path(run_path) | ||||
|     df_list = list() | ||||
| @@ -37,7 +76,10 @@ def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False): | ||||
|         skip_n = round(df_melted['Episode'].max() * 0.02) | ||||
|         df_melted = df_melted[df_melted['Episode'] % skip_n == 0] | ||||
|  | ||||
|     prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted, use_tex=use_tex) | ||||
|     if run_path.is_dir(): | ||||
|         prepare_plot(run_path / f'{run_path}_monitor_lineplot.png', df_melted, use_tex=use_tex) | ||||
|     elif run_path.exists() and run_path.is_file(): | ||||
|         prepare_plot(run_path.parent / f'{run_path.parent}_monitor_lineplot.png', df_melted, use_tex=use_tex) | ||||
|     print('Plotting done.') | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -1,4 +1,5 @@ | ||||
| import seaborn as sns | ||||
| import matplotlib as mpl | ||||
| from matplotlib import pyplot as plt | ||||
|  | ||||
| PALETTE = 10 * ( | ||||
| @@ -21,7 +22,14 @@ PALETTE = 10 * ( | ||||
| def plot(filepath, ext='png'): | ||||
|     plt.tight_layout() | ||||
|     figure = plt.gcf() | ||||
|     figure.savefig(str(filepath), format=ext) | ||||
|     ax = plt.gca() | ||||
|     legends = [c for c in ax.get_children() if isinstance(c, mpl.legend.Legend)] | ||||
|  | ||||
|     if legends: | ||||
|         figure.savefig(str(filepath), format=ext,  bbox_extra_artists=(*legends,), bbox_inches='tight') | ||||
|     else: | ||||
|         figure.savefig(str(filepath), format=ext) | ||||
|  | ||||
|     plt.show() | ||||
|     plt.clf() | ||||
|  | ||||
| @@ -30,7 +38,7 @@ def prepare_tex(df, hue, style, hue_order): | ||||
|     sns.set(rc={'text.usetex': True}, style='whitegrid') | ||||
|     lineplot = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE, | ||||
|                             hue_order=hue_order, hue=hue, style=style) | ||||
|     # lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}') | ||||
|     lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}') | ||||
|     plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0) | ||||
|     plt.tight_layout() | ||||
|     return lineplot | ||||
| @@ -48,6 +56,19 @@ def prepare_plt(df, hue, style, hue_order): | ||||
|     return lineplot | ||||
|  | ||||
|  | ||||
| def prepare_center_double_column_legend(df, hue, style, hue_order): | ||||
|     # Draw an 'Episode' vs. 'Score' line plot without LaTeX rendering on a fresh 10x11 figure | ||||
|     # and attach a three-column legend, centred horizontally and anchored below the axes. | ||||
|     # Returns the object returned by ``sns.lineplot``. | ||||
|     # NOTE(review): the message below talks about a LaTeX fallback although nothing in this | ||||
|     # function attempts LaTeX rendering - confirm that the wording is intended here. | ||||
|     print('Struggling to plot Figure using LaTeX - going back to normal.') | ||||
|     # Discard all open figures so that the plot starts from a clean state. | ||||
|     plt.close('all') | ||||
|     sns.set(rc={'text.usetex': False}, style='whitegrid') | ||||
|     # ``fig`` is not used afterwards; the call makes the new figure the current one. | ||||
|     fig = plt.figure(figsize=(10, 11)) | ||||
|     # legend=False: the automatic legend is suppressed and replaced by the explicit one below. | ||||
|     lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style, | ||||
|                             ci=95, palette=PALETTE, hue_order=hue_order, legend=False) | ||||
|     # plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0) | ||||
|     # The entries of ``hue_order`` are used as legend labels. | ||||
|     lineplot.legend(hue_order, ncol=3, loc='lower center', title='Parameter Combinations', bbox_to_anchor=(0.5, -0.43)) | ||||
|     plt.tight_layout() | ||||
|     return lineplot | ||||
|  | ||||
|  | ||||
| def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None, use_tex=False): | ||||
|     df = results_df.copy() | ||||
|     df[hue] = df[hue].str.replace('_', '-') | ||||
|   | ||||
| @@ -4,7 +4,10 @@ from pathlib import Path | ||||
| import yaml | ||||
| from stable_baselines3 import A2C, PPO, DQN | ||||
|  | ||||
| from environments.factory.factory_dirt import Constants as c | ||||
|  | ||||
| from environments.factory.factory_dirt import DirtFactory | ||||
| from environments.logging.envmonitor import EnvMonitor | ||||
| from environments.logging.recorder import EnvRecorder | ||||
|  | ||||
| warnings.filterwarnings('ignore', category=FutureWarning) | ||||
| @@ -16,32 +19,35 @@ if __name__ == '__main__': | ||||
|     determin = False | ||||
|     render = True | ||||
|     record = False | ||||
|     seed = 67 | ||||
|     verbose = True | ||||
|     seed = 13 | ||||
|     n_agents = 1 | ||||
|     # out_path = Path('study_out/e_1_new_reward/no_obs/dirt/A2C_new_reward/0_A2C_new_reward') | ||||
|     out_path = Path('study_out/test/dirt') | ||||
|     out_path = Path('study_out/reload') | ||||
|     model_path = out_path | ||||
|  | ||||
|     with (out_path / f'env_params.json').open('r') as f: | ||||
|         env_kwargs = yaml.load(f, Loader=yaml.FullLoader) | ||||
|         env_kwargs.update(additional_agent_placeholder=None, n_agents=n_agents, max_steps=150) | ||||
|         if gain_amount := env_kwargs.get('dirt_prop', {}).get('gain_amount', None): | ||||
|             env_kwargs['dirt_prop']['max_spawn_amount'] = gain_amount | ||||
|             del env_kwargs['dirt_prop']['gain_amount'] | ||||
|  | ||||
|         env_kwargs.update(record_episodes=record, done_at_collision=True) | ||||
|         env_kwargs.update(n_agents=n_agents, done_at_collision=False, verbose=verbose) | ||||
|  | ||||
|     this_model = out_path / 'model.zip' | ||||
|  | ||||
|     model_cls = PPO  # next(val for key, val in h.MODEL_MAP.items() if key in out_path.parent.name) | ||||
|     models = [model_cls.load(this_model)] | ||||
|     try: | ||||
|         # Legacy Cleanups | ||||
|         del env_kwargs['dirt_prop']['agent_can_interact'] | ||||
|         env_kwargs['verbose'] = True | ||||
|     except KeyError: | ||||
|         pass | ||||
|  | ||||
|     # Init Env | ||||
|     with DirtFactory(**env_kwargs) as env: | ||||
|         env = EnvRecorder(env) | ||||
|         env = EnvMonitor(env) | ||||
|         env = EnvRecorder(env) if record else env | ||||
|         obs_shape = env.observation_space.shape | ||||
|         # Evaluation Loop for i in range(n Episodes) | ||||
|         for episode in range(50): | ||||
|         for episode in range(500): | ||||
|             env_state = env.reset() | ||||
|             rew, done_bool = 0, False | ||||
|             while not done_bool: | ||||
| @@ -55,7 +61,17 @@ if __name__ == '__main__': | ||||
|                 rew += step_r | ||||
|                 if render: | ||||
|                     env.render() | ||||
|                 try: | ||||
|                     door = next(x for x in env.unwrapped.unwrapped[c.DOORS] if x.is_open) | ||||
|                     print('openDoor found') | ||||
|                 except StopIteration: | ||||
|                     pass | ||||
|  | ||||
|                 if done_bool: | ||||
|                     break | ||||
|             print(f'Factory run {episode} done, reward is:\n    {rew}') | ||||
|             print(f'Factory run {episode} done, steps taken {env.unwrapped.unwrapped._steps}, reward is:\n    {rew}') | ||||
|         env.save_run(out_path / 'reload_monitor.pick', | ||||
|                      auto_plotting_keys=['step_reward', 'cleanup_valid', 'cleanup_fail']) | ||||
|         if record: | ||||
|             env.save_records(out_path / 'reload_recorder.pick', save_occupation_map=True) | ||||
|     print('all done') | ||||
|   | ||||
| @@ -1,3 +1,4 @@ | ||||
| import itertools | ||||
| import sys | ||||
| from pathlib import Path | ||||
|  | ||||
| @@ -65,8 +66,8 @@ def load_model_run_baseline(policy_path, env_to_run): | ||||
|                 if done_bool: | ||||
|                     break | ||||
|             print(f'Factory run {episode} done, reward is:\n    {rew}') | ||||
|         recorded_env_factory.save_run(filepath=policy_path / f'monitor.pick') | ||||
|         recorded_env_factory.save_records(filepath=policy_path / f'recorder.json') | ||||
|         recorded_env_factory.save_run(filepath=policy_path / f'baseline_monitor.pick') | ||||
|         recorded_env_factory.save_records(filepath=policy_path / f'baseline_recorder.json') | ||||
|  | ||||
|  | ||||
| def load_model_run_combined(root_path, env_to_run, env_kwargs): | ||||
| @@ -89,134 +90,156 @@ def load_model_run_combined(root_path, env_to_run, env_kwargs): | ||||
|                                                        env_factory.named_observation_space, | ||||
|                                                        *[x.named_observation_space for x in models]) | ||||
|  | ||||
|         monitored_env_factory = EnvMonitor(env_factory) | ||||
|         recorded_env_factory = EnvRecorder(monitored_env_factory) | ||||
|         env = EnvMonitor(env_factory) | ||||
|         # Evaluation Loop for i in range(n Episodes) | ||||
|         for episode in range(5): | ||||
|             env_state = recorded_env_factory.reset() | ||||
|             env_state = env.reset() | ||||
|             rew, done_bool = 0, False | ||||
|             while not done_bool: | ||||
|                 translated_observations = observation_translator(env_state) | ||||
|                 actions = [model.predict(translated_observations[model_idx], deterministic=True)[0] | ||||
|                            for model_idx, model in enumerate(models)] | ||||
|                 translated_actions = action_translator(actions) | ||||
|                 env_state, step_r, done_bool, info_obj = recorded_env_factory.step(translated_actions) | ||||
|                 env_state, step_r, done_bool, info_obj = env.step(translated_actions) | ||||
|                 rew += step_r | ||||
|                 if done_bool: | ||||
|                     break | ||||
|             print(f'Factory run {episode} done, reward is:\n    {rew}') | ||||
|         recorded_env_factory.save_run(filepath=root_path / f'monitor.pick') | ||||
|         recorded_env_factory.save_records(filepath=root_path / f'recorder.json') | ||||
|         env.save_run(filepath=root_path / f'monitor_combined.pick') | ||||
|         # env.save_records(filepath=root_path / f'recorder_combined.json') | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     # What to do: | ||||
|     train = True | ||||
|     individual_run = True | ||||
|     individual_run = False | ||||
|     combined_run = False | ||||
|     multi_env = False | ||||
|  | ||||
|     train_steps = 2e6 | ||||
|     train_steps = 1e6 | ||||
|     frames_to_stack = 3 | ||||
|  | ||||
|     # Define a global study save path | ||||
|     study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}' | ||||
|     paremters_of_interest = dict( | ||||
|         show_global_position_info=[True, False], | ||||
|         pomdp_r=[3], | ||||
|         cast_shadows=[True, False], | ||||
|         allow_diagonal_movement=[True], | ||||
|         parse_doors=[True, False], | ||||
|         doors_have_area=[True, False], | ||||
|         done_at_collision=[True, False] | ||||
|     ) | ||||
|     keys, vals = zip(*paremters_of_interest.items()) | ||||
|  | ||||
|     def policy_model_kwargs(): | ||||
|         return dict() | ||||
|     # Then we build the Cartesian product of those value lists (every parameter combination) | ||||
|     p = list(itertools.product(*vals)) | ||||
|  | ||||
|     # Define Global Env Parameters | ||||
|     # Define properties object parameters | ||||
|     obs_props = ObservationProperties(render_agents=AgentRenderOptions.NOT, | ||||
|                                       additional_agent_placeholder=None, | ||||
|                                       omit_agent_self=True, | ||||
|                                       frames_to_stack=frames_to_stack, | ||||
|                                       pomdp_r=2, cast_shadows=True) | ||||
|     move_props = MovementProperties(allow_diagonal_movement=True, | ||||
|                                     allow_square_movement=True, | ||||
|                                     allow_no_op=False) | ||||
|     dirt_props = DirtProperties(initial_dirt_ratio=0.35, initial_dirt_spawn_r_var=0.1, | ||||
|                                 clean_amount=0.34, | ||||
|                                 max_spawn_amount=0.1, max_global_amount=20, | ||||
|                                 max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05, | ||||
|                                 dirt_smear_amount=0.0, agent_can_interact=True) | ||||
|     item_props = ItemProperties(n_items=10, spawn_frequency=30, n_drop_off_locations=2, | ||||
|                                 max_agent_inventory_capacity=15) | ||||
|     dest_props = DestProperties(n_dests=4, spawn_mode=DestModeOptions.GROUPED, spawn_frequency=1) | ||||
|     factory_kwargs = dict(n_agents=1, max_steps=500, parse_doors=True, | ||||
|                           level_name='rooms', doors_have_area=True, | ||||
|                           verbose=False, | ||||
|                           mv_prop=move_props, | ||||
|                           obs_prop=obs_props, | ||||
|                           done_at_collision=False | ||||
|                           ) | ||||
|     # Finally we can create our list of dicts (one dict per parameter combination) | ||||
|     result = [{keys[index]: entry[index] for index in range(len(entry))} for entry in p] | ||||
|  | ||||
|     # Bundle both environments with global kwargs and parameters | ||||
|     env_map = {} | ||||
|     env_map.update({'dirt': (DirtFactory, dict(dirt_prop=dirt_props, | ||||
|                                                **factory_kwargs.copy()))}) | ||||
|     env_map.update({'item': (ItemFactory, dict(item_prop=item_props, | ||||
|                                                **factory_kwargs.copy()))}) | ||||
|     # env_map.update({'dest': (DestFactory, dict(dest_prop=dest_props, | ||||
|     #                                           **factory_kwargs.copy()))}) | ||||
|     env_map.update({'combined': (DirtDestItemFactory, dict(dest_prop=dest_props, | ||||
|                                                            item_prop=item_props, | ||||
|                                                            dirt_prop=dirt_props, | ||||
|                                                            **factory_kwargs.copy()))}) | ||||
|     env_names = list(env_map.keys()) | ||||
|     for u in result: | ||||
|         file_name = '_'.join('_'.join([str(y)[0] for y in x]) for x in u.items()) | ||||
|         study_root_path = Path(__file__).parent.parent / 'study_out' / file_name | ||||
|  | ||||
|     # Train starts here ############################################################ | ||||
|     # Build Major Loop  parameters, parameter versions, Env Classes and models | ||||
|     if train: | ||||
|         for env_key in (env_key for env_key in env_map if 'combined' != env_key): | ||||
|             model_cls = h.MODEL_MAP['PPO'] | ||||
|             combination_path = study_root_path / env_key | ||||
|             env_class, env_kwargs = env_map[env_key] | ||||
|         # Model Kwargs | ||||
|         policy_model_kwargs = dict(ent_coef=0.01) | ||||
|  | ||||
|             # Output folder | ||||
|             if (combination_path / 'monitor.pick').exists(): | ||||
|                 continue | ||||
|             combination_path.mkdir(parents=True, exist_ok=True) | ||||
|         # Define Global Env Parameters | ||||
|         # Define properties object parameters | ||||
|         obs_props = ObservationProperties(render_agents=AgentRenderOptions.NOT, | ||||
|                                           additional_agent_placeholder=None, | ||||
|                                           omit_agent_self=True, | ||||
|                                           frames_to_stack=frames_to_stack, | ||||
|                                           pomdp_r=u['pomdp_r'], cast_shadows=u['cast_shadows'], | ||||
|                                           show_global_position_info=u['show_global_position_info']) | ||||
|         move_props = MovementProperties(allow_diagonal_movement=u['allow_diagonal_movement'], | ||||
|                                         allow_square_movement=True, | ||||
|                                         allow_no_op=False) | ||||
|         dirt_props = DirtProperties(initial_dirt_ratio=0.35, initial_dirt_spawn_r_var=0.1, | ||||
|                                     clean_amount=0.34, | ||||
|                                     max_spawn_amount=0.1, max_global_amount=20, | ||||
|                                     max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05, | ||||
|                                     dirt_smear_amount=0.0) | ||||
|         item_props = ItemProperties(n_items=10, spawn_frequency=30, n_drop_off_locations=2, | ||||
|                                     max_agent_inventory_capacity=15) | ||||
|         dest_props = DestProperties(n_dests=4, spawn_mode=DestModeOptions.GROUPED, spawn_frequency=1) | ||||
|         factory_kwargs = dict(n_agents=1, max_steps=500, parse_doors=u['parse_doors'], | ||||
|                               level_name='rooms', doors_have_area=u['doors_have_area'], | ||||
|                               verbose=False, | ||||
|                               mv_prop=move_props, | ||||
|                               obs_prop=obs_props, | ||||
|                               done_at_collision=u['done_at_collision'] | ||||
|                               ) | ||||
|  | ||||
|             if not multi_env: | ||||
|                 env_factory = encapsule_env_factory(env_class, env_kwargs)() | ||||
|             else: | ||||
|                 env_factory = SubprocVecEnv([encapsule_env_factory(env_class, env_kwargs) | ||||
|                                              for _ in range(6)], start_method="spawn") | ||||
|         # Bundle both environments with global kwargs and parameters | ||||
|         env_map = {} | ||||
|         env_map.update({'dirt': (DirtFactory, dict(dirt_prop=dirt_props, | ||||
|                                                    **factory_kwargs.copy()), | ||||
|                                  ['cleanup_valid', 'cleanup_fail'])}) | ||||
|         # env_map.update({'item': (ItemFactory, dict(item_prop=item_props, | ||||
|         #                                            **factory_kwargs.copy()), | ||||
|         #                          ['DROPOFF_FAIL', 'ITEMACTION_FAIL', 'DROPOFF_VALID', 'ITEMACTION_VALID'])}) | ||||
|         # env_map.update({'dest': (DestFactory, dict(dest_prop=dest_props, | ||||
|         #                                           **factory_kwargs.copy()))}) | ||||
|         env_map.update({'combined': (DirtDestItemFactory, dict(dest_prop=dest_props, | ||||
|                                                                item_prop=item_props, | ||||
|                                                                dirt_prop=dirt_props, | ||||
|                                                                **factory_kwargs.copy()))}) | ||||
|         env_names = list(env_map.keys()) | ||||
|  | ||||
|             param_path = combination_path / f'env_params.json' | ||||
|             try: | ||||
|                 env_factory.env_method('save_params', param_path) | ||||
|             except AttributeError: | ||||
|                 env_factory.save_params(param_path) | ||||
|         # Train starts here ############################################################ | ||||
|         # Build Major Loop  parameters, parameter versions, Env Classes and models | ||||
|         if train: | ||||
|             for env_key in (env_key for env_key in env_map if 'combined' != env_key): | ||||
|                 model_cls = h.MODEL_MAP['PPO'] | ||||
|                 combination_path = study_root_path / env_key | ||||
|                 env_class, env_kwargs, env_plot_keys = env_map[env_key] | ||||
|  | ||||
|             # EnvMonitor Init | ||||
|             callbacks = [EnvMonitor(env_factory)] | ||||
|                 # Output folder | ||||
|                 if (combination_path / 'monitor.pick').exists(): | ||||
|                     continue | ||||
|                 combination_path.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|             # Model Init | ||||
|             model = model_cls("MlpPolicy", env_factory, **policy_model_kwargs(), | ||||
|                               verbose=1, seed=69, device='cpu') | ||||
|                 if not multi_env: | ||||
|                     env_factory = encapsule_env_factory(env_class, env_kwargs)() | ||||
|                 else: | ||||
|                     env_factory = SubprocVecEnv([encapsule_env_factory(env_class, env_kwargs) | ||||
|                                                  for _ in range(6)], start_method="spawn") | ||||
|  | ||||
|             # Model train | ||||
|             model.learn(total_timesteps=int(train_steps), callback=callbacks) | ||||
|                 param_path = combination_path / f'env_params.json' | ||||
|                 try: | ||||
|                     env_factory.env_method('save_params', param_path) | ||||
|                 except AttributeError: | ||||
|                     env_factory.save_params(param_path) | ||||
|  | ||||
|             # Model save | ||||
|             try: | ||||
|                 model.named_action_space = env_factory.unwrapped.named_action_space | ||||
|                 model.named_observation_space = env_factory.unwrapped.named_observation_space | ||||
|             except AttributeError: | ||||
|                 model.named_action_space = env_factory.get_attr("named_action_space")[0] | ||||
|                 model.named_observation_space = env_factory.get_attr("named_observation_space")[0] | ||||
|             save_path = combination_path / f'model.zip' | ||||
|             model.save(save_path) | ||||
|                 # EnvMonitor Init | ||||
|                 callbacks = [EnvMonitor(env_factory)] | ||||
|  | ||||
|             # Monitor Save | ||||
|             callbacks[0].save_run(combination_path / 'monitor.pick') | ||||
|                 # Model Init | ||||
|                 model = model_cls("MlpPolicy", env_factory, **policy_model_kwargs, | ||||
|                                   verbose=1, seed=69, device='cpu') | ||||
|  | ||||
|             # Better safe than sorry: clean up! | ||||
|             del env_factory, model | ||||
|             import gc | ||||
|             gc.collect() | ||||
|                 # Model train | ||||
|                 model.learn(total_timesteps=int(train_steps), callback=callbacks) | ||||
|  | ||||
|                 # Model save | ||||
|                 try: | ||||
|                     model.named_action_space = env_factory.unwrapped.named_action_space | ||||
|                     model.named_observation_space = env_factory.unwrapped.named_observation_space | ||||
|                 except AttributeError: | ||||
|                     model.named_action_space = env_factory.get_attr("named_action_space")[0] | ||||
|                     model.named_observation_space = env_factory.get_attr("named_observation_space")[0] | ||||
|                 save_path = combination_path / f'model.zip' | ||||
|                 model.save(save_path) | ||||
|  | ||||
|                 # Monitor Save | ||||
|                 callbacks[0].save_run(combination_path / 'monitor.pick', | ||||
|                                       auto_plotting_keys=['step_reward', 'collision'] + env_plot_keys) | ||||
|  | ||||
|                 # Better safe than sorry: clean up! | ||||
|                 del env_factory, model | ||||
|                 import gc | ||||
|                 gc.collect() | ||||
|  | ||||
|     # Train ends here ############################################################ | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Steffen Illium
					Steffen Illium