add a more efficient (lazy), tensor-based experience queue implementation; adjust MARL algorithms accordingly

This commit is contained in:
Robert Müller
2022-02-03 13:14:48 +01:00
parent b09c461754
commit a9a4274370
8 changed files with 243 additions and 165 deletions
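The core change replaces the old pattern of deep-copying whole memory objects into a deque with a tensor-backed FIFO queue that only materializes a stacked tensor when the data is actually read. A minimal toy sketch of that idea follows (illustrative only; TinyLazyQueue and its names are made up here and are not the committed LazyTensorFiFoQueue):

from collections import deque
import torch

class TinyLazyQueue:
    """Toy illustration: cheap appends, stacking deferred until the data is read."""
    def __init__(self, maxlen):
        self.pending = deque(maxlen=maxlen)   # O(1) appends, oldest entries drop out
        self.stacked = None                   # materialized tensor, built on demand

    def append(self, t: torch.Tensor):
        self.pending.append(t)                # no copy of previously stored steps

    def __getitem__(self, sls):
        self.stacked = torch.stack(list(self.pending), dim=0)  # lazy materialization
        return self.stacked[sls]

q = TinyLazyQueue(maxlen=4)
for step in range(6):
    q.append(torch.full((2,), float(step)))  # e.g. per-step rewards of two agents
print(q[-2:])                                # stacked tensor of the last two stored steps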

View File

@@ -1,6 +1,5 @@
import torch
from typing import Union, List
-import copy
import numpy as np
from torch.distributions import Categorical
from algorithms.marl.memory import MARLActorCriticMemory
@@ -8,6 +7,28 @@ from algorithms.utils import add_env_props, instantiate_class
from pathlib import Path
import pandas as pd
from collections import deque
+class Names:
+    REWARD = 'reward'
+    DONE = 'done'
+    ACTION = 'action'
+    OBSERVATION = 'observation'
+    LOGITS = 'logits'
+    HIDDEN_ACTOR = 'hidden_actor'
+    HIDDEN_CRITIC = 'hidden_critic'
+    AGENT = 'agent'
+    ENV = 'env'
+    N_AGENTS = 'n_agents'
+    ALGORITHM = 'algorithm'
+    MAX_STEPS = 'max_steps'
+    N_STEPS = 'n_steps'
+    BUFFER_SIZE = 'buffer_size'
+    CRITIC = 'critic'
+    BATCH_SIZE = 'batch_size'
+    N_ACTIONS = 'n_actions'
+nms = Names
ListOrTensor = Union[List, torch.Tensor]
@@ -16,11 +37,12 @@ class BaseActorCritic:
        add_env_props(cfg)
        self.__training = True
        self.cfg = cfg
-        self.n_agents = cfg['env']['n_agents']
+        self.n_agents = cfg[nms.ENV][nms.N_AGENTS]
+        self.reset_memory_after_epoch = True
        self.setup()
    def setup(self):
-        self.net = instantiate_class(self.cfg['agent'])
+        self.net = instantiate_class(self.cfg[nms.AGENT])
        self.optimizer = torch.optim.RMSprop(self.net.parameters(), lr=3e-4, eps=1e-5)
    @classmethod
@@ -49,7 +71,7 @@ class BaseActorCritic:
        pass
    def get_actions(self, out) -> ListOrTensor:
-        actions = [Categorical(logits=logits).sample().item() for logits in out['logits']]
+        actions = [Categorical(logits=logits).sample().item() for logits in out[nms.LOGITS]]
        return actions
    def init_hidden(self) -> dict[ListOrTensor]:
@@ -63,47 +85,48 @@ class BaseActorCritic:
                ) -> dict[ListOrTensor]:
        pass
    @torch.no_grad()
    def train_loop(self, checkpointer=None):
-        env = instantiate_class(self.cfg['env'])
-        n_steps, max_steps = [self.cfg['algorithm'][k] for k in ['n_steps', 'max_steps']]
+        env = instantiate_class(self.cfg[nms.ENV])
+        n_steps, max_steps = [self.cfg[nms.ALGORITHM][k] for k in [nms.N_STEPS, nms.MAX_STEPS]]
+        tm = MARLActorCriticMemory(self.n_agents, self.cfg[nms.ALGORITHM].get(nms.BUFFER_SIZE, n_steps))
        global_steps, episode, df_results = 0, 0, []
        reward_queue = deque(maxlen=2000)
-        memory_queue = deque(maxlen=self.cfg['algorithm'].get('keep_n_segments', 1))
        while global_steps < max_steps:
-            tm = MARLActorCriticMemory(self.n_agents)
            obs = env.reset()
            last_hiddens = self.init_hidden()
            last_action, reward = [-1] * self.n_agents, [0.] * self.n_agents
            done, rew_log = [False] * self.n_agents, 0
-            tm.add(action=last_action, **last_hiddens)
+            if self.reset_memory_after_epoch:
+                tm.reset()
+            tm.add(observation=obs, action=last_action,
+                   logits=torch.zeros(self.n_agents, 1, self.cfg[nms.AGENT][nms.N_ACTIONS]),
+                   values=torch.zeros(self.n_agents, 1), reward=reward, done=done, **last_hiddens)
            while not all(done):
                out = self.forward(obs, last_action, **last_hiddens)
                action = self.get_actions(out)
                next_obs, reward, done, info = env.step(action)
-                next_obs = next_obs
-                if isinstance(done, bool): done = [done] * self.n_agents
+                done = [done] * self.n_agents if isinstance(done, bool) else done
+                last_hiddens = dict(hidden_actor=out[nms.HIDDEN_ACTOR],
+                                    hidden_critic=out[nms.HIDDEN_CRITIC])
                tm.add(observation=obs, action=action, reward=reward, done=done,
-                       logits=out.get('logits', None), values=out.get('critic', None))
+                       logits=out.get(nms.LOGITS, None), values=out.get(nms.CRITIC, None),
+                       **last_hiddens)
                obs = next_obs
                last_action = action
-                last_hiddens = dict(hidden_actor=out.get('hidden_actor', None),
-                                    hidden_critic=out.get('hidden_critic', None)
-                                    )
-                if len(tm) >= n_steps or all(done):
-                    tm.add(observation=next_obs)
-                    memory_queue.append(copy.deepcopy(tm))
-                    if self.__training:
+                if (global_steps+1) % n_steps == 0 or all(done):
                    with torch.inference_mode(False):
-                        tm_ = tm if memory_queue.maxlen <= 1 else list(memory_queue)
-                        self.learn(tm_)
+                        self.learn(tm)
-                    tm.reset()
-                    tm.add(action=last_action, **last_hiddens)
                global_steps += 1
                rew_log += sum(reward)
                reward_queue.extend(reward)
@@ -114,18 +137,19 @@ class BaseActorCritic:
                        for i, agent in enumerate([self.net] if not isinstance(self.net, List) else self.net)
                    ])
-                if global_steps >= max_steps: break
-            print(f'reward at step: {episode} = {rew_log}')
+                if global_steps >= max_steps:
+                    break
+            print(f'reward at episode: {episode} = {rew_log}')
            episode += 1
-            df_results.append([global_steps, rew_log])
-        df_results = pd.DataFrame(df_results, columns=['steps', 'reward'])
+            df_results.append([episode, rew_log, *reward])
+        df_results = pd.DataFrame(df_results, columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]])
        if checkpointer is not None:
            df_results.to_csv(checkpointer.path / 'results.csv', index=False)
        return df_results
    @torch.inference_mode(True)
    def eval_loop(self, n_episodes, render=False):
-        env = instantiate_class(self.cfg['env'])
+        env = instantiate_class(self.cfg[nms.ENV])
        episode, results = 0, []
        while episode < n_episodes:
            obs = env.reset()
@@ -142,8 +166,8 @@ class BaseActorCritic:
                if isinstance(done, bool): done = [done] * obs.shape[0]
                obs = next_obs
                last_action = action
-                last_hiddens = dict(hidden_actor=out.get('hidden_actor', None),
-                                    hidden_critic=out.get('hidden_critic', None)
+                last_hiddens = dict(hidden_actor=out.get(nms.HIDDEN_ACTOR, None),
+                                    hidden_critic=out.get(nms.HIDDEN_CRITIC, None)
                                    )
                eps_rew += torch.tensor(reward)
            results.append(eps_rew.tolist() + [sum(eps_rew).item()] + [episode])
@@ -169,11 +193,11 @@ class BaseActorCritic:
        return gaes
    def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
-        obs, actions, done, reward = tm.observation, tm.action, tm.done, tm.reward
-        out = network(obs, actions, tm.hidden_actor, tm.hidden_critic)
-        logits = out['logits'][:, :-1]  # last one only needed for v_{t+1}
-        critic = out['critic']
+        obs, actions, done, reward = tm.observation, tm.action, tm.done[:, 1:], tm.reward[:, 1:]
+        out = network(obs, actions, tm.hidden_actor[:, 0], tm.hidden_critic[:, 0])
+        logits = out[nms.LOGITS][:, :-1]  # last one only needed for v_{t+1}
+        critic = out[nms.CRITIC]
        entropy_loss = Categorical(logits=logits).entropy().mean(-1)
        advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef)
@@ -188,7 +212,7 @@ class BaseActorCritic:
        return loss.mean()
    def learn(self, tm: MARLActorCriticMemory, **kwargs):
-        loss = self.actor_critic(tm, self.net, **self.cfg['algorithm'], **kwargs)
+        loss = self.actor_critic(tm, self.net, **self.cfg[nms.ALGORITHM], **kwargs)
        # remove next_obs, will be added in next iter
        self.optimizer.zero_grad()
        loss.backward()
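With the bootstrap entry now written right after env.reset() (zero logits/values, action -1), each stored segment holds T+1 slots, which is why actor_critic() indexes tm.done[:, 1:] / tm.reward[:, 1:] and feeds only the first stored hidden state to the network. A small shape sketch of that alignment (illustrative tensors under these assumptions, not repo code):

import torch

n_agents, T = 2, 5
obs    = torch.randn(n_agents, T + 1, 3)   # slot 0 is the bootstrap entry, slots 1..T are env steps
reward = torch.zeros(n_agents, T + 1)      # reward[:, 0] is the dummy initial reward
done   = torch.zeros(n_agents, T + 1)

# the loss drops the bootstrap slot so targets line up with the T transitions ...
reward_t, done_t = reward[:, 1:], done[:, 1:]
assert reward_t.shape == (n_agents, T) and done_t.shape == (n_agents, T)
# ... while the network is unrolled from the first stored hidden state over all T+1
# observations; the trailing prediction is only used to bootstrap v_{t+1}.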

View File

@@ -1,5 +1,5 @@
import torch
-from algorithms.marl.base_ac import BaseActorCritic
+from algorithms.marl.base_ac import BaseActorCritic, nms
from algorithms.utils import instantiate_class
from pathlib import Path
from natsort import natsorted
@@ -13,7 +13,7 @@ class LoopIAC(BaseActorCritic):
    def setup(self):
        self.net = [
-            instantiate_class(self.cfg['agent']) for _ in range(self.n_agents)
+            instantiate_class(self.cfg[nms.AGENT]) for _ in range(self.n_agents)
        ]
        self.optimizer = [
            torch.optim.RMSprop(self.net[ag_i].parameters(), lr=3e-4, eps=1e-5) for ag_i in range(self.n_agents)
@@ -50,7 +50,7 @@ class LoopIAC(BaseActorCritic):
    def learn(self, tms: MARLActorCriticMemory, **kwargs):
        for ag_i in range(self.n_agents):
            tm, net = tms(ag_i), self.net[ag_i]
-            loss = self.actor_critic(tm, net, **self.cfg['algorithm'], **kwargs)
+            loss = self.actor_critic(tm, net, **self.cfg[nms.ALGORITHM], **kwargs)
            self.optimizer[ag_i].zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 0.5)

View File

@@ -1,39 +1,28 @@
+from algorithms.marl.base_ac import Names as nms
from algorithms.marl import LoopSNAC
from algorithms.marl.memory import MARLActorCriticMemory
-from typing import List
import random
import torch
from torch.distributions import Categorical
+from algorithms.utils import instantiate_class
class LoopMAPPO(LoopSNAC):
    def __init__(self, *args, **kwargs):
        super(LoopMAPPO, self).__init__(*args, **kwargs)
+        self.reset_memory_after_epoch = False
-    def build_batch(self, tm: List[MARLActorCriticMemory]):
-        sample = random.choices(tm, k=self.cfg['algorithm']['batch_size']-1)
-        sample.append(tm[-1])  # always use latest segment in batch
-        obs = torch.cat([s.observation for s in sample], 0)
-        actions = torch.cat([s.action for s in sample], 0)
-        hidden_actor = torch.cat([s.hidden_actor for s in sample], 0)
-        hidden_critic = torch.cat([s.hidden_critic for s in sample], 0)
-        logits = torch.cat([s.logits for s in sample], 0)
-        values = torch.cat([s.values for s in sample], 0)
-        reward = torch.cat([s.reward for s in sample], 0)
-        done = torch.cat([s.done for s in sample], 0)
-        log_props = torch.log_softmax(logits, -1)
-        log_props = torch.gather(log_props, index=actions[:, 1:].unsqueeze(-1), dim=-1).squeeze()
-        return obs, actions, hidden_actor, hidden_critic, log_props, values, reward, done
-    def learn(self, tm: List[MARLActorCriticMemory], **kwargs):
-        if len(tm) >= self.cfg['algorithm']['keep_n_segments']:
+    def setup(self):
+        self.net = instantiate_class(self.cfg[nms.AGENT])
+        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=3e-4, eps=1e-5)
+    def learn(self, tm: MARLActorCriticMemory, **kwargs):
+        if len(tm) >= self.cfg['algorithm']['buffer_size']:
            # only learn when buffer is full
            for batch_i in range(self.cfg['algorithm']['n_updates']):
-                loss = self.actor_critic(tm, self.net, **self.cfg['algorithm'], **kwargs)
+                batch = tm.chunk_dataloader(chunk_len=self.cfg['algorithm']['n_steps'],
+                                            k=self.cfg['algorithm']['batch_size'])
+                loss = self.mappo(batch, self.net, **self.cfg[nms.ALGORITHM], **kwargs)
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.net.parameters(), 0.5)
@@ -48,21 +37,21 @@ class LoopMAPPO(LoopSNAC):
        rewards_ = torch.stack(rewards_, dim=1)
        return rewards_
-    def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, clip_range, gae_coef=0.0, **kwargs):
-        obs, actions, hidden_actor, hidden_critic, old_log_probs, old_critic, reward, done = self.build_batch(tm)
-        out = network(obs, actions, hidden_actor, hidden_critic)
-        logits = out['logits'][:, :-1]  # last one only needed for v_{t+1}
-        critic = out['critic']
+    def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **kwargs):
+        out = network(batch[nms.OBSERVATION], batch[nms.ACTION], batch[nms.HIDDEN_ACTOR], batch[nms.HIDDEN_CRITIC])
+        logits = out[nms.LOGITS][:, :-1]  # last one only needed for v_{t+1}
+        old_log_probs = torch.log_softmax(batch[nms.LOGITS], -1)
+        old_log_probs = torch.gather(old_log_probs, index=batch[nms.ACTION][:, 1:].unsqueeze(-1), dim=-1).squeeze()
        # monte carlo returns
-        mc_returns = self.monte_carlo_returns(reward, done, gamma)
-        # monte_carlo_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-7)  todo: norm across agents?
-        advantages = mc_returns - critic[:, :-1]
+        mc_returns = self.monte_carlo_returns(batch[nms.REWARD], batch[nms.DONE], gamma)
+        mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8)  # todo: norm across agents ok?
+        advantages = mc_returns - out[nms.CRITIC][:, :-1]
        # policy loss
        log_ap = torch.log_softmax(logits, -1)
-        log_ap = torch.gather(log_ap, dim=-1, index=actions[:, 1:].unsqueeze(-1)).squeeze()
+        log_ap = torch.gather(log_ap, dim=-1, index=batch[nms.ACTION][:, 1:].unsqueeze(-1)).squeeze()
        ratio = (log_ap - old_log_probs).exp()
        surr1 = ratio * advantages.detach()
        surr2 = torch.clamp(ratio, 1 - clip_range, 1 + clip_range) * advantages.detach()
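The mappo() update turns the sampled chunks into the usual clipped surrogate. Since the hunk is cut off after surr2, here is a hedged sketch of the standard PPO policy term those lines feed into (illustrative shapes; only ratio/surr1/surr2 appear verbatim in the diff, the final min/mean combination is the standard form and assumed here):

import torch

batch_size, chunk_len, clip_range = 8, 10, 0.2
advantages    = torch.randn(batch_size, chunk_len)   # e.g. mc_returns minus value estimates
log_ap        = torch.randn(batch_size, chunk_len)   # log pi_new(a_t | s_t)
old_log_probs = torch.randn(batch_size, chunk_len)   # log pi_old(a_t | s_t), from the stored logits

ratio = (log_ap - old_log_probs).exp()
surr1 = ratio * advantages.detach()
surr2 = torch.clamp(ratio, 1 - clip_range, 1 + clip_range) * advantages.detach()
policy_loss = -torch.min(surr1, surr2).mean()        # clipped-surrogate objective, maximized via gradient descent on its negative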

View File

@@ -1,89 +1,93 @@
-import torch
-from typing import Union, List
-from torch import Tensor
import numpy as np
+from collections import deque
+import torch
+from typing import Union
+from torch import Tensor
+from torch.utils.data import Dataset, ConcatDataset
+import random
class ActorCriticMemory(object):
-    def __init__(self):
+    def __init__(self, capacity=10):
+        self.capacity = capacity
        self.reset()
    def reset(self):
-        self.__states = []
-        self.__actions = []
-        self.__rewards = []
-        self.__dones = []
-        self.__hiddens_actor = []
-        self.__hiddens_critic = []
-        self.__logits = []
-        self.__values = []
+        self.__actions = LazyTensorFiFoQueue(maxlen=self.capacity+1)
+        self.__hidden_actor = LazyTensorFiFoQueue(maxlen=self.capacity+1)
+        self.__hidden_critic = LazyTensorFiFoQueue(maxlen=self.capacity+1)
+        self.__states = LazyTensorFiFoQueue(maxlen=self.capacity+1)
+        self.__rewards = LazyTensorFiFoQueue(maxlen=self.capacity+1)
+        self.__dones = LazyTensorFiFoQueue(maxlen=self.capacity+1)
+        self.__logits = LazyTensorFiFoQueue(maxlen=self.capacity+1)
+        self.__values = LazyTensorFiFoQueue(maxlen=self.capacity+1)
    def __len__(self):
-        return len(self.__states)
+        return len(self.__rewards) - 1
    @property
-    def observation(self):  # add time dimension through stacking
-        return torch.stack(self.__states, 0).unsqueeze(0)  # 1 x timesteps x hidden dim
+    def observation(self, sls=slice(0, None)):  # add time dimension through stacking
+        return self.__states[sls].unsqueeze(0)  # 1 x time x hidden dim
    @property
-    def hidden_actor(self):
-        if len(self.__hiddens_actor) == 1:
-            return self.__hiddens_actor[0]
-        return torch.stack(self.__hiddens_actor, 0)  # layers x timesteps x hidden dim
+    def hidden_actor(self, sls=slice(0, None)):  # 1 x n_layers x dim
+        return self.__hidden_actor[sls].unsqueeze(0)  # 1 x time x n_layers x dim
    @property
-    def hidden_critic(self):
-        if len(self.__hiddens_critic) == 1:
-            return self.__hiddens_critic[0]
-        return torch.stack(self.__hiddens_critic, 0)  # layers x timesteps x hidden dim
+    def hidden_critic(self, sls=slice(0, None)):  # 1 x n_layers x dim
+        return self.__hidden_critic[sls].unsqueeze(0)  # 1 x time x n_layers x dim
    @property
-    def reward(self):
-        return torch.tensor(self.__rewards).float().unsqueeze(0)  # 1 x timesteps
+    def reward(self, sls=slice(0, None)):
+        return self.__rewards[sls].squeeze().unsqueeze(0)  # 1 x time
    @property
-    def action(self):
-        return torch.tensor(self.__actions).long().unsqueeze(0)  # 1 x timesteps+1
+    def action(self, sls=slice(0, None)):
+        return self.__actions[sls].long().squeeze().unsqueeze(0)  # 1 x time
    @property
-    def done(self):
-        return torch.tensor(self.__dones).float().unsqueeze(0)  # 1 x timesteps
+    def done(self, sls=slice(0, None)):
+        return self.__dones[sls].float().squeeze().unsqueeze(0)  # 1 x time
    @property
-    def logits(self):  # assumes a trailing 1 for time dimension - common when using output from NN
-        return torch.cat(self.__logits, 0).unsqueeze(0)  # 1 x timesteps x actions
+    def logits(self, sls=slice(0, None)):  # assumes a trailing 1 for time dimension - common when using output from NN
+        return self.__logits[sls].squeeze().unsqueeze(0)  # 1 x time x actions
    @property
-    def values(self):
-        return torch.cat(self.__values, 0).unsqueeze(0)  # 1 x timesteps x actions
+    def values(self, sls=slice(0, None)):
+        return self.__values[sls].squeeze().unsqueeze(0)  # 1 x time x actions
    def add_observation(self, state: Union[Tensor, np.ndarray]):
        self.__states.append(state if isinstance(state, Tensor) else torch.from_numpy(state))
    def add_hidden_actor(self, hidden: Tensor):
-        # 1x layers x hidden dim
-        if len(hidden.shape) < 3: hidden = hidden.unsqueeze(0)
-        self.__hiddens_actor.append(hidden)
+        # layers x hidden dim
+        self.__hidden_actor.append(hidden)
    def add_hidden_critic(self, hidden: Tensor):
-        # 1x layers x hidden dim
-        if len(hidden.shape) < 3: hidden = hidden.unsqueeze(0)
-        self.__hiddens_critic.append(hidden)
+        # layers x hidden dim
+        self.__hidden_critic.append(hidden)
-    def add_action(self, action: int):
+    def add_action(self, action: Union[int, Tensor]):
+        if not isinstance(action, Tensor):
+            action = torch.tensor(action)
        self.__actions.append(action)
-    def add_reward(self, reward: float):
+    def add_reward(self, reward: Union[float, Tensor]):
+        if not isinstance(reward, Tensor):
+            reward = torch.tensor(reward)
        self.__rewards.append(reward)
    def add_done(self, done: bool):
+        if not isinstance(done, Tensor):
+            done = torch.tensor(done)
        self.__dones.append(done)
    def add_logits(self, logits: Tensor):
        self.__logits.append(logits)
-    def add_values(self, logits: Tensor):
-        self.__values.append(logits)
+    def add_values(self, values: Tensor):
+        self.__values.append(values)
    def add(self, **kwargs):
        for k, v in kwargs.items():
@@ -92,10 +96,10 @@ class ActorCriticMemory(object):
class MARLActorCriticMemory(object):
-    def __init__(self, n_agents):
+    def __init__(self, n_agents, capacity):
        self.n_agents = n_agents
        self.memories = [
-            ActorCriticMemory() for _ in range(n_agents)
+            ActorCriticMemory(capacity) for _ in range(n_agents)
        ]
    def __call__(self, agent_i):
@@ -109,50 +113,109 @@ class MARLActorCriticMemory(object):
            mem.reset()
    def add(self, **kwargs):
+        # todo try catch - print all possible functions
        for agent_i in range(self.n_agents):
            for k, v in kwargs.items():
                func = getattr(ActorCriticMemory, f'add_{k}')
                func(self.memories[agent_i], v[agent_i])
-    @property
-    def observation(self):
-        all_obs = [mem.observation for mem in self.memories]
-        return torch.cat(all_obs, 0)  # agents x timesteps+1 x ...
-    @property
-    def action(self):
-        all_actions = [mem.action for mem in self.memories]
-        return torch.cat(all_actions, 0)  # agents x timesteps+1 x ...
-    @property
-    def done(self):
-        all_dones = [mem.done for mem in self.memories]
-        return torch.cat(all_dones, 0).float()  # agents x timesteps x ...
-    @property
-    def reward(self):
-        all_rewards = [mem.reward for mem in self.memories]
-        return torch.cat(all_rewards, 0).float()  # agents x timesteps x ...
-    @property
-    def hidden_actor(self):
-        all_ha = [mem.hidden_actor for mem in self.memories]
-        return torch.cat(all_ha, 0)  # agents x layers x x timesteps x hidden dim
-    @property
-    def hidden_critic(self):
-        all_hc = [mem.hidden_critic for mem in self.memories]
-        return torch.cat(all_hc, 0)  # agents x layers x timesteps x hidden dim
-    @property
-    def logits(self):
-        all_lgts = [mem.logits for mem in self.memories]
-        return torch.cat(all_lgts, 0)  # agents x layers x timesteps x hidden dim
-    @property
-    def values(self):
-        all_vals = [mem.values for mem in self.memories]
-        return torch.cat(all_vals, 0)  # agents x layers x timesteps x hidden dim
+    def __getattr__(self, attr):
+        all_attrs = [getattr(mem, attr) for mem in self.memories]
+        return torch.cat(all_attrs, 0)  # agents x time ...
+    def chunk_dataloader(self, chunk_len, k):
+        datasets = [ExperienceChunks(mem, chunk_len, k) for mem in self.memories]
+        dataset = ConcatDataset(datasets)
+        data = [dataset[i] for i in range(len(dataset))]
+        data = custom_collate_fn(data)
+        return data
+def custom_collate_fn(batch):
+    elem = batch[0]
+    return {key: torch.cat([d[key] for d in batch], dim=0) for key in elem}
+class ExperienceChunks(Dataset):
+    def __init__(self, memory, chunk_len, k):
+        assert chunk_len <= len(memory), 'chunk_len cannot be longer than the size of the memory'
+        self.memory = memory
+        self.chunk_len = chunk_len
+        self.k = k
+    @property
+    def whitelist(self):
+        whitelist = torch.ones(len(self.memory) - self.chunk_len)
+        for d in self.memory.done.squeeze().nonzero().flatten():
+            whitelist[max((0, d-self.chunk_len-1)):d+2] = 0
+        whitelist[0] = 0
+        return whitelist.tolist()
+    def sample(self, start=1):
+        cl = self.chunk_len
+        sample = dict(observation=self.memory.observation[:, start:start+cl+1],
+                      action=self.memory.action[:, start-1:start+cl],
+                      hidden_actor=self.memory.hidden_actor[:, start-1],
+                      hidden_critic=self.memory.hidden_critic[:, start-1],
+                      reward=self.memory.reward[:, start:start + cl],
+                      done=self.memory.done[:, start:start + cl],
+                      logits=self.memory.logits[:, start:start + cl],
+                      values=self.memory.values[:, start:start + cl])
+        return sample
+    def __len__(self):
+        return self.k
+    def __getitem__(self, i):
+        idx = random.choices(range(0, len(self.memory) - self.chunk_len), weights=self.whitelist, k=1)
+        return self.sample(idx[0])
+class LazyTensorFiFoQueue:
+    def __init__(self, maxlen):
+        self.maxlen = maxlen
+        self.reset()
+    def reset(self):
+        self.__lazy_queue = deque(maxlen=self.maxlen)
+        self.shape = None
+        self.queue = None
+    def shape_init(self, tensor: Tensor):
+        self.shape = torch.Size([self.maxlen, *tensor.shape])
+    def build_tensor_queue(self):
+        if len(self.__lazy_queue) > 0:
+            block = torch.stack(list(self.__lazy_queue), dim=0)
+            l = block.shape[0]
+            if self.queue is None:
+                self.queue = block
+            elif self.true_len() <= self.maxlen:
+                self.queue = torch.cat((self.queue, block), dim=0)
+            else:
+                self.queue = torch.cat((self.queue[l:], block), dim=0)
+            self.__lazy_queue.clear()
+    def append(self, data):
+        if self.shape is None:
+            self.shape_init(data)
+        self.__lazy_queue.append(data)
+        if len(self.__lazy_queue) >= self.maxlen:
+            self.build_tensor_queue()
+    def true_len(self):
+        return len(self.__lazy_queue) + (0 if self.queue is None else self.queue.shape[0])
+    def __len__(self):
+        return min((self.true_len(), self.maxlen))
+    def __str__(self):
+        return f'LazyTensorFiFoQueue\tmaxlen: {self.maxlen}, shape: {self.shape}, ' \
+               f'len: {len(self)}, true_len: {self.true_len()}, elements in lazy queue: {len(self.__lazy_queue)}'
+    def __getitem__(self, item_or_slice):
+        self.build_tensor_queue()
+        return self.queue[item_or_slice]
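A short usage sketch for the LazyTensorFiFoQueue added above (assuming it is imported from algorithms.marl.memory, where the diff defines it at module level): appends stay cheap deque operations, and the backing tensor is only (re)built when the queue is indexed.

import torch
from algorithms.marl.memory import LazyTensorFiFoQueue

queue = LazyTensorFiFoQueue(maxlen=6)
for t in range(8):                        # more appends than maxlen: the two oldest entries drop out
    queue.append(torch.tensor([float(t), 2.0 * t]))

print(len(queue))         # 6 -> capped at maxlen
print(queue[-1])          # tensor([ 7., 14.]) -> most recent element
print(queue[0:3].shape)   # torch.Size([3, 2]) -> slicing returns a stacked tensor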

View File

@@ -1,6 +1,7 @@
import torch
from torch.distributions import Categorical
from algorithms.marl.iac import LoopIAC
+from algorithms.marl.base_ac import nms
from algorithms.marl.memory import MARLActorCriticMemory
@@ -9,12 +10,12 @@ class LoopSEAC(LoopIAC):
        super(LoopSEAC, self).__init__(cfg)
    def actor_critic(self, tm, networks, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
-        obs, actions, done, reward = tm.observation, tm.action, tm.done, tm.reward
-        outputs = [net(obs, actions, tm.hidden_actor, tm.hidden_critic) for net in networks]
+        obs, actions, done, reward = tm.observation, tm.action, tm.done[:, 1:], tm.reward[:, 1:]
+        outputs = [net(obs, actions, tm.hidden_actor[:, 0], tm.hidden_critic[:, 0]) for net in networks]
        with torch.inference_mode(True):
            true_action_logp = torch.stack([
-                torch.log_softmax(out['logits'][ag_i, :-1], -1)
+                torch.log_softmax(out[nms.LOGITS][ag_i, :-1], -1)
                .gather(index=actions[ag_i, 1:, None], dim=-1)
                for ag_i, out in enumerate(outputs)
            ], 0).squeeze()
@@ -22,8 +23,8 @@ class LoopSEAC(LoopIAC):
        losses = []
        for ag_i, out in enumerate(outputs):
-            logits = out['logits'][:, :-1]  # last one only needed for v_{t+1}
-            critic = out['critic']
+            logits = out[nms.LOGITS][:, :-1]  # last one only needed for v_{t+1}
+            critic = out[nms.CRITIC]
            entropy_loss = Categorical(logits=logits[ag_i]).entropy().mean()
            advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef)
@@ -47,7 +48,7 @@ class LoopSEAC(LoopIAC):
        return losses
    def learn(self, tms: MARLActorCriticMemory, **kwargs):
-        losses = self.actor_critic(tms, self.net, **self.cfg['algorithm'], **kwargs)
+        losses = self.actor_critic(tms, self.net, **self.cfg[nms.ALGORITHM], **kwargs)
        for ag_i, loss in enumerate(losses):
            self.optimizer[ag_i].zero_grad()
            loss.backward()

View File

@@ -1,4 +1,5 @@
from algorithms.marl.base_ac import BaseActorCritic
+from algorithms.marl.base_ac import nms
import torch
from torch.distributions import Categorical
from pathlib import Path
@@ -21,7 +22,7 @@ class LoopSNAC(BaseActorCritic):
        )
    def get_actions(self, out):
-        actions = Categorical(logits=out['logits']).sample().squeeze()
+        actions = Categorical(logits=out[nms.LOGITS]).sample().squeeze()
        return actions
    def forward(self, observations, actions, hidden_actor, hidden_critic):

View File

@@ -6,7 +6,7 @@ from algorithms.utils import load_yaml_file, add_env_props, instantiate_class, l
for i in range(0, 5):
-    for name in ['mappo']:  # ['seac', 'iac', 'snac']:
+    for name in ['snac', 'mappo', 'iac', 'seac']:
        study_root = Path(__file__).parent / name
        cfg = load_yaml_file(study_root / f'{name}.yaml')
        add_env_props(cfg)

View File

@@ -3,12 +3,12 @@ from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
dfs = []
-for name in ['l2snac', 'iac', 'snac', 'seac']:
+for name in ['mappo']:
    for c in range(5):
        try:
            study_root = Path(__file__).parent / name / f'{name}#{c}'
+            print(study_root)
            df = pd.read_csv(study_root / 'results.csv', index_col=False)
            df.reward = df.reward.rolling(100).mean()
            df['method'] = name.upper()
@@ -17,6 +17,6 @@ for name in ['l2snac', 'iac', 'snac', 'seac']:
        pass
df = pd.concat(dfs).reset_index()
-sns.lineplot(data=df, x='episode', y='reward', hue='method', palette='husl', ci='sd', linewidth=1.5)
+sns.lineplot(data=df, x='steps', y='reward', hue='method', palette='husl', ci='sd', linewidth=1.5, err_style='bars')
plt.savefig('study.png')
print('saved image')