diff --git a/algorithms/utils.py b/algorithms/utils.py
index cd3fbd3..d72046a 100644
--- a/algorithms/utils.py
+++ b/algorithms/utils.py
@@ -5,7 +5,12 @@ import yaml
 from pathlib import Path
 from salina import instantiate_class
 from salina import TAgent
-from salina.agents.gyma import AutoResetGymAgent, _torch_type, _format_frame
+from salina.agents.gyma import (
+    AutoResetGymAgent,
+    _torch_type,
+    _format_frame,
+    _torch_cat_dict
+)
 
 
 def load_yaml_file(path: Path):
@@ -20,42 +25,47 @@ def add_env_props(cfg):
                              n_actions=env.action_space.n))
 
 
-class CombineActionsAgent(TAgent):
-    def __init__(self, pattern=r'^agent\d_action$'):
-        super().__init__()
-        self.pattern = pattern
 
-    def forward(self, t, **kwargs):
-        keys = list(self.workspace.keys())
-        action_keys = sorted([k for k in keys if bool(re.match(self.pattern, k))])
-        actions = torch.cat([self.get((k, t)) for k in action_keys], 0)
-        actions = actions if len(action_keys) <= 1 else actions.unsqueeze(0)
-        self.set((f'action', t), actions)
+
+AGENT_PREFIX = 'agent#'
+REWARD       = 'reward'
+CUMU_REWARD  = 'cumulated_reward'
+OBS          = 'env_obs'
+SEP          = '_'
+ACTION       = 'action'
+
+
+def access_str(agent_i, name, prefix=''):
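+    """Build the per-agent workspace key, e.g. access_str(0, OBS) -> 'agent#0_env_obs'."""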
+    return f'{prefix}{AGENT_PREFIX}{agent_i}{SEP}{name}'
 
 
 class AutoResetGymMultiAgent(AutoResetGymAgent):
-    AGENT_PREFIX = 'agent#'
-    REWARD       =  'reward'
-    CUMU_REWARD  = 'cumulated_reward'
-    SEP          = '_'
-
-    def __init__(self, *args, n_agents, **kwargs):
+    def __init__(self, *args, **kwargs):
         super(AutoResetGymMultiAgent, self).__init__(*args, **kwargs)
-        self.n_agents = n_agents
 
-    def prefix(self, agent_id, name):
-        return f'{self.AGENT_PREFIX}{agent_id}{self.SEP}{name}'
+    def per_agent_values(self, name, values):
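+        """Key each per-agent value by its agent-specific workspace name, e.g. 'agent#0_reward'."""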
+        return {access_str(agent_i, name): value
+                for agent_i, value in zip(range(self.n_agents), values)}
+
+    def _initialize_envs(self, n):
+        super()._initialize_envs(n)
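+        # Derive the (shared) number of agents from the wrapped envs.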
+        n_agents_list = [self.envs[i].unwrapped.n_agents for i in range(n)]
+        assert all(n_agents == n_agents_list[0] for n_agents in n_agents_list), \
+            'All envs must have the same number of agents.'
+        self.n_agents = n_agents_list[0]
 
     def _reset(self, k, save_render):
         ret = super()._reset(k, save_render)
+        obs = ret[OBS].squeeze()
         self.cumulated_reward[k] = [0.0]*self.n_agents
-        del ret['cumulated_reward']
-        cumu_rew = {self.prefix(agent_i, self.CUMU_REWARD): torch.zeros(1).float()
-                    for agent_i in range(self.n_agents)}
-        rewards  = {self.prefix(agent_i, self.REWARD)     : torch.zeros(1).float()
-                    for agent_i in range(self.n_agents)}
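+        # Re-emit reset outputs per agent: split the observation and zero each agent's rewards.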
+        obs      = self.per_agent_values(OBS,  [_format_frame(obs[i]) for i in range(self.n_agents)])
+        cumu_rew = self.per_agent_values(CUMU_REWARD, torch.zeros(self.n_agents, 1).float().unbind())
+        rewards  = self.per_agent_values(REWARD,      torch.zeros(self.n_agents, 1).float().unbind())
         ret.update(cumu_rew)
         ret.update(rewards)
+        ret.update(obs)
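+        # The joint keys are replaced by the per-agent entries set above.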
+        for remove in [OBS, CUMU_REWARD, REWARD]:
+            del ret[remove]
         return ret
 
     def _step(self, k, action, save_render):
@@ -68,28 +78,33 @@ class AutoResetGymMultiAgent(AutoResetGymAgent):
             action = np.array(action.tolist())
         o, r, d, _ = env.step(action)
         self.cumulated_reward[k] = [x+y for x, y in zip(r, self.cumulated_reward[k])]
-        print(o.shape)
-        observation = _format_frame(o)
-        if isinstance(observation, torch.Tensor):
-            print(observation.shape)
-            observation = {self.prefix(agent_i, 'env_obs'): observation[agent_i]
-                           for agent_i in range(self.n_agents)}
-            print(observation)
-        else:
-            assert isinstance(observation, dict)
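+        # Split the joint observation into one entry per agent.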
+        observation = self.per_agent_values(OBS, [_format_frame(o[i]) for i in range(self.n_agents)])
         if d:
             self.is_running[k] = False
-
         if save_render:
             image = env.render(mode="image").unsqueeze(0)
             observation["rendering"] = image
+        rewards           = self.per_agent_values(REWARD, torch.tensor(r).float().view(-1, 1).unbind())
+        cumulated_rewards = self.per_agent_values(CUMU_REWARD, torch.tensor(self.cumulated_reward[k]).float().view(-1, 1).unbind())
         ret = {
             **observation,
+            **rewards,
+            **cumulated_rewards,
             "done": torch.tensor([d]),
             "initial_state": torch.tensor([False]),
-            "reward": torch.tensor(r).float(),
-            "timestep": torch.tensor([self.timestep[k]]),
-            "cumulated_reward": torch.tensor(self.cumulated_reward[k]).float(),
+            "timestep": torch.tensor([self.timestep[k]])
         }
         return _torch_type(ret)
 
+
+class CombineActionsAgent(TAgent):
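+    """Gather the per-agent 'agent#<i>_action' entries and write them as the joint 'action' tensor."""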
+    def __init__(self):
+        super().__init__()
+        self.pattern = fr'^{AGENT_PREFIX}\d{SEP}{ACTION}$'
+
+    def forward(self, t, **kwargs):
+        keys = list(self.workspace.keys())
+        action_keys = sorted([k for k in keys if bool(re.match(self.pattern, k))])
+        actions = torch.cat([self.get((k, t)) for k in action_keys], 0)
+        actions = actions if len(action_keys) <= 1 else actions.unsqueeze(0)
+        self.set((ACTION, t), actions)
diff --git a/studies/sat_mad.py b/studies/sat_mad.py
index 4380c7d..6bf3e22 100644
--- a/studies/sat_mad.py
+++ b/studies/sat_mad.py
@@ -13,16 +13,18 @@ from algorithms.utils import (
     add_env_props,
     load_yaml_file,
     CombineActionsAgent,
-    AutoResetGymMultiAgent
+    AutoResetGymMultiAgent,
+    access_str,
+    AGENT_PREFIX, REWARD, CUMU_REWARD, OBS, SEP
 )
 
 
 class A2CAgent(TAgent):
-    def __init__(self, observation_size, hidden_size, n_actions, agent_id=-1, marl=False):
+    def __init__(self, observation_size, hidden_size, n_actions, agent_id):
         super().__init__()
         observation_size = np.prod(observation_size)
         self.agent_id = agent_id
-        self.marl = marl
         self.model = nn.Sequential(
             nn.Flatten(),
             nn.Linear(observation_size, hidden_size),
@@ -36,10 +38,7 @@ class A2CAgent(TAgent):
         self.critic_head = nn.Linear(hidden_size, 1)
 
     def get_obs(self, t):
-        observation = self.get(("env/env_obs", t))
-        print(observation.shape)
-        if self.marl:
-            observation = observation[self.agent_id]
+        observation = self.get((access_str(self.agent_id, OBS, 'env/'), t))
         return observation
 
     def forward(self, t, stochastic, **kwargs):
@@ -52,10 +51,9 @@ class A2CAgent(TAgent):
             action = torch.distributions.Categorical(probs).sample()
         else:
             action = probs.argmax(1)
-        agent_str = f'agent{self.agent_id}_'
-        self.set((f'{agent_str}action', t), action)
-        self.set((f'{agent_str}action_probs', t), probs)
-        self.set((f'{agent_str}critic', t), critic)
+        self.set((access_str(self.agent_id, "action"), t), action)
+        self.set((access_str(self.agent_id, "action_probs"), t), probs)
+        self.set((access_str(self.agent_id, "critic"), t), critic)
 
 
 if __name__ == '__main__':
@@ -73,13 +71,11 @@ if __name__ == '__main__':
     env_agent = AutoResetGymMultiAgent(
         get_class(cfg['env']),
         get_arguments(cfg['env']),
-        n_envs=1,
-        n_agents=n_agents
+        n_envs=1
     )
 
     a2c_agents = [instantiate_class({**cfg['agent'],
-                                     'agent_id': agent_id,
-                                     'marl':     n_agents > 1})
+                                     'agent_id': agent_id})
                   for agent_id in range(n_agents)]
 
     # combine agents
@@ -105,11 +101,12 @@ if __name__ == '__main__':
 
             for agent_id in range(n_agents):
                 critic, done, action_probs, reward, action = workspace[
-                    f"agent{agent_id}_critic", "env/done",
-                    f'agent{agent_id}_action_probs', "env/reward",
-                    f"agent{agent_id}_action"
+                    access_str(agent_id, 'critic'),
+                    "env/done",
+                    access_str(agent_id, 'action_probs'),
+                    access_str(agent_id, REWARD, 'env/'),
+                    access_str(agent_id, 'action')
                 ]
-                reward = reward[agent_id]
                 td = gae(critic, reward, done, 0.98, 0.25)
                 td_error = td ** 2
                 critic_loss = td_error.mean()
@@ -129,13 +126,14 @@ if __name__ == '__main__':
                 optimizer.step()
 
                 # Compute the cumulated reward on final_state
-                creward = workspace["env/cumulated_reward"]#[agent_id].unsqueeze(-1)
-                print(creward.shape, done.shape)
-                creward = creward[done]
-                if creward.size()[0] > 0:
-                    cum_r = creward.mean().item()
-                    if cum_r > best:
-                    #    torch.save(a2c_agent.state_dict(), Path(__file__).parent / f'agent_{uid}.pt')
-                        best = cum_r
-                    pbar.set_description(f"Cum. r: {cum_r:.2f}, Best r. so far: {best:.2f}", refresh=True)
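+                # Report the mean cumulated reward per agent over episodes that just finished.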
+                rews = ''
+                for agent_i in range(n_agents):
+                    creward = workspace[access_str(agent_i, CUMU_REWARD, 'env/')]
+                    creward = creward[done]
+                    if creward.size()[0] > 0:
+                        rews += f'{AGENT_PREFIX}{agent_i}: {creward.mean().item():.2f}  |  '
+                        """if cum_r > best:
+                            torch.save(a2c_agent.state_dict(), Path(__file__).parent / f'agent_{uid}.pt')
+                            best = cum_r"""
+                        pbar.set_description(rews, refresh=True)