From c1c7909925e1aa7850039a97a4c800c7a3ea87d5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Robert=20M=C3=BCller?=
Date: Tue, 23 Nov 2021 17:02:35 +0100
Subject: [PATCH] added running marl a2c

---
 algorithms/utils.py | 91 ++++++++++++++++++++++++++-------------------
 studies/sat_mad.py  | 54 +++++++++++++--------------
 2 files changed, 79 insertions(+), 66 deletions(-)

diff --git a/algorithms/utils.py b/algorithms/utils.py
index cd3fbd3..d72046a 100644
--- a/algorithms/utils.py
+++ b/algorithms/utils.py
@@ -5,7 +5,12 @@ import yaml
 from pathlib import Path
 from salina import instantiate_class
 from salina import TAgent
-from salina.agents.gyma import AutoResetGymAgent, _torch_type, _format_frame
+from salina.agents.gyma import (
+    AutoResetGymAgent,
+    _torch_type,
+    _format_frame,
+    _torch_cat_dict
+)
 
 
 def load_yaml_file(path: Path):
@@ -20,42 +25,47 @@ def add_env_props(cfg):
                               n_actions=env.action_space.n))
 
 
-class CombineActionsAgent(TAgent):
-    def __init__(self, pattern=r'^agent\d_action$'):
-        super().__init__()
-        self.pattern = pattern
-    def forward(self, t, **kwargs):
-        keys = list(self.workspace.keys())
-        action_keys = sorted([k for k in keys if bool(re.match(self.pattern, k))])
-        actions = torch.cat([self.get((k, t)) for k in action_keys], 0)
-        actions = actions if len(action_keys) <= 1 else actions.unsqueeze(0)
-        self.set((f'action', t), actions)
+
+AGENT_PREFIX = 'agent#'
+REWARD = 'reward'
+CUMU_REWARD = 'cumulated_reward'
+OBS = 'env_obs'
+SEP = '_'
+ACTION = 'action'
+
+
+def access_str(agent_i, name, prefix=''):
+    return f'{prefix}{AGENT_PREFIX}{agent_i}{SEP}{name}'
 
 
 class AutoResetGymMultiAgent(AutoResetGymAgent):
-    AGENT_PREFIX = 'agent#'
-    REWARD = 'reward'
-    CUMU_REWARD = 'cumulated_reward'
-    SEP = '_'
-
-    def __init__(self, *args, n_agents, **kwargs):
+    def __init__(self, *args, **kwargs):
         super(AutoResetGymMultiAgent, self).__init__(*args, **kwargs)
-        self.n_agents = n_agents
 
-    def prefix(self, agent_id, name):
-        return f'{self.AGENT_PREFIX}{agent_id}{self.SEP}{name}'
+    def per_agent_values(self, name, values):
+        return {access_str(agent_i, name): value
+                for agent_i, value in zip(range(self.n_agents), values)}
+
+    def _initialize_envs(self, n):
+        super()._initialize_envs(n)
+        n_agents_list = [self.envs[i].unwrapped.n_agents for i in range(n)]
+        assert all(n_agents == n_agents_list[0] for n_agents in n_agents_list), \
+            'All envs must have the same number of agents.'
+        self.n_agents = n_agents_list[0]
 
     def _reset(self, k, save_render):
         ret = super()._reset(k, save_render)
+        obs = ret['env_obs'].squeeze()
         self.cumulated_reward[k] = [0.0]*self.n_agents
-        del ret['cumulated_reward']
-        cumu_rew = {self.prefix(agent_i, self.CUMU_REWARD): torch.zeros(1).float()
-                    for agent_i in range(self.n_agents)}
-        rewards = {self.prefix(agent_i, self.REWARD) : torch.zeros(1).float()
-                   for agent_i in range(self.n_agents)}
+        obs = self.per_agent_values(OBS, [_format_frame(obs[i]) for i in range(self.n_agents)])
+        cumu_rew = self.per_agent_values(CUMU_REWARD, torch.zeros(self.n_agents, 1).float().unbind())
+        rewards = self.per_agent_values(REWARD, torch.zeros(self.n_agents, 1).float().unbind())
         ret.update(cumu_rew)
         ret.update(rewards)
+        ret.update(obs)
+        for remove in ['env_obs', 'cumulated_reward', 'reward']:
+            del ret[remove]
         return ret
 
     def _step(self, k, action, save_render):
@@ -68,28 +78,33 @@ class AutoResetGymMultiAgent(AutoResetGymAgent):
             action = np.array(action.tolist())
         o, r, d, _ = env.step(action)
         self.cumulated_reward[k] = [x+y for x, y in zip(r, self.cumulated_reward[k])]
-        print(o.shape)
-        observation = _format_frame(o)
-        if isinstance(observation, torch.Tensor):
-            print(observation.shape)
-            observation = {self.prefix(agent_i, 'env_obs'): observation[agent_i]
-                           for agent_i in range(self.n_agents)}
-            print(observation)
-        else:
-            assert isinstance(observation, dict)
+        observation = self.per_agent_values(OBS, [_format_frame(o[i]) for i in range(self.n_agents)])
         if d:
             self.is_running[k] = False
-
        if save_render:
             image = env.render(mode="image").unsqueeze(0)
             observation["rendering"] = image
+        rewards = self.per_agent_values(REWARD, torch.tensor(r).float().view(-1, 1).unbind())
+        cumulated_rewards = self.per_agent_values(CUMU_REWARD, torch.tensor(self.cumulated_reward[k]).float().view(-1, 1).unbind())
         ret = {
             **observation,
+            **rewards,
+            **cumulated_rewards,
             "done": torch.tensor([d]),
             "initial_state": torch.tensor([False]),
-            "reward": torch.tensor(r).float(),
-            "timestep": torch.tensor([self.timestep[k]]),
-            "cumulated_reward": torch.tensor(self.cumulated_reward[k]).float(),
+            "timestep": torch.tensor([self.timestep[k]])
         }
         return _torch_type(ret)
+
+class CombineActionsAgent(TAgent):
+    def __init__(self):
+        super().__init__()
+        self.pattern = fr'^{AGENT_PREFIX}\d{SEP}{ACTION}$'
+
+    def forward(self, t, **kwargs):
+        keys = list(self.workspace.keys())
+        action_keys = sorted([k for k in keys if bool(re.match(self.pattern, k))])
+        actions = torch.cat([self.get((k, t)) for k in action_keys], 0)
+        actions = actions if len(action_keys) <= 1 else actions.unsqueeze(0)
+        self.set((f'action', t), actions)
diff --git a/studies/sat_mad.py b/studies/sat_mad.py
index 4380c7d..6bf3e22 100644
--- a/studies/sat_mad.py
+++ b/studies/sat_mad.py
@@ -13,16 +13,18 @@ from algorithms.utils import (
     add_env_props,
     load_yaml_file,
     CombineActionsAgent,
-    AutoResetGymMultiAgent
+    AutoResetGymMultiAgent,
+    access_str,
+    AGENT_PREFIX, REWARD, CUMU_REWARD, OBS, SEP
 )
 
 
 class A2CAgent(TAgent):
-    def __init__(self, observation_size, hidden_size, n_actions, agent_id=-1, marl=False):
+    def __init__(self, observation_size, hidden_size, n_actions, agent_id):
         super().__init__()
         observation_size = np.prod(observation_size)
+        print(observation_size)
         self.agent_id = agent_id
-        self.marl = marl
         self.model = nn.Sequential(
             nn.Flatten(),
             nn.Linear(observation_size, hidden_size),
@@ -36,10 +38,7 @@ class A2CAgent(TAgent):
         self.critic_head = nn.Linear(hidden_size, 1)
 
     def get_obs(self, t):
-        observation = self.get(("env/env_obs", t))
-        print(observation.shape)
-        if self.marl:
-            observation = observation[self.agent_id]
+        observation = self.get((f'env/{access_str(self.agent_id, OBS)}', t))
         return observation
 
     def forward(self, t, stochastic, **kwargs):
@@ -52,10 +51,9 @@
             action = torch.distributions.Categorical(probs).sample()
         else:
             action = probs.argmax(1)
-        agent_str = f'agent{self.agent_id}_'
-        self.set((f'{agent_str}action', t), action)
-        self.set((f'{agent_str}action_probs', t), probs)
-        self.set((f'{agent_str}critic', t), critic)
+        self.set((f'{access_str(self.agent_id, "action")}', t), action)
+        self.set((f'{access_str(self.agent_id, "action_probs")}', t), probs)
+        self.set((f'{access_str(self.agent_id, "critic")}', t), critic)
 
 
 if __name__ == '__main__':
@@ -73,13 +71,11 @@ if __name__ == '__main__':
     env_agent = AutoResetGymMultiAgent(
         get_class(cfg['env']),
         get_arguments(cfg['env']),
-        n_envs=1,
-        n_agents=n_agents
+        n_envs=1
     )
 
     a2c_agents = [instantiate_class({**cfg['agent'],
-                                     'agent_id': agent_id,
-                                     'marl': n_agents > 1})
+                                     'agent_id': agent_id})
                   for agent_id in range(n_agents)]
 
     # combine agents
@@ -105,11 +101,12 @@ if __name__ == '__main__':
 
         for agent_id in range(n_agents):
             critic, done, action_probs, reward, action = workspace[
-                f"agent{agent_id}_critic", "env/done",
-                f'agent{agent_id}_action_probs', "env/reward",
-                f"agent{agent_id}_action"
+                access_str(agent_id, 'critic'),
+                "env/done",
+                access_str(agent_id, 'action_probs'),
+                access_str(agent_id, 'reward', 'env/'),
+                access_str(agent_id, 'action')
             ]
-            reward = reward[agent_id]
             td = gae(critic, reward, done, 0.98, 0.25)
             td_error = td ** 2
             critic_loss = td_error.mean()
@@ -129,13 +126,14 @@ if __name__ == '__main__':
            optimizer.step()
 
             # Compute the cumulated reward on final_state
-            creward = workspace["env/cumulated_reward"]#[agent_id].unsqueeze(-1)
-            print(creward.shape, done.shape)
-            creward = creward[done]
-            if creward.size()[0] > 0:
-                cum_r = creward.mean().item()
-                if cum_r > best:
-                    # torch.save(a2c_agent.state_dict(), Path(__file__).parent / f'agent_{uid}.pt')
-                    best = cum_r
-                pbar.set_description(f"Cum. r: {cum_r:.2f}, Best r. so far: {best:.2f}", refresh=True)
+            rews = ''
+            for agent_i in range(n_agents):
+                creward = workspace['env/'+access_str(agent_i, CUMU_REWARD)]
+                creward = creward[done]
+                if creward.size()[0] > 0:
+                    rews += f'{AGENT_PREFIX}{agent_i}: {creward.mean().item():.2f} | '
+                    """if cum_r > best:
+                        torch.save(a2c_agent.state_dict(), Path(__file__).parent / f'agent_{uid}.pt')
+                        best = cum_r"""
+            pbar.set_description(rews, refresh=True)