mirror of
				https://github.com/illiumst/marl-factory-grid.git
				synced 2025-10-31 12:37:27 +01:00 
			
		
		
		
	Rework for performance
This commit is contained in:
		| @@ -13,6 +13,8 @@ from gym.wrappers import FrameStack | |||||||
| from environments.factory.base.shadow_casting import Map | from environments.factory.base.shadow_casting import Map | ||||||
| from environments import helpers as h | from environments import helpers as h | ||||||
| from environments.helpers import Constants as c | from environments.helpers import Constants as c | ||||||
|  | from environments.helpers import EnvActions as a | ||||||
|  | from environments.helpers import Rewards as r | ||||||
| from environments.factory.base.objects import Agent, Tile, Action | from environments.factory.base.objects import Agent, Tile, Action | ||||||
| from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles, PlaceHolders, \ | from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles, PlaceHolders, \ | ||||||
|     GlobalPositions |     GlobalPositions | ||||||
| @@ -205,8 +207,9 @@ class BaseFactory(gym.Env): | |||||||
|  |  | ||||||
|         if self.obs_prop.show_global_position_info: |         if self.obs_prop.show_global_position_info: | ||||||
|             global_positions = GlobalPositions(self._level_shape) |             global_positions = GlobalPositions(self._level_shape) | ||||||
|             obs_shape_2d = self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2) |             # This moved into the GlobalPosition object | ||||||
|             global_positions.spawn_global_position_objects(obs_shape_2d, self[c.AGENT]) |             # obs_shape_2d = self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2) | ||||||
|  |             global_positions.spawn_global_position_objects(self[c.AGENT]) | ||||||
|             self._entities.register_additional_items({c.GLOBAL_POSITION: global_positions}) |             self._entities.register_additional_items({c.GLOBAL_POSITION: global_positions}) | ||||||
|  |  | ||||||
|         # Return |         # Return | ||||||
| @@ -232,37 +235,51 @@ class BaseFactory(gym.Env): | |||||||
|         # Pre step Hook for later use |         # Pre step Hook for later use | ||||||
|         self.hook_pre_step() |         self.hook_pre_step() | ||||||
|  |  | ||||||
|         # Move this in a seperate function? |  | ||||||
|         for action, agent in zip(actions, self[c.AGENT]): |         for action, agent in zip(actions, self[c.AGENT]): | ||||||
|             agent.clear_temp_state() |             agent.clear_temp_state() | ||||||
|             action_obj = self._actions[int(action)] |             action_obj = self._actions[int(action)] | ||||||
|  |             step_result = dict(collisions=[], rewards=[], info={}, action_name='', action_valid=False) | ||||||
|             # cls.print(f'Action #{action} has been resolved to: {action_obj}') |             # cls.print(f'Action #{action} has been resolved to: {action_obj}') | ||||||
|             if h.EnvActions.is_move(action_obj): |             if a.is_move(action_obj): | ||||||
|                 valid = self._move_or_colide(agent, action_obj) |                 action_valid, reward = self._do_move_action(agent, action_obj) | ||||||
|             elif h.EnvActions.NOOP == agent.temp_action: |             elif a.NOOP == action_obj: | ||||||
|                 valid = c.VALID |                 action_valid = c.VALID | ||||||
|             elif h.EnvActions.USE_DOOR == action_obj: |                 reward = dict(value=r.NOOP, reason=a.NOOP, info={f'{agent.pos}_NOOP': 1}) | ||||||
|                 valid = self._handle_door_interaction(agent) |             elif a.USE_DOOR == action_obj: | ||||||
|  |                 action_valid, reward = self._handle_door_interaction(agent) | ||||||
|             else: |             else: | ||||||
|                 valid = self.do_additional_actions(agent, action_obj) |                 # noinspection PyTupleAssignmentBalance | ||||||
|             assert valid is not None, 'This should not happen, every Action musst be detected correctly!' |                 action_valid, reward = self.do_additional_actions(agent, action_obj) | ||||||
|             agent.temp_action = action_obj |                 # Not needed any more since the tuple assignment above will fail in case of a failing action resolution. | ||||||
|             agent.temp_valid = valid |                 # assert step_result is not None, 'This should not happen, every Action musst be detected correctly!' | ||||||
|  |             step_result['action_name'] = action_obj.identifier | ||||||
|         # In-between step Hook for later use |             step_result['action_valid'] = action_valid | ||||||
|         info = self.do_additional_step() |             step_result['rewards'].append(reward) | ||||||
|  |             agent.step_result = step_result | ||||||
|  |  | ||||||
|  |         # Additional step and Reward, Info Init | ||||||
|  |         rewards, info = self.do_additional_step() | ||||||
|  |         # Todo: Make this faster, so that only tiles of entities that can collide are searched. | ||||||
|         tiles_with_collisions = self.get_all_tiles_with_collisions() |         tiles_with_collisions = self.get_all_tiles_with_collisions() | ||||||
|         for tile in tiles_with_collisions: |         for tile in tiles_with_collisions: | ||||||
|             guests = tile.guests_that_can_collide |             guests = tile.guests_that_can_collide | ||||||
|             for i, guest in enumerate(guests): |             for i, guest in enumerate(guests): | ||||||
|  |                 # This does make a copy, but is faster than .copy() | ||||||
|                 this_collisions = guests[:] |                 this_collisions = guests[:] | ||||||
|                 del this_collisions[i] |                 del this_collisions[i] | ||||||
|                 guest.temp_collisions = this_collisions |                 assert hasattr(guest, 'step_result') | ||||||
|  |                 for collision in this_collisions: | ||||||
|  |                     guest.step_result['collisions'].append(collision) | ||||||
|  |  | ||||||
|         done = self.done_at_collision and tiles_with_collisions |         done = False | ||||||
|  |         if self.done_at_collision: | ||||||
|  |             if done_at_col := bool(tiles_with_collisions): | ||||||
|  |                 done = done_at_col | ||||||
|  |                 info.update(COLLISION_DONE=done_at_col) | ||||||
|  |  | ||||||
|         done = done or self.check_additional_done() |         additional_done, additional_done_info = self.check_additional_done() | ||||||
|  |         done = done or additional_done | ||||||
|  |         info.update(additional_done_info) | ||||||
|  |  | ||||||
|         # Step the door close interval |         # Step the door close interval | ||||||
|         if self.parse_doors: |         if self.parse_doors: | ||||||
| @@ -270,7 +287,8 @@ class BaseFactory(gym.Env): | |||||||
|                 doors.tick_doors() |                 doors.tick_doors() | ||||||
|  |  | ||||||
|         # Finalize |         # Finalize | ||||||
|         reward, reward_info = self.calculate_reward() |         reward, reward_info = self.build_reward_result() | ||||||
|  |  | ||||||
|         info.update(reward_info) |         info.update(reward_info) | ||||||
|         if self._steps >= self.max_steps: |         if self._steps >= self.max_steps: | ||||||
|             done = True |             done = True | ||||||
| @@ -285,7 +303,7 @@ class BaseFactory(gym.Env): | |||||||
|  |  | ||||||
|         return obs, reward, done, info |         return obs, reward, done, info | ||||||
|  |  | ||||||
|     def _handle_door_interaction(self, agent) -> c: |     def _handle_door_interaction(self, agent) -> (bool, dict): | ||||||
|         if doors := self[c.DOORS]: |         if doors := self[c.DOORS]: | ||||||
|             # Check if agent really is standing on a door: |             # Check if agent really is standing on a door: | ||||||
|             if self.doors_have_area: |             if self.doors_have_area: | ||||||
| @@ -294,12 +312,21 @@ class BaseFactory(gym.Env): | |||||||
|                 door = doors.by_pos(agent.pos) |                 door = doors.by_pos(agent.pos) | ||||||
|             if door is not None: |             if door is not None: | ||||||
|                 door.use() |                 door.use() | ||||||
|                 return c.VALID |                 valid = c.VALID | ||||||
|  |                 self.print(f'{agent.name} just used a door {door.name}') | ||||||
|  |                 info_dict = {f'{agent.name}_door_use_{door.name}': 1} | ||||||
|             # When he doesn't... |             # When he doesn't... | ||||||
|             else: |             else: | ||||||
|                 return c.NOT_VALID |                 valid = c.NOT_VALID | ||||||
|  |                 info_dict = {f'{agent.name}_failed_door_use': 1} | ||||||
|  |                 self.print(f'{agent.name} just tried to use a door at {agent.pos}, but there is none.') | ||||||
|  |  | ||||||
|         else: |         else: | ||||||
|             return c.NOT_VALID |             raise RuntimeError('This should not happen, since the door action should not be available.') | ||||||
|  |         reward = dict(value=r.USE_DOOR_VALID if valid else r.USE_DOOR_FAIL, | ||||||
|  |                       reason=a.USE_DOOR, info=info_dict) | ||||||
|  |  | ||||||
|  |         return valid, reward | ||||||
|  |  | ||||||
|     def _build_observations(self) -> np.typing.ArrayLike: |     def _build_observations(self) -> np.typing.ArrayLike: | ||||||
|         # Observation dict: |         # Observation dict: | ||||||
| @@ -308,7 +335,7 @@ class BaseFactory(gym.Env): | |||||||
|         # General Observations |         # General Observations | ||||||
|         lvl_obs = self[c.WALLS].as_array() |         lvl_obs = self[c.WALLS].as_array() | ||||||
|         door_obs = self[c.DOORS].as_array() |         door_obs = self[c.DOORS].as_array() | ||||||
|         agent_obs = self[c.AGENT].as_array() if self.obs_prop.render_agents != a_obs.NOT else None |         global_agent_obs = self[c.AGENT].as_array() if self.obs_prop.render_agents != a_obs.NOT else None | ||||||
|         placeholder_obs = self[c.AGENT_PLACEHOLDER].as_array() if self[c.AGENT_PLACEHOLDER] else None |         placeholder_obs = self[c.AGENT_PLACEHOLDER].as_array() if self[c.AGENT_PLACEHOLDER] else None | ||||||
|         add_obs_dict = self._additional_observations() |         add_obs_dict = self._additional_observations() | ||||||
|  |  | ||||||
| @@ -318,15 +345,20 @@ class BaseFactory(gym.Env): | |||||||
|             if self.obs_prop.render_agents != a_obs.NOT: |             if self.obs_prop.render_agents != a_obs.NOT: | ||||||
|                 if self.obs_prop.omit_agent_self: |                 if self.obs_prop.omit_agent_self: | ||||||
|                     if self.obs_prop.render_agents == a_obs.SEPERATE: |                     if self.obs_prop.render_agents == a_obs.SEPERATE: | ||||||
|                         agent_obs = np.take(agent_obs, [x for x in range(self.n_agents) if x != agent_idx], axis=0) |                         other_agent_obs_idx = [x for x in range(self.n_agents) if x != agent_idx] | ||||||
|  |                         agent_obs = np.take(global_agent_obs, other_agent_obs_idx, axis=0) | ||||||
|                     else: |                     else: | ||||||
|                         agent_obs = agent_obs.copy() |                         agent_obs = global_agent_obs.copy() | ||||||
|                         agent_obs[(0, *agent.pos)] -= agent.encoding |                         agent_obs[(0, *agent.pos)] -= agent.encoding | ||||||
|  |                 else: | ||||||
|  |                     agent_obs = global_agent_obs | ||||||
|  |             else: | ||||||
|  |                 agent_obs = global_agent_obs | ||||||
|  |  | ||||||
|             # Build Level Observations |             # Build Level Observations | ||||||
|             if self.obs_prop.render_agents == a_obs.LEVEL: |             if self.obs_prop.render_agents == a_obs.LEVEL: | ||||||
|                 lvl_obs = lvl_obs.copy() |                 lvl_obs = lvl_obs.copy() | ||||||
|                 lvl_obs += agent_obs |                 lvl_obs += global_agent_obs | ||||||
|  |  | ||||||
|             obs_dict[c.WALLS] = lvl_obs |             obs_dict[c.WALLS] = lvl_obs | ||||||
|             if self.obs_prop.render_agents in [a_obs.SEPERATE, a_obs.COMBINED]: |             if self.obs_prop.render_agents in [a_obs.SEPERATE, a_obs.COMBINED]: | ||||||
| @@ -340,11 +372,12 @@ class BaseFactory(gym.Env): | |||||||
|                 obsn = self._do_pomdp_cutout(agent, obsn) |                 obsn = self._do_pomdp_cutout(agent, obsn) | ||||||
|  |  | ||||||
|             raw_obs = self._additional_per_agent_raw_observations(agent) |             raw_obs = self._additional_per_agent_raw_observations(agent) | ||||||
|             obsn = np.vstack((obsn, *list(raw_obs.values()))) |             raw_obs = {key: np.expand_dims(val, 0) if val.ndim != 3 else val for key, val in raw_obs.items()} | ||||||
|  |             obsn = np.vstack((obsn, *raw_obs.values())) | ||||||
|  |  | ||||||
|             keys = list(chain(obs_dict.keys(), raw_obs.keys())) |             keys = list(chain(obs_dict.keys(), raw_obs.keys())) | ||||||
|             idxs = np.cumsum([x.shape[0] for x in chain(obs_dict.values(), raw_obs.values())]) - 1 |             idxs = np.cumsum([x.shape[0] for x in chain(obs_dict.values(), raw_obs.values())]) - 1 | ||||||
|             per_agent_expl_idx[agent.name] = {key: list(range(a, b)) for key, a, b in |             per_agent_expl_idx[agent.name] = {key: list(range(d, b)) for key, d, b in | ||||||
|                                               zip(keys, idxs, list(idxs[1:]) + [idxs[-1]+1, ])} |                                               zip(keys, idxs, list(idxs[1:]) + [idxs[-1]+1, ])} | ||||||
|  |  | ||||||
|             # Shadow Casting |             # Shadow Casting | ||||||
| @@ -390,7 +423,13 @@ class BaseFactory(gym.Env): | |||||||
|                 if door_shadowing: |                 if door_shadowing: | ||||||
|                     # noinspection PyUnboundLocalVariable |                     # noinspection PyUnboundLocalVariable | ||||||
|                     light_block_map[xs, ys] = 0 |                     light_block_map[xs, ys] = 0 | ||||||
|                 agent.temp_light_map = light_block_map.copy() |                 if agent.step_result: | ||||||
|  |                     agent.step_result['lightmap'] = light_block_map | ||||||
|  |                     pass | ||||||
|  |                 else: | ||||||
|  |                     assert self._steps == 0 | ||||||
|  |                     agent.step_result = {'action_name': a.NOOP, 'action_valid': True, | ||||||
|  |                                          'collisions': [], 'lightmap': light_block_map} | ||||||
|  |  | ||||||
|                 obsn[shadowed_obs] = ((obsn[shadowed_obs] * light_block_map) + 0.) - (1 - light_block_map) |                 obsn[shadowed_obs] = ((obsn[shadowed_obs] * light_block_map) + 0.) - (1 - light_block_map) | ||||||
|             else: |             else: | ||||||
| @@ -410,27 +449,27 @@ class BaseFactory(gym.Env): | |||||||
|  |  | ||||||
|     def _do_pomdp_cutout(self, agent, obs_to_be_padded): |     def _do_pomdp_cutout(self, agent, obs_to_be_padded): | ||||||
|         assert obs_to_be_padded.ndim == 3 |         assert obs_to_be_padded.ndim == 3 | ||||||
|         r, d = self._pomdp_r, self.pomdp_diameter |         ra, d = self._pomdp_r, self.pomdp_diameter | ||||||
|         x0, x1 = max(0, agent.x - r), min(agent.x + r + 1, self._level_shape[0]) |         x0, x1 = max(0, agent.x - ra), min(agent.x + ra + 1, self._level_shape[0]) | ||||||
|         y0, y1 = max(0, agent.y - r), min(agent.y + r + 1, self._level_shape[1]) |         y0, y1 = max(0, agent.y - ra), min(agent.y + ra + 1, self._level_shape[1]) | ||||||
|         oobs = obs_to_be_padded[:, x0:x1, y0:y1] |         oobs = obs_to_be_padded[:, x0:x1, y0:y1] | ||||||
|         if oobs.shape[1:] != (d, d): |         if oobs.shape[1:] != (d, d): | ||||||
|             if xd := oobs.shape[1] % d: |             if xd := oobs.shape[1] % d: | ||||||
|                 if agent.x > r: |                 if agent.x > ra: | ||||||
|                     x0_pad = 0 |                     x0_pad = 0 | ||||||
|                     x1_pad = (d - xd) |                     x1_pad = (d - xd) | ||||||
|                 else: |                 else: | ||||||
|                     x0_pad = r - agent.x |                     x0_pad = ra - agent.x | ||||||
|                     x1_pad = 0 |                     x1_pad = 0 | ||||||
|             else: |             else: | ||||||
|                 x0_pad, x1_pad = 0, 0 |                 x0_pad, x1_pad = 0, 0 | ||||||
|  |  | ||||||
|             if yd := oobs.shape[2] % d: |             if yd := oobs.shape[2] % d: | ||||||
|                 if agent.y > r: |                 if agent.y > ra: | ||||||
|                     y0_pad = 0 |                     y0_pad = 0 | ||||||
|                     y1_pad = (d - yd) |                     y1_pad = (d - yd) | ||||||
|                 else: |                 else: | ||||||
|                     y0_pad = r - agent.y |                     y0_pad = ra - agent.y | ||||||
|                     y1_pad = 0 |                     y1_pad = 0 | ||||||
|             else: |             else: | ||||||
|                 y0_pad, y1_pad = 0, 0 |                 y0_pad, y1_pad = 0, 0 | ||||||
| @@ -439,22 +478,39 @@ class BaseFactory(gym.Env): | |||||||
|         return oobs |         return oobs | ||||||
|  |  | ||||||
|     def get_all_tiles_with_collisions(self) -> List[Tile]: |     def get_all_tiles_with_collisions(self) -> List[Tile]: | ||||||
|         tiles_with_collisions = list() |         tiles = [x.tile for y in self._entities for x in y if | ||||||
|         for tile in self[c.FLOOR]: |                  y.can_collide and not isinstance(y, WallTiles) and x.can_collide and len(x.tile.guests) > 1] | ||||||
|             if tile.is_occupied(): |         if False: | ||||||
|                 guests = tile.guests_that_can_collide |             tiles_with_collisions = list() | ||||||
|                 if len(guests) >= 2: |             for tile in self[c.FLOOR]: | ||||||
|                     tiles_with_collisions.append(tile) |                 if tile.is_occupied(): | ||||||
|         return tiles_with_collisions |                     guests = tile.guests_that_can_collide | ||||||
|  |                     if len(guests) >= 2: | ||||||
|  |                         tiles_with_collisions.append(tile) | ||||||
|  |         return tiles | ||||||
|  |  | ||||||
|     def _move_or_colide(self, agent: Agent, action: Action) -> bool: |     def _do_move_action(self, agent: Agent, action: Action) -> (dict, dict): | ||||||
|  |         info_dict = dict() | ||||||
|         new_tile, valid = self._check_agent_move(agent, action) |         new_tile, valid = self._check_agent_move(agent, action) | ||||||
|         if valid: |         if valid: | ||||||
|             # Does not collide with level boundaries |             # Does not collide with level boundaries | ||||||
|             return agent.move(new_tile) |             valid = agent.move(new_tile) | ||||||
|  |             if valid: | ||||||
|  |                 # This will spam your logs, beware! | ||||||
|  |                 # self.print(f'{agent.name} just moved from {agent.last_pos} to {agent.pos}.') | ||||||
|  |                 # info_dict.update({f'{agent.pos}_move': 1}) | ||||||
|  |                 pass | ||||||
|  |             else: | ||||||
|  |                 valid = c.NOT_VALID | ||||||
|  |                 self.print(f'{agent.name} just hit the wall at {agent.pos}.') | ||||||
|  |                 info_dict.update({f'{agent.pos}_wall_collide': 1}) | ||||||
|         else: |         else: | ||||||
|             # Agent seems to be trying to collide in this step |             # Agent seems to be trying to leave the level | ||||||
|             return c.NOT_VALID |             self.print(f'{agent.name} tried to leave the level {agent.pos}.') | ||||||
|  |             info_dict.update({f'{agent.pos}_wall_collide': 1}) | ||||||
|  |         reward_value = r.MOVEMENTS_VALID if valid else r.MOVEMENTS_FAIL | ||||||
|  |         reward = {'value': reward_value, 'reason': action.identifier, 'info': info_dict} | ||||||
|  |         return valid, reward | ||||||
|  |  | ||||||
|     def _check_agent_move(self, agent, action: Action) -> (Tile, bool): |     def _check_agent_move(self, agent, action: Action) -> (Tile, bool): | ||||||
|         # Actions |         # Actions | ||||||
| @@ -474,7 +530,7 @@ class BaseFactory(gym.Env): | |||||||
|             if doors := self[c.DOORS]: |             if doors := self[c.DOORS]: | ||||||
|                 if self.doors_have_area: |                 if self.doors_have_area: | ||||||
|                     if door := doors.by_pos(new_tile.pos): |                     if door := doors.by_pos(new_tile.pos): | ||||||
|                         if door.is_open: |                         if door.is_closed: | ||||||
|                             return agent.tile, c.NOT_VALID |                             return agent.tile, c.NOT_VALID | ||||||
|                         else:  # door.is_closed: |                         else:  # door.is_closed: | ||||||
|                             pass |                             pass | ||||||
| @@ -494,69 +550,46 @@ class BaseFactory(gym.Env): | |||||||
|  |  | ||||||
|         return new_tile, valid |         return new_tile, valid | ||||||
|  |  | ||||||
|     def calculate_reward(self) -> (int, dict): |     @abc.abstractmethod | ||||||
|  |     def additional_per_agent_rewards(self, agent) -> List[dict]: | ||||||
|  |         return [] | ||||||
|  |  | ||||||
|  |     def build_reward_result(self) -> (int, dict): | ||||||
|         # Returns: Reward, Info |         # Returns: Reward, Info | ||||||
|         per_agent_info_dict = defaultdict(dict) |         info = defaultdict(lambda: 0.0) | ||||||
|         reward = {} |  | ||||||
|  |  | ||||||
|  |         # Gather additional sub-env rewards and calculate collisions | ||||||
|         for agent in self[c.AGENT]: |         for agent in self[c.AGENT]: | ||||||
|             per_agent_reward = 0 |  | ||||||
|             if self._actions.is_moving_action(agent.temp_action): |  | ||||||
|                 if agent.temp_valid: |  | ||||||
|                     # info_dict.update(movement=1) |  | ||||||
|                     per_agent_reward -= 0.001 |  | ||||||
|                     pass |  | ||||||
|                 else: |  | ||||||
|                     per_agent_reward -= 0.05 |  | ||||||
|                     self.print(f'{agent.name} just hit the wall at {agent.pos}.') |  | ||||||
|                     per_agent_info_dict[agent.name].update({f'{agent.name}_vs_LEVEL': 1}) |  | ||||||
|  |  | ||||||
|             elif h.EnvActions.USE_DOOR == agent.temp_action: |             rewards = self.additional_per_agent_rewards(agent) | ||||||
|                 if agent.temp_valid: |             for reward in rewards: | ||||||
|                     # per_agent_reward += 0.00 |                 agent.step_result['rewards'].append(reward) | ||||||
|                     self.print(f'{agent.name} did just use the door at {agent.pos}.') |             if collisions := agent.step_result['collisions']: | ||||||
|                     per_agent_info_dict[agent.name].update(door_used=1) |                 self.print(f't = {self._steps}\t{agent.name} has collisions with {collisions}') | ||||||
|                 else: |                 info[c.COLLISION] += 1 | ||||||
|                     # per_agent_reward -= 0.00 |                 reward = {'value': r.COLLISION, 'reason': c.COLLISION, 'info': {f'{agent.name}_{c.COLLISION}': 1}} | ||||||
|                     self.print(f'{agent.name} just tried to use a door at {agent.pos}, but failed.') |                 agent.step_result['rewards'].append(reward) | ||||||
|                     per_agent_info_dict[agent.name].update({f'{agent.name}_failed_door_open': 1}) |  | ||||||
|             elif h.EnvActions.NOOP == agent.temp_action: |  | ||||||
|                 per_agent_info_dict[agent.name].update(no_op=1) |  | ||||||
|                 # per_agent_reward -= 0.00 |  | ||||||
|  |  | ||||||
|             # EnvMonitor Notes |  | ||||||
|             if agent.temp_valid: |  | ||||||
|                 per_agent_info_dict[agent.name].update(valid_action=1) |  | ||||||
|                 per_agent_info_dict[agent.name].update({f'{agent.name}_valid_action': 1}) |  | ||||||
|             else: |             else: | ||||||
|                 per_agent_info_dict[agent.name].update(failed_action=1) |                 # No Collisions, nothing to do | ||||||
|                 per_agent_info_dict[agent.name].update({f'{agent.name}_failed_action': 1}) |                 pass | ||||||
|  |  | ||||||
|             additional_reward, additional_info_dict = self.calculate_additional_reward(agent) |         comb_rewards = {agent.name: sum(x['value'] for x in agent.step_result['rewards']) for agent in self[c.AGENT]} | ||||||
|             per_agent_reward += additional_reward |  | ||||||
|             per_agent_info_dict[agent.name].update(additional_info_dict) |  | ||||||
|  |  | ||||||
|             if agent.temp_collisions: |  | ||||||
|                 self.print(f't = {self._steps}\t{agent.name} has collisions with {agent.temp_collisions}') |  | ||||||
|                 per_agent_info_dict[agent.name].update(collisions=1) |  | ||||||
|  |  | ||||||
|                 for other_agent in agent.temp_collisions: |  | ||||||
|                     per_agent_info_dict[agent.name].update({f'{agent.name}_vs_{other_agent.name}': 1}) |  | ||||||
|             reward[agent.name] = per_agent_reward |  | ||||||
|  |  | ||||||
|         # Combine the per_agent_info_dict: |         # Combine the per_agent_info_dict: | ||||||
|         combined_info_dict = defaultdict(lambda: 0) |         combined_info_dict = defaultdict(lambda: 0) | ||||||
|         for info_dict in per_agent_info_dict.values(): |         for agent in self[c.AGENT]: | ||||||
|             for key, value in info_dict.items(): |             for reward in agent.step_result['rewards']: | ||||||
|                 combined_info_dict[key] += value |                 combined_info_dict.update(reward['info']) | ||||||
|  |  | ||||||
|         combined_info_dict = dict(combined_info_dict) |         combined_info_dict = dict(combined_info_dict) | ||||||
|  |         combined_info_dict.update(info) | ||||||
|  |  | ||||||
|         if self.individual_rewards: |         if self.individual_rewards: | ||||||
|             self.print(f"rewards are {reward}") |             self.print(f"rewards are {comb_rewards}") | ||||||
|             reward = list(reward.values()) |             reward = list(comb_rewards.values()) | ||||||
|             return reward, combined_info_dict |             return reward, combined_info_dict | ||||||
|         else: |         else: | ||||||
|             reward = sum(reward.values()) |             reward = sum(comb_rewards.values()) | ||||||
|             self.print(f"reward is {reward}") |             self.print(f"reward is {reward}") | ||||||
|         return reward, combined_info_dict |         return reward, combined_info_dict | ||||||
|  |  | ||||||
| @@ -574,7 +607,7 @@ class BaseFactory(gym.Env): | |||||||
|         agents = [] |         agents = [] | ||||||
|         for i, agent in enumerate(self[c.AGENT]): |         for i, agent in enumerate(self[c.AGENT]): | ||||||
|             name, state = h.asset_str(agent) |             name, state = h.asset_str(agent) | ||||||
|             agents.append(RenderEntity(name, agent.pos, 1, 'none', state, i + 1, agent.temp_light_map)) |             agents.append(RenderEntity(name, agent.pos, 1, 'none', state, i + 1, agent.step_result['lightmap'])) | ||||||
|         doors = [] |         doors = [] | ||||||
|         if self.parse_doors: |         if self.parse_doors: | ||||||
|             for i, door in enumerate(self[c.DOORS]): |             for i, door in enumerate(self[c.DOORS]): | ||||||
| @@ -637,16 +670,16 @@ class BaseFactory(gym.Env): | |||||||
|         pass |         pass | ||||||
|  |  | ||||||
|     @abc.abstractmethod |     @abc.abstractmethod | ||||||
|     def do_additional_step(self) -> dict: |     def do_additional_step(self) -> (List[dict], dict): | ||||||
|         return {} |         return [], {} | ||||||
|  |  | ||||||
|     @abc.abstractmethod |     @abc.abstractmethod | ||||||
|     def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]: |     def do_additional_actions(self, agent: Agent, action: Action) -> (bool, dict): | ||||||
|         return None |         return None | ||||||
|  |  | ||||||
|     @abc.abstractmethod |     @abc.abstractmethod | ||||||
|     def check_additional_done(self) -> bool: |     def check_additional_done(self) -> (bool, dict): | ||||||
|         return False |         return False, {} | ||||||
|  |  | ||||||
|     @abc.abstractmethod |     @abc.abstractmethod | ||||||
|     def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]: |     def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]: | ||||||
| @@ -660,8 +693,8 @@ class BaseFactory(gym.Env): | |||||||
|         return additional_raw_observations |         return additional_raw_observations | ||||||
|  |  | ||||||
|     @abc.abstractmethod |     @abc.abstractmethod | ||||||
|     def calculate_additional_reward(self, agent: Agent) -> (int, dict): |     def additional_per_agent_reward(self, agent: Agent) -> Dict[str, dict]: | ||||||
|         return 0, {} |         return {} | ||||||
|  |  | ||||||
|     @abc.abstractmethod |     @abc.abstractmethod | ||||||
|     def render_additional_assets(self): |     def render_additional_assets(self): | ||||||
|   | |||||||
| @@ -33,7 +33,7 @@ class Object: | |||||||
|         else: |         else: | ||||||
|             return self._name |             return self._name | ||||||
|  |  | ||||||
|     def __init__(self, str_ident: Union[str, None] = None, is_blocking_light=False, **kwargs): |     def __init__(self, str_ident: Union[str, None] = None, **kwargs): | ||||||
|  |  | ||||||
|         self._str_ident = str_ident |         self._str_ident = str_ident | ||||||
|  |  | ||||||
| @@ -45,7 +45,6 @@ class Object: | |||||||
|         else: |         else: | ||||||
|             raise ValueError('Please use either of the idents.') |             raise ValueError('Please use either of the idents.') | ||||||
|  |  | ||||||
|         self._is_blocking_light = is_blocking_light |  | ||||||
|         if kwargs: |         if kwargs: | ||||||
|             print(f'Following kwargs were passed, but ignored: {kwargs}') |             print(f'Following kwargs were passed, but ignored: {kwargs}') | ||||||
|  |  | ||||||
| @@ -62,6 +61,10 @@ class EnvObject(Object): | |||||||
|  |  | ||||||
|     _u_idx = defaultdict(lambda: 0) |     _u_idx = defaultdict(lambda: 0) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def can_collide(self): | ||||||
|  |         return False | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def encoding(self): |     def encoding(self): | ||||||
|         return c.OCCUPIED_CELL |         return c.OCCUPIED_CELL | ||||||
| @@ -71,7 +74,10 @@ class EnvObject(Object): | |||||||
|         self._register = register |         self._register = register | ||||||
|  |  | ||||||
|     def change_register(self, register): |     def change_register(self, register): | ||||||
|  |         register.register_item(self) | ||||||
|  |         self._register.delete_env_object(self) | ||||||
|         self._register = register |         self._register = register | ||||||
|  |         return self._register == register | ||||||
|  |  | ||||||
|  |  | ||||||
| class BoundingMixin(Object): | class BoundingMixin(Object): | ||||||
| @@ -85,11 +91,6 @@ class BoundingMixin(Object): | |||||||
|         assert entity_to_be_bound is not None |         assert entity_to_be_bound is not None | ||||||
|         self._bound_entity = entity_to_be_bound |         self._bound_entity = entity_to_be_bound | ||||||
|  |  | ||||||
|     def __repr__(self): |  | ||||||
|         s = super(BoundingMixin, self).__repr__() |  | ||||||
|         i = s[:s.find('(')] |  | ||||||
|         return f'{s[:i]}[{self.bound_entity.name}]{s[i:]}' |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def name(self): |     def name(self): | ||||||
|         return f'{super(BoundingMixin, self).name}({self._bound_entity.name})' |         return f'{super(BoundingMixin, self).name}({self._bound_entity.name})' | ||||||
| @@ -101,13 +102,9 @@ class BoundingMixin(Object): | |||||||
| class Entity(EnvObject): | class Entity(EnvObject): | ||||||
|     """Full Env Entity that lives on the env Grid. Doors, Items, Dirt etc...""" |     """Full Env Entity that lives on the env Grid. Doors, Items, Dirt etc...""" | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def is_blocking_light(self): |  | ||||||
|         return self._is_blocking_light |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def can_collide(self): |     def can_collide(self): | ||||||
|         return True |         return False | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def x(self): |     def x(self): | ||||||
| @@ -125,10 +122,9 @@ class Entity(EnvObject): | |||||||
|     def tile(self): |     def tile(self): | ||||||
|         return self._tile |         return self._tile | ||||||
|  |  | ||||||
|     def __init__(self, tile, *args, is_blocking_light=True,  **kwargs): |     def __init__(self, tile, *args, **kwargs): | ||||||
|         super().__init__(*args, **kwargs) |         super().__init__(*args, **kwargs) | ||||||
|         self._tile = tile |         self._tile = tile | ||||||
|         self._is_blocking_light = is_blocking_light |  | ||||||
|         tile.enter(self) |         tile.enter(self) | ||||||
|  |  | ||||||
|     def summarize_state(self, **_) -> dict: |     def summarize_state(self, **_) -> dict: | ||||||
| @@ -170,9 +166,9 @@ class MoveableEntity(Entity): | |||||||
|             self._tile = next_tile |             self._tile = next_tile | ||||||
|             self._last_tile = curr_tile |             self._last_tile = curr_tile | ||||||
|             self._register.notify_change_to_value(self) |             self._register.notify_change_to_value(self) | ||||||
|             return True |             return c.VALID | ||||||
|         else: |         else: | ||||||
|             return False |             return c.NOT_VALID | ||||||
|  |  | ||||||
|  |  | ||||||
| ########################################################################## | ########################################################################## | ||||||
| @@ -284,6 +280,10 @@ class Tile(EnvObject): | |||||||
|  |  | ||||||
| class Wall(Tile): | class Wall(Tile): | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def can_collide(self): | ||||||
|  |         return True | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def encoding(self): |     def encoding(self): | ||||||
|         return c.OCCUPIED_CELL |         return c.OCCUPIED_CELL | ||||||
| @@ -381,6 +381,10 @@ class Door(Entity): | |||||||
|  |  | ||||||
| class Agent(MoveableEntity): | class Agent(MoveableEntity): | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def can_collide(self): | ||||||
|  |         return True | ||||||
|  |  | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|         super(Agent, self).__init__(*args, **kwargs) |         super(Agent, self).__init__(*args, **kwargs) | ||||||
|         self.clear_temp_state() |         self.clear_temp_state() | ||||||
| @@ -389,12 +393,9 @@ class Agent(MoveableEntity): | |||||||
|     def clear_temp_state(self): |     def clear_temp_state(self): | ||||||
|         # for attr in cls.__dict__: |         # for attr in cls.__dict__: | ||||||
|         #   if attr.startswith('temp'): |         #   if attr.startswith('temp'): | ||||||
|         self.temp_collisions = [] |         self.step_result = None | ||||||
|         self.temp_valid = None |  | ||||||
|         self.temp_action = None |  | ||||||
|         self.temp_light_map = None |  | ||||||
|  |  | ||||||
|     def summarize_state(self, **kwargs): |     def summarize_state(self, **kwargs): | ||||||
|         state_dict = super().summarize_state(**kwargs) |         state_dict = super().summarize_state(**kwargs) | ||||||
|         state_dict.update(valid=bool(self.temp_valid), action=str(self.temp_action)) |         state_dict.update(valid=bool(self.temp_action_result['valid']), action=str(self.temp_action_result['action'])) | ||||||
|         return state_dict |         return state_dict | ||||||
|   | |||||||
| @@ -85,19 +85,27 @@ class EnvObjectRegister(ObjectRegister): | |||||||
|     def encodings(self): |     def encodings(self): | ||||||
|         return [x.encoding for x in self] |         return [x.encoding for x in self] | ||||||
|  |  | ||||||
|     def __init__(self, obs_shape: (int, int), *args, individual_slices: bool = False, **kwargs): |     def __init__(self, obs_shape: (int, int), *args, | ||||||
|  |                  individual_slices: bool = False, | ||||||
|  |                  is_blocking_light: bool = False, | ||||||
|  |                  can_collide: bool = False, | ||||||
|  |                  can_be_shadowed: bool = True, **kwargs): | ||||||
|         super(EnvObjectRegister, self).__init__(*args, **kwargs) |         super(EnvObjectRegister, self).__init__(*args, **kwargs) | ||||||
|         self._shape = obs_shape |         self._shape = obs_shape | ||||||
|         self._array = None |         self._array = None | ||||||
|         self._individual_slices = individual_slices |         self._individual_slices = individual_slices | ||||||
|         self._lazy_eval_transforms = [] |         self._lazy_eval_transforms = [] | ||||||
|  |         self.is_blocking_light = is_blocking_light | ||||||
|  |         self.can_be_shadowed = can_be_shadowed | ||||||
|  |         self.can_collide = can_collide | ||||||
|  |  | ||||||
|     def register_item(self, other: EnvObject): |     def register_item(self, other: EnvObject): | ||||||
|         super(EnvObjectRegister, self).register_item(other) |         super(EnvObjectRegister, self).register_item(other) | ||||||
|         if self._array is None: |         if self._array is None: | ||||||
|             self._array = np.zeros((1, *self._shape)) |             self._array = np.zeros((1, *self._shape)) | ||||||
|         if self._individual_slices: |         else: | ||||||
|             self._array = np.vstack((self._array, np.zeros((1, *self._shape)))) |             if self._individual_slices: | ||||||
|  |                 self._array = np.vstack((self._array, np.zeros((1, *self._shape)))) | ||||||
|         self.notify_change_to_value(other) |         self.notify_change_to_value(other) | ||||||
|  |  | ||||||
|     def as_array(self): |     def as_array(self): | ||||||
| @@ -179,14 +187,9 @@ class EntityRegister(EnvObjectRegister, ABC): | |||||||
|     def tiles(self): |     def tiles(self): | ||||||
|         return [entity.tile for entity in self] |         return [entity.tile for entity in self] | ||||||
|  |  | ||||||
|     def __init__(self, level_shape, *args, |     def __init__(self, level_shape, *args, **kwargs): | ||||||
|                  is_blocking_light: bool = False, |  | ||||||
|                  can_be_shadowed: bool = True, |  | ||||||
|                  **kwargs): |  | ||||||
|         super(EntityRegister, self).__init__(level_shape, *args, **kwargs) |         super(EntityRegister, self).__init__(level_shape, *args, **kwargs) | ||||||
|         self._lazy_eval_transforms = [] |         self._lazy_eval_transforms = [] | ||||||
|         self.can_be_shadowed = can_be_shadowed |  | ||||||
|         self.is_blocking_light = is_blocking_light |  | ||||||
|  |  | ||||||
|     def __delitem__(self, name): |     def __delitem__(self, name): | ||||||
|         idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name) |         idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name) | ||||||
| @@ -220,7 +223,7 @@ class EntityRegister(EnvObjectRegister, ABC): | |||||||
|             return None |             return None | ||||||
|  |  | ||||||
|  |  | ||||||
| class BoundRegisterMixin(EnvObjectRegister, ABC): | class BoundEnvObjRegister(EnvObjectRegister, ABC): | ||||||
|  |  | ||||||
|     def __init__(self, entity_to_be_bound, *args, **kwargs): |     def __init__(self, entity_to_be_bound, *args, **kwargs): | ||||||
|         super().__init__(*args, **kwargs) |         super().__init__(*args, **kwargs) | ||||||
| @@ -229,6 +232,21 @@ class BoundRegisterMixin(EnvObjectRegister, ABC): | |||||||
|     def belongs_to_entity(self, entity): |     def belongs_to_entity(self, entity): | ||||||
|         return self._bound_entity == entity |         return self._bound_entity == entity | ||||||
|  |  | ||||||
|  |     def by_entity(self, entity): | ||||||
|  |         try: | ||||||
|  |             return next((x for x in self if x.belongs_to_entity(entity))) | ||||||
|  |         except StopIteration: | ||||||
|  |             return None | ||||||
|  |  | ||||||
|  |     def idx_by_entity(self, entity): | ||||||
|  |         try: | ||||||
|  |             return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity))) | ||||||
|  |         except StopIteration: | ||||||
|  |             return None | ||||||
|  |  | ||||||
|  |     def as_array_by_entity(self, entity): | ||||||
|  |         return self._array[self.idx_by_entity(entity)] | ||||||
|  |  | ||||||
|  |  | ||||||
| class MovingEntityObjectRegister(EntityRegister, ABC): | class MovingEntityObjectRegister(EntityRegister, ABC): | ||||||
|  |  | ||||||
| @@ -255,6 +273,7 @@ class GlobalPositions(EnvObjectRegister): | |||||||
|  |  | ||||||
|     is_blocking_light = False |     is_blocking_light = False | ||||||
|     can_be_shadowed = False |     can_be_shadowed = False | ||||||
|  |     can_collide = False | ||||||
|  |  | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|         super(GlobalPositions, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs) |         super(GlobalPositions, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs) | ||||||
| @@ -360,7 +379,6 @@ class Entities(ObjectRegister): | |||||||
|  |  | ||||||
| class WallTiles(EntityRegister): | class WallTiles(EntityRegister): | ||||||
|     _accepted_objects = Wall |     _accepted_objects = Wall | ||||||
|     _light_blocking = True |  | ||||||
|  |  | ||||||
|     def as_array(self): |     def as_array(self): | ||||||
|         if not np.any(self._array): |         if not np.any(self._array): | ||||||
| @@ -371,9 +389,10 @@ class WallTiles(EntityRegister): | |||||||
|             self._array[0, x, y] = self._value |             self._array[0, x, y] = self._value | ||||||
|         return self._array |         return self._array | ||||||
|  |  | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, is_blocking_light=True, **kwargs): | ||||||
|         super(WallTiles, self).__init__(*args, is_blocking_light=self._light_blocking, individual_slices=False, |         super(WallTiles, self).__init__(*args, individual_slices=False, | ||||||
|                                         **kwargs) |                                         can_collide=True, | ||||||
|  |                                         is_blocking_light=is_blocking_light, **kwargs) | ||||||
|         self._value = c.OCCUPIED_CELL |         self._value = c.OCCUPIED_CELL | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
| @@ -381,7 +400,7 @@ class WallTiles(EntityRegister): | |||||||
|         tiles = cls(*args, **kwargs) |         tiles = cls(*args, **kwargs) | ||||||
|         # noinspection PyTypeChecker |         # noinspection PyTypeChecker | ||||||
|         tiles.register_additional_items( |         tiles.register_additional_items( | ||||||
|             [cls._accepted_objects(pos, tiles, is_blocking_light=cls._light_blocking) |             [cls._accepted_objects(pos, tiles) | ||||||
|              for pos in argwhere_coordinates] |              for pos in argwhere_coordinates] | ||||||
|         ) |         ) | ||||||
|         return tiles |         return tiles | ||||||
| @@ -399,10 +418,9 @@ class WallTiles(EntityRegister): | |||||||
|  |  | ||||||
| class FloorTiles(WallTiles): | class FloorTiles(WallTiles): | ||||||
|     _accepted_objects = Tile |     _accepted_objects = Tile | ||||||
|     _light_blocking = False |  | ||||||
|  |  | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, is_blocking_light=False, **kwargs): | ||||||
|         super(FloorTiles, self).__init__(*args, **kwargs) |         super(FloorTiles, self).__init__(*args, is_blocking_light=is_blocking_light, **kwargs) | ||||||
|         self._value = c.FREE_CELL |         self._value = c.FREE_CELL | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
| @@ -430,7 +448,7 @@ class Agents(MovingEntityObjectRegister): | |||||||
|     _accepted_objects = Agent |     _accepted_objects = Agent | ||||||
|  |  | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|         super().__init__(*args, **kwargs) |         super().__init__(*args, can_collide=True, **kwargs) | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def positions(self): |     def positions(self): | ||||||
| @@ -446,7 +464,7 @@ class Agents(MovingEntityObjectRegister): | |||||||
| class Doors(EntityRegister): | class Doors(EntityRegister): | ||||||
|  |  | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|         super(Doors, self).__init__(*args, is_blocking_light=True, **kwargs) |         super(Doors, self).__init__(*args, is_blocking_light=True, can_collide=True, **kwargs) | ||||||
|  |  | ||||||
|     _accepted_objects = Door |     _accepted_objects = Door | ||||||
|  |  | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ import numpy as np | |||||||
|  |  | ||||||
| from environments.helpers import Constants as c | from environments.helpers import Constants as c | ||||||
|  |  | ||||||
|  | # Multipliers for transforming coordinates to other octants: | ||||||
| mult_array = np.asarray([ | mult_array = np.asarray([ | ||||||
|     [1,  0,  0, -1, -1,  0,  0,  1], |     [1,  0,  0, -1, -1,  0,  0,  1], | ||||||
|     [0,  1, -1,  0,  0, -1,  1,  0], |     [0,  1, -1,  0,  0, -1,  1,  0], | ||||||
| @@ -11,8 +12,6 @@ mult_array = np.asarray([ | |||||||
|  |  | ||||||
|  |  | ||||||
| class Map(object): | class Map(object): | ||||||
|     # Multipliers for transforming coordinates to other octants: |  | ||||||
|  |  | ||||||
|     def __init__(self, map_array: np.typing.ArrayLike, diamond_slope: float = 0.9): |     def __init__(self, map_array: np.typing.ArrayLike, diamond_slope: float = 0.9): | ||||||
|         self.data = map_array |         self.data = map_array | ||||||
|         self.width, self.height = map_array.shape |         self.width, self.height = map_array.shape | ||||||
| @@ -33,7 +32,7 @@ class Map(object): | |||||||
|             self.light[x, y] = self.flag |             self.light[x, y] = self.flag | ||||||
|  |  | ||||||
|     def _cast_light(self, cx, cy, row, start, end, radius, xx, xy, yx, yy, id): |     def _cast_light(self, cx, cy, row, start, end, radius, xx, xy, yx, yy, id): | ||||||
|         "Recursive lightcasting function" |         """Recursive lightcasting function""" | ||||||
|         if start < end: |         if start < end: | ||||||
|             return |             return | ||||||
|         radius_squared = radius*radius |         radius_squared = radius*radius | ||||||
|   | |||||||
| @@ -1,4 +1,4 @@ | |||||||
| from typing import Union, NamedTuple, Dict | from typing import Union, NamedTuple, Dict, List | ||||||
|  |  | ||||||
| import numpy as np | import numpy as np | ||||||
|  |  | ||||||
| @@ -6,13 +6,29 @@ from environments.factory.base.base_factory import BaseFactory | |||||||
| from environments.factory.base.objects import Agent, Action, Entity, EnvObject, BoundingMixin | from environments.factory.base.objects import Agent, Action, Entity, EnvObject, BoundingMixin | ||||||
| from environments.factory.base.registers import EntityRegister, EnvObjectRegister | from environments.factory.base.registers import EntityRegister, EnvObjectRegister | ||||||
| from environments.factory.base.renderer import RenderEntity | from environments.factory.base.renderer import RenderEntity | ||||||
| from environments.helpers import Constants as c, Constants | from environments.helpers import Constants as BaseConstants | ||||||
|  | from environments.helpers import EnvActions as BaseActions | ||||||
|  | from environments.helpers import Rewards as BaseRewards | ||||||
|  |  | ||||||
| from environments import helpers as h | from environments import helpers as h | ||||||
|  |  | ||||||
|  |  | ||||||
| CHARGE_ACTION = h.EnvActions.CHARGE | class Constants(BaseConstants): | ||||||
| CHARGE_POD = 1 |     # Battery Env | ||||||
|  |     CHARGE_PODS          = 'Charge_Pod' | ||||||
|  |     BATTERIES            = 'BATTERIES' | ||||||
|  |     BATTERY_DISCHARGED   = 'DISCHARGED' | ||||||
|  |     CHARGE_POD           = 1 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Actions(BaseActions): | ||||||
|  |     CHARGE              = 'do_charge_action' | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Rewards(BaseRewards): | ||||||
|  |     CHARGE_VALID        = 0.1 | ||||||
|  |     CHARGE_FAIL         = -0.1 | ||||||
|  |     BATTERY_DISCHARGED  = -1.0 | ||||||
|  |  | ||||||
|  |  | ||||||
| class BatteryProperties(NamedTuple): | class BatteryProperties(NamedTuple): | ||||||
| @@ -24,7 +40,12 @@ class BatteryProperties(NamedTuple): | |||||||
|     multi_charge: bool = False |     multi_charge: bool = False | ||||||
|  |  | ||||||
|  |  | ||||||
| class Battery(EnvObject, BoundingMixin): | c = Constants | ||||||
|  | a = Actions | ||||||
|  | r = Rewards | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Battery(BoundingMixin, EnvObject): | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def is_discharged(self): |     def is_discharged(self): | ||||||
| @@ -37,13 +58,13 @@ class Battery(EnvObject, BoundingMixin): | |||||||
|     def encoding(self): |     def encoding(self): | ||||||
|         return self.charge_level |         return self.charge_level | ||||||
|  |  | ||||||
|     def charge(self, amount) -> c: |     def do_charge_action(self, amount): | ||||||
|         if self.charge_level < 1: |         if self.charge_level < 1: | ||||||
|             # noinspection PyTypeChecker |             # noinspection PyTypeChecker | ||||||
|             self.charge_level = min(1, amount + self.charge_level) |             self.charge_level = min(1, amount + self.charge_level) | ||||||
|             return c.VALID |             return dict(valid=c.VALID, action=a.CHARGE, reward=r.CHARGE_VALID) | ||||||
|         else: |         else: | ||||||
|             return c.NOT_VALID |             return dict(valid=c.NOT_VALID, action=a.CHARGE, reward=r.CHARGE_FAIL) | ||||||
|  |  | ||||||
|     def decharge(self, amount) -> c: |     def decharge(self, amount) -> c: | ||||||
|         if self.charge_level != 0: |         if self.charge_level != 0: | ||||||
| @@ -54,7 +75,7 @@ class Battery(EnvObject, BoundingMixin): | |||||||
|         else: |         else: | ||||||
|             return c.NOT_VALID |             return c.NOT_VALID | ||||||
|  |  | ||||||
|     def summarize_state(self, **kwargs): |     def summarize_state(self, **_): | ||||||
|         attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'} |         attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'} | ||||||
|         attr_dict.update(dict(name=self.name)) |         attr_dict.update(dict(name=self.name)) | ||||||
|         return attr_dict |         return attr_dict | ||||||
| @@ -63,53 +84,43 @@ class Battery(EnvObject, BoundingMixin): | |||||||
| class BatteriesRegister(EnvObjectRegister): | class BatteriesRegister(EnvObjectRegister): | ||||||
|  |  | ||||||
|     _accepted_objects = Battery |     _accepted_objects = Battery | ||||||
|     is_blocking_light = False |  | ||||||
|     can_be_shadowed = False |  | ||||||
|  |  | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|         super(BatteriesRegister, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs) |         super(BatteriesRegister, self).__init__(*args, individual_slices=True, | ||||||
|  |                                                 is_blocking_light=False, can_be_shadowed=False, **kwargs) | ||||||
|         self.is_observable = True |         self.is_observable = True | ||||||
|  |  | ||||||
|     def as_array(self): |     def spawn_batteries(self, agents, initial_charge_level): | ||||||
|         # ToDO: Make this Lazy |         batteries = [self._accepted_objects(initial_charge_level, agent, self) for _, agent in enumerate(agents)] | ||||||
|         self._array[:] = c.FREE_CELL.value |  | ||||||
|         for inv_idx, battery in enumerate(self): |  | ||||||
|             self._array[inv_idx] = battery.as_array() |  | ||||||
|         return self._array |  | ||||||
|  |  | ||||||
|     def spawn_batteries(self, agents, pomdp_r, initial_charge_level): |  | ||||||
|         batteries = [self._accepted_objects(pomdp_r, self._shape, agent, |  | ||||||
|                                             initial_charge_level) |  | ||||||
|                      for _, agent in enumerate(agents)] |  | ||||||
|         self.register_additional_items(batteries) |         self.register_additional_items(batteries) | ||||||
|  |  | ||||||
|     def idx_by_entity(self, entity): |  | ||||||
|         try: |  | ||||||
|             return next((idx for idx, bat in enumerate(self) if bat.belongs_to_entity(entity))) |  | ||||||
|         except StopIteration: |  | ||||||
|             return None |  | ||||||
|  |  | ||||||
|     def by_entity(self, entity): |  | ||||||
|         try: |  | ||||||
|             return next((bat for bat in self if bat.belongs_to_entity(entity))) |  | ||||||
|         except StopIteration: |  | ||||||
|             return None |  | ||||||
|  |  | ||||||
|     def summarize_states(self, n_steps=None): |     def summarize_states(self, n_steps=None): | ||||||
|         # as dict with additional nesting |         # as dict with additional nesting | ||||||
|         # return dict(items=super(Inventories, cls).summarize_states()) |         # return dict(items=super(Inventories, cls).summarize_states()) | ||||||
|         return super(BatteriesRegister, self).summarize_states(n_steps=n_steps) |         return super(BatteriesRegister, self).summarize_states(n_steps=n_steps) | ||||||
|  |  | ||||||
|  |     # Todo Move this to Mixin! | ||||||
|  |     def by_entity(self, entity): | ||||||
|  |         try: | ||||||
|  |             return next((x for x in self if x.belongs_to_entity(entity))) | ||||||
|  |         except StopIteration: | ||||||
|  |             return None | ||||||
|  |  | ||||||
|  |     def idx_by_entity(self, entity): | ||||||
|  |         try: | ||||||
|  |             return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity))) | ||||||
|  |         except StopIteration: | ||||||
|  |             return None | ||||||
|  |  | ||||||
|  |     def as_array_by_entity(self, entity): | ||||||
|  |         return self._array[self.idx_by_entity(entity)] | ||||||
|  |  | ||||||
|  |  | ||||||
| class ChargePod(Entity): | class ChargePod(Entity): | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def can_collide(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def encoding(self): |     def encoding(self): | ||||||
|         return CHARGE_POD |         return c.CHARGE_POD | ||||||
|  |  | ||||||
|     def __init__(self, *args, charge_rate: float = 0.4, |     def __init__(self, *args, charge_rate: float = 0.4, | ||||||
|                  multi_charge: bool = False, **kwargs): |                  multi_charge: bool = False, **kwargs): | ||||||
| @@ -120,9 +131,9 @@ class ChargePod(Entity): | |||||||
|     def charge_battery(self, battery: Battery): |     def charge_battery(self, battery: Battery): | ||||||
|         if battery.charge_level == 1.0: |         if battery.charge_level == 1.0: | ||||||
|             return c.NOT_VALID |             return c.NOT_VALID | ||||||
|         if sum(guest for guest in self.tile.guests if c.AGENT.name in guest.name) > 1: |         if sum(guest for guest in self.tile.guests if 'agent' in guest.name.lower()) > 1: | ||||||
|             return c.NOT_VALID |             return c.NOT_VALID | ||||||
|         battery.charge(self.charge_rate) |         battery.do_charge_action(self.charge_rate) | ||||||
|         return c.VALID |         return c.VALID | ||||||
|  |  | ||||||
|     def summarize_state(self, n_steps=None) -> dict: |     def summarize_state(self, n_steps=None) -> dict: | ||||||
| @@ -135,14 +146,6 @@ class ChargePods(EntityRegister): | |||||||
|  |  | ||||||
|     _accepted_objects = ChargePod |     _accepted_objects = ChargePod | ||||||
|  |  | ||||||
|     @DeprecationWarning |  | ||||||
|     def Xas_array(self): |  | ||||||
|         self._array[:] = c.FREE_CELL.value |  | ||||||
|         for item in self: |  | ||||||
|             if item.pos != c.NO_POS.value: |  | ||||||
|                 self._array[0, item.x, item.y] = item.encoding |  | ||||||
|         return self._array |  | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         super(ChargePods, self).__repr__() |         super(ChargePods, self).__repr__() | ||||||
|  |  | ||||||
| @@ -155,14 +158,14 @@ class BatteryFactory(BaseFactory): | |||||||
|         self.btry_prop = btry_prop |         self.btry_prop = btry_prop | ||||||
|         super().__init__(*args, **kwargs) |         super().__init__(*args, **kwargs) | ||||||
|  |  | ||||||
|     def _additional_per_agent_raw_observations(self, agent) -> Dict[Constants, np.typing.ArrayLike]: |     def _additional_per_agent_raw_observations(self, agent) -> Dict[str, np.typing.ArrayLike]: | ||||||
|         additional_raw_observations = super()._additional_per_agent_raw_observations(agent) |         additional_raw_observations = super()._additional_per_agent_raw_observations(agent) | ||||||
|         additional_raw_observations.update({c.BATTERIES: self[c.BATTERIES].by_entity(agent).as_array()}) |         additional_raw_observations.update({c.BATTERIES: self[c.BATTERIES].as_array_by_entity(agent)}) | ||||||
|         return additional_raw_observations |         return additional_raw_observations | ||||||
|  |  | ||||||
|     def _additional_observations(self) -> Dict[Constants, np.typing.ArrayLike]: |     def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]: | ||||||
|         additional_observations = super()._additional_observations() |         additional_observations = super()._additional_observations() | ||||||
|         additional_observations.update({c.CHARGE_POD: self[c.CHARGE_POD].as_array()}) |         additional_observations.update({c.CHARGE_PODS: self[c.CHARGE_PODS].as_array()}) | ||||||
|         return additional_observations |         return additional_observations | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
| @@ -178,12 +181,12 @@ class BatteryFactory(BaseFactory): | |||||||
|  |  | ||||||
|         batteries = BatteriesRegister(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2), |         batteries = BatteriesRegister(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2), | ||||||
|                                       ) |                                       ) | ||||||
|         batteries.spawn_batteries(self[c.AGENT], self._pomdp_r, self.btry_prop.initial_charge) |         batteries.spawn_batteries(self[c.AGENT], self.btry_prop.initial_charge) | ||||||
|         super_entities.update({c.BATTERIES: batteries, c.CHARGE_POD: charge_pods}) |         super_entities.update({c.BATTERIES: batteries, c.CHARGE_PODS: charge_pods}) | ||||||
|         return super_entities |         return super_entities | ||||||
|  |  | ||||||
|     def do_additional_step(self) -> dict: |     def do_additional_step(self) -> (List[dict], dict): | ||||||
|         info_dict = super(BatteryFactory, self).do_additional_step() |         super_reward_info = super(BatteryFactory, self).do_additional_step() | ||||||
|  |  | ||||||
|         # Decharge |         # Decharge | ||||||
|         batteries = self[c.BATTERIES] |         batteries = self[c.BATTERIES] | ||||||
| @@ -196,65 +199,70 @@ class BatteryFactory(BaseFactory): | |||||||
|  |  | ||||||
|             batteries.by_entity(agent).decharge(energy_consumption) |             batteries.by_entity(agent).decharge(energy_consumption) | ||||||
|  |  | ||||||
|         return info_dict |         return super_reward_info | ||||||
|  |  | ||||||
|     def do_charge(self, agent) -> c: |     def do_charge_action(self, agent) -> (dict, dict): | ||||||
|         if charge_pod := self[c.CHARGE_POD].by_pos(agent.pos): |         if charge_pod := self[c.CHARGE_PODS].by_pos(agent.pos): | ||||||
|             return charge_pod.charge_battery(self[c.BATTERIES].by_entity(agent)) |             valid = charge_pod.charge_battery(self[c.BATTERIES].by_entity(agent)) | ||||||
|  |             if valid: | ||||||
|  |                 info_dict = {f'{agent.name}_{a.CHARGE}_VALID': 1} | ||||||
|  |                 self.print(f'{agent.name} just charged batteries at {charge_pod.name}.') | ||||||
|  |             else: | ||||||
|  |                 info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1} | ||||||
|  |                 self.print(f'{agent.name} failed to charged batteries at {charge_pod.name}.') | ||||||
|         else: |         else: | ||||||
|             return c.NOT_VALID |             valid = c.NOT_VALID | ||||||
|  |             info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1} | ||||||
|  |             # info_dict = {f'{agent.name}_no_charger': 1} | ||||||
|  |             self.print(f'{agent.name} failed to charged batteries at {agent.pos}.') | ||||||
|  |         reward = dict(value=r.CHARGE_VALID if valid else r.CHARGE_FAIL, reason=a.CHARGE, info=info_dict) | ||||||
|  |         return valid, reward | ||||||
|  |  | ||||||
|     def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]: |     def do_additional_actions(self, agent: Agent, action: Action) -> (bool, dict): | ||||||
|         valid = super().do_additional_actions(agent, action) |         action_result = super().do_additional_actions(agent, action) | ||||||
|         if valid is None: |         if action_result is None: | ||||||
|             if action == CHARGE_ACTION: |             if action == a.CHARGE: | ||||||
|                 valid = self.do_charge(agent) |                 action_result = self.do_charge_action(agent) | ||||||
|                 return valid |                 return action_result | ||||||
|             else: |             else: | ||||||
|                 return None |                 return None | ||||||
|         else: |         else: | ||||||
|             return valid |             return action_result | ||||||
|         pass |         pass | ||||||
|  |  | ||||||
|     def do_additional_reset(self) -> None: |     def do_additional_reset(self) -> None: | ||||||
|         # There is Nothing to reset. |         # There is Nothing to reset. | ||||||
|         pass |         pass | ||||||
|  |  | ||||||
|     def check_additional_done(self) -> bool: |     def check_additional_done(self) -> (bool, dict): | ||||||
|         super_done = super(BatteryFactory, self).check_additional_done() |         super_done, super_dict = super(BatteryFactory, self).check_additional_done() | ||||||
|         if super_done: |         if super_done: | ||||||
|             return super_done |             return super_done, super_dict | ||||||
|         else: |         else: | ||||||
|             return self.btry_prop.done_when_discharged and any(battery.is_discharged for battery in self[c.BATTERIES]) |             if self.btry_prop.done_when_discharged: | ||||||
|  |                 if btry_done := any(battery.is_discharged for battery in self[c.BATTERIES]): | ||||||
|  |                     super_dict.update(DISCHARGE_DONE=1) | ||||||
|  |                     return btry_done, super_dict | ||||||
|  |                 else: | ||||||
|  |                     pass | ||||||
|  |             else: | ||||||
|  |                 pass | ||||||
|         pass |         pass | ||||||
|  |  | ||||||
|     def calculate_additional_reward(self, agent: Agent) -> (int, dict): |     def additional_per_agent_reward(self, agent: Agent) -> Dict[str, dict]: | ||||||
|         reward, info_dict = super(BatteryFactory, self).calculate_additional_reward(agent) |         reward_event_dict = super(BatteryFactory, self).additional_per_agent_reward(agent) | ||||||
|         if h.EnvActions.CHARGE == agent.temp_action: |  | ||||||
|             if agent.temp_valid: |  | ||||||
|                 charge_pod = self[c.CHARGE_POD].by_pos(agent.pos) |  | ||||||
|                 info_dict.update({f'{agent.name}_charge': 1}) |  | ||||||
|                 info_dict.update(agent_charged=1) |  | ||||||
|                 self.print(f'{agent.name} just charged batteries at {charge_pod.pos}.') |  | ||||||
|                 reward += 0.1 |  | ||||||
|             else: |  | ||||||
|                 self[c.DROP_OFF].by_pos(agent.pos) |  | ||||||
|                 info_dict.update({f'{agent.name}_failed_charge': 1}) |  | ||||||
|                 info_dict.update(failed_charge=1) |  | ||||||
|                 self.print(f'{agent.name} just tried to charge at {agent.pos}, but failed.') |  | ||||||
|                 reward -= 0.1 |  | ||||||
|  |  | ||||||
|         if self[c.BATTERIES].by_entity(agent).is_discharged: |         if self[c.BATTERIES].by_entity(agent).is_discharged: | ||||||
|             info_dict.update({f'{agent.name}_discharged': 1}) |             self.print(f'{agent.name} Battery is discharged!') | ||||||
|             reward -= 1 |             info_dict = {f'{agent.name}_{c.BATTERY_DISCHARGED}': 1} | ||||||
|  |             reward_event_dict.update({c.BATTERY_DISCHARGED: {'reward': r.BATTERY_DISCHARGED, 'info': info_dict}}) | ||||||
|         else: |         else: | ||||||
|             info_dict.update({f'{agent.name}_battery_level': self[c.BATTERIES].by_entity(agent).charge_level}) |             # All Fine | ||||||
|         return reward, info_dict |             pass | ||||||
|  |         return reward_event_dict | ||||||
|  |  | ||||||
|     def render_additional_assets(self): |     def render_additional_assets(self): | ||||||
|         # noinspection PyUnresolvedReferences |         # noinspection PyUnresolvedReferences | ||||||
|         additional_assets = super().render_additional_assets() |         additional_assets = super().render_additional_assets() | ||||||
|         charge_pods = [RenderEntity(c.CHARGE_POD.value, charge_pod.tile.pos) for charge_pod in self[c.CHARGE_POD]] |         charge_pods = [RenderEntity(c.CHARGE_PODS, charge_pod.tile.pos) for charge_pod in self[c.CHARGE_PODS]] | ||||||
|         additional_assets.extend(charge_pods) |         additional_assets.extend(charge_pods) | ||||||
|         return additional_assets |         return additional_assets | ||||||
|  |  | ||||||
|   | |||||||
| @@ -6,18 +6,32 @@ import numpy as np | |||||||
| import random | import random | ||||||
|  |  | ||||||
| from environments.factory.base.base_factory import BaseFactory | from environments.factory.base.base_factory import BaseFactory | ||||||
| from environments.helpers import Constants as c, Constants | from environments.helpers import Constants as BaseConstants | ||||||
| from environments import helpers as h | from environments.helpers import EnvActions as BaseActions | ||||||
|  | from environments.helpers import Rewards as BaseRewards | ||||||
| from environments.factory.base.objects import Agent, Entity, Action | from environments.factory.base.objects import Agent, Entity, Action | ||||||
| from environments.factory.base.registers import Entities, EntityRegister | from environments.factory.base.registers import Entities, EntityRegister | ||||||
|  |  | ||||||
| from environments.factory.base.renderer import RenderEntity | from environments.factory.base.renderer import RenderEntity | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Constants(BaseConstants): | ||||||
|  |     # Destination Env | ||||||
|  |     DEST                    = 'Destination' | ||||||
|  |     DESTINATION             = 1 | ||||||
|  |     DESTINATION_DONE        = 0.5 | ||||||
|  |     DEST_REACHED            = 'ReachedDestination' | ||||||
|  |  | ||||||
|  |  | ||||||
| DESTINATION = 1 | class Actions(BaseActions): | ||||||
| DESTINATION_DONE = 0.5 |     WAIT_ON_DEST    = 'WAIT' | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Rewards(BaseRewards): | ||||||
|  |  | ||||||
|  |     WAIT_VALID      = 0.1 | ||||||
|  |     WAIT_FAIL      = -0.1 | ||||||
|  |     DEST_REACHED    = 5.0 | ||||||
|  |  | ||||||
|  |  | ||||||
| class Destination(Entity): | class Destination(Entity): | ||||||
| @@ -30,20 +44,16 @@ class Destination(Entity): | |||||||
|     def currently_dwelling_names(self): |     def currently_dwelling_names(self): | ||||||
|         return self._per_agent_times.keys() |         return self._per_agent_times.keys() | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def can_collide(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def encoding(self): |     def encoding(self): | ||||||
|         return DESTINATION |         return c.DESTINATION | ||||||
|  |  | ||||||
|     def __init__(self, *args, dwell_time: int = 0, **kwargs): |     def __init__(self, *args, dwell_time: int = 0, **kwargs): | ||||||
|         super(Destination, self).__init__(*args, **kwargs) |         super(Destination, self).__init__(*args, **kwargs) | ||||||
|         self.dwell_time = dwell_time |         self.dwell_time = dwell_time | ||||||
|         self._per_agent_times = defaultdict(lambda: dwell_time) |         self._per_agent_times = defaultdict(lambda: dwell_time) | ||||||
|  |  | ||||||
|     def wait(self, agent: Agent): |     def do_wait_action(self, agent: Agent): | ||||||
|         self._per_agent_times[agent.name] -= 1 |         self._per_agent_times[agent.name] -= 1 | ||||||
|         return c.VALID |         return c.VALID | ||||||
|  |  | ||||||
| @@ -52,7 +62,7 @@ class Destination(Entity): | |||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def is_considered_reached(self): |     def is_considered_reached(self): | ||||||
|         agent_at_position = any(c.AGENT.name.lower() in x.name.lower() for x in self.tile.guests_that_can_collide) |         agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in self.tile.guests_that_can_collide) | ||||||
|         return (agent_at_position and not self.dwell_time) or any(x == 0 for x in self._per_agent_times.values()) |         return (agent_at_position and not self.dwell_time) or any(x == 0 for x in self._per_agent_times.values()) | ||||||
|  |  | ||||||
|     def agent_is_dwelling(self, agent: Agent): |     def agent_is_dwelling(self, agent: Agent): | ||||||
| @@ -67,15 +77,19 @@ class Destination(Entity): | |||||||
| class Destinations(EntityRegister): | class Destinations(EntityRegister): | ||||||
|  |  | ||||||
|     _accepted_objects = Destination |     _accepted_objects = Destination | ||||||
|     _light_blocking = False |  | ||||||
|  |     def __init__(self, *args, **kwargs): | ||||||
|  |         super().__init__(*args, **kwargs) | ||||||
|  |         self.is_blocking_light = False | ||||||
|  |         self.can_be_shadowed = False | ||||||
|  |  | ||||||
|     def as_array(self): |     def as_array(self): | ||||||
|         self._array[:] = c.FREE_CELL.value |         self._array[:] = c.FREE_CELL | ||||||
|         # ToDo: Switch to new Style Array Put |         # ToDo: Switch to new Style Array Put | ||||||
|         # indices = list(zip(range(len(cls)), *zip(*[x.pos for x in cls]))) |         # indices = list(zip(range(len(cls)), *zip(*[x.pos for x in cls]))) | ||||||
|         # np.put(cls._array, [np.ravel_multi_index(x, cls._array.shape) for x in indices], cls.encodings) |         # np.put(cls._array, [np.ravel_multi_index(x, cls._array.shape) for x in indices], cls.encodings) | ||||||
|         for item in self: |         for item in self: | ||||||
|             if item.pos != c.NO_POS.value: |             if item.pos != c.NO_POS: | ||||||
|                 self._array[0, item.x, item.y] = item.encoding |                 self._array[0, item.x, item.y] = item.encoding | ||||||
|         return self._array |         return self._array | ||||||
|  |  | ||||||
| @@ -85,10 +99,11 @@ class Destinations(EntityRegister): | |||||||
|  |  | ||||||
| class ReachedDestinations(Destinations): | class ReachedDestinations(Destinations): | ||||||
|     _accepted_objects = Destination |     _accepted_objects = Destination | ||||||
|     _light_blocking = False |  | ||||||
|  |  | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|         super(ReachedDestinations, self).__init__(*args, **kwargs) |         super(ReachedDestinations, self).__init__(*args, **kwargs) | ||||||
|  |         self.can_be_shadowed = False | ||||||
|  |         self.is_blocking_light = False | ||||||
|  |  | ||||||
|     def summarize_states(self, n_steps=None): |     def summarize_states(self, n_steps=None): | ||||||
|         return {} |         return {} | ||||||
| @@ -102,7 +117,7 @@ class DestModeOptions(object): | |||||||
|  |  | ||||||
| class DestProperties(NamedTuple): | class DestProperties(NamedTuple): | ||||||
|     n_dests:                                     int = 1     # How many destinations are there |     n_dests:                                     int = 1     # How many destinations are there | ||||||
|     dwell_time:                                  int = 0     # How long does the agent need to "wait" on a destination |     dwell_time:                                  int = 0     # How long does the agent need to "do_wait_action" on a destination | ||||||
|     spawn_frequency:                             int = 0 |     spawn_frequency:                             int = 0 | ||||||
|     spawn_in_other_zone:                        bool = True  # |     spawn_in_other_zone:                        bool = True  # | ||||||
|     spawn_mode:                                  str = DestModeOptions.DONE |     spawn_mode:                                  str = DestModeOptions.DONE | ||||||
| @@ -113,6 +128,11 @@ class DestProperties(NamedTuple): | |||||||
|     assert (spawn_mode == DestModeOptions.DONE) != bool(spawn_frequency) |     assert (spawn_mode == DestModeOptions.DONE) != bool(spawn_frequency) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | c = Constants | ||||||
|  | a = Actions | ||||||
|  | r = Rewards | ||||||
|  |  | ||||||
|  |  | ||||||
| # noinspection PyAttributeOutsideInit, PyAbstractClass | # noinspection PyAttributeOutsideInit, PyAbstractClass | ||||||
| class DestFactory(BaseFactory): | class DestFactory(BaseFactory): | ||||||
|     # noinspection PyMissingConstructor |     # noinspection PyMissingConstructor | ||||||
| @@ -131,7 +151,7 @@ class DestFactory(BaseFactory): | |||||||
|         # noinspection PyUnresolvedReferences |         # noinspection PyUnresolvedReferences | ||||||
|         super_actions = super().additional_actions |         super_actions = super().additional_actions | ||||||
|         if self.dest_prop.dwell_time: |         if self.dest_prop.dwell_time: | ||||||
|             super_actions.append(Action(enum_ident=h.EnvActions.WAIT_ON_DEST)) |             super_actions.append(Action(enum_ident=a.WAIT_ON_DEST)) | ||||||
|         return super_actions |         return super_actions | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
| @@ -147,27 +167,32 @@ class DestFactory(BaseFactory): | |||||||
|         ) |         ) | ||||||
|         reached_destinations = ReachedDestinations(level_shape=self._level_shape) |         reached_destinations = ReachedDestinations(level_shape=self._level_shape) | ||||||
|  |  | ||||||
|         super_entities.update({c.DESTINATION: destinations, c.REACHEDDESTINATION: reached_destinations}) |         super_entities.update({c.DEST: destinations, c.DEST_REACHED: reached_destinations}) | ||||||
|         return super_entities |         return super_entities | ||||||
|  |  | ||||||
|     def wait(self, agent: Agent): |     def do_wait_action(self, agent: Agent) -> (dict, dict): | ||||||
|         if destiantion := self[c.DESTINATION].by_pos(agent.pos): |         if destination := self[c.DEST].by_pos(agent.pos): | ||||||
|             valid = destiantion.wait(agent) |             valid = destination.do_wait_action(agent) | ||||||
|             return valid |             self.print(f'{agent.name} just waited at {agent.pos}') | ||||||
|  |             info_dict = {f'{agent.name}_{a.WAIT_ON_DEST}_VALID': 1} | ||||||
|         else: |         else: | ||||||
|             return c.NOT_VALID |             valid = c.NOT_VALID | ||||||
|  |             self.print(f'{agent.name} just tried to do_wait_action do_wait_action at {agent.pos} but failed') | ||||||
|  |             info_dict = {f'{agent.name}_{a.WAIT_ON_DEST}_FAIL': 1} | ||||||
|  |         reward = dict(value=r.WAIT_VALID if valid else r.WAIT_FAIL, reason=a.WAIT_ON_DEST, info=info_dict) | ||||||
|  |         return valid, reward | ||||||
|  |  | ||||||
|     def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]: |     def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict): | ||||||
|         # noinspection PyUnresolvedReferences |         # noinspection PyUnresolvedReferences | ||||||
|         valid = super().do_additional_actions(agent, action) |         super_action_result = super().do_additional_actions(agent, action) | ||||||
|         if valid is None: |         if super_action_result is None: | ||||||
|             if action == h.EnvActions.WAIT_ON_DEST: |             if action == a.WAIT_ON_DEST: | ||||||
|                 valid = self.wait(agent) |                 action_result = self.do_wait_action(agent) | ||||||
|                 return valid |                 return action_result | ||||||
|             else: |             else: | ||||||
|                 return None |                 return None | ||||||
|         else: |         else: | ||||||
|             return valid |             return super_action_result | ||||||
|  |  | ||||||
|     def do_additional_reset(self) -> None: |     def do_additional_reset(self) -> None: | ||||||
|         # noinspection PyUnresolvedReferences |         # noinspection PyUnresolvedReferences | ||||||
| @@ -180,14 +205,14 @@ class DestFactory(BaseFactory): | |||||||
|         if destinations_to_spawn: |         if destinations_to_spawn: | ||||||
|             n_dest_to_spawn = len(destinations_to_spawn) |             n_dest_to_spawn = len(destinations_to_spawn) | ||||||
|             if self.dest_prop.spawn_mode != DestModeOptions.GROUPED: |             if self.dest_prop.spawn_mode != DestModeOptions.GROUPED: | ||||||
|                 destinations = [Destination(tile) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]] |                 destinations = [Destination(tile, c.DEST) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]] | ||||||
|                 self[c.DESTINATION].register_additional_items(destinations) |                 self[c.DEST].register_additional_items(destinations) | ||||||
|                 for dest in destinations_to_spawn: |                 for dest in destinations_to_spawn: | ||||||
|                     del self._dest_spawn_timer[dest] |                     del self._dest_spawn_timer[dest] | ||||||
|                 self.print(f'{n_dest_to_spawn} new destinations have been spawned') |                 self.print(f'{n_dest_to_spawn} new destinations have been spawned') | ||||||
|             elif self.dest_prop.spawn_mode == DestModeOptions.GROUPED and n_dest_to_spawn == self.dest_prop.n_dests: |             elif self.dest_prop.spawn_mode == DestModeOptions.GROUPED and n_dest_to_spawn == self.dest_prop.n_dests: | ||||||
|                 destinations = [Destination(tile) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]] |                 destinations = [Destination(tile, self[c.DEST]) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]] | ||||||
|                 self[c.DESTINATION].register_additional_items(destinations) |                 self[c.DEST].register_additional_items(destinations) | ||||||
|                 for dest in destinations_to_spawn: |                 for dest in destinations_to_spawn: | ||||||
|                     del self._dest_spawn_timer[dest] |                     del self._dest_spawn_timer[dest] | ||||||
|                 self.print(f'{n_dest_to_spawn} new destinations have been spawned') |                 self.print(f'{n_dest_to_spawn} new destinations have been spawned') | ||||||
| @@ -197,15 +222,14 @@ class DestFactory(BaseFactory): | |||||||
|         else: |         else: | ||||||
|             self.print('No Items are spawning, limit is reached.') |             self.print('No Items are spawning, limit is reached.') | ||||||
|  |  | ||||||
|     def do_additional_step(self) -> dict: |     def do_additional_step(self) -> (List[dict], dict): | ||||||
|         # noinspection PyUnresolvedReferences |         # noinspection PyUnresolvedReferences | ||||||
|         info_dict = super().do_additional_step() |         super_reward_info = super().do_additional_step() | ||||||
|         for key, val in self._dest_spawn_timer.items(): |         for key, val in self._dest_spawn_timer.items(): | ||||||
|             self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1) |             self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1) | ||||||
|         for dest in list(self[c.DESTINATION].values()): |         for dest in list(self[c.DEST].values()): | ||||||
|             if dest.is_considered_reached: |             if dest.is_considered_reached: | ||||||
|                 self[c.REACHEDDESTINATION].register_item(dest) |                 dest.change_register(self[c.DEST]) | ||||||
|                 self[c.DESTINATION].delete_env_object(dest) |  | ||||||
|                 self._dest_spawn_timer[dest.name] = 0 |                 self._dest_spawn_timer[dest.name] = 0 | ||||||
|                 self.print(f'{dest.name} is reached now, removing...') |                 self.print(f'{dest.name} is reached now, removing...') | ||||||
|             else: |             else: | ||||||
| @@ -218,41 +242,29 @@ class DestFactory(BaseFactory): | |||||||
|                         dest.leave(agent) |                         dest.leave(agent) | ||||||
|                         self.print(f'{agent.name} left the destination early.') |                         self.print(f'{agent.name} left the destination early.') | ||||||
|         self.trigger_destination_spawn() |         self.trigger_destination_spawn() | ||||||
|         return info_dict |         return super_reward_info | ||||||
|  |  | ||||||
|     def _additional_observations(self) -> Dict[Constants, np.typing.ArrayLike]: |     def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]: | ||||||
|         additional_observations = super()._additional_observations() |         additional_observations = super()._additional_observations() | ||||||
|         additional_observations.update({c.DESTINATION: self[c.DESTINATION].as_array()}) |         additional_observations.update({c.DEST: self[c.DEST].as_array()}) | ||||||
|         return additional_observations |         return additional_observations | ||||||
|  |  | ||||||
|     def calculate_additional_reward(self, agent: Agent) -> (int, dict): |     def additional_per_agent_reward(self, agent: Agent) -> Dict[str, dict]: | ||||||
|         # noinspection PyUnresolvedReferences |         # noinspection PyUnresolvedReferences | ||||||
|         reward, info_dict = super().calculate_additional_reward(agent) |         reward_event_dict = super().additional_per_agent_reward(agent) | ||||||
|         if h.EnvActions.WAIT_ON_DEST == agent.temp_action: |         if len(self[c.DEST_REACHED]): | ||||||
|             if agent.temp_valid: |             for reached_dest in list(self[c.DEST_REACHED]): | ||||||
|                 info_dict.update({f'{agent.name}_waiting_at_dest': 1}) |  | ||||||
|                 info_dict.update(agent_waiting_at_dest=1) |  | ||||||
|                 self.print(f'{agent.name} just waited at {agent.pos}') |  | ||||||
|                 reward += 0.1 |  | ||||||
|             else: |  | ||||||
|                 info_dict.update({f'{agent.name}_tried_failed': 1}) |  | ||||||
|                 info_dict.update(agent_waiting_failed=1) |  | ||||||
|                 self.print(f'{agent.name} just tried to wait wait at {agent.pos} but failed') |  | ||||||
|                 reward -= 0.1 |  | ||||||
|         if len(self[c.REACHEDDESTINATION]): |  | ||||||
|             for reached_dest in list(self[c.REACHEDDESTINATION]): |  | ||||||
|                 if agent.pos == reached_dest.pos: |                 if agent.pos == reached_dest.pos: | ||||||
|                     info_dict.update({f'{agent.name}_reached_destination': 1}) |  | ||||||
|                     info_dict.update(agent_reached_destination=1) |  | ||||||
|                     self.print(f'{agent.name} just reached destination at {agent.pos}') |                     self.print(f'{agent.name} just reached destination at {agent.pos}') | ||||||
|                     reward += 0.5 |                     self[c.DEST_REACHED].delete_env_object(reached_dest) | ||||||
|                     self[c.REACHEDDESTINATION].delete_env_object(reached_dest) |                     info_dict = {f'{agent.name}_{c.DEST_REACHED}': 1} | ||||||
|         return reward, info_dict |                     reward_event_dict.update({c.DEST_REACHED: {'reward': r.DEST_REACHED, 'info': info_dict}}) | ||||||
|  |         return reward_event_dict | ||||||
|  |  | ||||||
|     def render_additional_assets(self, mode='human'): |     def render_additional_assets(self, mode='human'): | ||||||
|         # noinspection PyUnresolvedReferences |         # noinspection PyUnresolvedReferences | ||||||
|         additional_assets = super().render_additional_assets() |         additional_assets = super().render_additional_assets() | ||||||
|         destinations = [RenderEntity(c.DESTINATION.value, dest.pos) for dest in self[c.DESTINATION]] |         destinations = [RenderEntity(c.DEST, dest.pos) for dest in self[c.DEST]] | ||||||
|         additional_assets.extend(destinations) |         additional_assets.extend(destinations) | ||||||
|         return additional_assets |         return additional_assets | ||||||
|  |  | ||||||
|   | |||||||
| @@ -8,6 +8,7 @@ import numpy as np | |||||||
| # from algorithms.TSP_dirt_agent import TSPDirtAgent | # from algorithms.TSP_dirt_agent import TSPDirtAgent | ||||||
| from environments.helpers import Constants as BaseConstants | from environments.helpers import Constants as BaseConstants | ||||||
| from environments.helpers import EnvActions as BaseActions | from environments.helpers import EnvActions as BaseActions | ||||||
|  | from environments.helpers import Rewards as BaseRewards | ||||||
|  |  | ||||||
| from environments.factory.base.base_factory import BaseFactory | from environments.factory.base.base_factory import BaseFactory | ||||||
| from environments.factory.base.objects import Agent, Action, Entity, Tile | from environments.factory.base.objects import Agent, Action, Entity, Tile | ||||||
| @@ -21,8 +22,14 @@ class Constants(BaseConstants): | |||||||
|     DIRT = 'Dirt' |     DIRT = 'Dirt' | ||||||
|  |  | ||||||
|  |  | ||||||
| class EnvActions(BaseActions): | class Actions(BaseActions): | ||||||
|     CLEAN_UP = 'clean_up' |     CLEAN_UP = 'do_cleanup_action' | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Rewards(BaseRewards): | ||||||
|  |     CLEAN_UP_VALID          = 0.5 | ||||||
|  |     CLEAN_UP_FAIL          = -0.1 | ||||||
|  |     CLEAN_UP_LAST_PIECE     = 4.5 | ||||||
|  |  | ||||||
|  |  | ||||||
| class DirtProperties(NamedTuple): | class DirtProperties(NamedTuple): | ||||||
| @@ -41,10 +48,6 @@ class DirtProperties(NamedTuple): | |||||||
|  |  | ||||||
| class Dirt(Entity): | class Dirt(Entity): | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def can_collide(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def amount(self): |     def amount(self): | ||||||
|         return self._amount |         return self._amount | ||||||
| @@ -116,6 +119,8 @@ def entropy(x): | |||||||
|  |  | ||||||
|  |  | ||||||
| c = Constants | c = Constants | ||||||
|  | a = Actions | ||||||
|  | r = Rewards | ||||||
|  |  | ||||||
|  |  | ||||||
| # noinspection PyAttributeOutsideInit, PyAbstractClass | # noinspection PyAttributeOutsideInit, PyAbstractClass | ||||||
| @@ -125,7 +130,7 @@ class DirtFactory(BaseFactory): | |||||||
|     def additional_actions(self) -> Union[Action, List[Action]]: |     def additional_actions(self) -> Union[Action, List[Action]]: | ||||||
|         super_actions = super().additional_actions |         super_actions = super().additional_actions | ||||||
|         if self.dirt_prop.agent_can_interact: |         if self.dirt_prop.agent_can_interact: | ||||||
|             super_actions.append(Action(str_ident=EnvActions.CLEAN_UP)) |             super_actions.append(Action(str_ident=a.CLEAN_UP)) | ||||||
|         return super_actions |         return super_actions | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
| @@ -151,7 +156,7 @@ class DirtFactory(BaseFactory): | |||||||
|         additional_assets.extend(dirt) |         additional_assets.extend(dirt) | ||||||
|         return additional_assets |         return additional_assets | ||||||
|  |  | ||||||
|     def clean_up(self, agent: Agent) -> c: |     def do_cleanup_action(self, agent: Agent) -> (dict, dict): | ||||||
|         if dirt := self[c.DIRT].by_pos(agent.pos): |         if dirt := self[c.DIRT].by_pos(agent.pos): | ||||||
|             new_dirt_amount = dirt.amount - self.dirt_prop.clean_amount |             new_dirt_amount = dirt.amount - self.dirt_prop.clean_amount | ||||||
|  |  | ||||||
| @@ -159,9 +164,21 @@ class DirtFactory(BaseFactory): | |||||||
|                 self[c.DIRT].delete_env_object(dirt) |                 self[c.DIRT].delete_env_object(dirt) | ||||||
|             else: |             else: | ||||||
|                 dirt.set_new_amount(max(new_dirt_amount, c.FREE_CELL.value)) |                 dirt.set_new_amount(max(new_dirt_amount, c.FREE_CELL.value)) | ||||||
|             return c.VALID |             valid = c.VALID | ||||||
|  |             self.print(f'{agent.name} did just clean up some dirt at {agent.pos}.') | ||||||
|  |             info_dict = {f'{agent.name}_{a.CLEAN_UP}_VALID': 1} | ||||||
|  |             reward = r.CLEAN_UP_VALID | ||||||
|         else: |         else: | ||||||
|             return c.NOT_VALID |             valid = c.NOT_VALID | ||||||
|  |             self.print(f'{agent.name} just tried to clean up some dirt at {agent.pos}, but failed.') | ||||||
|  |             info_dict = {f'{agent.name}_{a.CLEAN_UP}_FAIL': 1} | ||||||
|  |             reward = r.CLEAN_UP_FAIL | ||||||
|  |  | ||||||
|  |         if valid and self.dirt_prop.done_when_clean and (len(self[c.DIRT]) == 0): | ||||||
|  |             reward += r.CLEAN_UP_LAST_PIECE | ||||||
|  |             self.print(f'{agent.name} picked up the last piece of dirt!') | ||||||
|  |             info_dict = {f'{agent.name}_{a.CLEAN_UP}_LAST_PIECE': 1} | ||||||
|  |         return valid, dict(value=reward, reason=a.CLEAN_UP, info=info_dict) | ||||||
|  |  | ||||||
|     def trigger_dirt_spawn(self, initial_spawn=False): |     def trigger_dirt_spawn(self, initial_spawn=False): | ||||||
|         dirt_rng = self._dirt_rng |         dirt_rng = self._dirt_rng | ||||||
| @@ -177,8 +194,8 @@ class DirtFactory(BaseFactory): | |||||||
|         n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt))) |         n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt))) | ||||||
|         self[c.DIRT].spawn_dirt(free_for_dirt[:n_dirt_tiles]) |         self[c.DIRT].spawn_dirt(free_for_dirt[:n_dirt_tiles]) | ||||||
|  |  | ||||||
|     def do_additional_step(self) -> dict: |     def do_additional_step(self) -> (List[dict], dict): | ||||||
|         info_dict = super().do_additional_step() |         super_reward_info = super().do_additional_step() | ||||||
|         if smear_amount := self.dirt_prop.dirt_smear_amount: |         if smear_amount := self.dirt_prop.dirt_smear_amount: | ||||||
|             for agent in self[c.AGENT]: |             for agent in self[c.AGENT]: | ||||||
|                 if agent.temp_valid and agent.last_pos != c.NO_POS: |                 if agent.temp_valid and agent.last_pos != c.NO_POS: | ||||||
| @@ -199,42 +216,44 @@ class DirtFactory(BaseFactory): | |||||||
|             self._next_dirt_spawn = self.dirt_prop.spawn_frequency |             self._next_dirt_spawn = self.dirt_prop.spawn_frequency | ||||||
|         else: |         else: | ||||||
|             self._next_dirt_spawn -= 1 |             self._next_dirt_spawn -= 1 | ||||||
|         return info_dict |         return super_reward_info | ||||||
|  |  | ||||||
|     def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]: |     def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict): | ||||||
|         valid = super().do_additional_actions(agent, action) |         action_result = super().do_additional_actions(agent, action) | ||||||
|         if valid is None: |         if action_result is None: | ||||||
|             if action == EnvActions.CLEAN_UP: |             if action == a.CLEAN_UP: | ||||||
|                 if self.dirt_prop.agent_can_interact: |                 return self.do_cleanup_action(agent) | ||||||
|                     valid = self.clean_up(agent) |  | ||||||
|                     return valid |  | ||||||
|                 else: |  | ||||||
|                     return c.NOT_VALID |  | ||||||
|             else: |             else: | ||||||
|                 return None |                 return None | ||||||
|         else: |         else: | ||||||
|             return valid |             return action_result | ||||||
|  |  | ||||||
|     def do_additional_reset(self) -> None: |     def do_additional_reset(self) -> None: | ||||||
|         super().do_additional_reset() |         super().do_additional_reset() | ||||||
|         self.trigger_dirt_spawn(initial_spawn=True) |         self.trigger_dirt_spawn(initial_spawn=True) | ||||||
|         self._next_dirt_spawn = self.dirt_prop.spawn_frequency if self.dirt_prop.spawn_frequency else -1 |         self._next_dirt_spawn = self.dirt_prop.spawn_frequency if self.dirt_prop.spawn_frequency else -1 | ||||||
|  |  | ||||||
|     def check_additional_done(self): |     def check_additional_done(self) -> (bool, dict): | ||||||
|         super_done = super().check_additional_done() |         super_done, super_dict = super().check_additional_done() | ||||||
|         done = self.dirt_prop.done_when_clean and (len(self[c.DIRT]) == 0) |         if self.dirt_prop.done_when_clean: | ||||||
|         return super_done or done |             if all_cleaned := len(self[c.DIRT]) == 0: | ||||||
|  |                 super_dict.update(ALL_CLEAN_DONE=all_cleaned) | ||||||
|  |                 return all_cleaned, super_dict | ||||||
|  |         return super_done, super_dict | ||||||
|  |  | ||||||
|     def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]: |     def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]: | ||||||
|         additional_observations = super()._additional_observations() |         additional_observations = super()._additional_observations() | ||||||
|         additional_observations.update({c.DIRT: self[c.DIRT].as_array()}) |         additional_observations.update({c.DIRT: self[c.DIRT].as_array()}) | ||||||
|         return additional_observations |         return additional_observations | ||||||
|  |  | ||||||
|     def calculate_additional_reward(self, agent: Agent) -> (int, dict): |     def gather_additional_info(self, agent: Agent) -> dict: | ||||||
|         reward, info_dict = super().calculate_additional_reward(agent) |         event_reward_dict = super().additional_per_agent_reward(agent) | ||||||
|  |         info_dict = dict() | ||||||
|  |  | ||||||
|         dirt = [dirt.amount for dirt in self[c.DIRT]] |         dirt = [dirt.amount for dirt in self[c.DIRT]] | ||||||
|         current_dirt_amount = sum(dirt) |         current_dirt_amount = sum(dirt) | ||||||
|         dirty_tile_count = len(dirt) |         dirty_tile_count = len(dirt) | ||||||
|  |  | ||||||
|         # if dirty_tile_count: |         # if dirty_tile_count: | ||||||
|         #    dirt_distribution_score = entropy(softmax(np.asarray(dirt)) / dirty_tile_count) |         #    dirt_distribution_score = entropy(softmax(np.asarray(dirt)) / dirty_tile_count) | ||||||
|         # else: |         # else: | ||||||
| @@ -242,33 +261,14 @@ class DirtFactory(BaseFactory): | |||||||
|  |  | ||||||
|         info_dict.update(dirt_amount=current_dirt_amount) |         info_dict.update(dirt_amount=current_dirt_amount) | ||||||
|         info_dict.update(dirty_tile_count=dirty_tile_count) |         info_dict.update(dirty_tile_count=dirty_tile_count) | ||||||
|         # info_dict.update(dirt_distribution_score=dirt_distribution_score) |  | ||||||
|  |  | ||||||
|         if agent.temp_action == EnvActions.CLEAN_UP: |         event_reward_dict.update({'info': info_dict}) | ||||||
|             if agent.temp_valid: |         return event_reward_dict | ||||||
|                 # Reward if pickup succeds, |  | ||||||
|                 #  0.5 on every pickup |  | ||||||
|                 reward += 0.5 |  | ||||||
|                 info_dict.update(dirt_cleaned=1) |  | ||||||
|                 if self.dirt_prop.done_when_clean and (len(self[c.DIRT]) == 0): |  | ||||||
|                     #  0.5 additional reward for the very last pickup |  | ||||||
|                     reward += 4.5 |  | ||||||
|                     info_dict.update(done_clean=1) |  | ||||||
|                 self.print(f'{agent.name} did just clean up some dirt at {agent.pos}.') |  | ||||||
|             else: |  | ||||||
|                 reward -= 0.01 |  | ||||||
|                 self.print(f'{agent.name} just tried to clean up some dirt at {agent.pos}, but failed.') |  | ||||||
|                 info_dict.update({f'{agent.name}_failed_dirt_cleanup': 1}) |  | ||||||
|                 info_dict.update(failed_dirt_clean=1) |  | ||||||
|  |  | ||||||
|         # Potential based rewards -> |  | ||||||
|         #  track the last reward , minus the current reward = potential |  | ||||||
|         return reward, info_dict |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     from environments.utility_classes import AgentRenderOptions as aro |     from environments.utility_classes import AgentRenderOptions as aro | ||||||
|     render = True |     render = False | ||||||
|  |  | ||||||
|     dirt_props = DirtProperties( |     dirt_props = DirtProperties( | ||||||
|         initial_dirt_ratio=0.35, |         initial_dirt_ratio=0.35, | ||||||
| @@ -289,14 +289,15 @@ if __name__ == '__main__': | |||||||
|     move_props = {'allow_square_movement': True, |     move_props = {'allow_square_movement': True, | ||||||
|                   'allow_diagonal_movement': False, |                   'allow_diagonal_movement': False, | ||||||
|                   'allow_no_op': False} |                   'allow_no_op': False} | ||||||
|  |     import time | ||||||
|     global_timings = [] |     global_timings = [] | ||||||
|     for i in range(20): |     for i in range(10): | ||||||
|  |  | ||||||
|         factory = DirtFactory(n_agents=2, done_at_collision=False, |         factory = DirtFactory(n_agents=2, done_at_collision=False, | ||||||
|                               level_name='rooms', max_steps=1000, |                               level_name='rooms', max_steps=1000, | ||||||
|                               doors_have_area=False, |                               doors_have_area=False, | ||||||
|                               obs_prop=obs_props, parse_doors=True, |                               obs_prop=obs_props, parse_doors=True, | ||||||
|                               record_episodes=True, verbose=True, |                               verbose=False, | ||||||
|                               mv_prop=move_props, dirt_prop=dirt_props, |                               mv_prop=move_props, dirt_prop=dirt_props, | ||||||
|                               # inject_agents=[TSPDirtAgent], |                               # inject_agents=[TSPDirtAgent], | ||||||
|                               ) |                               ) | ||||||
| @@ -307,7 +308,6 @@ if __name__ == '__main__': | |||||||
|         obs_space = factory.observation_space |         obs_space = factory.observation_space | ||||||
|         obs_space_named = factory.named_observation_space |         obs_space_named = factory.named_observation_space | ||||||
|         times = [] |         times = [] | ||||||
|         import time |  | ||||||
|         for epoch in range(10): |         for epoch in range(10): | ||||||
|             start_time = time.time() |             start_time = time.time() | ||||||
|             random_actions = [[random.randint(0, n_actions) for _ |             random_actions = [[random.randint(0, n_actions) for _ | ||||||
| @@ -318,18 +318,19 @@ if __name__ == '__main__': | |||||||
|                 factory.render() |                 factory.render() | ||||||
|             # tsp_agent = factory.get_injected_agents()[0] |             # tsp_agent = factory.get_injected_agents()[0] | ||||||
|  |  | ||||||
|             r = 0 |             rwrd = 0 | ||||||
|             for agent_i_action in random_actions: |             for agent_i_action in random_actions: | ||||||
|                 env_state, step_r, done_bool, info_obj = factory.step(agent_i_action) |                 env_state, step_rwrd, done_bool, info_obj = factory.step(agent_i_action) | ||||||
|                 r += step_r |                 rwrd += step_rwrd | ||||||
|                 if render: |                 if render: | ||||||
|                     factory.render() |                     factory.render() | ||||||
|                 if done_bool: |                 if done_bool: | ||||||
|                     break |                     break | ||||||
|             times.append(time.time() - start_time) |             times.append(time.time() - start_time) | ||||||
|             # print(f'Factory run {epoch} done, reward is:\n    {r}') |             # print(f'Factory run {epoch} done, reward is:\n    {r}') | ||||||
|         print('Time Taken: ', sum(times) / 10) |         print('Mean Time Taken: ', sum(times) / 10) | ||||||
|         global_timings.append(sum(times) / 10) |         global_timings.extend(times) | ||||||
|     print('Time Taken: ', sum(global_timings[10:]) / 10) |     print('Mean Time Taken: ', sum(global_timings) / len(global_timings)) | ||||||
|  |     print('Median Time Taken: ', global_timings[len(global_timings)//2]) | ||||||
|  |  | ||||||
| pass | pass | ||||||
|   | |||||||
| @@ -7,9 +7,10 @@ import random | |||||||
| from environments.factory.base.base_factory import BaseFactory | from environments.factory.base.base_factory import BaseFactory | ||||||
| from environments.helpers import Constants as BaseConstants | from environments.helpers import Constants as BaseConstants | ||||||
| from environments.helpers import EnvActions as BaseActions | from environments.helpers import EnvActions as BaseActions | ||||||
|  | from environments.helpers import Rewards as BaseRewards | ||||||
| from environments import helpers as h | from environments import helpers as h | ||||||
| from environments.factory.base.objects import Agent, Entity, Action, Tile | from environments.factory.base.objects import Agent, Entity, Action, Tile | ||||||
| from environments.factory.base.registers import Entities, EntityRegister, BoundRegisterMixin, ObjectRegister | from environments.factory.base.registers import Entities, EntityRegister, BoundEnvObjRegister, ObjectRegister | ||||||
|  |  | ||||||
| from environments.factory.base.renderer import RenderEntity | from environments.factory.base.renderer import RenderEntity | ||||||
|  |  | ||||||
| @@ -23,10 +24,17 @@ class Constants(BaseConstants): | |||||||
|     DROP_OFF            = 'Drop_Off' |     DROP_OFF            = 'Drop_Off' | ||||||
|  |  | ||||||
|  |  | ||||||
| class EnvActions(BaseActions): | class Actions(BaseActions): | ||||||
|     ITEM_ACTION     = 'item_action' |     ITEM_ACTION     = 'item_action' | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Rewards(BaseRewards): | ||||||
|  |     DROP_OFF_VALID = 0.1 | ||||||
|  |     DROP_OFF_FAIL = -0.1 | ||||||
|  |     PICK_UP_FAIL  = -0.1 | ||||||
|  |     PICK_UP_VALID  = 0.1 | ||||||
|  |  | ||||||
|  |  | ||||||
| class Item(Entity): | class Item(Entity): | ||||||
|  |  | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
| @@ -37,10 +45,6 @@ class Item(Entity): | |||||||
|     def auto_despawn(self): |     def auto_despawn(self): | ||||||
|         return self._auto_despawn |         return self._auto_despawn | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def can_collide(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def encoding(self): |     def encoding(self): | ||||||
|         # Edit this if you want items to be drawn in the ops differently |         # Edit this if you want items to be drawn in the ops differently | ||||||
| @@ -68,7 +72,7 @@ class ItemRegister(EntityRegister): | |||||||
|             del self[item] |             del self[item] | ||||||
|  |  | ||||||
|  |  | ||||||
| class Inventory(BoundRegisterMixin): | class Inventory(BoundEnvObjRegister): | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def name(self): |     def name(self): | ||||||
| @@ -131,10 +135,6 @@ class Inventories(ObjectRegister): | |||||||
|  |  | ||||||
| class DropOffLocation(Entity): | class DropOffLocation(Entity): | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def can_collide(self): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def encoding(self): |     def encoding(self): | ||||||
|         return Constants.ITEM_DROP_OFF |         return Constants.ITEM_DROP_OFF | ||||||
| @@ -176,7 +176,8 @@ class ItemProperties(NamedTuple): | |||||||
|  |  | ||||||
|  |  | ||||||
| c = Constants | c = Constants | ||||||
| a = EnvActions | a = Actions | ||||||
|  | r = Rewards | ||||||
|  |  | ||||||
|  |  | ||||||
| # noinspection PyAttributeOutsideInit, PyAbstractClass | # noinspection PyAttributeOutsideInit, PyAbstractClass | ||||||
| @@ -230,37 +231,43 @@ class ItemFactory(BaseFactory): | |||||||
|         additional_observations.update({c.DROP_OFF: self[c.DROP_OFF].as_array()}) |         additional_observations.update({c.DROP_OFF: self[c.DROP_OFF].as_array()}) | ||||||
|         return additional_observations |         return additional_observations | ||||||
|  |  | ||||||
|     def do_item_action(self, agent: Agent): |     def do_item_action(self, agent: Agent) -> (dict, dict): | ||||||
|         inventory = self[c.INVENTORY].by_entity(agent) |         inventory = self[c.INVENTORY].by_entity(agent) | ||||||
|         if drop_off := self[c.DROP_OFF].by_pos(agent.pos): |         if drop_off := self[c.DROP_OFF].by_pos(agent.pos): | ||||||
|             if inventory: |             if inventory: | ||||||
|                 valid = drop_off.place_item(inventory.pop()) |                 valid = drop_off.place_item(inventory.pop()) | ||||||
|                 return valid |  | ||||||
|             else: |             else: | ||||||
|                 return c.NOT_VALID |                 valid = c.NOT_VALID | ||||||
|  |             if valid: | ||||||
|  |                 self.print(f'{agent.name} just dropped of an item at {drop_off.pos}.') | ||||||
|  |                 info_dict = {f'{agent.name}_DROPOFF_VALID': 1} | ||||||
|  |             else: | ||||||
|  |                 self.print(f'{agent.name} just tried to drop off at {agent.pos}, but failed.') | ||||||
|  |                 info_dict = {f'{agent.name}_DROPOFF_FAIL': 1} | ||||||
|  |             reward = dict(value=r.DROP_OFF_VALID if valid else r.DROP_OFF_FAIL, reason=a.ITEM_ACTION, info=info_dict) | ||||||
|  |             return valid, reward | ||||||
|         elif item := self[c.ITEM].by_pos(agent.pos): |         elif item := self[c.ITEM].by_pos(agent.pos): | ||||||
|             try: |             item.change_register(inventory) | ||||||
|                 inventory.register_item(item) |             item.set_tile_to(self._NO_POS_TILE) | ||||||
|                 item.change_register(inventory) |             self.print(f'{agent.name} just picked up an item at {agent.pos}') | ||||||
|                 self[c.ITEM].delete_env_object(item) |             info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1} | ||||||
|                 item.set_tile_to(self._NO_POS_TILE) |             return c.VALID, dict(value=r.PICK_UP_VALID, reason=a.ITEM_ACTION, info=info_dict) | ||||||
|                 return c.VALID |  | ||||||
|             except RuntimeError: |  | ||||||
|                 return c.NOT_VALID |  | ||||||
|         else: |         else: | ||||||
|             return c.NOT_VALID |             self.print(f'{agent.name} just tried to pick up an item at {agent.pos}, but failed.') | ||||||
|  |             info_dict = {f'{agent.name}_{a.ITEM_ACTION}_FAIL': 1} | ||||||
|  |             return c.NOT_VALID, dict(value=r.PICK_UP_FAIL, reason=a.ITEM_ACTION, info=info_dict) | ||||||
|  |  | ||||||
|     def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]: |     def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict): | ||||||
|         # noinspection PyUnresolvedReferences |         # noinspection PyUnresolvedReferences | ||||||
|         valid = super().do_additional_actions(agent, action) |         action_result = super().do_additional_actions(agent, action) | ||||||
|         if valid is None: |         if action_result is None: | ||||||
|             if action == a.ITEM_ACTION: |             if action == a.ITEM_ACTION: | ||||||
|                 valid = self.do_item_action(agent) |                 action_result = self.do_item_action(agent) | ||||||
|                 return valid |                 return action_result | ||||||
|             else: |             else: | ||||||
|                 return None |                 return None | ||||||
|         else: |         else: | ||||||
|             return valid |             return action_result | ||||||
|  |  | ||||||
|     def do_additional_reset(self) -> None: |     def do_additional_reset(self) -> None: | ||||||
|         # noinspection PyUnresolvedReferences |         # noinspection PyUnresolvedReferences | ||||||
| @@ -277,9 +284,9 @@ class ItemFactory(BaseFactory): | |||||||
|         else: |         else: | ||||||
|             self.print('No Items are spawning, limit is reached.') |             self.print('No Items are spawning, limit is reached.') | ||||||
|  |  | ||||||
|     def do_additional_step(self) -> dict: |     def do_additional_step(self) -> (List[dict], dict): | ||||||
|         # noinspection PyUnresolvedReferences |         # noinspection PyUnresolvedReferences | ||||||
|         info_dict = super().do_additional_step() |         super_reward_info = super().do_additional_step() | ||||||
|         for item in list(self[c.ITEM].values()): |         for item in list(self[c.ITEM].values()): | ||||||
|             if item.auto_despawn >= 1: |             if item.auto_despawn >= 1: | ||||||
|                 item.set_auto_despawn(item.auto_despawn-1) |                 item.set_auto_despawn(item.auto_despawn-1) | ||||||
| @@ -292,35 +299,7 @@ class ItemFactory(BaseFactory): | |||||||
|             self.trigger_item_spawn() |             self.trigger_item_spawn() | ||||||
|         else: |         else: | ||||||
|             self._next_item_spawn = max(0, self._next_item_spawn-1) |             self._next_item_spawn = max(0, self._next_item_spawn-1) | ||||||
|         return info_dict |         return super_reward_info | ||||||
|  |  | ||||||
|     def calculate_additional_reward(self, agent: Agent) -> (int, dict): |  | ||||||
|         # noinspection PyUnresolvedReferences |  | ||||||
|         reward, info_dict = super().calculate_additional_reward(agent) |  | ||||||
|         if a.ITEM_ACTION == agent.temp_action: |  | ||||||
|             if agent.temp_valid: |  | ||||||
|                 if drop_off := self[c.DROP_OFF].by_pos(agent.pos): |  | ||||||
|                     info_dict.update({f'{agent.name}_item_drop_off': 1}) |  | ||||||
|                     info_dict.update(item_drop_off=1) |  | ||||||
|                     self.print(f'{agent.name} just dropped of an item at {drop_off.pos}.') |  | ||||||
|                     reward += 1 |  | ||||||
|                 else: |  | ||||||
|                     info_dict.update({f'{agent.name}_item_pickup': 1}) |  | ||||||
|                     info_dict.update(item_pickup=1) |  | ||||||
|                     self.print(f'{agent.name} just picked up an item at {agent.pos}') |  | ||||||
|                     reward += 0.2 |  | ||||||
|             else: |  | ||||||
|                 if self[c.DROP_OFF].by_pos(agent.pos): |  | ||||||
|                     info_dict.update({f'{agent.name}_failed_drop_off': 1}) |  | ||||||
|                     info_dict.update(failed_drop_off=1) |  | ||||||
|                     self.print(f'{agent.name} just tried to drop off at {agent.pos}, but failed.') |  | ||||||
|                     reward -= 0.1 |  | ||||||
|                 else: |  | ||||||
|                     info_dict.update({f'{agent.name}_failed_item_action': 1}) |  | ||||||
|                     info_dict.update(failed_pick_up=1) |  | ||||||
|                     self.print(f'{agent.name} just tried to pick up an item at {agent.pos}, but failed.') |  | ||||||
|                     reward -= 0.1 |  | ||||||
|         return reward, info_dict |  | ||||||
|  |  | ||||||
|     def render_additional_assets(self, mode='human'): |     def render_additional_assets(self, mode='human'): | ||||||
|         # noinspection PyUnresolvedReferences |         # noinspection PyUnresolvedReferences | ||||||
| @@ -335,9 +314,9 @@ class ItemFactory(BaseFactory): | |||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     from environments.utility_classes import AgentRenderOptions as aro, ObservationProperties |     from environments.utility_classes import AgentRenderOptions as aro, ObservationProperties | ||||||
|  |  | ||||||
|     render = True |     render = False | ||||||
|  |  | ||||||
|     item_probs = ItemProperties(n_items=30) |     item_probs = ItemProperties(n_items=30, n_drop_off_locations=6) | ||||||
|  |  | ||||||
|     obs_props = ObservationProperties(render_agents=aro.SEPERATE, omit_agent_self=True, pomdp_r=2) |     obs_props = ObservationProperties(render_agents=aro.SEPERATE, omit_agent_self=True, pomdp_r=2) | ||||||
|  |  | ||||||
| @@ -345,7 +324,7 @@ if __name__ == '__main__': | |||||||
|                   'allow_diagonal_movement': True, |                   'allow_diagonal_movement': True, | ||||||
|                   'allow_no_op': False} |                   'allow_no_op': False} | ||||||
|  |  | ||||||
|     factory = ItemFactory(n_agents=2, done_at_collision=False, |     factory = ItemFactory(n_agents=6, done_at_collision=False, | ||||||
|                           level_name='rooms', max_steps=400, |                           level_name='rooms', max_steps=400, | ||||||
|                           obs_prop=obs_props, parse_doors=True, |                           obs_prop=obs_props, parse_doors=True, | ||||||
|                           record_episodes=True, verbose=True, |                           record_episodes=True, verbose=True, | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| import itertools | import itertools | ||||||
| from collections import defaultdict | from collections import defaultdict | ||||||
| from typing import Tuple, Union, Dict, List | from typing import Tuple, Union, Dict, List, NamedTuple | ||||||
|  |  | ||||||
| import networkx as nx | import networkx as nx | ||||||
| import numpy as np | import numpy as np | ||||||
| @@ -38,37 +38,27 @@ class Constants: | |||||||
|     OPEN_DOOR           = 'open' |     OPEN_DOOR           = 'open' | ||||||
|  |  | ||||||
|     ACTION              = 'action' |     ACTION              = 'action' | ||||||
|     COLLISIONS          = 'collision' |     COLLISION          = 'collision' | ||||||
|     VALID               = 'valid' |     VALID               = True | ||||||
|     NOT_VALID           = 'not_valid' |     NOT_VALID           = False | ||||||
|  |  | ||||||
|     # Battery Env |  | ||||||
|     CHARGE_POD          = 'Charge_Pod' |  | ||||||
|     BATTERIES           = 'BATTERIES' |  | ||||||
|  |  | ||||||
|     # Destination Env |  | ||||||
|     DESTINATION         = 'Destination' |  | ||||||
|     REACHEDDESTINATION  = 'ReachedDestination' |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class EnvActions: | class EnvActions: | ||||||
|     # Movements |     # Movements | ||||||
|     NORTH = 'north' |     NORTH           = 'north' | ||||||
|     EAST = 'east' |     EAST            = 'east' | ||||||
|     SOUTH = 'south' |     SOUTH           = 'south' | ||||||
|     WEST = 'west' |     WEST            = 'west' | ||||||
|     NORTHEAST = 'north_east' |     NORTHEAST       = 'north_east' | ||||||
|     SOUTHEAST = 'south_east' |     SOUTHEAST       = 'south_east' | ||||||
|     SOUTHWEST = 'south_west' |     SOUTHWEST       = 'south_west' | ||||||
|     NORTHWEST = 'north_west' |     NORTHWEST       = 'north_west' | ||||||
|  |  | ||||||
|     # Other |     # Other | ||||||
|     NOOP = 'no_op' |     # MOVE            = 'move' | ||||||
|  |     NOOP            = 'no_op' | ||||||
|     USE_DOOR        = 'use_door' |     USE_DOOR        = 'use_door' | ||||||
|  |  | ||||||
|     CHARGE          = 'charge' |  | ||||||
|     WAIT_ON_DEST    = 'wait' |  | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def is_move(cls, other): |     def is_move(cls, other): | ||||||
|         return any([other == direction for direction in cls.movement_actions()]) |         return any([other == direction for direction in cls.movement_actions()]) | ||||||
| @@ -86,8 +76,19 @@ class EnvActions: | |||||||
|         return list(itertools.chain(cls.square_move(), cls.diagonal_move())) |         return list(itertools.chain(cls.square_move(), cls.diagonal_move())) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Rewards: | ||||||
|  |  | ||||||
|  |     MOVEMENTS_VALID = -0.001 | ||||||
|  |     MOVEMENTS_FAIL  = -0.001 | ||||||
|  |     NOOP = -0.1 | ||||||
|  |     USE_DOOR_VALID = -0.001 | ||||||
|  |     USE_DOOR_FAIL  = -0.001 | ||||||
|  |     COLLISION      = -1 | ||||||
|  |  | ||||||
|  |  | ||||||
| m = EnvActions | m = EnvActions | ||||||
| c = Constants | c = Constants | ||||||
|  | r = Rewards | ||||||
|  |  | ||||||
| ACTIONMAP = defaultdict(lambda: (0, 0), {m.NORTH: (-1, 0), m.NORTHEAST: (-1, +1), | ACTIONMAP = defaultdict(lambda: (0, 0), {m.NORTH: (-1, 0), m.NORTHEAST: (-1, +1), | ||||||
|                                          m.EAST: (0, 1),   m.SOUTHEAST: (1, 1), |                                          m.EAST: (0, 1),   m.SOUTHEAST: (1, 1), | ||||||
| @@ -184,15 +185,20 @@ def asset_str(agent): | |||||||
|     # What does this abonimation do? |     # What does this abonimation do? | ||||||
|     # if any([x is None for x in [cls._slices[j] for j in agent.collisions]]): |     # if any([x is None for x in [cls._slices[j] for j in agent.collisions]]): | ||||||
|     #     print('error') |     #     print('error') | ||||||
|     col_names = [x.name for x in agent.temp_collisions] |     if step_result := agent.step_result: | ||||||
|     if any(c.AGENT in name for name in col_names): |         action = step_result['action_name'] | ||||||
|         return 'agent_collision', 'blank' |         valid = step_result['action_valid'] | ||||||
|     elif not agent.temp_valid or c.LEVEL in col_names or c.AGENT in col_names: |         col_names = [x.name for x in step_result['collisions']] | ||||||
|         return c.AGENT, 'invalid' |         if any(c.AGENT in name for name in col_names): | ||||||
|     elif agent.temp_valid and not EnvActions.is_move(agent.temp_action): |             return 'agent_collision', 'blank' | ||||||
|         return c.AGENT, 'valid' |         elif not valid or c.LEVEL in col_names or c.AGENT in col_names: | ||||||
|     elif agent.temp_valid and EnvActions.is_move(agent.temp_action): |             return c.AGENT, 'invalid' | ||||||
|         return c.AGENT, 'move' |         elif valid and not EnvActions.is_move(action): | ||||||
|  |             return c.AGENT, 'valid' | ||||||
|  |         elif valid and EnvActions.is_move(action): | ||||||
|  |             return c.AGENT, 'move' | ||||||
|  |         else: | ||||||
|  |             return c.AGENT, 'idle' | ||||||
|     else: |     else: | ||||||
|         return c.AGENT, 'idle' |         return c.AGENT, 'idle' | ||||||
|  |  | ||||||
|   | |||||||
| @@ -134,8 +134,7 @@ if __name__ == '__main__': | |||||||
|                                 max_spawn_amount=0.1, max_global_amount=20, |                                 max_spawn_amount=0.1, max_global_amount=20, | ||||||
|                                 max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05, |                                 max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05, | ||||||
|                                 dirt_smear_amount=0.0, agent_can_interact=True) |                                 dirt_smear_amount=0.0, agent_can_interact=True) | ||||||
|     item_props = ItemProperties(n_items=10, agent_can_interact=True, |     item_props = ItemProperties(n_items=10, spawn_frequency=30, n_drop_off_locations=2, | ||||||
|                                 spawn_frequency=30, n_drop_off_locations=2, |  | ||||||
|                                 max_agent_inventory_capacity=15) |                                 max_agent_inventory_capacity=15) | ||||||
|     dest_props = DestProperties(n_dests=4, spawn_mode=DestModeOptions.GROUPED, spawn_frequency=1) |     dest_props = DestProperties(n_dests=4, spawn_mode=DestModeOptions.GROUPED, spawn_frequency=1) | ||||||
|     factory_kwargs = dict(n_agents=1, max_steps=400, parse_doors=True, |     factory_kwargs = dict(n_agents=1, max_steps=400, parse_doors=True, | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Steffen Illium
					Steffen Illium