added first working MAPPO implementation
commit b09c461754
parent ffc47752a7
@@ -1,4 +1,6 @@
 from algorithms.marl.base_ac import BaseActorCritic
 from algorithms.marl.iac import LoopIAC
 from algorithms.marl.snac import LoopSNAC
 from algorithms.marl.seac import LoopSEAC
+from algorithms.marl.mappo import LoopMAPPO
+from algorithms.marl.memory import MARLActorCriticMemory
@@ -1,5 +1,6 @@
 import torch
 from typing import Union, List
+import copy
 import numpy as np
 from torch.distributions import Categorical
 from algorithms.marl.memory import MARLActorCriticMemory
@@ -59,7 +60,7 @@ class BaseActorCritic:
                 actions: ListOrTensor,
                 hidden_actor: ListOrTensor,
                 hidden_critic: ListOrTensor
-                 ):
+                 ) -> dict[ListOrTensor]:
        pass
@@ -67,8 +68,9 @@ class BaseActorCritic:
    def train_loop(self, checkpointer=None):
        env = instantiate_class(self.cfg['env'])
        n_steps, max_steps = [self.cfg['algorithm'][k] for k in ['n_steps', 'max_steps']]
-        global_steps = 0
+        global_steps, episode, df_results = 0, 0, []
        reward_queue = deque(maxlen=2000)
+        memory_queue = deque(maxlen=self.cfg['algorithm'].get('keep_n_segments', 1))
        while global_steps < max_steps:
            tm = MARLActorCriticMemory(self.n_agents)
            obs = env.reset()
@@ -85,7 +87,8 @@ class BaseActorCritic:
                next_obs = next_obs
                if isinstance(done, bool): done = [done] * self.n_agents
-                tm.add(observation=obs, action=action, reward=reward, done=done)
+                tm.add(observation=obs, action=action, reward=reward, done=done,
+                       logits=out.get('logits', None), values=out.get('critic', None))
                obs = next_obs
                last_action = action
                last_hiddens = dict(hidden_actor=out.get('hidden_actor', None),
@@ -94,9 +97,11 @@ class BaseActorCritic:

                if len(tm) >= n_steps or all(done):
                    tm.add(observation=next_obs)
+                    memory_queue.append(copy.deepcopy(tm))
                    if self.__training:
                        with torch.inference_mode(False):
-                            self.learn(tm)
+                            tm_ = tm if memory_queue.maxlen <= 1 else list(memory_queue)
+                            self.learn(tm_)
                    tm.reset()
                    tm.add(action=last_action, **last_hiddens)
                global_steps += 1
@@ -110,7 +115,13 @@ class BaseActorCritic:
                ])

            if global_steps >= max_steps: break
-            print(f'reward at step: {global_steps} = {rew_log}')
+            print(f'reward at step: {episode} = {rew_log}')
+            episode += 1
+            df_results.append([global_steps, rew_log])
+        df_results = pd.DataFrame(df_results, columns=['steps', 'reward'])
+        if checkpointer is not None:
+            df_results.to_csv(checkpointer.path / 'results.csv', index=False)
+        return df_results

    @torch.inference_mode(True)
    def eval_loop(self, n_episodes, render=False):
@@ -143,10 +154,21 @@ class BaseActorCritic:
        return results

    @staticmethod
-    def compute_advantages(critic, reward, done, gamma):
-        return (reward + gamma * (1.0 - done) * critic[:, 1:].detach()) - critic[:, :-1]
+    def compute_advantages(critic, reward, done, gamma, gae_coef=0.0):
+        tds = (reward + gamma * (1.0 - done) * critic[:, 1:].detach()) - critic[:, :-1]
+        if gae_coef <= 0:
+            return tds
+
+        gae = torch.zeros_like(tds[:, -1])
+        gaes = []
+        for t in range(tds.shape[1]-1, -1, -1):
+            gae = tds[:, t] + gamma * gae_coef * (1.0 - done[:, t]) * gae
+            gaes.insert(0, gae)
+        gaes = torch.stack(gaes, dim=1)
+        return gaes

-    def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, **kwargs):
+    def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
        obs, actions, done, reward = tm.observation, tm.action, tm.done, tm.reward

        out = network(obs, actions, tm.hidden_actor, tm.hidden_critic)
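The new gae_coef path above implements generalized advantage estimation over the time axis; with gae_coef=0 it falls back to the plain one-step TD error used before this commit. A minimal standalone sketch of the same recursion (function name and the batch x time shapes are illustrative, not taken from the repo):

import torch

def gae(tds, done, gamma, gae_coef):
    # tds: one-step TD errors, shape (batch, T); done: 1.0 where an episode ended
    gae_t = torch.zeros_like(tds[:, -1])
    out = []
    for t in range(tds.shape[1] - 1, -1, -1):
        # discounted, lambda-weighted sum of future TD errors, cut at episode boundaries
        gae_t = tds[:, t] + gamma * gae_coef * (1.0 - done[:, t]) * gae_t
        out.insert(0, gae_t)
    return torch.stack(out, dim=1)

# With gae_coef=0.0 every gae_t equals tds[:, t], matching the early-return branch above.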
@@ -154,7 +176,7 @@ class BaseActorCritic:
        critic = out['critic']

        entropy_loss = Categorical(logits=logits).entropy().mean(-1)
-        advantages = self.compute_advantages(critic, reward, done, gamma)
+        advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef)
        value_loss = advantages.pow(2).mean(-1)  # n_agent

        # policy loss
@@ -163,7 +185,6 @@ class BaseActorCritic:
        a2c_loss = -(advantages.detach() * log_ap).mean(-1)
        # weighted loss
        loss = a2c_loss + vf_coef*value_loss - entropy_coef * entropy_loss

        return loss.mean()

    def learn(self, tm: MARLActorCriticMemory, **kwargs):
@@ -21,7 +21,6 @@ class LoopIAC(BaseActorCritic):

    def load_state_dict(self, path: Path):
        paths = natsorted(list(path.glob('*.pt')))
-        print(list(paths))
        for path, net in zip(paths, self.net):
            net.load_state_dict(torch.load(path))
algorithms/marl/mappo.py (new file, 78 lines)
@@ -0,0 +1,78 @@
+from algorithms.marl import LoopSNAC
+from algorithms.marl.memory import MARLActorCriticMemory
+from typing import List
+import random
+import torch
+from torch.distributions import Categorical
+
+
+class LoopMAPPO(LoopSNAC):
+    def __init__(self, *args, **kwargs):
+        super(LoopMAPPO, self).__init__(*args, **kwargs)
+
+    def build_batch(self, tm: List[MARLActorCriticMemory]):
+        sample = random.choices(tm, k=self.cfg['algorithm']['batch_size']-1)
+        sample.append(tm[-1])  # always use latest segment in batch
+
+        obs = torch.cat([s.observation for s in sample], 0)
+        actions = torch.cat([s.action for s in sample], 0)
+        hidden_actor = torch.cat([s.hidden_actor for s in sample], 0)
+        hidden_critic = torch.cat([s.hidden_critic for s in sample], 0)
+        logits = torch.cat([s.logits for s in sample], 0)
+        values = torch.cat([s.values for s in sample], 0)
+        reward = torch.cat([s.reward for s in sample], 0)
+        done = torch.cat([s.done for s in sample], 0)
+
+
+        log_props = torch.log_softmax(logits, -1)
+        log_props = torch.gather(log_props, index=actions[:, 1:].unsqueeze(-1), dim=-1).squeeze()
+
+        return obs, actions, hidden_actor, hidden_critic, log_props, values, reward, done
+
+    def learn(self, tm: List[MARLActorCriticMemory], **kwargs):
+        if len(tm) >= self.cfg['algorithm']['keep_n_segments']:
+            # only learn when buffer is full
+            for batch_i in range(self.cfg['algorithm']['n_updates']):
+                loss = self.actor_critic(tm, self.net, **self.cfg['algorithm'], **kwargs)
+                self.optimizer.zero_grad()
+                loss.backward()
+                torch.nn.utils.clip_grad_norm_(self.net.parameters(), 0.5)
+                self.optimizer.step()
+
+    def monte_carlo_returns(self, rewards, done, gamma):
+        rewards_ = []
+        discounted_reward = torch.zeros_like(rewards[:, -1])
+        for t in range(rewards.shape[1]-1, -1, -1):
+            discounted_reward = rewards[:, t] + (gamma * (1.0 - done[:, t]) * discounted_reward)
+            rewards_.insert(0, discounted_reward)
+        rewards_ = torch.stack(rewards_, dim=1)
+        return rewards_
+
+    def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, clip_range, gae_coef=0.0, **kwargs):
+        obs, actions, hidden_actor, hidden_critic, old_log_probs, old_critic, reward, done = self.build_batch(tm)
+
+        out = network(obs, actions, hidden_actor, hidden_critic)
+        logits = out['logits'][:, :-1]  # last one only needed for v_{t+1}
+        critic = out['critic']
+
+        # monte carlo returns
+        mc_returns = self.monte_carlo_returns(reward, done, gamma)
+        # monte_carlo_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-7) todo: norm across agents?
+        advantages = mc_returns - critic[:, :-1]
+
+        # policy loss
+        log_ap = torch.log_softmax(logits, -1)
+        log_ap = torch.gather(log_ap, dim=-1, index=actions[:, 1:].unsqueeze(-1)).squeeze()
+        ratio = (log_ap - old_log_probs).exp()
+        surr1 = ratio * advantages.detach()
+        surr2 = torch.clamp(ratio, 1 - clip_range, 1 + clip_range) * advantages.detach()
+        policy_loss = -torch.min(surr1, surr2).mean(-1)
+
+        # entropy & value loss
+        entropy_loss = Categorical(logits=logits).entropy().mean(-1)
+        value_loss = advantages.pow(2).mean(-1)  # n_agent
+
+        # weighted loss
+        loss = policy_loss + vf_coef*value_loss - entropy_coef * entropy_loss
+
+        return loss.mean()
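The core of the new loss is the PPO clipped surrogate: the probability ratio between the current policy and the policy that produced the stored segment is clamped to [1 - clip_range, 1 + clip_range] before being multiplied with the detached advantage. A toy sketch of just that term, with made-up numbers (not taken from the repo):

import torch

old_log_probs = torch.tensor([[-0.7, -1.4]])   # log pi_old(a|s), shape batch x time
new_log_probs = torch.tensor([[-0.5, -1.6]])   # log pi_new(a|s)
advantages    = torch.tensor([[ 1.0, -1.0]])
clip_range = 0.2

ratio = (new_log_probs - old_log_probs).exp()  # pi_new / pi_old
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1 - clip_range, 1 + clip_range) * advantages
policy_loss = -torch.min(surr1, surr2).mean()
print(policy_loss)  # clipping caps the gain from moving far away from the old policy

Note that build_batch recomputes old_log_probs from the logits stored in the memory segments, so they stay fixed across the n_updates epochs while log_ap tracks the updated network.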
@@ -13,14 +13,16 @@ class ActorCriticMemory(object):
        self.__actions = []
        self.__rewards = []
        self.__dones = []
        self.__hiddens_actor = []
        self.__hiddens_critic = []
+        self.__logits = []
+        self.__values = []

    def __len__(self):
        return len(self.__states)

    @property
-    def observation(self):
+    def observation(self):  # add time dimension through stacking
        return torch.stack(self.__states, 0).unsqueeze(0)  # 1 x timesteps x hidden dim

    @property
@@ -47,6 +49,14 @@ class ActorCriticMemory(object):
    def done(self):
        return torch.tensor(self.__dones).float().unsqueeze(0)  # 1 x timesteps

+    @property
+    def logits(self):  # assumes a trailing 1 for time dimension - common when using output from NN
+        return torch.cat(self.__logits, 0).unsqueeze(0)  # 1 x timesteps x actions
+
+    @property
+    def values(self):
+        return torch.cat(self.__values, 0).unsqueeze(0)  # 1 x timesteps x actions
+
    def add_observation(self, state: Union[Tensor, np.ndarray]):
        self.__states.append(state if isinstance(state, Tensor) else torch.from_numpy(state))
@@ -69,6 +79,12 @@ class ActorCriticMemory(object):
    def add_done(self, done: bool):
        self.__dones.append(done)

+    def add_logits(self, logits: Tensor):
+        self.__logits.append(logits)
+
+    def add_values(self, logits: Tensor):
+        self.__values.append(logits)
+
    def add(self, **kwargs):
        for k, v in kwargs.items():
            func = getattr(ActorCriticMemory, f'add_{k}')
@@ -129,3 +145,14 @@ class MARLActorCriticMemory(object):
        all_hc = [mem.hidden_critic for mem in self.memories]
        return torch.cat(all_hc, 0)  # agents x layers x timesteps x hidden dim

+    @property
+    def logits(self):
+        all_lgts = [mem.logits for mem in self.memories]
+        return torch.cat(all_lgts, 0)  # agents x layers x timesteps x hidden dim
+
+    @property
+    def values(self):
+        all_vals = [mem.values for mem in self.memories]
+        return torch.cat(all_vals, 0)  # agents x layers x timesteps x hidden dim
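The tm.add(..., logits=..., values=...) call in base_ac works without further plumbing because ActorCriticMemory.add dispatches each keyword to the matching add_<name> method via getattr. A minimal, hypothetical illustration of that pattern (TinyMemory is not the project's class):

import torch

class TinyMemory:
    def __init__(self):
        self.logits = []

    def add_logits(self, logits: torch.Tensor):
        self.logits.append(logits)

    def add(self, **kwargs):
        # same getattr-based dispatch idea as ActorCriticMemory.add
        for k, v in kwargs.items():
            getattr(self, f'add_{k}')(v)

m = TinyMemory()
m.add(logits=torch.zeros(1, 5))
m.add(logits=torch.zeros(1, 5))
print(torch.cat(m.logits, 0).unsqueeze(0).shape)  # torch.Size([1, 2, 5]) -> 1 x timesteps x actions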
@@ -12,6 +12,7 @@ class RecurrentAC(nn.Module):
        super(RecurrentAC, self).__init__()
        observation_size = np.prod(observation_size)
        self.n_layers = 1
+        self.n_actions = n_actions
        self.use_agent_embedding = use_agent_embedding
        self.hidden_size_actor = hidden_size_actor
        self.hidden_size_critic = hidden_size_critic
@@ -25,13 +26,14 @@ class RecurrentAC(nn.Module):
            nn.Tanh(),
            nn.Linear(obs_emb_size, obs_emb_size)
        )
        self.gru_actor = nn.GRU(obs_emb_size, hidden_size_actor, batch_first=True, num_layers=self.n_layers)
        self.gru_critic = nn.GRU(obs_emb_size, hidden_size_critic, batch_first=True, num_layers=self.n_layers)
        self.action_head = nn.Sequential(
-            spectral_norm(nn.Linear(hidden_size_actor, hidden_size_actor)),
+            nn.Linear(hidden_size_actor, hidden_size_actor),
            nn.Tanh(),
            nn.Linear(hidden_size_actor, n_actions)
        )
+        # spectral_norm(nn.Linear(hidden_size_actor, hidden_size_actor)),
        self.critic_head = nn.Sequential(
            nn.Linear(hidden_size_critic, hidden_size_critic),
            nn.Tanh(),
|
||||
n_agents, t, *_ = observations.shape
|
||||
obs_emb = self.obs_proj(observations.view(n_agents, t, -1).float())
|
||||
action_emb = self.action_emb(actions+1) # shift by one due to padding idx
|
||||
agent_emb = self.agent_emb(
|
||||
torch.cat([torch.arange(0, n_agents, 1).view(-1, 1)]*t, 1)
|
||||
)
|
||||
x_t = torch.cat((obs_emb, action_emb), -1) \
|
||||
if not self.use_agent_embedding else torch.cat((obs_emb, agent_emb, action_emb), -1)
|
||||
|
||||
if not self.use_agent_embedding:
|
||||
x_t = torch.cat((obs_emb, action_emb), -1)
|
||||
else:
|
||||
agent_emb = self.agent_emb(
|
||||
torch.cat([torch.arange(0, n_agents, 1).view(-1, 1)] * t, 1)
|
||||
)
|
||||
x_t = torch.cat((obs_emb, agent_emb, action_emb), -1)
|
||||
|
||||
mixed_x_t = self.mix(x_t)
|
||||
output_p, _ = self.gru_actor(input=mixed_x_t, hx=hidden_actor.swapaxes(1, 0))
|
||||
@@ -66,6 +70,15 @@ class RecurrentAC(nn.Module):
        return dict(logits=logits, critic=critic, hidden_actor=output_p, hidden_critic=output_c)


+class RecurrentACL2(RecurrentAC):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.action_head = nn.Sequential(
+            nn.Linear(self.hidden_size_actor, self.hidden_size_actor),
+            nn.Tanh(),
+            NormalizedLinear(self.hidden_size_actor, self.n_actions, trainable_magnitude=True)
+        )
+
+
 class NormalizedLinear(nn.Linear):
     def __init__(self, in_features: int, out_features: int,
@@ -8,7 +8,7 @@ class LoopSEAC(LoopIAC):
    def __init__(self, cfg):
        super(LoopSEAC, self).__init__(cfg)

-    def actor_critic(self, tm, networks, gamma, entropy_coef, vf_coef, **kwargs):
+    def actor_critic(self, tm, networks, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
        obs, actions, done, reward = tm.observation, tm.action, tm.done, tm.reward
        outputs = [net(obs, actions, tm.hidden_actor, tm.hidden_critic) for net in networks]
@@ -26,7 +26,7 @@ class LoopSEAC(LoopIAC):
            critic = out['critic']

            entropy_loss = Categorical(logits=logits[ag_i]).entropy().mean()
-            advantages = self.compute_advantages(critic, reward, done, gamma)
+            advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef)

            # policy loss
            log_ap = torch.log_softmax(logits, -1)
@@ -1,12 +1,14 @@
 numpy
 scipy
 tqdm
+pandas
 seaborn>=0.11.1
-matplotlib>=3.4.1
+matplotlib>=3.3.4
 stable-baselines3>=1.0
 pygame>=2.1.0
 gym>=0.18.0
 networkx>=2.6.3
 simplejson>=3.17.5
 PyYAML>=6.0
 einops
+natsort
@@ -1,14 +1,13 @@
 from algorithms.utils import Checkpointer
 from pathlib import Path
 from algorithms.utils import load_yaml_file, add_env_props, instantiate_class, load_class
-from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC
+#from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC


-#study_root = Path(__file__).parent / 'curious_study'
-study_root = Path('/Users/romue/PycharmProjects/EDYS/algorithms/marl')

 for i in range(0, 5):
-    for name in ['example_config']:
+    for name in ['mappo']:#['seac', 'iac', 'snac']:
+        study_root = Path(__file__).parent / name
         cfg = load_yaml_file(study_root / f'{name}.yaml')
         add_env_props(cfg)
@@ -17,7 +16,7 @@ for i in range(0, 5):
        max_steps = cfg['algorithm']['max_steps']
        n_steps = cfg['algorithm']['n_steps']

-        checkpointer = Checkpointer(f'{name}#{i}', study_root, cfg, max_steps, 250)
+        checkpointer = Checkpointer(f'{name}#{i}', study_root, cfg, max_steps, 50)

        loop = load_class(cfg['method'])(cfg)
        df = loop.train_loop(checkpointer)
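The training script only assumes that the loaded YAML exposes a 'method' entry, an 'env' entry, and an 'algorithm' block containing the keys read throughout this diff. The actual mappo.yaml is not part of the commit; a hedged sketch of the equivalent structure as a Python dict, with placeholder values:

cfg = {
    'method': 'algorithms.marl.LoopMAPPO',   # placeholder dotted path, consumed by load_class(cfg['method'])
    'env': {},                               # env spec, consumed by instantiate_class(cfg['env'])
    'algorithm': {
        'n_steps': 5, 'max_steps': 1_000_000,     # rollout segment length / total env steps
        'keep_n_segments': 10, 'batch_size': 8,   # segment buffer size and minibatch size in build_batch
        'n_updates': 4,                           # optimisation passes per learn() call
        'gamma': 0.99, 'gae_coef': 0.0,
        'entropy_coef': 0.01, 'vf_coef': 0.5, 'clip_range': 0.2,
    },
}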
@@ -1,32 +1,22 @@
 import numpy as np
 import pandas as pd
 from pathlib import Path
 import matplotlib.pyplot as plt
 import seaborn as sns

-study_root = Path(__file__).parent / 'entropy_study'
-names_all = ['basic_gru', 'layernorm_gru', 'spectralnorm_gru', 'nonorm_gru']
-names_only_1 = ['L2OnlyAh_gru', 'L2OnlyChAh_gru', 'L2OnlyMix_gru', 'basic_gru']
-names_only_2 = ['L2NoCh_gru', 'L2NoAh_gru', 'nomix_gru', 'basic_gru']
-
-names = names_only_2
-#names = ['nonorm_gru']
-# /Users/romue/PycharmProjects/EDYS/studies/normalization_study/basic_gru#3
-csvs = []
-for name in ['basic_gru', 'nonorm_gru', 'spectralnorm_gru']:
-    for run in range(0, 1):
+dfs = []
+for name in ['l2snac', 'iac', 'snac', 'seac']:
+    for c in range(5):
         try:
-            df = pd.read_csv(study_root / f'{name}#{run}' / 'results.csv')
-            df = df[df.agent == 'sum']
-            df = df.groupby(['checkpoint', 'run']).mean().reset_index()
-            df['method'] = name
-            df['run_'] = run
-
-            df.reward = df.reward.rolling(15).mean()
-            csvs.append(df)
+            study_root = Path(__file__).parent / name / f'{name}#{c}'
+            df = pd.read_csv(study_root / 'results.csv', index_col=False)
+            df.reward = df.reward.rolling(100).mean()
+            df['method'] = name.upper()
+            dfs.append(df)
         except Exception as e:
             print(f'skipped {run}\t {name}')
             pass

-csvs = pd.concat(csvs).rename(columns={"checkpoint": "steps*2e3", "B": "c"})
-sns.lineplot(data=csvs, x='steps*2e3', y='reward', hue='method', palette='husl', ci='sd', linewidth=1.8)
-plt.savefig('entropy.png')
+df = pd.concat(dfs).reset_index()
+sns.lineplot(data=df, x='episode', y='reward', hue='method', palette='husl', ci='sd', linewidth=1.5)
+plt.savefig('study.png')
+print('saved image')
@@ -3,19 +3,21 @@ from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC
 from pathlib import Path
 from algorithms.utils import load_yaml_file
 from tqdm import trange

-study = 'curious_study'
-study_root = Path(__file__).parent / study
+study = 'example_config#0'
+#study_root = Path(__file__).parent / study
+study_root = Path('/Users/romue/PycharmProjects/EDYS/algorithms/marl/')

 #['L2NoAh_gru', 'L2NoCh_gru', 'nomix_gru']:
 render = True
 eval_eps = 3
 for run in range(0, 5):
-    for name in ['basic_gru']:#['L2OnlyAh_gru', 'L2OnlyChAh_gru', 'L2OnlyMix_gru']: #['layernorm_gru', 'basic_gru', 'nonorm_gru', 'spectralnorm_gru']:
-        cfg = load_yaml_file(Path(__file__).parent / study / f'{name}.yaml')
-        p_root = Path(study_root / f'{name}#{run}')
+    for name in ['example_config']:#['L2OnlyAh_gru', 'L2OnlyChAh_gru', 'L2OnlyMix_gru']: #['layernorm_gru', 'basic_gru', 'nonorm_gru', 'spectralnorm_gru']:
+        cfg = load_yaml_file(study_root / study / 'config.yaml')
+        #p_root = Path(study_root / study / f'{name}#{run}')
         dfs = []
         for i in trange(500):
-            path = p_root / f'checkpoint_{i}'
+            path = study_root / study / f'checkpoint_{161}'
+            print(path)

             snac = LoopSEAC(cfg)
             snac.load_state_dict(path)