add a more efficient (lazy), tensor-based experience queue implementation; adjust MARL algorithms
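Editor's note on the change: the commit message describes a lazier, tensor-based experience queue, and in the diff below MARLActorCriticMemory gains a buffer-size argument while the deque of deep-copied memory segments appears to be retired. As a rough sketch of the idea only — the class name LazyTensorBuffer and its add/reset/__len__ methods are hypothetical stand-ins, not the project's actual MARLActorCriticMemory API — such a buffer could look like this:

# Illustrative only: a minimal, lazily allocated, tensor-backed rollout buffer.
# Storage is allocated once shapes are known and then overwritten in place,
# instead of deep-copying Python containers for every segment.
import torch


class LazyTensorBuffer:
    def __init__(self, n_agents: int, capacity: int):
        self.n_agents, self.capacity = n_agents, capacity
        self.storage = {}   # field name -> pre-allocated tensor of shape (capacity, ...)
        self.t = 0          # number of add() calls so far

    def reset(self):
        self.t = 0          # rewind the write cursor, keep the allocated tensors

    def add(self, **fields):
        idx = self.t % self.capacity                 # ring-buffer write position
        for key, value in fields.items():
            value = torch.as_tensor(value, dtype=torch.float32)
            if key not in self.storage:              # lazy allocation on first use
                self.storage[key] = torch.zeros(self.capacity, *value.shape)
            self.storage[key][idx] = value
        self.t += 1

    def __len__(self):
        return min(self.t, self.capacity)

With a buffer along these lines, the per-segment memory_queue.append(copy.deepcopy(tm)) visible in the diff is no longer needed: the same tensors are reused across episodes and only the write cursor is rewound.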
@@ -1,6 +1,5 @@
import torch
from typing import Union, List
import copy
import numpy as np
from torch.distributions import Categorical
from algorithms.marl.memory import MARLActorCriticMemory
@@ -8,6 +7,28 @@ from algorithms.utils import add_env_props, instantiate_class
from pathlib import Path
import pandas as pd
from collections import deque


class Names:
    REWARD = 'reward'
    DONE = 'done'
    ACTION = 'action'
    OBSERVATION = 'observation'
    LOGITS = 'logits'
    HIDDEN_ACTOR = 'hidden_actor'
    HIDDEN_CRITIC = 'hidden_critic'
    AGENT = 'agent'
    ENV = 'env'
    N_AGENTS = 'n_agents'
    ALGORITHM = 'algorithm'
    MAX_STEPS = 'max_steps'
    N_STEPS = 'n_steps'
    BUFFER_SIZE = 'buffer_size'
    CRITIC = 'critic'
    BATCH_SIZE = 'batch_size'
    N_ACTIONS = 'n_actions'

nms = Names
ListOrTensor = Union[List, torch.Tensor]

@@ -16,11 +37,12 @@ class BaseActorCritic:
        add_env_props(cfg)
        self.__training = True
        self.cfg = cfg
        self.n_agents = cfg['env']['n_agents']
        self.n_agents = cfg[nms.ENV][nms.N_AGENTS]
        self.reset_memory_after_epoch = True
        self.setup()

    def setup(self):
        self.net = instantiate_class(self.cfg['agent'])
        self.net = instantiate_class(self.cfg[nms.AGENT])
        self.optimizer = torch.optim.RMSprop(self.net.parameters(), lr=3e-4, eps=1e-5)

    @classmethod
@@ -49,7 +71,7 @@ class BaseActorCritic:
        pass

    def get_actions(self, out) -> ListOrTensor:
        actions = [Categorical(logits=logits).sample().item() for logits in out['logits']]
        actions = [Categorical(logits=logits).sample().item() for logits in out[nms.LOGITS]]
        return actions

    def init_hidden(self) -> dict[ListOrTensor]:
@@ -63,47 +85,48 @@ class BaseActorCritic:
                ) -> dict[ListOrTensor]:
        pass


    @torch.no_grad()
    def train_loop(self, checkpointer=None):
        env = instantiate_class(self.cfg['env'])
        n_steps, max_steps = [self.cfg['algorithm'][k] for k in ['n_steps', 'max_steps']]
        global_steps, episode, df_results = 0, 0, []
        env = instantiate_class(self.cfg[nms.ENV])
        n_steps, max_steps = [self.cfg[nms.ALGORITHM][k] for k in [nms.N_STEPS, nms.MAX_STEPS]]
        tm = MARLActorCriticMemory(self.n_agents, self.cfg[nms.ALGORITHM].get(nms.BUFFER_SIZE, n_steps))
        global_steps, episode, df_results = 0, 0, []
        reward_queue = deque(maxlen=2000)
        memory_queue = deque(maxlen=self.cfg['algorithm'].get('keep_n_segments', 1))

        while global_steps < max_steps:
            tm = MARLActorCriticMemory(self.n_agents)
            obs = env.reset()
            last_hiddens = self.init_hidden()
            last_action, reward = [-1] * self.n_agents, [0.] * self.n_agents
            done, rew_log = [False] * self.n_agents, 0
            tm.add(action=last_action, **last_hiddens)
            done, rew_log = [False] * self.n_agents, 0

            if self.reset_memory_after_epoch:
                tm.reset()

            tm.add(observation=obs, action=last_action,
                   logits=torch.zeros(self.n_agents, 1, self.cfg[nms.AGENT][nms.N_ACTIONS]),
                   values=torch.zeros(self.n_agents, 1), reward=reward, done=done, **last_hiddens)

            while not all(done):

                out = self.forward(obs, last_action, **last_hiddens)
                action = self.get_actions(out)
                next_obs, reward, done, info = env.step(action)
                next_obs = next_obs
                if isinstance(done, bool): done = [done] * self.n_agents
                done = [done] * self.n_agents if isinstance(done, bool) else done

                last_hiddens = dict(hidden_actor =out[nms.HIDDEN_ACTOR],
                                    hidden_critic=out[nms.HIDDEN_CRITIC])


                tm.add(observation=obs, action=action, reward=reward, done=done,
                       logits=out.get('logits', None), values=out.get('critic', None))
                       logits=out.get(nms.LOGITS, None), values=out.get(nms.CRITIC, None),
                       **last_hiddens)

                obs = next_obs
                last_action = action
                last_hiddens = dict(hidden_actor=out.get('hidden_actor', None),
                                    hidden_critic=out.get('hidden_critic', None)
                                    )

                if len(tm) >= n_steps or all(done):
                    tm.add(observation=next_obs)
                    memory_queue.append(copy.deepcopy(tm))
                    if self.__training:
                        with torch.inference_mode(False):
                            tm_ = tm if memory_queue.maxlen <= 1 else list(memory_queue)
                            self.learn(tm_)
                    tm.reset()
                    tm.add(action=last_action, **last_hiddens)
                if (global_steps+1) % n_steps == 0 or all(done):
                    with torch.inference_mode(False):
                        self.learn(tm)

                global_steps += 1
                rew_log += sum(reward)
                reward_queue.extend(reward)
@@ -114,18 +137,19 @@ class BaseActorCritic:
                    for i, agent in enumerate([self.net] if not isinstance(self.net, List) else self.net)
                ])

            if global_steps >= max_steps: break
            print(f'reward at step: {episode} = {rew_log}')
            if global_steps >= max_steps:
                break
            print(f'reward at episode: {episode} = {rew_log}')
            episode += 1
            df_results.append([global_steps, rew_log])
        df_results = pd.DataFrame(df_results, columns=['steps', 'reward'])
            df_results.append([episode, rew_log, *reward])
        df_results = pd.DataFrame(df_results, columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]])
        if checkpointer is not None:
            df_results.to_csv(checkpointer.path / 'results.csv', index=False)
        return df_results

    @torch.inference_mode(True)
    def eval_loop(self, n_episodes, render=False):
        env = instantiate_class(self.cfg['env'])
        env = instantiate_class(self.cfg[nms.ENV])
        episode, results = 0, []
        while episode < n_episodes:
            obs = env.reset()
@@ -142,8 +166,8 @@ class BaseActorCritic:
                if isinstance(done, bool): done = [done] * obs.shape[0]
                obs = next_obs
                last_action = action
                last_hiddens = dict(hidden_actor=out.get('hidden_actor', None),
                                    hidden_critic=out.get('hidden_critic', None)
                last_hiddens = dict(hidden_actor=out.get(nms.HIDDEN_ACTOR, None),
                                    hidden_critic=out.get(nms.HIDDEN_CRITIC, None)
                                    )
                eps_rew += torch.tensor(reward)
            results.append(eps_rew.tolist() + [sum(eps_rew).item()] + [episode])
@@ -169,11 +193,11 @@ class BaseActorCritic:
        return gaes

    def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
        obs, actions, done, reward = tm.observation, tm.action, tm.done, tm.reward
        obs, actions, done, reward = tm.observation, tm.action, tm.done[:, 1:], tm.reward[:, 1:]

        out = network(obs, actions, tm.hidden_actor, tm.hidden_critic)
        logits = out['logits'][:, :-1] # last one only needed for v_{t+1}
        critic = out['critic']
        out = network(obs, actions, tm.hidden_actor[:, 0], tm.hidden_critic[:, 0])
        logits = out[nms.LOGITS][:, :-1] # last one only needed for v_{t+1}
        critic = out[nms.CRITIC]

        entropy_loss = Categorical(logits=logits).entropy().mean(-1)
        advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef)
@@ -188,7 +212,7 @@ class BaseActorCritic:
        return loss.mean()

    def learn(self, tm: MARLActorCriticMemory, **kwargs):
        loss = self.actor_critic(tm, self.net, **self.cfg['algorithm'], **kwargs)
        loss = self.actor_critic(tm, self.net, **self.cfg[nms.ALGORITHM], **kwargs)
        # remove next_obs, will be added in next iter
        self.optimizer.zero_grad()
        loss.backward()
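A note on the advantage computation above: the gae_coef argument passed to compute_advantages suggests generalized advantage estimation, although the method body is not shown in this diff. Assuming that is what it implements, the returned gaes would follow the standard recursion, with \lambda corresponding to gae_coef and gae_coef=0.0 reducing to the one-step TD advantage:

\delta_t = r_t + \gamma\,(1 - d_t)\,V(s_{t+1}) - V(s_t)
A_t = \delta_t + \gamma\,\lambda\,(1 - d_t)\,A_{t+1}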