mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-18 18:52:52 +02:00
added individual eps-greedy for VDN
This commit is contained in:
@ -1,4 +1,6 @@
|
||||
from typing import Union
|
||||
import torch
|
||||
import numpy as np
|
||||
from algorithms.q_learner import QLearner
|
||||
|
||||
|
||||
@ -7,6 +9,21 @@ class VDNLearner(QLearner):
|
||||
super(VDNLearner, self).__init__(*args, **kwargs)
|
||||
assert self.n_agents >= 2, 'VDN requires more than one agent, use QLearner instead'
|
||||
|
||||
def get_action(self, obs) -> Union[int, np.ndarray]:
|
||||
o = torch.from_numpy(obs).unsqueeze(0) if self.n_agents <= 1 else torch.from_numpy(obs)
|
||||
eps = np.random.rand(self.n_agents)
|
||||
greedy = eps > self.eps
|
||||
agent_actions = None
|
||||
actions = []
|
||||
for i in range(self.n_agents):
|
||||
if greedy[i]:
|
||||
if agent_actions is None: agent_actions = self.q_net.act(o.float())
|
||||
action = agent_actions[i]
|
||||
else:
|
||||
action = self.env.action_space.sample()
|
||||
actions.append(action)
|
||||
return np.array(actions)
|
||||
|
||||
def train(self):
|
||||
if len(self.buffer) < self.batch_size: return
|
||||
for _ in range(self.n_grad_steps):
|
||||
|
Reference in New Issue
Block a user