added first working MAPPO implementation

Robert Müller 2022-01-28 11:07:25 +01:00
parent ffc47752a7
commit b09c461754
11 changed files with 194 additions and 61 deletions

View File

@ -1,4 +1,6 @@
from algorithms.marl.base_ac import BaseActorCritic
from algorithms.marl.iac import LoopIAC
from algorithms.marl.snac import LoopSNAC
from algorithms.marl.seac import LoopSEAC
from algorithms.marl.mappo import LoopMAPPO
from algorithms.marl.memory import MARLActorCriticMemory

View File

@ -1,5 +1,6 @@
import torch
from typing import Union, List, Dict
import copy
import numpy as np
from torch.distributions import Categorical
from algorithms.marl.memory import MARLActorCriticMemory
@ -59,7 +60,7 @@ class BaseActorCritic:
actions: ListOrTensor,
hidden_actor: ListOrTensor,
hidden_critic: ListOrTensor
):
) -> Dict[str, ListOrTensor]:
pass
@ -67,8 +68,9 @@ class BaseActorCritic:
def train_loop(self, checkpointer=None):
env = instantiate_class(self.cfg['env'])
n_steps, max_steps = [self.cfg['algorithm'][k] for k in ['n_steps', 'max_steps']]
global_steps = 0
global_steps, episode, df_results = 0, 0, []
reward_queue = deque(maxlen=2000)
memory_queue = deque(maxlen=self.cfg['algorithm'].get('keep_n_segments', 1))
while global_steps < max_steps:
tm = MARLActorCriticMemory(self.n_agents)
obs = env.reset()
@ -85,7 +87,8 @@ class BaseActorCritic:
next_obs = next_obs
if isinstance(done, bool): done = [done] * self.n_agents
tm.add(observation=obs, action=action, reward=reward, done=done)
tm.add(observation=obs, action=action, reward=reward, done=done,
logits=out.get('logits', None), values=out.get('critic', None))
obs = next_obs
last_action = action
last_hiddens = dict(hidden_actor=out.get('hidden_actor', None),
@ -94,9 +97,11 @@ class BaseActorCritic:
if len(tm) >= n_steps or all(done):
tm.add(observation=next_obs)
memory_queue.append(copy.deepcopy(tm))
if self.__training:
with torch.inference_mode(False):
self.learn(tm)
tm_ = tm if memory_queue.maxlen <= 1 else list(memory_queue)
self.learn(tm_)
tm.reset()
tm.add(action=last_action, **last_hiddens)
global_steps += 1
@ -110,7 +115,13 @@ class BaseActorCritic:
])
if global_steps >= max_steps: break
print(f'reward at step: {global_steps} = {rew_log}')
print(f'reward at episode: {episode} = {rew_log}')
episode += 1
df_results.append([global_steps, rew_log])
df_results = pd.DataFrame(df_results, columns=['steps', 'reward'])
if checkpointer is not None:
df_results.to_csv(checkpointer.path / 'results.csv', index=False)
return df_results
@torch.inference_mode(True)
def eval_loop(self, n_episodes, render=False):
@ -143,10 +154,21 @@ class BaseActorCritic:
return results
@staticmethod
def compute_advantages(critic, reward, done, gamma):
return (reward + gamma * (1.0 - done) * critic[:, 1:].detach()) - critic[:, :-1]
def compute_advantages(critic, reward, done, gamma, gae_coef=0.0):
tds = (reward + gamma * (1.0 - done) * critic[:, 1:].detach()) - critic[:, :-1]
def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, **kwargs):
if gae_coef <= 0:
return tds
gae = torch.zeros_like(tds[:, -1])
gaes = []
for t in range(tds.shape[1]-1, -1, -1):
gae = tds[:, t] + gamma * gae_coef * (1.0 - done[:, t]) * gae
gaes.insert(0, gae)
gaes = torch.stack(gaes, dim=1)
return gaes
def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
obs, actions, done, reward = tm.observation, tm.action, tm.done, tm.reward
out = network(obs, actions, tm.hidden_actor, tm.hidden_critic)
@ -154,7 +176,7 @@ class BaseActorCritic:
critic = out['critic']
entropy_loss = Categorical(logits=logits).entropy().mean(-1)
advantages = self.compute_advantages(critic, reward, done, gamma)
advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef)
value_loss = advantages.pow(2).mean(-1) # n_agent
# policy loss
@ -163,7 +185,6 @@ class BaseActorCritic:
a2c_loss = -(advantages.detach() * log_ap).mean(-1)
# weighted loss
loss = a2c_loss + vf_coef*value_loss - entropy_coef * entropy_loss
return loss.mean()
def learn(self, tm: MARLActorCriticMemory, **kwargs):
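Note: the new gae_coef path in compute_advantages implements the standard GAE recursion over the one-step TD errors. A minimal standalone sketch of the same computation on dummy tensors (the function name and shapes here are illustrative only, not part of the repo):

import torch

def gae_advantages(values, rewards, dones, gamma=0.99, gae_coef=0.95):
    # values: agents x (t+1), the extra entry bootstraps; rewards/dones: agents x t
    tds = rewards + gamma * (1.0 - dones) * values[:, 1:].detach() - values[:, :-1]
    gae = torch.zeros_like(tds[:, -1])
    gaes = []
    for t in range(tds.shape[1] - 1, -1, -1):  # accumulate TD errors backwards in time
        gae = tds[:, t] + gamma * gae_coef * (1.0 - dones[:, t]) * gae
        gaes.insert(0, gae)
    return torch.stack(gaes, dim=1)  # agents x t; equals tds when gae_coef == 0

adv = gae_advantages(torch.randn(2, 6), torch.randn(2, 5), torch.zeros(2, 5))
print(adv.shape)  # torch.Size([2, 5])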

View File

@ -21,7 +21,6 @@ class LoopIAC(BaseActorCritic):
def load_state_dict(self, path: Path):
paths = natsorted(list(path.glob('*.pt')))
print(list(paths))
for path, net in zip(paths, self.net):
net.load_state_dict(torch.load(path))

algorithms/marl/mappo.py (new file, 78 lines)
View File

@ -0,0 +1,78 @@
from algorithms.marl import LoopSNAC
from algorithms.marl.memory import MARLActorCriticMemory
from typing import List
import random
import torch
from torch.distributions import Categorical
class LoopMAPPO(LoopSNAC):
def __init__(self, *args, **kwargs):
super(LoopMAPPO, self).__init__(*args, **kwargs)
def build_batch(self, tm: List[MARLActorCriticMemory]):
sample = random.choices(tm, k=self.cfg['algorithm']['batch_size']-1)
sample.append(tm[-1]) # always use latest segment in batch
obs = torch.cat([s.observation for s in sample], 0)
actions = torch.cat([s.action for s in sample], 0)
hidden_actor = torch.cat([s.hidden_actor for s in sample], 0)
hidden_critic = torch.cat([s.hidden_critic for s in sample], 0)
logits = torch.cat([s.logits for s in sample], 0)
values = torch.cat([s.values for s in sample], 0)
reward = torch.cat([s.reward for s in sample], 0)
done = torch.cat([s.done for s in sample], 0)
log_probs = torch.log_softmax(logits, -1)
log_probs = torch.gather(log_probs, index=actions[:, 1:].unsqueeze(-1), dim=-1).squeeze()
return obs, actions, hidden_actor, hidden_critic, log_probs, values, reward, done
def learn(self, tm: List[MARLActorCriticMemory], **kwargs):
if len(tm) >= self.cfg['algorithm']['keep_n_segments']:
# only learn when buffer is full
for batch_i in range(self.cfg['algorithm']['n_updates']):
loss = self.actor_critic(tm, self.net, **self.cfg['algorithm'], **kwargs)
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.net.parameters(), 0.5)
self.optimizer.step()
def monte_carlo_returns(self, rewards, done, gamma):
rewards_ = []
discounted_reward = torch.zeros_like(rewards[:, -1])
for t in range(rewards.shape[1]-1, -1, -1):
discounted_reward = rewards[:, t] + (gamma * (1.0 - done[:, t]) * discounted_reward)
rewards_.insert(0, discounted_reward)
rewards_ = torch.stack(rewards_, dim=1)
return rewards_
def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, clip_range, gae_coef=0.0, **kwargs):
obs, actions, hidden_actor, hidden_critic, old_log_probs, old_critic, reward, done = self.build_batch(tm)
out = network(obs, actions, hidden_actor, hidden_critic)
logits = out['logits'][:, :-1] # last one only needed for v_{t+1}
critic = out['critic']
# monte carlo returns
mc_returns = self.monte_carlo_returns(reward, done, gamma)
# monte_carlo_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-7) todo: norm across agents?
advantages = mc_returns - critic[:, :-1]
# policy loss
log_ap = torch.log_softmax(logits, -1)
log_ap = torch.gather(log_ap, dim=-1, index=actions[:, 1:].unsqueeze(-1)).squeeze()
ratio = (log_ap - old_log_probs).exp()
surr1 = ratio * advantages.detach()
surr2 = torch.clamp(ratio, 1 - clip_range, 1 + clip_range) * advantages.detach()
policy_loss = -torch.min(surr1, surr2).mean(-1)
# entropy & value loss
entropy_loss = Categorical(logits=logits).entropy().mean(-1)
value_loss = advantages.pow(2).mean(-1) # n_agent
# weighted loss
loss = policy_loss + vf_coef*value_loss - entropy_coef * entropy_loss
return loss.mean()
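Note: both build_batch and the clipped loss rely on a one-step shift of the stored actions; index 0 appears to hold the previous action that re-seeds each segment (cf. tm.add(action=last_action, ...) in train_loop). A self-contained sketch of the log-prob gather on dummy tensors (shapes are illustrative, not the repo API):

import torch

agents, t, n_actions = 2, 5, 4
logits = torch.randn(agents, t, n_actions)              # per-step policy logits
actions = torch.randint(0, n_actions, (agents, t + 1))  # leading slot = previous action, then the t taken actions

log_probs = torch.log_softmax(logits, -1)
# actions[:, 1:] drops the leading slot so each log-prob is paired with the action taken at that step
taken = torch.gather(log_probs, dim=-1, index=actions[:, 1:].unsqueeze(-1)).squeeze(-1)
print(taken.shape)  # torch.Size([2, 5])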

View File

@ -13,14 +13,16 @@ class ActorCriticMemory(object):
self.__actions = []
self.__rewards = []
self.__dones = []
self.__hiddens_actor = []
self.__hiddens_critic = []
self.__logits = []
self.__values = []
def __len__(self):
return len(self.__states)
@property
def observation(self):
def observation(self): # add time dimension through stacking
return torch.stack(self.__states, 0).unsqueeze(0) # 1 x timesteps x hidden dim
@property
@ -47,6 +49,14 @@ class ActorCriticMemory(object):
def done(self):
return torch.tensor(self.__dones).float().unsqueeze(0) # 1 x timesteps
@property
def logits(self): # assumes a trailing 1 for time dimension - common when using output from NN
return torch.cat(self.__logits, 0).unsqueeze(0) # 1 x timesteps x actions
@property
def values(self):
return torch.cat(self.__values, 0).unsqueeze(0) # 1 x timesteps (value estimates)
def add_observation(self, state: Union[Tensor, np.ndarray]):
self.__states.append(state if isinstance(state, Tensor) else torch.from_numpy(state))
@ -69,6 +79,12 @@ class ActorCriticMemory(object):
def add_done(self, done: bool):
self.__dones.append(done)
def add_logits(self, logits: Tensor):
self.__logits.append(logits)
def add_values(self, values: Tensor):
self.__values.append(values)
def add(self, **kwargs):
for k, v in kwargs.items():
func = getattr(ActorCriticMemory, f'add_{k}')
@ -129,3 +145,14 @@ class MARLActorCriticMemory(object):
all_hc = [mem.hidden_critic for mem in self.memories]
return torch.cat(all_hc, 0) # agents x layers x timesteps x hidden dim
@property
def logits(self):
all_lgts = [mem.logits for mem in self.memories]
return torch.cat(all_lgts, 0) # agents x timesteps x n_actions
@property
def values(self):
all_vals = [mem.values for mem in self.memories]
return torch.cat(all_vals, 0) # agents x timesteps (value estimates)
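Note: the new logits/values buffers follow the same stack-then-cat convention as the other memory fields. A standalone shape sketch with dummy tensors (assumes each per-step, per-agent logit has shape 1 x n_actions, which matches the comments above; this is not the repo API):

import torch

n_agents, t, n_actions = 3, 7, 5
# per agent: one logits tensor per timestep, shape (1, n_actions)
per_agent_steps = [[torch.randn(1, n_actions) for _ in range(t)] for _ in range(n_agents)]

# ActorCriticMemory.logits: cat over time, add a leading dim -> 1 x t x n_actions
per_agent = [torch.cat(steps, 0).unsqueeze(0) for steps in per_agent_steps]
# MARLActorCriticMemory.logits: cat over agents -> n_agents x t x n_actions
all_logits = torch.cat(per_agent, 0)
print(all_logits.shape)  # torch.Size([3, 7, 5])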

View File

@ -12,6 +12,7 @@ class RecurrentAC(nn.Module):
super(RecurrentAC, self).__init__()
observation_size = np.prod(observation_size)
self.n_layers = 1
self.n_actions = n_actions
self.use_agent_embedding = use_agent_embedding
self.hidden_size_actor = hidden_size_actor
self.hidden_size_critic = hidden_size_critic
@ -25,13 +26,14 @@ class RecurrentAC(nn.Module):
nn.Tanh(),
nn.Linear(obs_emb_size, obs_emb_size)
)
self.gru_actor = nn.GRU(obs_emb_size, hidden_size_actor, batch_first=True, num_layers=self.n_layers)
self.gru_critic = nn.GRU(obs_emb_size, hidden_size_critic, batch_first=True, num_layers=self.n_layers)
self.action_head = nn.Sequential(
spectral_norm(nn.Linear(hidden_size_actor, hidden_size_actor)),
nn.Linear(hidden_size_actor, hidden_size_actor),
nn.Tanh(),
nn.Linear(hidden_size_actor, n_actions)
)
# spectral_norm(nn.Linear(hidden_size_actor, hidden_size_actor)),
self.critic_head = nn.Sequential(
nn.Linear(hidden_size_critic, hidden_size_critic),
nn.Tanh(),
@ -50,12 +52,14 @@ class RecurrentAC(nn.Module):
n_agents, t, *_ = observations.shape
obs_emb = self.obs_proj(observations.view(n_agents, t, -1).float())
action_emb = self.action_emb(actions+1) # shift by one due to padding idx
agent_emb = self.agent_emb(
torch.cat([torch.arange(0, n_agents, 1).view(-1, 1)]*t, 1)
)
x_t = torch.cat((obs_emb, action_emb), -1) \
if not self.use_agent_embedding else torch.cat((obs_emb, agent_emb, action_emb), -1)
if not self.use_agent_embedding:
x_t = torch.cat((obs_emb, action_emb), -1)
else:
agent_emb = self.agent_emb(
torch.cat([torch.arange(0, n_agents, 1).view(-1, 1)] * t, 1)
)
x_t = torch.cat((obs_emb, agent_emb, action_emb), -1)
mixed_x_t = self.mix(x_t)
output_p, _ = self.gru_actor(input=mixed_x_t, hx=hidden_actor.swapaxes(1, 0))
@ -66,6 +70,15 @@ class RecurrentAC(nn.Module):
return dict(logits=logits, critic=critic, hidden_actor=output_p, hidden_critic=output_c)
class RecurrentACL2(RecurrentAC):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.action_head = nn.Sequential(
nn.Linear(self.hidden_size_actor, self.hidden_size_actor),
nn.Tanh(),
NormalizedLinear(self.hidden_size_actor, self.n_actions, trainable_magnitude=True)
)
class NormalizedLinear(nn.Linear):
def __init__(self, in_features: int, out_features: int,

View File

@ -8,7 +8,7 @@ class LoopSEAC(LoopIAC):
def __init__(self, cfg):
super(LoopSEAC, self).__init__(cfg)
def actor_critic(self, tm, networks, gamma, entropy_coef, vf_coef, **kwargs):
def actor_critic(self, tm, networks, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
obs, actions, done, reward = tm.observation, tm.action, tm.done, tm.reward
outputs = [net(obs, actions, tm.hidden_actor, tm.hidden_critic) for net in networks]
@ -26,7 +26,7 @@ class LoopSEAC(LoopIAC):
critic = out['critic']
entropy_loss = Categorical(logits=logits[ag_i]).entropy().mean()
advantages = self.compute_advantages(critic, reward, done, gamma)
advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef)
# policy loss
log_ap = torch.log_softmax(logits, -1)

View File

@ -1,8 +1,9 @@
numpy
scipy
tqdm
pandas
seaborn>=0.11.1
matplotlib>=3.4.1
matplotlib>=3.3.4
stable-baselines3>=1.0
pygame>=2.1.0
gym>=0.18.0
@ -10,3 +11,4 @@ networkx>=2.6.3
simplejson>=3.17.5
PyYAML>=6.0
einops
natsort

View File

@ -1,14 +1,13 @@
from algorithms.utils import Checkpointer
from pathlib import Path
from algorithms.utils import load_yaml_file, add_env_props, instantiate_class, load_class
from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC
#from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC
#study_root = Path(__file__).parent / 'curious_study'
study_root = Path('/Users/romue/PycharmProjects/EDYS/algorithms/marl')
for i in range(0, 5):
for name in ['example_config']:
for name in ['mappo']:#['seac', 'iac', 'snac']:
study_root = Path(__file__).parent / name
cfg = load_yaml_file(study_root / f'{name}.yaml')
add_env_props(cfg)
@ -17,7 +16,7 @@ for i in range(0, 5):
max_steps = cfg['algorithm']['max_steps']
n_steps = cfg['algorithm']['n_steps']
checkpointer = Checkpointer(f'{name}#{i}', study_root, cfg, max_steps, 250)
checkpointer = Checkpointer(f'{name}#{i}', study_root, cfg, max_steps, 50)
loop = load_class(cfg['method'])(cfg)
df = loop.train_loop(checkpointer)
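Note: the mappo.yaml loaded here is not part of the visible hunks. A hypothetical sketch of its algorithm section, written as a Python dict and listing only keys that the code in this commit actually reads; the values are placeholders, not the author's settings:

cfg_algorithm = dict(
    max_steps=250_000,     # train_loop stop condition
    n_steps=5,             # segment length before a learn() call
    keep_n_segments=10,    # memory_queue size; LoopMAPPO.learn waits until the queue is full
    batch_size=8,          # segments sampled per update in build_batch (latest segment always included)
    n_updates=4,           # gradient steps per learn() call
    gamma=0.99,
    clip_range=0.2,        # PPO ratio clipping
    entropy_coef=0.01,
    vf_coef=0.5,
    gae_coef=0.0,          # used by the base/SEAC advantage path; the MAPPO loss uses Monte Carlo returns
)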

View File

@ -1,32 +1,22 @@
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
study_root = Path(__file__).parent / 'entropy_study'
names_all = ['basic_gru', 'layernorm_gru', 'spectralnorm_gru', 'nonorm_gru']
names_only_1 = ['L2OnlyAh_gru', 'L2OnlyChAh_gru', 'L2OnlyMix_gru', 'basic_gru']
names_only_2 = ['L2NoCh_gru', 'L2NoAh_gru', 'nomix_gru', 'basic_gru']
names = names_only_2
#names = ['nonorm_gru']
# /Users/romue/PycharmProjects/EDYS/studies/normalization_study/basic_gru#3
csvs = []
for name in ['basic_gru', 'nonorm_gru', 'spectralnorm_gru']:
for run in range(0, 1):
dfs = []
for name in ['l2snac', 'iac', 'snac', 'seac']:
for c in range(5):
try:
df = pd.read_csv(study_root / f'{name}#{run}' / 'results.csv')
df = df[df.agent == 'sum']
df = df.groupby(['checkpoint', 'run']).mean().reset_index()
df['method'] = name
df['run_'] = run
df.reward = df.reward.rolling(15).mean()
csvs.append(df)
study_root = Path(__file__).parent / name / f'{name}#{c}'
df = pd.read_csv(study_root / 'results.csv', index_col=False)
df.reward = df.reward.rolling(100).mean()
df['method'] = name.upper()
dfs.append(df)
except Exception as e:
print(f'skipped {name}#{c}')
pass
csvs = pd.concat(csvs).rename(columns={"checkpoint": "steps*2e3", "B": "c"})
sns.lineplot(data=csvs, x='steps*2e3', y='reward', hue='method', palette='husl', ci='sd', linewidth=1.8)
plt.savefig('entropy.png')
df = pd.concat(dfs).reset_index()
sns.lineplot(data=df, x='steps', y='reward', hue='method', palette='husl', ci='sd', linewidth=1.5)
plt.savefig('study.png')
print('saved image')

View File

@ -3,19 +3,21 @@ from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC
from pathlib import Path
from algorithms.utils import load_yaml_file
from tqdm import trange
study = 'curious_study'
study_root = Path(__file__).parent / study
study = 'example_config#0'
#study_root = Path(__file__).parent / study
study_root = Path('/Users/romue/PycharmProjects/EDYS/algorithms/marl/')
#['L2NoAh_gru', 'L2NoCh_gru', 'nomix_gru']:
render = True
eval_eps = 3
for run in range(0, 5):
for name in ['basic_gru']:#['L2OnlyAh_gru', 'L2OnlyChAh_gru', 'L2OnlyMix_gru']: #['layernorm_gru', 'basic_gru', 'nonorm_gru', 'spectralnorm_gru']:
cfg = load_yaml_file(Path(__file__).parent / study / f'{name}.yaml')
p_root = Path(study_root / f'{name}#{run}')
for name in ['example_config']:#['L2OnlyAh_gru', 'L2OnlyChAh_gru', 'L2OnlyMix_gru']: #['layernorm_gru', 'basic_gru', 'nonorm_gru', 'spectralnorm_gru']:
cfg = load_yaml_file(study_root / study / 'config.yaml')
#p_root = Path(study_root / study / f'{name}#{run}')
dfs = []
for i in trange(500):
path = p_root / f'checkpoint_{i}'
path = study_root / study / f'checkpoint_{161}'
print(path)
snac = LoopSEAC(cfg)
snac.load_state_dict(path)