added individual eps-greedy for VDN
This commit is contained in:
@ -71,6 +71,7 @@ def mlp_maker(dims, flatten=False, activation='elu', activation_last='identity')
|
||||
return nn.Sequential(OrderedDict(layers))
|
||||
|
||||
|
||||
|
||||
class BaseDQN(nn.Module):
|
||||
def __init__(self, dims=[3*5*5, 64, 64, 9]):
|
||||
super(BaseDQN, self).__init__()
|
||||
@ -100,3 +101,20 @@ class BaseDDQN(BaseDQN):
|
||||
advantages = self.advantage_head(features)
|
||||
values = self.value_head(features)
|
||||
return values + (advantages - advantages.mean())
|
||||
|
||||
|
||||
class QTRANtestNet(nn.Module):
    """Q-network that exposes both its Q-values and its backbone features.

    The input goes through ``backbone`` and the resulting features go through
    ``q_head``; ``forward`` returns both tensors so callers can reuse the
    features alongside the Q-values.
    """

    def __init__(self, backbone_dims=[3*5*5, 64, 64], q_head=[64, 9]):
        """Build the feature backbone and the Q-value head with ``mlp_maker``.

        Args:
            backbone_dims: Layer sizes passed to ``mlp_maker`` for the
                backbone (built with ``flatten=True`` and
                ``activation_last='elu'``).
            q_head: Layer sizes passed to ``mlp_maker`` for the Q-value head.
                NOTE(review): presumably ``q_head[0]`` has to match
                ``backbone_dims[-1]`` -- confirm against ``mlp_maker``.
        """
        super(QTRANtestNet, self).__init__()
        self.backbone = mlp_maker(backbone_dims, flatten=True, activation_last='elu')
        self.q_head = mlp_maker(q_head)

    def forward(self, x):
        """Return ``(qs, features)`` for input ``x``.

        ``features`` is the backbone output and ``qs`` is the output of the
        Q-value head applied to those features.
        """
        features = self.backbone(x)
        qs = self.q_head(features)
        return qs, features

    @torch.no_grad()
    def act(self, x) -> np.ndarray:
        """Return the greedy action indices (argmax over the last axis of the Q-values).

        Runs without gradient tracking. The result is a NumPy array of
        indices with the last dimension of the Q-values reduced away.
        """
        q_values = self.forward(x)[0]
        greedy = q_values.max(-1)[1]
        # .numpy() only works on host tensors; .cpu() is a no-op when the
        # tensor already lives on the CPU and makes this safe on other devices.
        return greedy.cpu().numpy()
|
Reference in New Issue
Block a user