Stable Baseline Running
@@ -1,4 +1,3 @@
import abc
from typing import List, Union, Iterable

import gym
@@ -43,10 +42,6 @@ class BaseFactory(gym.Env):
    def observation_space(self):
        return spaces.Box(low=-1, high=1, shape=self.state.shape, dtype=np.float32)

    @property
    def monitor_as_df_list(self):
        return [x.to_pd_dataframe() for x in self._monitor_list]

    @property
    def movement_actions(self):
        return (int(self.allow_vertical_movement) + int(self.allow_horizontal_movement)) * 4
@@ -61,7 +56,6 @@ class BaseFactory(gym.Env):
        self.allow_vertical_movement = True
        self.allow_horizontal_movement = True
        self.allow_no_OP = True
        self._monitor_list = list()
        self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions()
        self.level = h.one_hot_level(
            h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
@@ -77,7 +71,6 @@ class BaseFactory(gym.Env):
        self.steps = 0
        self.cumulative_reward = 0
        self.monitor = FactoryMonitor(self)
        self._monitor_list.append(self.monitor)
        self.agent_states = []
        # Agent placement ...
        agents = np.zeros((self.n_agents, *self.level.shape), dtype=np.int8)
@@ -92,7 +85,6 @@ class BaseFactory(gym.Env):
        # state.shape = level, agent 1,..., agent n,
        self.state = np.concatenate((np.expand_dims(self.level, axis=0), agents), axis=0)
        # Returns State

        return self.state

    def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
@@ -122,11 +114,11 @@ class BaseFactory(gym.Env):

        self.agent_states = states
        reward, info = self.calculate_reward(states)
        self.cumulative_reward += reward

        if self.steps >= self.max_steps:
            self.done = True
        return self.state, self.cumulative_reward, self.done, info
        self.monitor.add('step_reward', reward)
        return self.state, reward, self.done, info

    def _is_moving_action(self, action):
        return action < self.movement_actions

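Note (not part of the commit): after this hunk, step() logs the raw reward to the monitor as 'step_reward' and returns it directly instead of the running cumulative_reward, so episode returns have to be summed by the caller. A minimal rollout sketch under that reading; the SimpleFactory import path, its constructor defaults, and the presence of a sampleable action_space are assumptions not shown in this diff.

# Sketch only: plain rollout loop against the changed step() contract.
# from <package>.simple_factory import SimpleFactory   # import path assumed, not shown in this diff
env = SimpleFactory()                                   # constructor defaults assumed
state = env.reset()
episode_return, done = 0.0, False
while not done:
    action = env.action_space.sample()                  # random policy; assumes a Discrete action_space is defined elsewhere
    state, reward, done, info = env.step(action)
    episode_return += reward                            # the caller accumulates; the env no longer does
print(f'episode return after {env.steps} steps: {episode_return}')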
@@ -120,8 +120,8 @@ class SimpleFactory(BaseFactory):
        dirty_tiles = len(np.nonzero(self.state[DIRT_INDEX]))

        try:
            this_step_reward = -(dirty_tiles / current_dirt_amount)
        except ZeroDivisionError:
            this_step_reward = (dirty_tiles / current_dirt_amount)
        except (ZeroDivisionError, RuntimeWarning):
            this_step_reward = 0

        for agent_state in agent_states:
@@ -136,6 +136,7 @@ class SimpleFactory(BaseFactory):
                    self.monitor.add(f'agent_{agent_state.i}_vs_{self.slice_strings[entity]}', 1)
        self.monitor.set('dirt_amount', current_dirt_amount)
        self.monitor.set('dirty_tiles', dirty_tiles)
        print(f"reward is {this_step_reward}")
        return this_step_reward, {}

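Two observations on the SimpleFactory hunk above (a sketch, not part of the commit): len(np.nonzero(arr)) returns the number of index arrays, i.e. arr.ndim, not the number of nonzero cells, so np.count_nonzero is the usual call for counting dirty tiles; and an `except RuntimeWarning` clause only has an effect if warnings are escalated to errors (e.g. warnings.simplefilter('error')), since NumPy otherwise just warns and returns inf/nan on division by zero.

import numpy as np

dirt_slice = np.array([[0.0, 0.5, 0.0],
                       [0.2, 0.0, 0.7]])     # stand-in for self.state[DIRT_INDEX]
print(len(np.nonzero(dirt_slice)))           # 2: one index array per dimension (ndim)
print(np.count_nonzero(dirt_slice))          # 3: the actual number of dirty cells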
@@ -60,17 +60,26 @@ class MonitorCallback(BaseCallback):
        super(MonitorCallback, self).__init__()
        self._outpath = Path(outpath)
        self._filename = filename
        self._monitor_list = list()
        self.out_file = self._outpath / f'{self._filename.split(".")[0]}.pick'
        self.env = env
        self.started = False
        self.closed = False

    @property
    def monitor_as_df_list(self):
        return [x.to_pd_dataframe() for x in self._monitor_list]

    def __enter__(self):
        self._on_training_start()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._on_training_end()

    def _on_rollout_end(self) -> None:
        self._monitor_list.append(self.env.monitor)
        pass

    def _on_training_start(self) -> None:
        if self.started:
            pass
@@ -85,7 +94,7 @@ class MonitorCallback(BaseCallback):
        else:
            # self.out_file.unlink(missing_ok=True)
            with self.out_file.open('wb') as f:
                pickle.dump(self.env.monitor_as_df_list, f, protocol=pickle.HIGHEST_PROTOCOL)
                pickle.dump(self.monitor_as_df_list, f, protocol=pickle.HIGHEST_PROTOCOL)
            self.closed = True

    def _on_step(self) -> bool:

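With this commit the monitor aggregation moves into the callback: _on_rollout_end() collects env.monitor and _on_training_end() pickles monitor_as_df_list into the .pick file. A hedged usage sketch matching the commit title; the stable_baselines3 import, the choice of PPO, the SimpleFactory defaults, and the MonitorCallback keyword names are assumptions, not shown in this diff.

import pickle
from stable_baselines3 import PPO                       # SB3 and PPO are assumptions; any algorithm would do

env = SimpleFactory()                                   # constructor defaults assumed
callback = MonitorCallback(env=env, outpath='debug_out', filename='monitor.pick')  # keyword names assumed; output dir assumed to exist

model = PPO('MlpPolicy', env, verbose=1)
model.learn(total_timesteps=10_000, callback=callback)  # SB3 drives _on_training_start/_on_rollout_end/_on_training_end

with callback.out_file.open('rb') as f:                 # out_file as built in __init__ above
    monitor_dfs = pickle.load(f)                        # list of pandas DataFrames, one per collected rollout monitor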