Lightning integration basic ae, dataloaders and dataset

This commit is contained in:
Si11ium
2019-08-16 14:29:48 +02:00
parent fbe0600e24
commit 265c900f33
10 changed files with 406 additions and 49 deletions

View File

@ -90,13 +90,15 @@ class Repeater(Module):
class RNNOutputFilter(Module):
def __init__(self, return_output=True):
def __init__(self, return_output=True, only_last=False):
super(RNNOutputFilter, self).__init__()
self.only_last = only_last
self.return_output = return_output
def forward(self, x: tuple):
outputs, hidden = x
return outputs if self.return_output else hidden
out = outputs if self.return_output else hidden
return out if not self.only_last else out[:, -1, :]
if __name__ == '__main__':