added first working MAPPO implementation
@@ -1,5 +1,6 @@
 import torch
 from typing import Union, List
+import copy
 import numpy as np
 from torch.distributions import Categorical
 from algorithms.marl.memory import MARLActorCriticMemory
@@ -59,7 +60,7 @@ class BaseActorCritic:
                 actions: ListOrTensor,
                 hidden_actor: ListOrTensor,
                 hidden_critic: ListOrTensor
-                ):
+                ) -> dict[ListOrTensor]:
         pass


@@ -67,8 +68,9 @@ class BaseActorCritic:
     def train_loop(self, checkpointer=None):
         env = instantiate_class(self.cfg['env'])
         n_steps, max_steps = [self.cfg['algorithm'][k] for k in ['n_steps', 'max_steps']]
-        global_steps = 0
+        global_steps, episode, df_results = 0, 0, []
         reward_queue = deque(maxlen=2000)
+        memory_queue = deque(maxlen=self.cfg['algorithm'].get('keep_n_segments', 1))
         while global_steps < max_steps:
             tm = MARLActorCriticMemory(self.n_agents)
             obs = env.reset()
@@ -85,7 +87,8 @@
             next_obs = next_obs
             if isinstance(done, bool): done = [done] * self.n_agents

-            tm.add(observation=obs, action=action, reward=reward, done=done)
+            tm.add(observation=obs, action=action, reward=reward, done=done,
+                   logits=out.get('logits', None), values=out.get('critic', None))
             obs = next_obs
             last_action = action
             last_hiddens = dict(hidden_actor=out.get('hidden_actor', None),
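Note on the tm.add change above: the memory now also stores the behaviour policy's logits and the critic's value estimates at collection time. Given the MAPPO goal in the commit title, the stored logits are presumably what later supplies the old log-probabilities for a PPO-style importance ratio; a generic, self-contained sketch of that use (all tensors here are dummies, not the repo's objects):

    import torch
    from torch.distributions import Categorical

    # hypothetical shapes: [n_agents, T, n_actions] for logits, [n_agents, T] for actions
    old_logits = torch.randn(2, 5, 4)                     # stored at rollout time
    new_logits = old_logits + 0.1 * torch.randn(2, 5, 4)  # recomputed by the current policy
    actions = torch.randint(4, (2, 5))

    old_log_p = Categorical(logits=old_logits).log_prob(actions)
    new_log_p = Categorical(logits=new_logits).log_prob(actions)
    ratio = (new_log_p - old_log_p).exp()                 # pi_new / pi_old, the PPO ratio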
@@ -94,9 +97,11 @@

             if len(tm) >= n_steps or all(done):
                 tm.add(observation=next_obs)
+                memory_queue.append(copy.deepcopy(tm))
                 if self.__training:
                     with torch.inference_mode(False):
-                        self.learn(tm)
+                        tm_ = tm if memory_queue.maxlen <= 1 else list(memory_queue)
+                        self.learn(tm_)
                 tm.reset()
                 tm.add(action=last_action, **last_hiddens)
             global_steps += 1
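Note on the segment buffering above: with keep_n_segments at its default of 1 the behaviour is unchanged and learn() still receives the single current memory; for larger values, learn() receives a list of deep-copied recent segments. A minimal runnable sketch of the pattern (learn and the dict segments are stand-ins, not the repo's classes):

    from collections import deque
    import copy

    memory_queue = deque(maxlen=3)      # e.g. keep_n_segments = 3

    def learn(batch):                   # stand-in for the real update step
        print(f'updating on {len(batch)} segment(s)')

    for segment in [{'step': i} for i in range(5)]:   # dummy rollout segments
        memory_queue.append(copy.deepcopy(segment))
        # maxlen == 1 reproduces the old single-segment update;
        # maxlen > 1 hands the learner the most recent segments
        batch = [memory_queue[-1]] if memory_queue.maxlen <= 1 else list(memory_queue)
        learn(batch)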
@@ -110,7 +115,13 @@
                 ])

             if global_steps >= max_steps: break
-            print(f'reward at step: {global_steps} = {rew_log}')
+            print(f'reward at step: {episode} = {rew_log}')
+            episode += 1
+            df_results.append([global_steps, rew_log])
+        df_results = pd.DataFrame(df_results, columns=['steps', 'reward'])
+        if checkpointer is not None:
+            df_results.to_csv(checkpointer.path / 'results.csv', index=False)
+        return df_results

     @torch.inference_mode(True)
     def eval_loop(self, n_episodes, render=False):
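Note: train_loop now returns the per-episode rewards as a pandas DataFrame and, when a checkpointer is passed, writes them to results.csv. The hunk uses pd without this diff adding an import, so pandas is presumably already imported at module level. A small sketch of the returned layout (agent is a hypothetical caller):

    import pandas as pd

    # mirrors df_results as built in the hunk above
    df = pd.DataFrame([[100, 0.2], [200, 0.5]], columns=['steps', 'reward'])
    print(df.tail())
    # df.plot(x='steps', y='reward')       # quick learning-curve check
    # df = agent.train_loop(checkpointer)  # hypothetical call site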
@@ -143,10 +154,21 @@ class BaseActorCritic:
         return results

     @staticmethod
-    def compute_advantages(critic, reward, done, gamma):
-        return (reward + gamma * (1.0 - done) * critic[:, 1:].detach()) - critic[:, :-1]
+    def compute_advantages(critic, reward, done, gamma, gae_coef=0.0):
+        tds = (reward + gamma * (1.0 - done) * critic[:, 1:].detach()) - critic[:, :-1]
+
+        if gae_coef <= 0:
+            return tds
+
+        gae = torch.zeros_like(tds[:, -1])
+        gaes = []
+        for t in range(tds.shape[1]-1, -1, -1):
+            gae = tds[:, t] + gamma * gae_coef * (1.0 - done[:, t]) * gae
+            gaes.insert(0, gae)
+        gaes = torch.stack(gaes, dim=1)
+        return gaes

-    def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, **kwargs):
+    def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
         obs, actions, done, reward = tm.observation, tm.action, tm.done, tm.reward

         out = network(obs, actions, tm.hidden_actor, tm.hidden_critic)
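Note on the advantage change above: compute_advantages now optionally applies Generalized Advantage Estimation, with gae_coef playing the role of the usual lambda. gae_coef = 0 keeps the previous one-step TD errors, tds_t = r_t + gamma * (1 - done_t) * V(s_{t+1}) - V(s_t), while gae_coef > 0 accumulates A_t = tds_t + gamma * lambda * (1 - done_t) * A_{t+1} backwards through the segment. A self-contained sanity check (shapes [n_agents, T] are assumed from the indexing; the function body just reproduces the new implementation):

    import torch

    def compute_advantages(critic, reward, done, gamma, gae_coef=0.0):
        # critic: [n_agents, T+1] incl. bootstrap value; reward/done: [n_agents, T]
        tds = (reward + gamma * (1.0 - done) * critic[:, 1:].detach()) - critic[:, :-1]
        if gae_coef <= 0:
            return tds
        gae = torch.zeros_like(tds[:, -1])
        gaes = []
        for t in range(tds.shape[1] - 1, -1, -1):
            gae = tds[:, t] + gamma * gae_coef * (1.0 - done[:, t]) * gae
            gaes.insert(0, gae)
        return torch.stack(gaes, dim=1)

    critic = torch.randn(2, 6)   # 2 agents, 5-step segment plus bootstrap value
    reward = torch.randn(2, 5)
    done = torch.zeros(2, 5)

    # with gae_coef=1 and no terminals, GAE is exactly the discounted sum of TD errors
    tds = compute_advantages(critic, reward, done, 0.99)
    adv = compute_advantages(critic, reward, done, 0.99, gae_coef=1.0)
    direct = torch.stack([sum(0.99 ** l * tds[:, t + l] for l in range(5 - t))
                          for t in range(5)], dim=1)
    assert torch.allclose(adv, direct, atol=1e-5)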
@@ -154,7 +176,7 @@
         critic = out['critic']

         entropy_loss = Categorical(logits=logits).entropy().mean(-1)
-        advantages = self.compute_advantages(critic, reward, done, gamma)
+        advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef)
         value_loss = advantages.pow(2).mean(-1)  # n_agent

         # policy loss
@@ -163,7 +185,6 @@
         a2c_loss = -(advantages.detach() * log_ap).mean(-1)
         # weighted loss
         loss = a2c_loss + vf_coef*value_loss - entropy_coef * entropy_loss
-
         return loss.mean()

     def learn(self, tm: MARLActorCriticMemory, **kwargs):
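Note on the objective: the loss assembled above is the standard entropy-regularized actor-critic objective, loss = a2c_loss + vf_coef * value_loss - entropy_coef * entropy_loss, averaged over agents. A toy end-to-end computation with dummy shapes (the coefficient values are assumptions, not from this commit):

    import torch
    from torch.distributions import Categorical

    n_agents, T, n_actions = 2, 5, 4
    logits = torch.randn(n_agents, T, n_actions, requires_grad=True)
    actions = torch.randint(n_actions, (n_agents, T))
    advantages = torch.randn(n_agents, T)       # stands in for compute_advantages output

    dist = Categorical(logits=logits)
    log_ap = dist.log_prob(actions)             # log pi(a_t | o_t) per agent and step
    entropy_loss = dist.entropy().mean(-1)      # [n_agents]
    a2c_loss = -(advantages.detach() * log_ap).mean(-1)
    value_loss = advantages.pow(2).mean(-1)     # squared TD/GAE errors train the critic

    vf_coef, entropy_coef = 0.5, 0.01           # assumed values
    loss = (a2c_loss + vf_coef * value_loss - entropy_coef * entropy_loss).mean()
    loss.backward()                             # gradients flow to the policy logits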