From a9a4274370046e2e6cc80d4a2b19fb9a928a3693 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Robert=20M=C3=BCller?= Date: Thu, 3 Feb 2022 13:14:48 +0100 Subject: [PATCH] add more efficient (lazy) experience queue implementation based on tensor, adjusted marl algorithms --- algorithms/marl/base_ac.py | 104 +++++++++------ algorithms/marl/iac.py | 6 +- algorithms/marl/mappo.py | 51 +++----- algorithms/marl/memory.py | 223 +++++++++++++++++++++------------ algorithms/marl/seac.py | 13 +- algorithms/marl/snac.py | 3 +- studies/normalization_study.py | 2 +- studies/playground_file.py | 6 +- 8 files changed, 243 insertions(+), 165 deletions(-) diff --git a/algorithms/marl/base_ac.py b/algorithms/marl/base_ac.py index 99d6591..882cb4a 100644 --- a/algorithms/marl/base_ac.py +++ b/algorithms/marl/base_ac.py @@ -1,6 +1,5 @@ import torch from typing import Union, List -import copy import numpy as np from torch.distributions import Categorical from algorithms.marl.memory import MARLActorCriticMemory @@ -8,6 +7,28 @@ from algorithms.utils import add_env_props, instantiate_class from pathlib import Path import pandas as pd from collections import deque + + +class Names: + REWARD = 'reward' + DONE = 'done' + ACTION = 'action' + OBSERVATION = 'observation' + LOGITS = 'logits' + HIDDEN_ACTOR = 'hidden_actor' + HIDDEN_CRITIC = 'hidden_critic' + AGENT = 'agent' + ENV = 'env' + N_AGENTS = 'n_agents' + ALGORITHM = 'algorithm' + MAX_STEPS = 'max_steps' + N_STEPS = 'n_steps' + BUFFER_SIZE = 'buffer_size' + CRITIC = 'critic' + BATCH_SIZE = 'bnatch_size' + N_ACTIONS = 'n_actions' + +nms = Names ListOrTensor = Union[List, torch.Tensor] @@ -16,11 +37,12 @@ class BaseActorCritic: add_env_props(cfg) self.__training = True self.cfg = cfg - self.n_agents = cfg['env']['n_agents'] + self.n_agents = cfg[nms.ENV][nms.N_AGENTS] + self.reset_memory_after_epoch = True self.setup() def setup(self): - self.net = instantiate_class(self.cfg['agent']) + self.net = instantiate_class(self.cfg[nms.AGENT]) self.optimizer = torch.optim.RMSprop(self.net.parameters(), lr=3e-4, eps=1e-5) @classmethod @@ -49,7 +71,7 @@ class BaseActorCritic: pass def get_actions(self, out) -> ListOrTensor: - actions = [Categorical(logits=logits).sample().item() for logits in out['logits']] + actions = [Categorical(logits=logits).sample().item() for logits in out[nms.LOGITS]] return actions def init_hidden(self) -> dict[ListOrTensor]: @@ -63,47 +85,48 @@ class BaseActorCritic: ) -> dict[ListOrTensor]: pass - @torch.no_grad() def train_loop(self, checkpointer=None): - env = instantiate_class(self.cfg['env']) - n_steps, max_steps = [self.cfg['algorithm'][k] for k in ['n_steps', 'max_steps']] - global_steps, episode, df_results = 0, 0, [] + env = instantiate_class(self.cfg[nms.ENV]) + n_steps, max_steps = [self.cfg[nms.ALGORITHM][k] for k in [nms.N_STEPS, nms.MAX_STEPS]] + tm = MARLActorCriticMemory(self.n_agents, self.cfg[nms.ALGORITHM].get(nms.BUFFER_SIZE, n_steps)) + global_steps, episode, df_results = 0, 0, [] reward_queue = deque(maxlen=2000) - memory_queue = deque(maxlen=self.cfg['algorithm'].get('keep_n_segments', 1)) + while global_steps < max_steps: - tm = MARLActorCriticMemory(self.n_agents) obs = env.reset() last_hiddens = self.init_hidden() last_action, reward = [-1] * self.n_agents, [0.] 
* self.n_agents - done, rew_log = [False] * self.n_agents, 0 - tm.add(action=last_action, **last_hiddens) + done, rew_log = [False] * self.n_agents, 0 + + if self.reset_memory_after_epoch: + tm.reset() + + tm.add(observation=obs, action=last_action, + logits=torch.zeros(self.n_agents, 1, self.cfg[nms.AGENT][nms.N_ACTIONS]), + values=torch.zeros(self.n_agents, 1), reward=reward, done=done, **last_hiddens) while not all(done): - out = self.forward(obs, last_action, **last_hiddens) action = self.get_actions(out) next_obs, reward, done, info = env.step(action) - next_obs = next_obs - if isinstance(done, bool): done = [done] * self.n_agents + done = [done] * self.n_agents if isinstance(done, bool) else done + + last_hiddens = dict(hidden_actor =out[nms.HIDDEN_ACTOR], + hidden_critic=out[nms.HIDDEN_CRITIC]) + tm.add(observation=obs, action=action, reward=reward, done=done, - logits=out.get('logits', None), values=out.get('critic', None)) + logits=out.get(nms.LOGITS, None), values=out.get(nms.CRITIC, None), + **last_hiddens) + obs = next_obs last_action = action - last_hiddens = dict(hidden_actor=out.get('hidden_actor', None), - hidden_critic=out.get('hidden_critic', None) - ) - if len(tm) >= n_steps or all(done): - tm.add(observation=next_obs) - memory_queue.append(copy.deepcopy(tm)) - if self.__training: - with torch.inference_mode(False): - tm_ = tm if memory_queue.maxlen <= 1 else list(memory_queue) - self.learn(tm_) - tm.reset() - tm.add(action=last_action, **last_hiddens) + if (global_steps+1) % n_steps == 0 or all(done): + with torch.inference_mode(False): + self.learn(tm) + global_steps += 1 rew_log += sum(reward) reward_queue.extend(reward) @@ -114,18 +137,19 @@ class BaseActorCritic: for i, agent in enumerate([self.net] if not isinstance(self.net, List) else self.net) ]) - if global_steps >= max_steps: break - print(f'reward at step: {episode} = {rew_log}') + if global_steps >= max_steps: + break + print(f'reward at episode: {episode} = {rew_log}') episode += 1 - df_results.append([global_steps, rew_log]) - df_results = pd.DataFrame(df_results, columns=['steps', 'reward']) + df_results.append([episode, rew_log, *reward]) + df_results = pd.DataFrame(df_results, columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]]) if checkpointer is not None: df_results.to_csv(checkpointer.path / 'results.csv', index=False) return df_results @torch.inference_mode(True) def eval_loop(self, n_episodes, render=False): - env = instantiate_class(self.cfg['env']) + env = instantiate_class(self.cfg[nms.ENV]) episode, results = 0, [] while episode < n_episodes: obs = env.reset() @@ -142,8 +166,8 @@ class BaseActorCritic: if isinstance(done, bool): done = [done] * obs.shape[0] obs = next_obs last_action = action - last_hiddens = dict(hidden_actor=out.get('hidden_actor', None), - hidden_critic=out.get('hidden_critic', None) + last_hiddens = dict(hidden_actor=out.get(nms.HIDDEN_ACTOR, None), + hidden_critic=out.get(nms.HIDDEN_CRITIC, None) ) eps_rew += torch.tensor(reward) results.append(eps_rew.tolist() + [sum(eps_rew).item()] + [episode]) @@ -169,11 +193,11 @@ class BaseActorCritic: return gaes def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs): - obs, actions, done, reward = tm.observation, tm.action, tm.done, tm.reward + obs, actions, done, reward = tm.observation, tm.action, tm.done[:, 1:], tm.reward[:, 1:] - out = network(obs, actions, tm.hidden_actor, tm.hidden_critic) - logits = out['logits'][:, :-1] # last one only needed for v_{t+1} - critic = 
out['critic'] + out = network(obs, actions, tm.hidden_actor[:, 0], tm.hidden_critic[:, 0]) + logits = out[nms.LOGITS][:, :-1] # last one only needed for v_{t+1} + critic = out[nms.CRITIC] entropy_loss = Categorical(logits=logits).entropy().mean(-1) advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef) @@ -188,7 +212,7 @@ class BaseActorCritic: return loss.mean() def learn(self, tm: MARLActorCriticMemory, **kwargs): - loss = self.actor_critic(tm, self.net, **self.cfg['algorithm'], **kwargs) + loss = self.actor_critic(tm, self.net, **self.cfg[nms.ALGORITHM], **kwargs) # remove next_obs, will be added in next iter self.optimizer.zero_grad() loss.backward() diff --git a/algorithms/marl/iac.py b/algorithms/marl/iac.py index 7d0c640..95a4923 100644 --- a/algorithms/marl/iac.py +++ b/algorithms/marl/iac.py @@ -1,5 +1,5 @@ import torch -from algorithms.marl.base_ac import BaseActorCritic +from algorithms.marl.base_ac import BaseActorCritic, nms from algorithms.utils import instantiate_class from pathlib import Path from natsort import natsorted @@ -13,7 +13,7 @@ class LoopIAC(BaseActorCritic): def setup(self): self.net = [ - instantiate_class(self.cfg['agent']) for _ in range(self.n_agents) + instantiate_class(self.cfg[nms.AGENT]) for _ in range(self.n_agents) ] self.optimizer = [ torch.optim.RMSprop(self.net[ag_i].parameters(), lr=3e-4, eps=1e-5) for ag_i in range(self.n_agents) @@ -50,7 +50,7 @@ class LoopIAC(BaseActorCritic): def learn(self, tms: MARLActorCriticMemory, **kwargs): for ag_i in range(self.n_agents): tm, net = tms(ag_i), self.net[ag_i] - loss = self.actor_critic(tm, net, **self.cfg['algorithm'], **kwargs) + loss = self.actor_critic(tm, net, **self.cfg[nms.ALGORITHM], **kwargs) self.optimizer[ag_i].zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(net.parameters(), 0.5) diff --git a/algorithms/marl/mappo.py b/algorithms/marl/mappo.py index 6719c47..9d339bc 100644 --- a/algorithms/marl/mappo.py +++ b/algorithms/marl/mappo.py @@ -1,39 +1,28 @@ +from algorithms.marl.base_ac import Names as nms from algorithms.marl import LoopSNAC from algorithms.marl.memory import MARLActorCriticMemory -from typing import List import random import torch from torch.distributions import Categorical +from algorithms.utils import instantiate_class class LoopMAPPO(LoopSNAC): def __init__(self, *args, **kwargs): super(LoopMAPPO, self).__init__(*args, **kwargs) + self.reset_memory_after_epoch = False - def build_batch(self, tm: List[MARLActorCriticMemory]): - sample = random.choices(tm, k=self.cfg['algorithm']['batch_size']-1) - sample.append(tm[-1]) # always use latest segment in batch + def setup(self): + self.net = instantiate_class(self.cfg[nms.AGENT]) + self.optimizer = torch.optim.Adam(self.net.parameters(), lr=3e-4, eps=1e-5) - obs = torch.cat([s.observation for s in sample], 0) - actions = torch.cat([s.action for s in sample], 0) - hidden_actor = torch.cat([s.hidden_actor for s in sample], 0) - hidden_critic = torch.cat([s.hidden_critic for s in sample], 0) - logits = torch.cat([s.logits for s in sample], 0) - values = torch.cat([s.values for s in sample], 0) - reward = torch.cat([s.reward for s in sample], 0) - done = torch.cat([s.done for s in sample], 0) - - - log_props = torch.log_softmax(logits, -1) - log_props = torch.gather(log_props, index=actions[:, 1:].unsqueeze(-1), dim=-1).squeeze() - - return obs, actions, hidden_actor, hidden_critic, log_props, values, reward, done - - def learn(self, tm: List[MARLActorCriticMemory], **kwargs): - if len(tm) >= 
self.cfg['algorithm']['keep_n_segments']: + def learn(self, tm: MARLActorCriticMemory, **kwargs): + if len(tm) >= self.cfg['algorithm']['buffer_size']: # only learn when buffer is full for batch_i in range(self.cfg['algorithm']['n_updates']): - loss = self.actor_critic(tm, self.net, **self.cfg['algorithm'], **kwargs) + batch = tm.chunk_dataloader(chunk_len=self.cfg['algorithm']['n_steps'], + k=self.cfg['algorithm']['batch_size']) + loss = self.mappo(batch, self.net, **self.cfg[nms.ALGORITHM], **kwargs) self.optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(self.net.parameters(), 0.5) @@ -48,21 +37,21 @@ class LoopMAPPO(LoopSNAC): rewards_ = torch.stack(rewards_, dim=1) return rewards_ - def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, clip_range, gae_coef=0.0, **kwargs): - obs, actions, hidden_actor, hidden_critic, old_log_probs, old_critic, reward, done = self.build_batch(tm) + def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **kwargs): + out = network(batch[nms.OBSERVATION], batch[nms.ACTION], batch[nms.HIDDEN_ACTOR], batch[nms.HIDDEN_CRITIC]) + logits = out[nms.LOGITS][:, :-1] # last one only needed for v_{t+1} - out = network(obs, actions, hidden_actor, hidden_critic) - logits = out['logits'][:, :-1] # last one only needed for v_{t+1} - critic = out['critic'] + old_log_probs = torch.log_softmax(batch[nms.LOGITS], -1) + old_log_probs = torch.gather(old_log_probs, index=batch[nms.ACTION][:, 1:].unsqueeze(-1), dim=-1).squeeze() # monte carlo returns - mc_returns = self.monte_carlo_returns(reward, done, gamma) - # monte_carlo_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-7) todo: norm across agents? - advantages = mc_returns - critic[:, :-1] + mc_returns = self.monte_carlo_returns(batch[nms.REWARD], batch[nms.DONE], gamma) + mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8) #todo: norm across agents ok? 
+ advantages = mc_returns - out[nms.CRITIC][:, :-1] # policy loss log_ap = torch.log_softmax(logits, -1) - log_ap = torch.gather(log_ap, dim=-1, index=actions[:, 1:].unsqueeze(-1)).squeeze() + log_ap = torch.gather(log_ap, dim=-1, index=batch[nms.ACTION][:, 1:].unsqueeze(-1)).squeeze() ratio = (log_ap - old_log_probs).exp() surr1 = ratio * advantages.detach() surr2 = torch.clamp(ratio, 1 - clip_range, 1 + clip_range) * advantages.detach() diff --git a/algorithms/marl/memory.py b/algorithms/marl/memory.py index 71dcad7..ab908c2 100644 --- a/algorithms/marl/memory.py +++ b/algorithms/marl/memory.py @@ -1,89 +1,93 @@ -import torch -from typing import Union, List -from torch import Tensor import numpy as np +from collections import deque +import torch +from typing import Union +from torch import Tensor +from torch.utils.data import Dataset, ConcatDataset +import random class ActorCriticMemory(object): - def __init__(self): + def __init__(self, capacity=10): + self.capacity = capacity self.reset() def reset(self): - self.__states = [] - self.__actions = [] - self.__rewards = [] - self.__dones = [] - self.__hiddens_actor = [] - self.__hiddens_critic = [] - self.__logits = [] - self.__values = [] + self.__actions = LazyTensorFiFoQueue(maxlen=self.capacity+1) + self.__hidden_actor = LazyTensorFiFoQueue(maxlen=self.capacity+1) + self.__hidden_critic = LazyTensorFiFoQueue(maxlen=self.capacity+1) + self.__states = LazyTensorFiFoQueue(maxlen=self.capacity+1) + self.__rewards = LazyTensorFiFoQueue(maxlen=self.capacity+1) + self.__dones = LazyTensorFiFoQueue(maxlen=self.capacity+1) + self.__logits = LazyTensorFiFoQueue(maxlen=self.capacity+1) + self.__values = LazyTensorFiFoQueue(maxlen=self.capacity+1) def __len__(self): - return len(self.__states) + return len(self.__rewards) - 1 @property - def observation(self): # add time dimension through stacking - return torch.stack(self.__states, 0).unsqueeze(0) # 1 x timesteps x hidden dim + def observation(self, sls=slice(0, None)): # add time dimension through stacking + return self.__states[sls].unsqueeze(0) # 1 x time x hidden dim @property - def hidden_actor(self): - if len(self.__hiddens_actor) == 1: - return self.__hiddens_actor[0] - return torch.stack(self.__hiddens_actor, 0) # layers x timesteps x hidden dim + def hidden_actor(self, sls=slice(0, None)): # 1 x n_layers x dim + return self.__hidden_actor[sls].unsqueeze(0) # 1 x time x n_layers x dim @property - def hidden_critic(self): - if len(self.__hiddens_critic) == 1: - return self.__hiddens_critic[0] - return torch.stack(self.__hiddens_critic, 0) # layers x timesteps x hidden dim + def hidden_critic(self, sls=slice(0, None)): # 1 x n_layers x dim + return self.__hidden_critic[sls].unsqueeze(0) # 1 x time x n_layers x dim @property - def reward(self): - return torch.tensor(self.__rewards).float().unsqueeze(0) # 1 x timesteps + def reward(self, sls=slice(0, None)): + return self.__rewards[sls].squeeze().unsqueeze(0) # 1 x time @property - def action(self): - return torch.tensor(self.__actions).long().unsqueeze(0) # 1 x timesteps+1 + def action(self, sls=slice(0, None)): + return self.__actions[sls].long().squeeze().unsqueeze(0) # 1 x time @property - def done(self): - return torch.tensor(self.__dones).float().unsqueeze(0) # 1 x timesteps + def done(self, sls=slice(0, None)): + return self.__dones[sls].float().squeeze().unsqueeze(0) # 1 x time @property - def logits(self): # assumes a trailing 1 for time dimension - common when using output from NN - return torch.cat(self.__logits, 0).unsqueeze(0) # 
1 x timesteps x actions + def logits(self, sls=slice(0, None)): # assumes a trailing 1 for time dimension - common when using output from NN + return self.__logits[sls].squeeze().unsqueeze(0) # 1 x time x actions @property - def values(self): - return torch.cat(self.__values, 0).unsqueeze(0) # 1 x timesteps x actions + def values(self, sls=slice(0, None)): + return self.__values[sls].squeeze().unsqueeze(0) # 1 x time x actions def add_observation(self, state: Union[Tensor, np.ndarray]): self.__states.append(state if isinstance(state, Tensor) else torch.from_numpy(state)) def add_hidden_actor(self, hidden: Tensor): - # 1x layers x hidden dim - if len(hidden.shape) < 3: hidden = hidden.unsqueeze(0) - self.__hiddens_actor.append(hidden) + # layers x hidden dim + self.__hidden_actor.append(hidden) def add_hidden_critic(self, hidden: Tensor): - # 1x layers x hidden dim - if len(hidden.shape) < 3: hidden = hidden.unsqueeze(0) - self.__hiddens_critic.append(hidden) + # layers x hidden dim + self.__hidden_critic.append(hidden) - def add_action(self, action: int): + def add_action(self, action: Union[int, Tensor]): + if not isinstance(action, Tensor): + action = torch.tensor(action) self.__actions.append(action) - def add_reward(self, reward: float): + def add_reward(self, reward: Union[float, Tensor]): + if not isinstance(reward, Tensor): + reward = torch.tensor(reward) self.__rewards.append(reward) def add_done(self, done: bool): + if not isinstance(done, Tensor): + done = torch.tensor(done) self.__dones.append(done) def add_logits(self, logits: Tensor): self.__logits.append(logits) - def add_values(self, logits: Tensor): - self.__values.append(logits) + def add_values(self, values: Tensor): + self.__values.append(values) def add(self, **kwargs): for k, v in kwargs.items(): @@ -92,10 +96,10 @@ class ActorCriticMemory(object): class MARLActorCriticMemory(object): - def __init__(self, n_agents): + def __init__(self, n_agents, capacity): self.n_agents = n_agents self.memories = [ - ActorCriticMemory() for _ in range(n_agents) + ActorCriticMemory(capacity) for _ in range(n_agents) ] def __call__(self, agent_i): @@ -109,50 +113,109 @@ class MARLActorCriticMemory(object): mem.reset() def add(self, **kwargs): - # todo try catch - print all possible functions for agent_i in range(self.n_agents): for k, v in kwargs.items(): func = getattr(ActorCriticMemory, f'add_{k}') func(self.memories[agent_i], v[agent_i]) - @property - def observation(self): - all_obs = [mem.observation for mem in self.memories] - return torch.cat(all_obs, 0) # agents x timesteps+1 x ... + def __getattr__(self, attr): + all_attrs = [getattr(mem, attr) for mem in self.memories] + return torch.cat(all_attrs, 0) # agents x time ... + + def chunk_dataloader(self, chunk_len, k): + datasets = [ExperienceChunks(mem, chunk_len, k) for mem in self.memories] + dataset = ConcatDataset(datasets) + data = [dataset[i] for i in range(len(dataset))] + data = custom_collate_fn(data) + return data + + +def custom_collate_fn(batch): + elem = batch[0] + return {key: torch.cat([d[key] for d in batch], dim=0) for key in elem} + + +class ExperienceChunks(Dataset): + def __init__(self, memory, chunk_len, k): + assert chunk_len <= len(memory), 'chunk_len cannot be longer than the size of the memory' + self.memory = memory + self.chunk_len = chunk_len + self.k = k @property - def action(self): - all_actions = [mem.action for mem in self.memories] - return torch.cat(all_actions, 0) # agents x timesteps+1 x ... 
+ def whitelist(self): + whitelist = torch.ones(len(self.memory) - self.chunk_len) + for d in self.memory.done.squeeze().nonzero().flatten(): + whitelist[max((0, d-self.chunk_len-1)):d+2] = 0 + whitelist[0] = 0 + return whitelist.tolist() - @property - def done(self): - all_dones = [mem.done for mem in self.memories] - return torch.cat(all_dones, 0).float() # agents x timesteps x ... + def sample(self, start=1): + cl = self.chunk_len + sample = dict(observation=self.memory.observation[:, start:start+cl+1], + action=self.memory.action[:, start-1:start+cl], + hidden_actor=self.memory.hidden_actor[:, start-1], + hidden_critic=self.memory.hidden_critic[:, start-1], + reward=self.memory.reward[:, start:start + cl], + done=self.memory.done[:, start:start + cl], + logits=self.memory.logits[:, start:start + cl], + values=self.memory.values[:, start:start + cl]) + return sample + + def __len__(self): + return self.k + + def __getitem__(self, i): + idx = random.choices(range(0, len(self.memory) - self.chunk_len), weights=self.whitelist, k=1) + return self.sample(idx[0]) + + +class LazyTensorFiFoQueue: + def __init__(self, maxlen): + self.maxlen = maxlen + self.reset() + + def reset(self): + self.__lazy_queue = deque(maxlen=self.maxlen) + self.shape = None + self.queue = None + + def shape_init(self, tensor: Tensor): + self.shape = torch.Size([self.maxlen, *tensor.shape]) + + def build_tensor_queue(self): + if len(self.__lazy_queue) > 0: + block = torch.stack(list(self.__lazy_queue), dim=0) + l = block.shape[0] + if self.queue is None: + self.queue = block + elif self.true_len() <= self.maxlen: + self.queue = torch.cat((self.queue, block), dim=0) + else: + self.queue = torch.cat((self.queue[l:], block), dim=0) + self.__lazy_queue.clear() + + def append(self, data): + if self.shape is None: + self.shape_init(data) + self.__lazy_queue.append(data) + if len(self.__lazy_queue) >= self.maxlen: + self.build_tensor_queue() + + def true_len(self): + return len(self.__lazy_queue) + (0 if self.queue is None else self.queue.shape[0]) + + def __len__(self): + return min((self.true_len(), self.maxlen)) + + def __str__(self): + return f'LazyTensorFiFoQueue\tmaxlen: {self.maxlen}, shape: {self.shape}, ' \ + f'len: {len(self)}, true_len: {self.true_len()}, elements in lazy queue: {len(self.__lazy_queue)}' + + def __getitem__(self, item_or_slice): + self.build_tensor_queue() + return self.queue[item_or_slice] - @property - def reward(self): - all_rewards = [mem.reward for mem in self.memories] - return torch.cat(all_rewards, 0).float() # agents x timesteps x ... 
- @property - def hidden_actor(self): - all_ha = [mem.hidden_actor for mem in self.memories] - return torch.cat(all_ha, 0) # agents x layers x x timesteps x hidden dim - - @property - def hidden_critic(self): - all_hc = [mem.hidden_critic for mem in self.memories] - return torch.cat(all_hc, 0) # agents x layers x timesteps x hidden dim - - @property - def logits(self): - all_lgts = [mem.logits for mem in self.memories] - return torch.cat(all_lgts, 0) # agents x layers x timesteps x hidden dim - - @property - def values(self): - all_vals = [mem.values for mem in self.memories] - return torch.cat(all_vals, 0) # agents x layers x timesteps x hidden dim diff --git a/algorithms/marl/seac.py b/algorithms/marl/seac.py index 5b33b0a..be572a5 100644 --- a/algorithms/marl/seac.py +++ b/algorithms/marl/seac.py @@ -1,6 +1,7 @@ import torch from torch.distributions import Categorical from algorithms.marl.iac import LoopIAC +from algorithms.marl.base_ac import nms from algorithms.marl.memory import MARLActorCriticMemory @@ -9,12 +10,12 @@ class LoopSEAC(LoopIAC): super(LoopSEAC, self).__init__(cfg) def actor_critic(self, tm, networks, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs): - obs, actions, done, reward = tm.observation, tm.action, tm.done, tm.reward - outputs = [net(obs, actions, tm.hidden_actor, tm.hidden_critic) for net in networks] + obs, actions, done, reward = tm.observation, tm.action, tm.done[:, 1:], tm.reward[:, 1:] + outputs = [net(obs, actions, tm.hidden_actor[:, 0], tm.hidden_critic[:, 0]) for net in networks] with torch.inference_mode(True): true_action_logp = torch.stack([ - torch.log_softmax(out['logits'][ag_i, :-1], -1) + torch.log_softmax(out[nms.LOGITS][ag_i, :-1], -1) .gather(index=actions[ag_i, 1:, None], dim=-1) for ag_i, out in enumerate(outputs) ], 0).squeeze() @@ -22,8 +23,8 @@ class LoopSEAC(LoopIAC): losses = [] for ag_i, out in enumerate(outputs): - logits = out['logits'][:, :-1] # last one only needed for v_{t+1} - critic = out['critic'] + logits = out[nms.LOGITS][:, :-1] # last one only needed for v_{t+1} + critic = out[nms.CRITIC] entropy_loss = Categorical(logits=logits[ag_i]).entropy().mean() advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef) @@ -47,7 +48,7 @@ class LoopSEAC(LoopIAC): return losses def learn(self, tms: MARLActorCriticMemory, **kwargs): - losses = self.actor_critic(tms, self.net, **self.cfg['algorithm'], **kwargs) + losses = self.actor_critic(tms, self.net, **self.cfg[nms.ALGORITHM], **kwargs) for ag_i, loss in enumerate(losses): self.optimizer[ag_i].zero_grad() loss.backward() diff --git a/algorithms/marl/snac.py b/algorithms/marl/snac.py index 3d312a8..f0ff01d 100644 --- a/algorithms/marl/snac.py +++ b/algorithms/marl/snac.py @@ -1,4 +1,5 @@ from algorithms.marl.base_ac import BaseActorCritic +from algorithms.marl.base_ac import nms import torch from torch.distributions import Categorical from pathlib import Path @@ -21,7 +22,7 @@ class LoopSNAC(BaseActorCritic): ) def get_actions(self, out): - actions = Categorical(logits=out['logits']).sample().squeeze() + actions = Categorical(logits=out[nms.LOGITS]).sample().squeeze() return actions def forward(self, observations, actions, hidden_actor, hidden_critic): diff --git a/studies/normalization_study.py b/studies/normalization_study.py index e8e4d14..37e10c4 100644 --- a/studies/normalization_study.py +++ b/studies/normalization_study.py @@ -6,7 +6,7 @@ from algorithms.utils import load_yaml_file, add_env_props, instantiate_class, l for i in range(0, 5): - for name in 
['mappo']:#['seac', 'iac', 'snac']: + for name in ['snac', 'mappo', 'iac', 'seac']: study_root = Path(__file__).parent / name cfg = load_yaml_file(study_root / f'{name}.yaml') add_env_props(cfg) diff --git a/studies/playground_file.py b/studies/playground_file.py index b58c82b..065e04e 100644 --- a/studies/playground_file.py +++ b/studies/playground_file.py @@ -3,12 +3,12 @@ from pathlib import Path import matplotlib.pyplot as plt import seaborn as sns - dfs = [] -for name in ['l2snac', 'iac', 'snac', 'seac']: +for name in ['mappo']: for c in range(5): try: study_root = Path(__file__).parent / name / f'{name}#{c}' + print(study_root) df = pd.read_csv(study_root / 'results.csv', index_col=False) df.reward = df.reward.rolling(100).mean() df['method'] = name.upper() @@ -17,6 +17,6 @@ for name in ['l2snac', 'iac', 'snac', 'seac']: pass df = pd.concat(dfs).reset_index() -sns.lineplot(data=df, x='episode', y='reward', hue='method', palette='husl', ci='sd', linewidth=1.5) +sns.lineplot(data=df, x='steps', y='reward', hue='method', palette='husl', ci='sd', linewidth=1.5, err_style='bars') plt.savefig('study.png') print('saved image') \ No newline at end of file
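
Usage note on the new lazy experience queue. LazyTensorFiFoQueue (algorithms/marl/memory.py) buffers incoming tensors in a plain deque and only stacks them into one contiguous tensor when the queue is indexed or the deque reaches maxlen, so per-step appends stay cheap. A minimal sketch of the expected behaviour, assuming the module path above; the tensor values are arbitrary:

import torch
from algorithms.marl.memory import LazyTensorFiFoQueue

q = LazyTensorFiFoQueue(maxlen=4)
for t in range(6):
    q.append(torch.full((2,), float(t)))   # appends only touch the internal deque

print(len(q))       # 4 -> length is capped at maxlen
print(q[:].shape)   # torch.Size([4, 2]); indexing triggers build_tensor_queue()
print(q[:, 0])      # tensor([2., 3., 4., 5.]) -> the two oldest entries were dropped

ActorCriticMemory allocates each queue with maxlen=capacity+1 and reports len(reward) - 1 as its own length, because the first add() per episode stores a bootstrap entry (initial observation, dummy action, zero reward/logits/values).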
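
Note on chunk sampling for MAPPO. LoopMAPPO now learns from fixed-length chunks drawn via MARLActorCriticMemory.chunk_dataloader(): each agent memory is wrapped in an ExperienceChunks dataset, batch_size chunks are drawn per agent, and custom_collate_fn concatenates them into one dict of batch-first tensors keyed by the field names above. The whitelist property masks start indices whose chunk would straddle an episode boundary, and index 0 is always masked because sample(start) also reads start-1 for the previous action and hidden states; random.choices then uses the whitelist as sampling weights. A simplified, self-contained illustration of that masking rule (the done flags here are made up):

import torch

chunk_len = 3
done = torch.tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])   # hypothetical 10-step buffer
whitelist = torch.ones(len(done) - chunk_len)
for d in done.nonzero().flatten():
    whitelist[max((0, d - chunk_len - 1)):d + 2] = 0             # block starts around the episode cut
whitelist[0] = 0                                                  # start 0 has no preceding step
print(whitelist)   # tensor([0., 0., 0., 0., 0., 1., 1.]) -> only starts 5 and 6 may be sampled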
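
Config surface touched by this patch. The old keep_n_segments segment queue is gone; the trainer now reads buffer_size (memory capacity, defaulting to n_steps) and, for MAPPO, batch_size and n_updates. A hedged sketch of the algorithm block the updated code expects, shown as a Python dict although the repo loads it from YAML via load_yaml_file; the key names come from the diff, but the numeric values below are placeholders, not taken from the repository's config files:

algorithm = dict(
    max_steps=10_000,    # total environment steps in train_loop
    n_steps=5,           # learn() fires every n_steps; also the chunk length for MAPPO
    buffer_size=64,      # capacity passed to MARLActorCriticMemory (falls back to n_steps)
    batch_size=8,        # chunks drawn per agent by chunk_dataloader (MAPPO only)
    n_updates=4,         # gradient steps per learn() call (MAPPO only)
    gamma=0.99,          # discount used by compute_advantages / monte_carlo_returns
    entropy_coef=0.01,
    vf_coef=0.5,
    clip_range=0.2,      # PPO ratio clipping (MAPPO only)
    gae_coef=0.0,        # GAE mixing in the base actor-critic path
)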