mirror of
				https://github.com/illiumst/marl-factory-grid.git
				synced 2025-11-01 21:17:26 +01:00 
			
		
		
		
	Logging Monitor Callback
This commit is contained in:
		| @@ -56,6 +56,7 @@ class BaseFactory(gym.Env): | ||||
|         self.allow_vertical_movement = True | ||||
|         self.allow_horizontal_movement = True | ||||
|         self.allow_no_OP = True | ||||
|         self.done_at_collision = True | ||||
|         self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions() | ||||
|         self.level = h.one_hot_level( | ||||
|             h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt') | ||||
| @@ -96,7 +97,7 @@ class BaseFactory(gym.Env): | ||||
|         self.steps += 1 | ||||
|  | ||||
|         # Move this in a seperate function? | ||||
|         states = list() | ||||
|         agent_states = list() | ||||
|         for agent_i, action in enumerate(actions): | ||||
|             agent_i_state = AgentState(agent_i, action) | ||||
|             if self._is_moving_action(action): | ||||
| @@ -107,13 +108,15 @@ class BaseFactory(gym.Env): | ||||
|                 pos, valid = self.additional_actions(agent_i, action) | ||||
|             # Update state accordingly | ||||
|             agent_i_state.update(pos=pos, action_valid=valid) | ||||
|             states.append(agent_i_state) | ||||
|             agent_states.append(agent_i_state) | ||||
|  | ||||
|         for i, collision_vec in enumerate(self.check_all_collisions(states, self.state.shape[0])): | ||||
|             states[i].update(collision_vector=collision_vec) | ||||
|         for i, collision_vec in enumerate(self.check_all_collisions(agent_states, self.state.shape[0])): | ||||
|             agent_states[i].update(collision_vector=collision_vec) | ||||
|             if self.done_at_collision and collision_vec.any(): | ||||
|                 self.done = True | ||||
|  | ||||
|         self.agent_states = states | ||||
|         reward, info = self.calculate_reward(states) | ||||
|         self.agent_states = agent_states | ||||
|         reward, info = self.calculate_reward(agent_states) | ||||
|  | ||||
|         if self.steps >= self.max_steps: | ||||
|             self.done = True | ||||
|   | ||||
| @@ -17,7 +17,7 @@ DIRT_INDEX = -1 | ||||
|  | ||||
| @dataclass | ||||
| class DirtProperties: | ||||
|     clean_amount = 0.25 | ||||
|     clean_amount = 10 | ||||
|     max_spawn_ratio = 0.1 | ||||
|     gain_amount = 0.1 | ||||
|     spawn_frequency = 5 | ||||
| @@ -31,8 +31,9 @@ class SimpleFactory(BaseFactory): | ||||
|     def _is_clean_up_action(self, action): | ||||
|         return self.action_space.n - 1 == action | ||||
|  | ||||
|     def __init__(self, *args, dirt_properties: DirtProperties, **kwargs): | ||||
|     def __init__(self, *args, dirt_properties: DirtProperties, verbose=False, **kwargs): | ||||
|         self._dirt_properties = dirt_properties | ||||
|         self.verbose = verbose | ||||
|         super(SimpleFactory, self).__init__(*args, **kwargs) | ||||
|         self.slice_strings.update({self.state.shape[0]-1: 'dirt'}) | ||||
|         self.renderer = None  # expensive - dont use it when not required ! | ||||
| @@ -81,6 +82,9 @@ class SimpleFactory(BaseFactory): | ||||
|             return pos, cleanup_was_sucessfull | ||||
|  | ||||
|     def step(self, actions): | ||||
|         if self.state[h.LEVEL_IDX][self.agent_i_position(0)] == h.IS_OCCUPIED_CELL: | ||||
|             print(f'fAgent placed on wall!!!!, step is :{self.steps}') | ||||
|             raise Exception('Agent placed on wall!!!!') | ||||
|         _, _, _, info = super(SimpleFactory, self).step(actions) | ||||
|         if not self.next_dirt_spawn: | ||||
|             self.spawn_dirt() | ||||
| @@ -94,12 +98,6 @@ class SimpleFactory(BaseFactory): | ||||
|             if self._is_clean_up_action(action): | ||||
|                 agent_i_pos = self.agent_i_position(agent_i) | ||||
|                 _, valid = self.clean_up(agent_i_pos) | ||||
|                 if valid: | ||||
|                     print(f'Agent {agent_i} did just clean up some dirt at {agent_i_pos}.') | ||||
|                     self.monitor.add('dirt_cleaned', self._dirt_properties.clean_amount) | ||||
|                 else: | ||||
|                     print(f'Agent {agent_i} just tried to clean up some dirt at {agent_i_pos}, but was unsucsessfull.') | ||||
|                     self.monitor.add('failed_cleanup_attempt', 1) | ||||
|                 return agent_i_pos, valid | ||||
|             else: | ||||
|                 raise RuntimeError('This should not happen!!!') | ||||
| @@ -120,25 +118,44 @@ class SimpleFactory(BaseFactory): | ||||
|         dirty_tiles = len(np.nonzero(self.state[DIRT_INDEX])) | ||||
|  | ||||
|         try: | ||||
|             this_step_reward = (dirty_tiles / current_dirt_amount) | ||||
|             # penalty = current_dirt_amount | ||||
|             penalty = 0 | ||||
|         except (ZeroDivisionError, RuntimeWarning): | ||||
|             this_step_reward = 0 | ||||
|  | ||||
|             penalty = 0 | ||||
|         inforcements = 0 | ||||
|         for agent_state in agent_states: | ||||
|             collisions = agent_state.collisions | ||||
|             print(f't = {self.steps}\tAgent {agent_state.i} has collisions with ' | ||||
|                   f'{[self.slice_strings[entity] for entity in collisions if entity != self.string_slices["dirt"]]}') | ||||
|             if self._is_clean_up_action(agent_state.action) and agent_state.action_valid: | ||||
|                 this_step_reward += 1 | ||||
|             cols = agent_state.collisions | ||||
|             self.print(f't = {self.steps}\tAgent {agent_state.i} has collisions with ' | ||||
|                        f'{[self.slice_strings[entity] for entity in cols if entity != self.string_slices["dirt"]]}') | ||||
|             if self._is_clean_up_action(agent_state.action): | ||||
|                 if agent_state.action_valid: | ||||
|                     inforcements += 10 | ||||
|                     self.print(f'Agent {agent_state.i} did just clean up some dirt at {agent_state.pos}.') | ||||
|                     self.monitor.add('dirt_cleaned', self._dirt_properties.clean_amount) | ||||
|                 else: | ||||
|                     self.print(f'Agent {agent_state.i} just tried to clean up some dirt ' | ||||
|                                f'at {agent_state.pos}, but was unsucsessfull.') | ||||
|                     self.monitor.add('failed_cleanup_attempt', 1) | ||||
|             elif self._is_moving_action(agent_state.action): | ||||
|                 if not agent_state.action_valid: | ||||
|                     penalty += 10 | ||||
|                 else: | ||||
|                     inforcements += 1 | ||||
|  | ||||
|             for entity in collisions: | ||||
|             for entity in cols: | ||||
|                 if entity != self.string_slices["dirt"]: | ||||
|                     self.monitor.add(f'agent_{agent_state.i}_vs_{self.slice_strings[entity]}', 1) | ||||
|  | ||||
|         this_step_reward = max(0, inforcements-penalty) | ||||
|         self.monitor.set('dirt_amount', current_dirt_amount) | ||||
|         self.monitor.set('dirty_tiles', dirty_tiles) | ||||
|         print(f"reward is {this_step_reward}") | ||||
|         self.print(f"reward is {this_step_reward}") | ||||
|         return this_step_reward, {} | ||||
|  | ||||
|     def print(self, string): | ||||
|         if self.verbose: | ||||
|             print(string) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     render = True | ||||
|   | ||||
| @@ -56,12 +56,10 @@ class FactoryMonitor: | ||||
|  | ||||
| class MonitorCallback(BaseCallback): | ||||
|  | ||||
|     def __init__(self, env, outpath='debug_out', filename='monitor'): | ||||
|     def __init__(self, env, filepath=Path('debug_out/monitor.pick')): | ||||
|         super(MonitorCallback, self).__init__() | ||||
|         self._outpath = Path(outpath) | ||||
|         self._filename = filename | ||||
|         self.filepath = Path(filepath) | ||||
|         self._monitor_list = list() | ||||
|         self.out_file = self._outpath / f'{self._filename.split(".")[0]}.pick' | ||||
|         self.env = env | ||||
|         self.started = False | ||||
|         self.closed = False | ||||
| @@ -84,7 +82,7 @@ class MonitorCallback(BaseCallback): | ||||
|         if self.started: | ||||
|             pass | ||||
|         else: | ||||
|             self.out_file.parent.mkdir(exist_ok=True, parents=True) | ||||
|             self.filepath.parent.mkdir(exist_ok=True, parents=True) | ||||
|             self.started = True | ||||
|         pass | ||||
|  | ||||
| @@ -93,7 +91,7 @@ class MonitorCallback(BaseCallback): | ||||
|             pass | ||||
|         else: | ||||
|             # self.out_file.unlink(missing_ok=True) | ||||
|             with self.out_file.open('wb') as f: | ||||
|             with self.filepath.open('wb') as f: | ||||
|                 pickle.dump(self.monitor_as_df_list, f, protocol=pickle.HIGHEST_PROTOCOL) | ||||
|             self.closed = True | ||||
|  | ||||
|   | ||||
							
								
								
									
										35
									
								
								environments/logging/training.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								environments/logging/training.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,35 @@ | ||||
| from pathlib import Path | ||||
|  | ||||
| import pandas as pd | ||||
| from stable_baselines3.common.callbacks import BaseCallback | ||||
|  | ||||
|  | ||||
class TraningMonitor(BaseCallback):
    """Stable-Baselines3 callback that buffers per-step rewards and
    periodically flushes them to a CSV file.

    NOTE(review): the class name keeps the original 'TraningMonitor'
    spelling (sic) so existing callers are not broken.
    """

    def __init__(self, filepath, flush_interval=None):
        """
        :param filepath: destination CSV path (parent directory must exist).
        :param flush_interval: number of steps between CSV flushes; when
            falsy, defaults at training start to 10% of total timesteps.
        """
        super(TraningMonitor, self).__init__()
        self.values = dict()  # num_timesteps -> dict of logged scalars
        self.filepath = Path(filepath)
        self.flush_interval = flush_interval

    def _on_training_start(self) -> None:
        # Resolve the default interval lazily, once total_timesteps is known.
        # Clamp to a positive int: total_timesteps * 0.1 is a float and can
        # round to 0 for short runs, which would make the modulo check in
        # _on_step raise ZeroDivisionError.
        self.flush_interval = self.flush_interval or max(
            1, int(self.locals['total_timesteps'] * 0.1))

    def _flush(self):
        # orient='index' yields one row per timestep with stable columns
        # (e.g. 'reward'), so successive header-less appends stay aligned.
        # The default orient='columns' would make the timesteps themselves
        # the columns, which differ between flushes and corrupt the CSV.
        df = pd.DataFrame.from_dict(self.values, orient='index')
        if not self.filepath.exists():
            # Text mode 'w' is required: DataFrame.to_csv rejects binary
            # mode ('wb') when given a path.
            df.to_csv(self.filepath, mode='w', header=True)
        else:
            df.to_csv(self.filepath, mode='a', header=False)
        self.values = dict()

    def _on_step(self) -> bool:
        # Record this step's scalar reward, keyed by the global step count.
        self.values[self.num_timesteps] = dict(reward=self.locals['rewards'].item())
        if self.num_timesteps % self.flush_interval == 0:
            self._flush()
        return True

    def on_training_end(self) -> None:
        # Flush whatever is still buffered so the tail of training is kept.
        self._flush()
|  | ||||
		Reference in New Issue
	
	Block a user