first icm steps
This commit is contained in:
@ -24,7 +24,7 @@ class BaseLearner:
|
||||
self.n_agents = n_agents
|
||||
self.n_grad_steps = n_grad_steps
|
||||
self.train_every = train_every
|
||||
self.stack_n_frames = deque(stack_n_frames)
|
||||
self.stack_n_frames = deque(maxlen=stack_n_frames)
|
||||
self.device = 'cpu'
|
||||
self.n_updates = 0
|
||||
self.step = 0
|
||||
@ -51,6 +51,9 @@ class BaseLearner:
|
||||
def on_episode_end(self, n_steps):
    """Hook that does nothing by default.

    Args:
        n_steps: Accepted but ignored by this base implementation.
            Presumably the number of steps of the finished episode;
            the call site is not visible here, so confirm against the caller.
    """
    return None
|
||||
|
||||
def on_all_done(self):
    """Hook that does nothing by default; subclasses may override it."""
    return None
|
||||
|
||||
def train(self):
    """Placeholder training step; the base implementation does nothing."""
    return None
|
||||
|
||||
@ -93,6 +96,7 @@ class BaseLearner:
|
||||
f' eps: {self.eps:.4f}\tRunning loss: {sum(list(self.running_loss)) / len(self.running_loss):.4f}\tUpdates:{self.n_updates}')
|
||||
except Exception as e:
|
||||
pass
|
||||
self.on_all_done()
|
||||
|
||||
|
||||
class BaseBuffer:
|
||||
@ -180,18 +184,17 @@ class BaseDDQN(BaseDQN):
|
||||
return values + (advantages - advantages.mean())
|
||||
|
||||
|
||||
class QTRANtestNet(nn.Module):
    """Network made of a backbone and a Q head, both built by ``mlp_maker``."""

    def __init__(self, backbone_dims=None, q_head=None):
        """Create the backbone and the Q head.

        Args:
            backbone_dims: Layer sizes passed to ``mlp_maker`` for the
                backbone. Defaults to ``[3*5*5, 64, 64]``.
            q_head: Layer sizes passed to ``mlp_maker`` for the head.
                Defaults to ``[64, 9]``.
        """
        super().__init__()
        # Build the defaults per call so that no list object is shared
        # between instances (mutable default arguments are shared).
        if backbone_dims is None:
            backbone_dims = [3 * 5 * 5, 64, 64]
        if q_head is None:
            q_head = [64, 9]
        # Creation order is kept as-is: it determines sub-module
        # registration order and the order of random initialisation.
        self.backbone = mlp_maker(backbone_dims, flatten=True, activation_last='elu')
        self.q_head = mlp_maker(q_head)
|
||||
class BaseICM(nn.Module):
    """Module holding a feature backbone, an ``icm`` head and a cross-entropy loss.

    Presumably "ICM" stands for Intrinsic Curiosity Module; confirm with the author.
    """

    def __init__(self, backbone_dims=None, head_dims=None):
        """Create the backbone, the head and the loss.

        Args:
            backbone_dims: Layer sizes passed to ``mlp_maker`` for the
                backbone. Defaults to ``[2*3*5*5, 64, 64]``.
            head_dims: Layer sizes passed to ``mlp_maker`` for the head.
                Defaults to ``[2*64, 64, 9]``.
        """
        super().__init__()
        # Build the defaults per call so that no list object is shared
        # between instances (mutable default arguments are shared).
        if backbone_dims is None:
            # NOTE(review): the default input size is 2*3*5*5; confirm that a
            # single state passed to the backbone really flattens to that size.
            backbone_dims = [2 * 3 * 5 * 5, 64, 64]
        if head_dims is None:
            head_dims = [2 * 64, 64, 9]
        # Creation order is kept as-is: it determines sub-module
        # registration order and the order of random initialisation.
        self.backbone = mlp_maker(backbone_dims, flatten=True)
        self.icm = mlp_maker(head_dims)
        self.ce = nn.CrossEntropyLoss()
|
||||
|
||||
def forward(self, x):
    """Run ``x`` through the backbone and then the Q head.

    Returns:
        A ``(qs, features)`` tuple: the head output and the backbone
        output it was computed from.
    """
    backbone_out = self.backbone(x)
    return self.q_head(backbone_out), backbone_out
|
||||
|
||||
@torch.no_grad()
def act(self, x) -> np.ndarray:
    """Return, as a NumPy array, the index of the largest value along the last axis.

    The values come from the first element of ``self.forward(x)``;
    gradients are disabled for the whole call.
    """
    q_values = self.forward(x)[0]
    # ``max`` over a dim returns (values, indices); only the indices are needed.
    greedy_indices = q_values.max(-1)[1]
    return greedy_indices.numpy()
|
||||
def forward(self, s0, s1, a):
    """Predict the action ``a`` from the backbone features of ``s0`` and ``s1``.

    Both inputs go through the same backbone; their features are
    concatenated along dim 1 and passed to the ``icm`` head.

    Args:
        s0: First input to the backbone.
        s1: Second input to the backbone.
        a: Targets for ``self.ce`` (class indices are assumed; confirm
            against the caller).

    Returns:
        dict with ``prediction`` (softmax over the head output) and
        ``loss`` (cross-entropy between the head output and ``a``).
    """
    phi_s0 = self.backbone(s0)
    phi_s1 = self.backbone(s1)
    logits = self.icm(torch.cat((phi_s0, phi_s1), dim=1))
    # nn.CrossEntropyLoss applies log-softmax itself, so it must receive the
    # raw logits; passing softmax output would normalise twice and flatten
    # the gradients.
    loss = self.ce(logits, a)
    a_prime = torch.softmax(logits, dim=-1)
    return dict(prediction=a_prime, loss=loss)
|
Reference in New Issue
Block a user