mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-18 18:52:52 +02:00
my update
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
from typing import Union
|
||||
import torch
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from algorithms.q_learner import QLearner
|
||||
|
||||
|
||||
@ -37,4 +38,18 @@ class VDNLearner(QLearner):
|
||||
target_q_raw += next_q_values_raw
|
||||
target_q = experience.reward + (1 - experience.done) * self.gamma * target_q_raw
|
||||
loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - target_q, 2))
|
||||
self._backprop_loss(loss)
|
||||
self._backprop_loss(loss)
|
||||
|
||||
def evaluate(self, n_episodes: int = 100, render: bool = False) -> pd.DataFrame:
    """Roll out the current policy and collect per-step statistics.

    Runs ``n_episodes`` episodes on ``self.env`` with gradient tracking
    disabled, choosing actions via ``self.get_action``.

    Args:
        n_episodes: Number of evaluation episodes to run.
        render: If true, call ``self.env.render()`` after every step.

    Returns:
        A DataFrame with one row per environment step. Each row holds the
        entries of the step's ``info`` dict plus ``'reward'`` and
        ``'eval_episode'``; missing values are filled with 0.
    """
    # NOTE(review): one row per step (not per episode) is assumed here --
    # confirm against the intended logging granularity.
    rows = []
    with torch.no_grad():
        for eval_i in range(n_episodes):
            obs, done = self.env.reset(), False
            while not done:
                action = self.get_action(obs)
                obs, reward, done, info = self.env.step(action)
                if render:
                    self.env.render()
                # Copy instead of info.update(...): the env may hand back the
                # same dict object every step, in which case storing it
                # directly would make all rows alias the final step's values.
                rows.append({**info, 'reward': reward, 'eval_episode': eval_i})
    return pd.DataFrame(rows).fillna(0)
|
||||
|
Reference in New Issue
Block a user