added first working MAPPO implementation

Robert Müller 2022-01-28 11:07:25 +01:00
parent ffc47752a7
commit b09c461754
11 changed files with 194 additions and 61 deletions

View File

@@ -2,3 +2,5 @@ from algorithms.marl.base_ac import BaseActorCritic
 from algorithms.marl.iac import LoopIAC
 from algorithms.marl.snac import LoopSNAC
 from algorithms.marl.seac import LoopSEAC
+from algorithms.marl.mappo import LoopMAPPO
+from algorithms.marl.memory import MARLActorCriticMemory
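For orientation, a minimal way to drive the new class directly; this is a sketch, not part of the commit. It assumes a YAML config whose 'method' key resolves to LoopMAPPO (the path below is hypothetical) and mirrors what the updated training script in this commit does via load_class:

# Hedged usage sketch; the config path and its contents are assumptions.
from pathlib import Path
from algorithms.utils import load_yaml_file, add_env_props, load_class

cfg = load_yaml_file(Path('studies/mappo/mappo.yaml'))  # hypothetical location
add_env_props(cfg)
loop = load_class(cfg['method'])(cfg)                   # resolves to LoopMAPPO for a MAPPO config
df = loop.train_loop(checkpointer=None)                 # returns a DataFrame with 'steps' and 'reward' columns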

View File

@@ -1,5 +1,6 @@
 import torch
 from typing import Union, List
+import copy
 import numpy as np
 from torch.distributions import Categorical
 from algorithms.marl.memory import MARLActorCriticMemory
@@ -59,7 +60,7 @@ class BaseActorCritic:
                  actions: ListOrTensor,
                  hidden_actor: ListOrTensor,
                  hidden_critic: ListOrTensor
-                 ):
+                 ) -> dict[ListOrTensor]:
        pass
@@ -67,8 +68,9 @@ class BaseActorCritic:
     def train_loop(self, checkpointer=None):
         env = instantiate_class(self.cfg['env'])
         n_steps, max_steps = [self.cfg['algorithm'][k] for k in ['n_steps', 'max_steps']]
-        global_steps = 0
+        global_steps, episode, df_results = 0, 0, []
         reward_queue = deque(maxlen=2000)
+        memory_queue = deque(maxlen=self.cfg['algorithm'].get('keep_n_segments', 1))
         while global_steps < max_steps:
             tm = MARLActorCriticMemory(self.n_agents)
             obs = env.reset()
@@ -85,7 +87,8 @@ class BaseActorCritic:
                 next_obs = next_obs
                 if isinstance(done, bool): done = [done] * self.n_agents
-                tm.add(observation=obs, action=action, reward=reward, done=done)
+                tm.add(observation=obs, action=action, reward=reward, done=done,
+                       logits=out.get('logits', None), values=out.get('critic', None))
                 obs = next_obs
                 last_action = action
                 last_hiddens = dict(hidden_actor=out.get('hidden_actor', None),
@@ -94,9 +97,11 @@ class BaseActorCritic:
                 if len(tm) >= n_steps or all(done):
                     tm.add(observation=next_obs)
+                    memory_queue.append(copy.deepcopy(tm))
                     if self.__training:
                         with torch.inference_mode(False):
-                            self.learn(tm)
+                            tm_ = tm if memory_queue.maxlen <= 1 else list(memory_queue)
+                            self.learn(tm_)
                     tm.reset()
                     tm.add(action=last_action, **last_hiddens)
                 global_steps += 1
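The deque above turns the rollout memory into a small segment buffer: with the default keep_n_segments of 1, learn() still receives a single MARLActorCriticMemory as before, while larger values hand it a list of recent segments, which is what the new LoopMAPPO consumes. A self-contained toy of that branching, with strings standing in for memory segments:

# Minimal illustration of the segment buffer; keep_n_segments=3 is an assumed config value.
from collections import deque

memory_queue = deque(maxlen=3)
for segment_id in range(5):                     # pretend five rollout segments are collected over time
    memory_queue.append(f'segment_{segment_id}')
    batch = list(memory_queue) if memory_queue.maxlen > 1 else memory_queue[-1]
    print(batch)                                # a list of up to 3 segments would be handed to learn()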
@ -110,7 +115,13 @@ class BaseActorCritic:
]) ])
if global_steps >= max_steps: break if global_steps >= max_steps: break
print(f'reward at step: {global_steps} = {rew_log}') print(f'reward at step: {episode} = {rew_log}')
episode += 1
df_results.append([global_steps, rew_log])
df_results = pd.DataFrame(df_results, columns=['steps', 'reward'])
if checkpointer is not None:
df_results.to_csv(checkpointer.path / 'results.csv', index=False)
return df_results
@torch.inference_mode(True) @torch.inference_mode(True)
def eval_loop(self, n_episodes, render=False): def eval_loop(self, n_episodes, render=False):
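train_loop now logs one (steps, reward) row per episode and, when a checkpointer is passed, writes results.csv into the checkpoint directory; the reworked plotting script at the end of this commit reads exactly that file. A minimal read-back sketch (the path layout is assumed from that script):

# Hedged sketch: reading the per-run results written above; the path is hypothetical.
import pandas as pd

df = pd.read_csv('mappo/mappo#0/results.csv', index_col=False)   # columns: 'steps', 'reward'
df['reward_smooth'] = df.reward.rolling(100).mean()              # same smoothing window as the plot script
print(df.tail())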
@@ -143,10 +154,21 @@ class BaseActorCritic:
         return results

     @staticmethod
-    def compute_advantages(critic, reward, done, gamma):
-        return (reward + gamma * (1.0 - done) * critic[:, 1:].detach()) - critic[:, :-1]
+    def compute_advantages(critic, reward, done, gamma, gae_coef=0.0):
+        tds = (reward + gamma * (1.0 - done) * critic[:, 1:].detach()) - critic[:, :-1]
+        if gae_coef <= 0:
+            return tds
+
+        gae = torch.zeros_like(tds[:, -1])
+        gaes = []
+        for t in range(tds.shape[1]-1, -1, -1):
+            gae = tds[:, t] + gamma * gae_coef * (1.0 - done[:, t]) * gae
+            gaes.insert(0, gae)
+        gaes = torch.stack(gaes, dim=1)
+        return gaes

-    def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, **kwargs):
+    def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
         obs, actions, done, reward = tm.observation, tm.action, tm.done, tm.reward

         out = network(obs, actions, tm.hidden_actor, tm.hidden_critic)
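compute_advantages now falls back to plain one-step TD errors when gae_coef <= 0 and otherwise accumulates them backwards in time into generalized advantage estimates. A self-contained re-run of the same recursion on dummy tensors, assuming the shapes used in this file (agents x T+1 critic values, agents x T rewards and dones):

# Toy check of the GAE recursion; values are made up, the function body mirrors the diff above.
import torch

def gae_advantages(critic, reward, done, gamma, gae_coef=0.0):
    tds = (reward + gamma * (1.0 - done) * critic[:, 1:].detach()) - critic[:, :-1]
    if gae_coef <= 0:
        return tds                                   # plain TD(0) advantages
    gae = torch.zeros_like(tds[:, -1])
    gaes = []
    for t in range(tds.shape[1] - 1, -1, -1):        # walk backwards through the segment
        gae = tds[:, t] + gamma * gae_coef * (1.0 - done[:, t]) * gae
        gaes.insert(0, gae)
    return torch.stack(gaes, dim=1)

critic = torch.tensor([[0.5, 0.4, 0.3]])             # 1 agent, T=2, so T+1 value estimates
reward = torch.tensor([[1.0, 0.0]])
done   = torch.tensor([[0.0, 1.0]])
print(gae_advantages(critic, reward, done, gamma=0.99, gae_coef=0.95))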
@ -154,7 +176,7 @@ class BaseActorCritic:
critic = out['critic'] critic = out['critic']
entropy_loss = Categorical(logits=logits).entropy().mean(-1) entropy_loss = Categorical(logits=logits).entropy().mean(-1)
advantages = self.compute_advantages(critic, reward, done, gamma) advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef)
value_loss = advantages.pow(2).mean(-1) # n_agent value_loss = advantages.pow(2).mean(-1) # n_agent
# policy loss # policy loss
@ -163,7 +185,6 @@ class BaseActorCritic:
a2c_loss = -(advantages.detach() * log_ap).mean(-1) a2c_loss = -(advantages.detach() * log_ap).mean(-1)
# weighted loss # weighted loss
loss = a2c_loss + vf_coef*value_loss - entropy_coef * entropy_loss loss = a2c_loss + vf_coef*value_loss - entropy_coef * entropy_loss
return loss.mean() return loss.mean()
def learn(self, tm: MARLActorCriticMemory, **kwargs): def learn(self, tm: MARLActorCriticMemory, **kwargs):

View File

@@ -21,7 +21,6 @@ class LoopIAC(BaseActorCritic):
     def load_state_dict(self, path: Path):
         paths = natsorted(list(path.glob('*.pt')))
-        print(list(paths))
         for path, net in zip(paths, self.net):
             net.load_state_dict(torch.load(path))

algorithms/marl/mappo.py (new file, 78 lines)
View File

@@ -0,0 +1,78 @@
+from algorithms.marl import LoopSNAC
+from algorithms.marl.memory import MARLActorCriticMemory
+from typing import List
+import random
+import torch
+from torch.distributions import Categorical
+
+
+class LoopMAPPO(LoopSNAC):
+    def __init__(self, *args, **kwargs):
+        super(LoopMAPPO, self).__init__(*args, **kwargs)
+
+    def build_batch(self, tm: List[MARLActorCriticMemory]):
+        sample = random.choices(tm, k=self.cfg['algorithm']['batch_size']-1)
+        sample.append(tm[-1])  # always use latest segment in batch
+
+        obs = torch.cat([s.observation for s in sample], 0)
+        actions = torch.cat([s.action for s in sample], 0)
+        hidden_actor = torch.cat([s.hidden_actor for s in sample], 0)
+        hidden_critic = torch.cat([s.hidden_critic for s in sample], 0)
+        logits = torch.cat([s.logits for s in sample], 0)
+        values = torch.cat([s.values for s in sample], 0)
+        reward = torch.cat([s.reward for s in sample], 0)
+        done = torch.cat([s.done for s in sample], 0)
+
+        log_props = torch.log_softmax(logits, -1)
+        log_props = torch.gather(log_props, index=actions[:, 1:].unsqueeze(-1), dim=-1).squeeze()
+
+        return obs, actions, hidden_actor, hidden_critic, log_props, values, reward, done
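build_batch assembles a minibatch from the stored segments: it samples batch_size - 1 of them with replacement, always appends the most recent one, and recomputes the behaviour policy's log-probabilities from the stored logits. The sampling step in isolation (batch_size=3 is an assumed value, strings stand in for memory segments):

# Minimal illustration of the segment sampling used in build_batch.
import random

segments = ['seg_0', 'seg_1', 'seg_2', 'seg_3']      # stand-ins for MARLActorCriticMemory segments
batch_size = 3
sample = random.choices(segments, k=batch_size - 1)  # sample older segments with replacement
sample.append(segments[-1])                          # the newest segment is always included
print(sample)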
+    def learn(self, tm: List[MARLActorCriticMemory], **kwargs):
+        if len(tm) >= self.cfg['algorithm']['keep_n_segments']:
+            # only learn when buffer is full
+            for batch_i in range(self.cfg['algorithm']['n_updates']):
+                loss = self.actor_critic(tm, self.net, **self.cfg['algorithm'], **kwargs)
+                self.optimizer.zero_grad()
+                loss.backward()
+                torch.nn.utils.clip_grad_norm_(self.net.parameters(), 0.5)
+                self.optimizer.step()
+
+    def monte_carlo_returns(self, rewards, done, gamma):
+        rewards_ = []
+        discounted_reward = torch.zeros_like(rewards[:, -1])
+        for t in range(rewards.shape[1]-1, -1, -1):
+            discounted_reward = rewards[:, t] + (gamma * (1.0 - done[:, t]) * discounted_reward)
+            rewards_.insert(0, discounted_reward)
+        rewards_ = torch.stack(rewards_, dim=1)
+        return rewards_
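learn() waits until the buffer holds keep_n_segments segments, then runs n_updates optimisation passes with gradients clipped at 0.5, while monte_carlo_returns discounts rewards backwards through time and cuts the bootstrap at episode ends. The 'algorithm' config block this relies on might look as follows; the key names are taken from this file and from train_loop, the values are placeholders:

# Hypothetical 'algorithm' config block (values are illustrative, not from the commit).
algorithm_cfg = dict(
    n_steps=256, max_steps=1_000_000,   # rollout segment length / total env steps (train_loop)
    keep_n_segments=8,                  # segments kept in the buffer before learning starts
    batch_size=8, n_updates=4,          # segments per minibatch / passes per learn() call
    gamma=0.99, entropy_coef=0.01,      # discount and entropy bonus weight
    vf_coef=0.5, clip_range=0.2,        # value loss weight and PPO clipping range
)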
+    def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, clip_range, gae_coef=0.0, **kwargs):
+        obs, actions, hidden_actor, hidden_critic, old_log_probs, old_critic, reward, done = self.build_batch(tm)
+
+        out = network(obs, actions, hidden_actor, hidden_critic)
+        logits = out['logits'][:, :-1]  # last one only needed for v_{t+1}
+        critic = out['critic']
+
+        # monte carlo returns
+        mc_returns = self.monte_carlo_returns(reward, done, gamma)
+        # monte_carlo_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-7)  todo: norm across agents?
+        advantages = mc_returns - critic[:, :-1]
+
+        # policy loss
+        log_ap = torch.log_softmax(logits, -1)
+        log_ap = torch.gather(log_ap, dim=-1, index=actions[:, 1:].unsqueeze(-1)).squeeze()
+        ratio = (log_ap - old_log_probs).exp()
+        surr1 = ratio * advantages.detach()
+        surr2 = torch.clamp(ratio, 1 - clip_range, 1 + clip_range) * advantages.detach()
+        policy_loss = -torch.min(surr1, surr2).mean(-1)
+
+        # entropy & value loss
+        entropy_loss = Categorical(logits=logits).entropy().mean(-1)
+        value_loss = advantages.pow(2).mean(-1)  # n_agent
+
+        # weighted loss
+        loss = policy_loss + vf_coef*value_loss - entropy_coef * entropy_loss
+        return loss.mean()
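The loss is the standard PPO clipped surrogate (here against Monte Carlo return advantages) plus the usual value and entropy terms. A tiny numeric illustration of the clipping behaviour, with made-up advantages and probability ratios and clip_range=0.2 assumed:

# Toy illustration of the clipped surrogate objective; numbers are arbitrary.
import torch

advantages = torch.tensor([1.0, -1.0, 0.5])
ratio      = torch.tensor([1.5,  1.5, 0.9])   # exp(new_log_prob - old_log_prob)
clip_range = 0.2
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1 - clip_range, 1 + clip_range) * advantages
print(-torch.min(surr1, surr2).mean())        # the first ratio is clipped to 1.2; the second keeps the worse, unclipped term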

View File

@@ -15,12 +15,14 @@ class ActorCriticMemory(object):
         self.__dones = []
         self.__hiddens_actor = []
         self.__hiddens_critic = []
+        self.__logits = []
+        self.__values = []

     def __len__(self):
         return len(self.__states)

     @property
-    def observation(self):
+    def observation(self):  # add time dimension through stacking
         return torch.stack(self.__states, 0).unsqueeze(0)  # 1 x timesteps x hidden dim

     @property
@@ -47,6 +49,14 @@ class ActorCriticMemory(object):
     def done(self):
         return torch.tensor(self.__dones).float().unsqueeze(0)  # 1 x timesteps

+    @property
+    def logits(self):  # assumes a trailing 1 for time dimension - common when using output from NN
+        return torch.cat(self.__logits, 0).unsqueeze(0)  # 1 x timesteps x actions
+
+    @property
+    def values(self):
+        return torch.cat(self.__values, 0).unsqueeze(0)  # 1 x timesteps x actions
+
     def add_observation(self, state: Union[Tensor, np.ndarray]):
         self.__states.append(state if isinstance(state, Tensor) else torch.from_numpy(state))
@@ -69,6 +79,12 @@ class ActorCriticMemory(object):
     def add_done(self, done: bool):
         self.__dones.append(done)

+    def add_logits(self, logits: Tensor):
+        self.__logits.append(logits)
+
+    def add_values(self, logits: Tensor):
+        self.__values.append(logits)
+
     def add(self, **kwargs):
         for k, v in kwargs.items():
             func = getattr(ActorCriticMemory, f'add_{k}')
@@ -129,3 +145,14 @@ class MARLActorCriticMemory(object):
         all_hc = [mem.hidden_critic for mem in self.memories]
         return torch.cat(all_hc, 0)  # agents x layers x timesteps x hidden dim

+    @property
+    def logits(self):
+        all_lgts = [mem.logits for mem in self.memories]
+        return torch.cat(all_lgts, 0)  # agents x layers x timesteps x hidden dim
+
+    @property
+    def values(self):
+        all_vals = [mem.values for mem in self.memories]
+        return torch.cat(all_vals, 0)  # agents x layers x timesteps x hidden dim
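The per-agent memory stores one logits/values tensor per step and re-adds a leading batch dimension on read, while the MARL wrapper concatenates the per-agent results along dimension 0. A quick shape check, assuming each stored entry is a 1 x n_actions tensor per step (the trailing-1 assumption mentioned in the code comment):

# Shape sketch for the new logits/values buffers; the per-step entry shape is an assumption.
import torch

n_agents, timesteps, n_actions = 2, 3, 4
per_step_logits = [torch.randn(1, n_actions) for _ in range(timesteps)]
per_agent = torch.cat(per_step_logits, 0).unsqueeze(0)   # 1 x timesteps x n_actions (ActorCriticMemory.logits)
all_agents = torch.cat([per_agent] * n_agents, 0)        # n_agents x timesteps x n_actions (MARLActorCriticMemory.logits)
print(per_agent.shape, all_agents.shape)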

View File

@@ -12,6 +12,7 @@ class RecurrentAC(nn.Module):
         super(RecurrentAC, self).__init__()
         observation_size = np.prod(observation_size)
         self.n_layers = 1
+        self.n_actions = n_actions
         self.use_agent_embedding = use_agent_embedding
         self.hidden_size_actor = hidden_size_actor
         self.hidden_size_critic = hidden_size_critic
@@ -28,10 +29,11 @@ class RecurrentAC(nn.Module):
         self.gru_actor = nn.GRU(obs_emb_size, hidden_size_actor, batch_first=True, num_layers=self.n_layers)
         self.gru_critic = nn.GRU(obs_emb_size, hidden_size_critic, batch_first=True, num_layers=self.n_layers)
         self.action_head = nn.Sequential(
-            spectral_norm(nn.Linear(hidden_size_actor, hidden_size_actor)),
+            nn.Linear(hidden_size_actor, hidden_size_actor),
             nn.Tanh(),
             nn.Linear(hidden_size_actor, n_actions)
         )
+        # spectral_norm(nn.Linear(hidden_size_actor, hidden_size_actor)),
         self.critic_head = nn.Sequential(
             nn.Linear(hidden_size_critic, hidden_size_critic),
             nn.Tanh(),
@@ -50,12 +52,14 @@ class RecurrentAC(nn.Module):
         n_agents, t, *_ = observations.shape
         obs_emb = self.obs_proj(observations.view(n_agents, t, -1).float())
         action_emb = self.action_emb(actions+1)  # shift by one due to padding idx
-        agent_emb = self.agent_emb(
-            torch.cat([torch.arange(0, n_agents, 1).view(-1, 1)]*t, 1)
-        )
-        x_t = torch.cat((obs_emb, action_emb), -1) \
-            if not self.use_agent_embedding else torch.cat((obs_emb, agent_emb, action_emb), -1)
+
+        if not self.use_agent_embedding:
+            x_t = torch.cat((obs_emb, action_emb), -1)
+        else:
+            agent_emb = self.agent_emb(
+                torch.cat([torch.arange(0, n_agents, 1).view(-1, 1)] * t, 1)
+            )
+            x_t = torch.cat((obs_emb, agent_emb, action_emb), -1)

         mixed_x_t = self.mix(x_t)
         output_p, _ = self.gru_actor(input=mixed_x_t, hx=hidden_actor.swapaxes(1, 0))
@@ -66,6 +70,15 @@ class RecurrentAC(nn.Module):
         return dict(logits=logits, critic=critic, hidden_actor=output_p, hidden_critic=output_c)

+
+class RecurrentACL2(RecurrentAC):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.action_head = nn.Sequential(
+            nn.Linear(self.hidden_size_actor, self.hidden_size_actor),
+            nn.Tanh(),
+            NormalizedLinear(self.hidden_size_actor, self.n_actions, trainable_magnitude=True)
+        )
+
 class NormalizedLinear(nn.Linear):
     def __init__(self, in_features: int, out_features: int,

View File

@@ -8,7 +8,7 @@ class LoopSEAC(LoopIAC):
     def __init__(self, cfg):
         super(LoopSEAC, self).__init__(cfg)

-    def actor_critic(self, tm, networks, gamma, entropy_coef, vf_coef, **kwargs):
+    def actor_critic(self, tm, networks, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
         obs, actions, done, reward = tm.observation, tm.action, tm.done, tm.reward

         outputs = [net(obs, actions, tm.hidden_actor, tm.hidden_critic) for net in networks]
@@ -26,7 +26,7 @@ class LoopSEAC(LoopIAC):
             critic = out['critic']

             entropy_loss = Categorical(logits=logits[ag_i]).entropy().mean()
-            advantages = self.compute_advantages(critic, reward, done, gamma)
+            advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef)

             # policy loss
             log_ap = torch.log_softmax(logits, -1)

View File

@@ -1,8 +1,9 @@
 numpy
 scipy
 tqdm
+pandas
 seaborn>=0.11.1
-matplotlib>=3.4.1
+matplotlib>=3.3.4
 stable-baselines3>=1.0
 pygame>=2.1.0
 gym>=0.18.0
@@ -10,3 +11,4 @@ networkx>=2.6.3
 simplejson>=3.17.5
 PyYAML>=6.0
 einops
+natsort

View File

@ -1,14 +1,13 @@
from algorithms.utils import Checkpointer from algorithms.utils import Checkpointer
from pathlib import Path from pathlib import Path
from algorithms.utils import load_yaml_file, add_env_props, instantiate_class, load_class from algorithms.utils import load_yaml_file, add_env_props, instantiate_class, load_class
from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC #from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC
#study_root = Path(__file__).parent / 'curious_study'
study_root = Path('/Users/romue/PycharmProjects/EDYS/algorithms/marl')
for i in range(0, 5): for i in range(0, 5):
for name in ['example_config']: for name in ['mappo']:#['seac', 'iac', 'snac']:
study_root = Path(__file__).parent / name
cfg = load_yaml_file(study_root / f'{name}.yaml') cfg = load_yaml_file(study_root / f'{name}.yaml')
add_env_props(cfg) add_env_props(cfg)
@ -17,7 +16,7 @@ for i in range(0, 5):
max_steps = cfg['algorithm']['max_steps'] max_steps = cfg['algorithm']['max_steps']
n_steps = cfg['algorithm']['n_steps'] n_steps = cfg['algorithm']['n_steps']
checkpointer = Checkpointer(f'{name}#{i}', study_root, cfg, max_steps, 250) checkpointer = Checkpointer(f'{name}#{i}', study_root, cfg, max_steps, 50)
loop = load_class(cfg['method'])(cfg) loop = load_class(cfg['method'])(cfg)
df = loop.train_loop(checkpointer) df = loop.train_loop(checkpointer)

View File

@@ -1,32 +1,22 @@
-import numpy as np
 import pandas as pd
 from pathlib import Path
 import matplotlib.pyplot as plt
 import seaborn as sns

-study_root = Path(__file__).parent / 'entropy_study'
-
-names_all = ['basic_gru', 'layernorm_gru', 'spectralnorm_gru', 'nonorm_gru']
-names_only_1 = ['L2OnlyAh_gru', 'L2OnlyChAh_gru', 'L2OnlyMix_gru', 'basic_gru']
-names_only_2 = ['L2NoCh_gru', 'L2NoAh_gru', 'nomix_gru', 'basic_gru']
-
-names = names_only_2
-#names = ['nonorm_gru']
-# /Users/romue/PycharmProjects/EDYS/studies/normalization_study/basic_gru#3
-csvs = []
-for name in ['basic_gru', 'nonorm_gru', 'spectralnorm_gru']:
-    for run in range(0, 1):
+dfs = []
+for name in ['l2snac', 'iac', 'snac', 'seac']:
+    for c in range(5):
         try:
-            df = pd.read_csv(study_root / f'{name}#{run}' / 'results.csv')
-            df = df[df.agent == 'sum']
-            df = df.groupby(['checkpoint', 'run']).mean().reset_index()
-            df['method'] = name
-            df['run_'] = run
-            df.reward = df.reward.rolling(15).mean()
-            csvs.append(df)
+            study_root = Path(__file__).parent / name / f'{name}#{c}'
+            df = pd.read_csv(study_root / 'results.csv', index_col=False)
+            df.reward = df.reward.rolling(100).mean()
+            df['method'] = name.upper()
+            dfs.append(df)
         except Exception as e:
-            print(f'skipped {run}\t {name}')
+            pass

-csvs = pd.concat(csvs).rename(columns={"checkpoint": "steps*2e3", "B": "c"})
-sns.lineplot(data=csvs, x='steps*2e3', y='reward', hue='method', palette='husl', ci='sd', linewidth=1.8)
-plt.savefig('entropy.png')
-print('saved image')
+df = pd.concat(dfs).reset_index()
+sns.lineplot(data=df, x='episode', y='reward', hue='method', palette='husl', ci='sd', linewidth=1.5)
+plt.savefig('study.png')

View File

@@ -3,19 +3,21 @@ from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC
 from pathlib import Path
 from algorithms.utils import load_yaml_file
 from tqdm import trange

-study = 'curious_study'
-study_root = Path(__file__).parent / study
+study = 'example_config#0'
+#study_root = Path(__file__).parent / study
+study_root = Path('/Users/romue/PycharmProjects/EDYS/algorithms/marl/')

 #['L2NoAh_gru', 'L2NoCh_gru', 'nomix_gru']:
 render = True
 eval_eps = 3
 for run in range(0, 5):
-    for name in ['basic_gru']:  #['L2OnlyAh_gru', 'L2OnlyChAh_gru', 'L2OnlyMix_gru']: #['layernorm_gru', 'basic_gru', 'nonorm_gru', 'spectralnorm_gru']:
-        cfg = load_yaml_file(Path(__file__).parent / study / f'{name}.yaml')
-        p_root = Path(study_root / f'{name}#{run}')
+    for name in ['example_config']:  #['L2OnlyAh_gru', 'L2OnlyChAh_gru', 'L2OnlyMix_gru']: #['layernorm_gru', 'basic_gru', 'nonorm_gru', 'spectralnorm_gru']:
+        cfg = load_yaml_file(study_root / study / 'config.yaml')
+        #p_root = Path(study_root / study / f'{name}#{run}')
         dfs = []
         for i in trange(500):
-            path = p_root / f'checkpoint_{i}'
+            path = study_root / study / f'checkpoint_{161}'
+            print(path)

             snac = LoopSEAC(cfg)
             snac.load_state_dict(path)