Transformer running

This commit is contained in:
Steffen Illium
2021-03-04 12:01:08 +01:00
parent b5e3e5aec1
commit f89f0f8528
14 changed files with 349 additions and 80 deletions

View File

@ -25,5 +25,12 @@ class _BaseDataModule(LightningDataModule):
self.datasets = dict()
def transfer_batch_to_device(self, batch, device):
return batch.to(device)
if isinstance(batch, list):
for idx, item in enumerate(batch):
try:
batch[idx] = item.to(device)
except (AttributeError, RuntimeError):
continue
return batch
else:
return batch.to(device)