added individual eps-greedy for VDN

This commit is contained in:
romue
2021-06-25 15:42:55 +02:00
parent 42f0dde056
commit 456e48f2e0
4 changed files with 85 additions and 2 deletions

View File

@ -71,6 +71,7 @@ def mlp_maker(dims, flatten=False, activation='elu', activation_last='identity')
return nn.Sequential(OrderedDict(layers))
class BaseDQN(nn.Module):
def __init__(self, dims=[3*5*5, 64, 64, 9]):
super(BaseDQN, self).__init__()
@ -100,3 +101,20 @@ class BaseDDQN(BaseDQN):
advantages = self.advantage_head(features)
values = self.value_head(features)
return values + (advantages - advantages.mean())
class QTRANtestNet(nn.Module):
def __init__(self, backbone_dims=[3*5*5, 64, 64], q_head=[64, 9]):
super(QTRANtestNet, self).__init__()
self.backbone = mlp_maker(backbone_dims, flatten=True, activation_last='elu')
self.q_head = mlp_maker(q_head)
def forward(self, x):
features = self.backbone(x)
qs = self.q_head(features)
return qs, features
@torch.no_grad()
def act(self, x) -> np.ndarray:
action = self.forward(x)[0].max(-1)[1].numpy()
return action