mirror of https://github.com/illiumst/marl-factory-grid.git
added running marl a2c (Robert Müller)
@@ -5,7 +5,12 @@ import yaml
 from pathlib import Path
 from salina import instantiate_class
 from salina import TAgent
-from salina.agents.gyma import AutoResetGymAgent, _torch_type, _format_frame
+from salina.agents.gyma import (
+    AutoResetGymAgent,
+    _torch_type,
+    _format_frame,
+    _torch_cat_dict
+)
 
 
 def load_yaml_file(path: Path):
@@ -20,42 +25,47 @@ def add_env_props(cfg):
                              n_actions=env.action_space.n))
 
 
-class CombineActionsAgent(TAgent):
-    def __init__(self, pattern=r'^agent\d_action$'):
-        super().__init__()
-        self.pattern = pattern
-
-    def forward(self, t, **kwargs):
-        keys = list(self.workspace.keys())
-        action_keys = sorted([k for k in keys if bool(re.match(self.pattern, k))])
-        actions = torch.cat([self.get((k, t)) for k in action_keys], 0)
-        actions = actions if len(action_keys) <= 1 else actions.unsqueeze(0)
-        self.set((f'action', t), actions)
 
+AGENT_PREFIX = 'agent#'
+REWARD       =  'reward'
+CUMU_REWARD  = 'cumulated_reward'
+OBS          = 'env_obs'
+SEP          = '_'
+ACTION       = 'action'
+
+
+def access_str(agent_i, name, prefix=''):
+    return f'{prefix}{AGENT_PREFIX}{agent_i}{SEP}{name}'
+
+
 class AutoResetGymMultiAgent(AutoResetGymAgent):
-    AGENT_PREFIX = 'agent#'
-    REWARD       =  'reward'
-    CUMU_REWARD  = 'cumulated_reward'
-    SEP          = '_'
-
-    def __init__(self, *args, n_agents, **kwargs):
+    def __init__(self, *args, **kwargs):
         super(AutoResetGymMultiAgent, self).__init__(*args, **kwargs)
-        self.n_agents = n_agents
 
-    def prefix(self, agent_id, name):
-        return f'{self.AGENT_PREFIX}{agent_id}{self.SEP}{name}'
+    def per_agent_values(self, name, values):
+        return {access_str(agent_i, name): value
+                for agent_i, value in zip(range(self.n_agents), values)}
 
+    def _initialize_envs(self, n):
+        super()._initialize_envs(n)
+        n_agents_list = [self.envs[i].unwrapped.n_agents for i in range(n)]
+        assert all(n_agents == n_agents_list[0] for n_agents in n_agents_list), \
+            'All envs must have the same number of agents.'
+        self.n_agents = n_agents_list[0]
+
     def _reset(self, k, save_render):
         ret = super()._reset(k, save_render)
+        obs = ret['env_obs'].squeeze()
         self.cumulated_reward[k] = [0.0]*self.n_agents
-        del ret['cumulated_reward']
-        cumu_rew = {self.prefix(agent_i, self.CUMU_REWARD): torch.zeros(1).float()
-                    for agent_i in range(self.n_agents)}
-        rewards  = {self.prefix(agent_i, self.REWARD)     : torch.zeros(1).float()
-                    for agent_i in range(self.n_agents)}
+        obs      = self.per_agent_values(OBS,  [_format_frame(obs[i]) for i in range(self.n_agents)])
+        cumu_rew = self.per_agent_values(CUMU_REWARD, torch.zeros(self.n_agents, 1).float().unbind())
+        rewards  = self.per_agent_values(REWARD,      torch.zeros(self.n_agents, 1).float().unbind())
         ret.update(cumu_rew)
         ret.update(rewards)
+        ret.update(obs)
+        for remove in ['env_obs', 'cumulated_reward', 'reward']:
+            del ret[remove]
         return ret
 
     def _step(self, k, action, save_render):
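Note (illustration, not part of the commit): the module-level constants and access_str define the per-agent workspace key naming that the rest of this diff relies on. A minimal sketch of the keys they produce, assuming two agents:

AGENT_PREFIX, SEP = 'agent#', '_'

def access_str(agent_i, name, prefix=''):
    # Same helper as above, repeated here only so the example runs standalone.
    return f'{prefix}{AGENT_PREFIX}{agent_i}{SEP}{name}'

assert access_str(0, 'reward') == 'agent#0_reward'
assert access_str(1, 'env_obs', prefix='env/') == 'env/agent#1_env_obs'

per_agent_values(name, values) simply zips one value per agent onto these keys, e.g. {'agent#0_reward': r0, 'agent#1_reward': r1}, which is how _reset and _step publish per-agent observations and rewards.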
@@ -68,28 +78,33 @@ class AutoResetGymMultiAgent(AutoResetGymAgent):
             action = np.array(action.tolist())
         o, r, d, _ = env.step(action)
         self.cumulated_reward[k] = [x+y for x, y in zip(r, self.cumulated_reward[k])]
-        print(o.shape)
-        observation = _format_frame(o)
-        if isinstance(observation, torch.Tensor):
-            print(observation.shape)
-            observation = {self.prefix(agent_i, 'env_obs'): observation[agent_i]
-                           for agent_i in range(self.n_agents)}
-            print(observation)
-        else:
-            assert isinstance(observation, dict)
+        observation = self.per_agent_values(OBS, [_format_frame(o[i]) for i in range(self.n_agents)])
         if d:
             self.is_running[k] = False
 
         if save_render:
             image = env.render(mode="image").unsqueeze(0)
             observation["rendering"] = image
+        rewards           = self.per_agent_values(REWARD, torch.tensor(r).float().view(-1, 1).unbind())
+        cumulated_rewards = self.per_agent_values(CUMU_REWARD, torch.tensor(self.cumulated_reward[k]).float().view(-1, 1).unbind())
         ret = {
+            **observation,
+            **rewards,
+            **cumulated_rewards,
             "done": torch.tensor([d]),
             "initial_state": torch.tensor([False]),
-            "reward": torch.tensor(r).float(),
-            "timestep": torch.tensor([self.timestep[k]]),
-            "cumulated_reward": torch.tensor(self.cumulated_reward[k]).float(),
+            "timestep": torch.tensor([self.timestep[k]])
         }
         return _torch_type(ret)
 
 
+class CombineActionsAgent(TAgent):
+    def __init__(self):
+        super().__init__()
+        self.pattern = fr'^{AGENT_PREFIX}\d{SEP}{ACTION}$'
+
+    def forward(self, t, **kwargs):
+        keys = list(self.workspace.keys())
+        action_keys = sorted([k for k in keys if bool(re.match(self.pattern, k))])
+        actions = torch.cat([self.get((k, t)) for k in action_keys], 0)
+        actions = actions if len(action_keys) <= 1 else actions.unsqueeze(0)
+        self.set((f'action', t), actions)
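Note: CombineActionsAgent has moved below the env wrapper and now builds its regex from the shared constants, so the pattern expands to r'^agent#\d_action$' (a single \d, which assumes fewer than ten agents). A small sketch, not from the commit, of how the per-agent actions are merged into the single 'action' entry the env agent reads:

import re
import torch

pattern = r'^agent#\d_action$'       # fr'^{AGENT_PREFIX}\d{SEP}{ACTION}$' with the constants substituted
keys = ['agent#0_action', 'env/done', 'agent#1_action']
action_keys = sorted(k for k in keys if re.match(pattern, k))        # ['agent#0_action', 'agent#1_action']
per_agent = {'agent#0_action': torch.tensor([2]), 'agent#1_action': torch.tensor([0])}
actions = torch.cat([per_agent[k] for k in action_keys], 0)          # tensor([2, 0])
joint = actions if len(action_keys) <= 1 else actions.unsqueeze(0)   # shape (1, 2): one env row, two agents
print(joint)

The remaining hunks touch the A2C training script, which imports these helpers from algorithms.utils.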
@@ -13,16 +13,18 @@ from algorithms.utils import (
     add_env_props,
     load_yaml_file,
     CombineActionsAgent,
-    AutoResetGymMultiAgent
+    AutoResetGymMultiAgent,
+    access_str,
+    AGENT_PREFIX, REWARD, CUMU_REWARD, OBS, SEP
 )
 
 
 class A2CAgent(TAgent):
-    def __init__(self, observation_size, hidden_size, n_actions, agent_id=-1, marl=False):
+    def __init__(self, observation_size, hidden_size, n_actions, agent_id):
         super().__init__()
         observation_size = np.prod(observation_size)
-        print(observation_size)
         self.agent_id = agent_id
-        self.marl = marl
         self.model = nn.Sequential(
             nn.Flatten(),
             nn.Linear(observation_size, hidden_size),
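Note: A2CAgent now requires an explicit agent_id and drops the marl flag; the per-agent observation shape from the config is flattened with np.prod before the first linear layer. A standalone sketch of that input path, with shape, width, and activation chosen only for illustration:

import numpy as np
import torch
import torch.nn as nn

observation_size, hidden_size = (4, 5, 5), 64       # per-agent obs shape and width are assumptions
model = nn.Sequential(nn.Flatten(),
                      nn.Linear(int(np.prod(observation_size)), hidden_size),
                      nn.ReLU())                     # activation picked arbitrarily for the sketch
obs = torch.zeros(1, *observation_size)              # one per-agent frame with a leading batch dim
print(model(obs).shape)                              # torch.Size([1, 64])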
@@ -36,10 +38,7 @@ class A2CAgent(TAgent):
         self.critic_head = nn.Linear(hidden_size, 1)
 
     def get_obs(self, t):
-        observation = self.get(("env/env_obs", t))
-        print(observation.shape)
-        if self.marl:
-            observation = observation[self.agent_id]
+        observation = self.get((f'env/{access_str(self.agent_id, OBS)}', t))
         return observation
 
     def forward(self, t, stochastic, **kwargs):
@@ -52,10 +51,9 @@ class A2CAgent(TAgent):
             action = torch.distributions.Categorical(probs).sample()
         else:
             action = probs.argmax(1)
-        agent_str = f'agent{self.agent_id}_'
-        self.set((f'{agent_str}action', t), action)
-        self.set((f'{agent_str}action_probs', t), probs)
-        self.set((f'{agent_str}critic', t), critic)
+        self.set((f'{access_str(self.agent_id, "action")}', t), action)
+        self.set((f'{access_str(self.agent_id, "action_probs")}', t), probs)
+        self.set((f'{access_str(self.agent_id, "critic")}', t), critic)
 
 
 if __name__ == '__main__':
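Note: with access_str in place, each A2CAgent reads and writes only its own workspace slots. For agent_id = 0 the keys work out as below (runnable from the repository root, since it imports the helpers added by this commit):

from algorithms.utils import access_str, OBS    # helpers introduced by this commit

agent_id = 0
assert f"env/{access_str(agent_id, OBS)}"   == 'env/agent#0_env_obs'   # read in get_obs
assert access_str(agent_id, 'action')       == 'agent#0_action'        # written in forward
assert access_str(agent_id, 'action_probs') == 'agent#0_action_probs'  # used for the policy loss
assert access_str(agent_id, 'critic')       == 'agent#0_critic'        # used for the critic loss / GAE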
@@ -73,13 +71,11 @@ if __name__ == '__main__':
     env_agent = AutoResetGymMultiAgent(
         get_class(cfg['env']),
         get_arguments(cfg['env']),
-        n_envs=1,
-        n_agents=n_agents
+        n_envs=1
     )
 
     a2c_agents = [instantiate_class({**cfg['agent'],
-                                     'agent_id': agent_id,
-                                     'marl':     n_agents > 1})
+                                     'agent_id': agent_id})
                   for agent_id in range(n_agents)]
 
     # combine agents
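Note: the "# combine agents" step itself is outside this hunk. In salina this is usually done by chaining the env agent, the per-agent A2C agents and CombineActionsAgent inside an Agents container and wrapping that in a TemporalAgent; a sketch under that assumption (the actual script may differ):

from salina import Workspace
from salina.agents import Agents, TemporalAgent

# Order matters: the env agent writes env/agent#i_env_obs, each A2C agent writes its
# agent#i_action, and CombineActionsAgent merges them into the single 'action' slot
# that the env agent consumes on the next step.
acquisition_agent = TemporalAgent(Agents(env_agent, *a2c_agents, CombineActionsAgent()))
workspace = Workspace()
acquisition_agent(workspace, t=0, n_steps=100, stochastic=True)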
@@ -105,11 +101,12 @@ if __name__ == '__main__':
 
             for agent_id in range(n_agents):
                 critic, done, action_probs, reward, action = workspace[
-                    f"agent{agent_id}_critic", "env/done",
-                    f'agent{agent_id}_action_probs', "env/reward",
-                    f"agent{agent_id}_action"
+                    access_str(agent_id, 'critic'),
+                    "env/done",
+                    access_str(agent_id, 'action_probs'),
+                    access_str(agent_id, 'reward', 'env/'),
+                    access_str(agent_id, 'action')
                 ]
-                reward = reward[agent_id]
                 td = gae(critic, reward, done, 0.98, 0.25)
                 td_error = td ** 2
                 critic_loss = td_error.mean()
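Note: the gae() helper is not part of this diff. For reference, a generic GAE(lambda) estimator with the same call shape might look like the sketch below; the argument order (discount=0.98, lambda=0.25) and the reward/critic alignment are assumptions about the repository's helper, not taken from it:

import torch

def gae(critic, reward, done, discount, lam):
    # Generic GAE(lambda) over a rollout of shape (T, n_envs); returns (T-1, n_envs) advantages.
    not_done = 1.0 - done.float()
    deltas = reward[1:] + discount * not_done[1:] * critic[1:] - critic[:-1]   # one-step TD errors
    advantages = torch.zeros_like(deltas)
    running = torch.zeros_like(deltas[0])
    for t in reversed(range(deltas.size(0))):
        running = deltas[t] + discount * lam * not_done[1:][t] * running
        advantages[t] = running
    return advantages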
@@ -129,13 +126,14 @@ if __name__ == '__main__':
                 optimizer.step()
 
                 # Compute the cumulated reward on final_state
-                creward = workspace["env/cumulated_reward"]#[agent_id].unsqueeze(-1)
-                print(creward.shape, done.shape)
-                creward = creward[done]
-                if creward.size()[0] > 0:
-                    cum_r = creward.mean().item()
-                    if cum_r > best:
-                    #    torch.save(a2c_agent.state_dict(), Path(__file__).parent / f'agent_{uid}.pt')
-                        best = cum_r
-                    pbar.set_description(f"Cum. r: {cum_r:.2f}, Best r. so far: {best:.2f}", refresh=True)
+                rews = ''
+                for agent_i in range(n_agents):
+                    creward = workspace['env/'+access_str(agent_i, CUMU_REWARD)]
+                    creward = creward[done]
+                    if creward.size()[0] > 0:
+                        rews += f'{AGENT_PREFIX}{agent_i}: {creward.mean().item():.2f}  |  '
+                        """if cum_r > best:
+                            torch.save(a2c_agent.state_dict(), Path(__file__).parent / f'agent_{uid}.pt')
+                            best = cum_r"""
+                        pbar.set_description(rews, refresh=True)
 
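Note: the progress bar now reports one running return per agent instead of a single best value, and the old best-model bookkeeping survives only inside a triple-quoted string, so cum_r and best are no longer defined anywhere. With two agents the description string is built roughly like this (made-up numbers):

rews = ''
for agent_i, mean_creward in enumerate([1.25, 0.75]):   # illustrative per-agent mean returns
    rews += f'agent#{agent_i}: {mean_creward:.2f}  |  '
print(rews)   # agent#0: 1.25  |  agent#1: 0.75  |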