111 lines
		
	
	
		
			3.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			111 lines
		
	
	
		
			3.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import re
 | |
| import torch
 | |
| import numpy as np
 | |
| import yaml
 | |
| from pathlib import Path
 | |
| from salina import instantiate_class
 | |
| from salina import TAgent
 | |
| from salina.agents.gyma import (
 | |
|     AutoResetGymAgent,
 | |
|     _torch_type,
 | |
|     _format_frame,
 | |
|     _torch_cat_dict
 | |
| )
 | |
| 
 | |
| 
 | |
def load_yaml_file(path: "str | Path"):
    """Load and return the parsed contents of a YAML file.

    Args:
        path: Location of the YAML file, either a ``pathlib.Path`` or a plain
            string (strings are wrapped in ``Path`` first).

    Returns:
        The parsed YAML document (e.g. a dict when the top level is a mapping).

    Note:
        ``yaml.FullLoader`` is kept for compatibility with existing config
        files. It is not meant for untrusted input; prefer ``yaml.safe_load``
        if the configs never rely on Python-specific tags.
    """
    # Explicit UTF-8 so parsing does not depend on the platform's locale
    # default encoding (e.g. cp1252 on Windows).
    with Path(path).open(encoding="utf-8") as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)
 | |
| 
 | |
| 
 | |
def add_env_props(cfg):
    """Copy environment dimensions into the agent section of *cfg*, in place.

    A temporary environment is built from ``cfg['env']`` only to read its
    spaces. ``cfg['agent']`` is then updated with:

    - ``observation_size``: ``env.observation_space.shape``
    - ``n_actions``: ``env.action_space.n`` (assumes a discrete action space
      exposing ``.n``)

    Args:
        cfg: Config mapping holding ``'env'`` and ``'agent'`` sub-mappings.

    Returns:
        The same *cfg* object (it is also modified in place).
    """
    # Shallow copy, presumably so instantiate_class cannot mutate the
    # caller's env config — TODO confirm against instantiate_class.
    env = instantiate_class(cfg['env'].copy())
    try:
        cfg['agent'].update(dict(observation_size=env.observation_space.shape,
                                 n_actions=env.action_space.n))
    finally:
        # The env was only needed for its spaces; release whatever it holds
        # instead of leaking it. Guarded because instantiate_class may return
        # an object without a close() method.
        close = getattr(env, 'close', None)
        if callable(close):
            close()
    return cfg
 | |
| 
 | |
| 
 | |
| 
 | |
| 
 | |
# Names of the variables written to / read from the workspace. Per-agent
# variables are keyed as f"{prefix}{AGENT_PREFIX}{agent_index}{SEP}{name}"
# by access_str, e.g. "agent#0_reward".
AGENT_PREFIX = "agent#"            # marks a per-agent variable
SEP = "_"                          # separates the agent index from the name
OBS = "env_obs"                    # observation variable name
REWARD = "reward"                  # per-step reward variable name
CUMU_REWARD = "cumulated_reward"   # running sum of rewards since reset
ACTION = "action"                  # action variable name
 | |
| 
 | |
| 
 | |
def access_str(agent_i, name, prefix=''):
    """Return the workspace key for variable *name* of agent *agent_i*.

    The pieces ``prefix``, ``AGENT_PREFIX``, ``agent_i``, ``SEP`` and ``name``
    are each formatted with ``format()`` and concatenated, e.g.
    ``access_str(0, 'reward')`` gives ``'agent#0_reward'``.
    """
    pieces = (prefix, AGENT_PREFIX, agent_i, SEP, name)
    # format(x) is exactly what an f-string placeholder applies to x.
    return "".join(map(format, pieces))
 | |
| 
 | |
| 
 | |
class AutoResetGymMultiAgent(AutoResetGymAgent):
    """``AutoResetGymAgent`` variant for environments hosting several agents.

    Instead of the parent's single ``env_obs`` / ``reward`` /
    ``cumulated_reward`` variables, one variable per agent is written, keyed
    by ``access_str`` (e.g. ``agent#0_env_obs``, ``agent#1_reward``).
    ``done``, ``initial_state`` and ``timestep`` remain shared by all agents.

    Every wrapped env must expose ``unwrapped.n_agents``, and all envs must
    agree on that number.
    """

    def per_agent_values(self, name, values):
        """Map ``values[i]`` to the key ``access_str(i, name)``.

        Only the first ``self.n_agents`` entries of *values* are used.
        """
        return {access_str(agent_i, name): value
                for agent_i, value in zip(range(self.n_agents), values)}

    def _initialize_envs(self, n):
        """Create *n* envs via the parent and record their common agent count.

        Raises:
            AssertionError: if the envs do not all report the same ``n_agents``.
        """
        super()._initialize_envs(n)
        n_agents_list = [self.envs[i].unwrapped.n_agents for i in range(n)]
        assert all(n_agents == n_agents_list[0] for n_agents in n_agents_list), \
            'All envs must have the same number of agents.'
        self.n_agents = n_agents_list[0]

    def _reset(self, k, save_render):
        """Reset env *k* and split the parent's output into per-agent variables.

        The parent's shared ``env_obs``, ``reward`` and ``cumulated_reward``
        entries are replaced by per-agent observations and per-agent zero
        rewards / cumulated rewards; the per-agent running sums are reset.
        """
        ret = super()._reset(k, save_render)
        # Assumption (TODO confirm): the parent's env_obs has a leading axis of
        # size 1 followed by the agent axis. Drop only axis 0: a bare squeeze()
        # would also collapse the agent axis when n_agents == 1 (or any size-1
        # feature axis), after which obs[i] is no longer agent i's observation.
        obs = ret[OBS].squeeze(0)
        self.cumulated_reward[k] = [0.0] * self.n_agents
        per_agent_obs = self.per_agent_values(
            OBS, [_format_frame(obs[i]) for i in range(self.n_agents)])
        cumu_rew = self.per_agent_values(CUMU_REWARD, torch.zeros(self.n_agents, 1).float().unbind())
        rewards = self.per_agent_values(REWARD, torch.zeros(self.n_agents, 1).float().unbind())
        ret.update(cumu_rew)
        ret.update(rewards)
        ret.update(per_agent_obs)
        # Drop the parent's shared (non per-agent) versions of these variables.
        for shared_key in (OBS, CUMU_REWARD, REWARD):
            del ret[shared_key]
        return ret

    def _step(self, k, action, save_render):
        """Advance env *k* by one step with the joint *action*.

        Args:
            k: Index of the env in ``self.envs``.
            action: Tensor holding the joint action. A 0-d tensor is passed to
                the env as a Python ``int``, anything else as a NumPy array.
            save_render: If true, an env rendering is added under
                ``"rendering"``.

        Returns:
            Per-agent observation / reward / cumulated-reward variables plus
            the shared ``done``, ``initial_state`` and ``timestep`` entries,
            passed through ``_torch_type``.
        """
        self.timestep[k] += 1
        env = self.envs[k]
        if action.dim() == 0:
            action = action.item()
            assert isinstance(action, int)
        else:
            action = np.array(action.tolist())
        o, r, d, _ = env.step(action)
        # Multi-agent envs may report one done flag per agent. A plain `if d:`
        # treats any non-empty list as "done" and torch.tensor([d]) would get
        # the wrong shape, so reduce to one flag that is true once every agent
        # is done. A scalar flag passes through unchanged.
        done = bool(np.all(d))
        self.cumulated_reward[k] = [x + y for x, y in zip(r, self.cumulated_reward[k])]
        observation = self.per_agent_values(
            OBS, [_format_frame(o[i]) for i in range(self.n_agents)])
        if done:
            self.is_running[k] = False
        if save_render:
            image = env.render(mode="image").unsqueeze(0)
            observation["rendering"] = image
        rewards = self.per_agent_values(REWARD, torch.tensor(r).float().view(-1, 1).unbind())
        cumulated_rewards = self.per_agent_values(
            CUMU_REWARD, torch.tensor(self.cumulated_reward[k]).float().view(-1, 1).unbind())
        ret = {
            **observation,
            **rewards,
            **cumulated_rewards,
            "done": torch.tensor([done]),
            "initial_state": torch.tensor([False]),
            "timestep": torch.tensor([self.timestep[k]]),
        }
        return _torch_type(ret)
 | |
| 
 | |
| 
 | |
class CombineActionsAgent(TAgent):
    """Merge the per-agent ``agent#<i>_action`` variables into one ``action``.

    At time ``t``, every workspace variable whose name is
    ``AGENT_PREFIX + <index> + SEP + ACTION`` is read. The variables are
    ordered by agent index and written back as the joint ``action`` variable.
    """

    def __init__(self):
        super().__init__()
        # One or more digits: a single \d would silently ignore agents >= 10.
        # The constant parts are escaped so they are matched literally; the
        # agent index is captured so keys can be ordered numerically.
        self.pattern = (rf'^{re.escape(AGENT_PREFIX)}(\d+)'
                        rf'{re.escape(SEP)}{re.escape(ACTION)}$')

    def forward(self, t, **kwargs):
        """Write the joint action for time step *t*.

        With a single agent, a copy of its action tensor is written. With
        several agents, the tensors are stacked along a new axis 1. For a batch
        of one, this equals the former ``cat(..., 0).unsqueeze(0)`` result. For
        larger batches, each env's actions stay in their own row instead of
        being interleaved.

        Raises:
            RuntimeError: if the workspace holds no per-agent action variable.
        """
        indexed_keys = []
        for key in list(self.workspace.keys()):
            match = re.match(self.pattern, key)
            if match:
                indexed_keys.append((int(match.group(1)), key))
        if not indexed_keys:
            raise RuntimeError(
                f'No per-agent action variables matching {self.pattern!r} in the workspace.')
        # Numeric order, so agent 10 follows agent 9 rather than agent 1.
        per_agent = [self.get((key, t)) for _, key in sorted(indexed_keys)]
        if len(per_agent) == 1:
            # Copy (as torch.cat did) so 'action' does not alias the agent's tensor.
            actions = per_agent[0].clone()
        else:
            actions = torch.stack(per_agent, dim=1)
        self.set((ACTION, t), actions)
 | 
