my update

This commit is contained in:
romue
2021-11-11 10:59:13 +01:00
parent 6287380f60
commit ea4582a59e
10 changed files with 88 additions and 308 deletions

View File

@ -1,3 +1,4 @@
import numpy as np
import torch
import torch.nn.functional as F
from algorithms.q_learner import QLearner
@ -53,19 +54,24 @@ class MQLearner(QLearner):
self._backprop_loss(loss)
from tqdm import trange
from collections import deque
class MQICMLearner(MQLearner):
def __init__(self, *args, icm, **kwargs):
super(MQICMLearner, self).__init__(*args, **kwargs)
self.icm = icm
self.icm_optimizer = torch.optim.Adam(self.icm.parameters())
self.icm_optimizer = torch.optim.AdamW(self.icm.parameters())
self.normalize_reward = deque(maxlen=1000)
def on_all_done(self):
for b in trange(50000):
from collections import deque
losses = deque(maxlen=100)
for b in trange(10000):
batch = self.buffer.sample(128, 0)
s0, s1, a = batch.observation, batch.next_observation, batch.action
loss = self.icm(s0, s1, a.squeeze())['loss']
self.icm_optimizer.zero_grad()
loss.backward()
self.icm_optimizer.step()
losses.append(loss.item())
if b%100 == 0:
print(loss.item())
print(np.mean(losses))