Code Comments, Getting Dirty Env, Naming
This commit is contained in:
@@ -8,8 +8,6 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
import sys
|
||||
sys.path.append(str(Path(__file__).parent))
|
||||
|
||||
@@ -40,7 +38,7 @@ class LinearModule(ShapeMixin, nn.Module):
|
||||
tensor = self.flat(x)
|
||||
tensor = self.dropout(tensor)
|
||||
tensor = self.norm(tensor)
|
||||
tensor = self.linear(tensor)
|
||||
tensor = self.linear(tensor.float())
|
||||
tensor = self.activation(tensor)
|
||||
return tensor
|
||||
|
||||
@@ -249,6 +247,7 @@ class Attention(nn.Module):
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x, mask=None, return_attn_weights=False):
|
||||
from einops import rearrange, repeat
|
||||
# noinspection PyTupleAssignmentBalance
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
|
||||
|
@@ -129,8 +129,10 @@ try:
|
||||
self._weight_init = weight_init
|
||||
self.params = ModelParameters(model_parameters)
|
||||
|
||||
self.metrics = PLMetrics(self.params.n_classes, tag='PL')
|
||||
pass
|
||||
if hasattr(self.params, 'n_classes'):
|
||||
self.metrics = PLMetrics(self.params.n_classes, tag='PL')
|
||||
else:
|
||||
pass
|
||||
|
||||
def size(self):
|
||||
return self.shape
|
||||
|
Reference in New Issue
Block a user