Lightning integration basic ae, dataloaders and dataset
This commit is contained in:
@ -90,13 +90,15 @@ class Repeater(Module):
|
||||
|
||||
class RNNOutputFilter(Module):
|
||||
|
||||
def __init__(self, return_output=True):
|
||||
def __init__(self, return_output=True, only_last=False):
|
||||
super(RNNOutputFilter, self).__init__()
|
||||
self.only_last = only_last
|
||||
self.return_output = return_output
|
||||
|
||||
def forward(self, x: tuple):
|
||||
outputs, hidden = x
|
||||
return outputs if self.return_output else hidden
|
||||
out = outputs if self.return_output else hidden
|
||||
return out if not self.only_last else out[:, -1, :]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Reference in New Issue
Block a user