Code Comments, Getting Dirty Env, Naming

This commit is contained in:
Steffen Illium
2021-05-11 10:31:34 +02:00
parent faa27c3cf9
commit ab01006eae
7 changed files with 51 additions and 16 deletions

View File

@@ -8,8 +8,6 @@ import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange, repeat
import sys
sys.path.append(str(Path(__file__).parent))
@@ -40,7 +38,7 @@ class LinearModule(ShapeMixin, nn.Module):
tensor = self.flat(x)
tensor = self.dropout(tensor)
tensor = self.norm(tensor)
tensor = self.linear(tensor)
tensor = self.linear(tensor.float())
tensor = self.activation(tensor)
return tensor
@@ -249,6 +247,7 @@ class Attention(nn.Module):
) if project_out else nn.Identity()
def forward(self, x, mask=None, return_attn_weights=False):
from einops import rearrange, repeat
# noinspection PyTupleAssignmentBalance
b, n, _, h = *x.shape, self.heads

View File

@@ -129,8 +129,10 @@ try:
self._weight_init = weight_init
self.params = ModelParameters(model_parameters)
self.metrics = PLMetrics(self.params.n_classes, tag='PL')
pass
if hasattr(self.params, 'n_classes'):
self.metrics = PLMetrics(self.params.n_classes, tag='PL')
else:
pass
def size(self):
return self.shape