my update

This commit is contained in:
romue
2021-11-11 10:59:13 +01:00
parent 6287380f60
commit ea4582a59e
10 changed files with 88 additions and 308 deletions

View File

@@ -1,6 +1,7 @@
from typing import Union
import torch
import numpy as np
import pandas as pd
from algorithms.q_learner import QLearner
@@ -37,4 +38,18 @@ class VDNLearner(QLearner):
target_q_raw += next_q_values_raw
target_q = experience.reward + (1 - experience.done) * self.gamma * target_q_raw
loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - target_q, 2))
self._backprop_loss(loss)
self._backprop_loss(loss)
def evaluate(self, n_episodes: int = 100, render: bool = False) -> pd.DataFrame:
    """Roll out the current policy on ``self.env`` and log every step.

    Each episode starts with ``self.env.reset()`` and runs until the
    environment reports ``done``. Actions come from ``self.get_action``
    and the whole rollout runs under ``torch.no_grad()``.

    Args:
        n_episodes: Number of episodes to play.
        render: If true, call ``self.env.render()`` after every step.

    Returns:
        A DataFrame with one row per environment step, holding that
        step's ``info`` entries plus ``'reward'`` and ``'eval_episode'``
        columns; missing values are filled with 0.
    """
    records = []
    with torch.no_grad():
        for eval_i in range(n_episodes):
            obs, done = self.env.reset(), False
            while not done:
                action = self.get_action(obs)
                next_obs, reward, done, info = self.env.step(action)
                if render:
                    self.env.render()
                # The next action must be chosen from the fresh observation.
                obs = next_obs
                # Build a new dict per step instead of updating ``info`` in
                # place: the environment's dict stays untouched, and rows
                # cannot alias each other if the env reuses one info object.
                records.append(
                    {**info, 'reward': reward, 'eval_episode': eval_i}
                )
    return pd.DataFrame(records).fillna(0)