ml_lib/utils/_basedatamodule.py
2021-02-15 11:39:54 +01:00

30 lines
739 B
Python

from pytorch_lightning import LightningDataModule
# Dataset Options
from ml_lib.utils.tools import add_argparse_args
# Identifiers for the dataset splits a data module can hold.
# `_BaseDataModule.shape` looks up `datasets[DATA_OPTION_train]` by this name.
DATA_OPTION_test, DATA_OPTION_devel, DATA_OPTION_train = 'test', 'devel', 'train'
# All split identifiers, ordered train -> devel -> test.
DATA_OPTIONS = [DATA_OPTION_train, DATA_OPTION_devel, DATA_OPTION_test]
def _move_to_device(data, device):
    """Recursively move ``data`` onto ``device``.

    - Objects exposing a callable ``.to`` (e.g. tensors) are moved with
      ``data.to(device)``.
    - ``dict`` instances are rebuilt as a plain ``dict`` with every value moved.
    - ``list``/``tuple`` instances (including namedtuples) are rebuilt as the
      same type with every element moved.
    - Anything else is returned unchanged.
    """
    if callable(getattr(data, 'to', None)):
        return data.to(device)
    if isinstance(data, dict):
        return {key: _move_to_device(value, device) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        moved = [_move_to_device(item, device) for item in data]
        if isinstance(data, tuple) and hasattr(data, '_fields'):
            # namedtuple constructors take their fields positionally
            return type(data)(*moved)
        return type(data)(moved)
    return data


class _BaseDataModule(LightningDataModule):
    """Shared base class for Lightning data modules in this library.

    Holds one dataset per split in ``self.datasets``, keyed by the
    ``DATA_OPTION_*`` names. This class only creates the empty mapping;
    populating it is left to subclasses or callers.
    """

    def __init__(self, *args, **kwargs):
        # All arguments are forwarded unchanged to LightningDataModule.
        super().__init__(*args, **kwargs)
        # split name -> dataset; starts empty.
        self.datasets = {}

    @property
    def shape(self):
        """Return ``sample_shape`` of the training dataset.

        NOTE(review): assumes the train dataset defines ``sample_shape``;
        this class does not enforce it.

        Raises:
            KeyError: if no dataset is registered under ``DATA_OPTION_train``.
        """
        try:
            train_dataset = self.datasets[DATA_OPTION_train]
        except KeyError as err:
            raise KeyError(
                f"No '{DATA_OPTION_train}' dataset registered in `datasets`; "
                "set up the data module before querying its shape."
            ) from err
        return train_dataset.sample_shape

    @classmethod
    def add_argparse_args(cls, parent_parser):
        """Add this class's command-line options to ``parent_parser``.

        Delegates to ``ml_lib.utils.tools.add_argparse_args``. Inside this
        method the bare name resolves to that module-level import, not to the
        classmethod itself, so this is not recursion.
        """
        return add_argparse_args(cls, parent_parser)

    def transfer_batch_to_device(self, batch, device, dataloader_idx=None):
        """Move ``batch`` onto ``device``.

        Tensors and any object with a ``.to`` method are moved directly.
        Lists, tuples (including namedtuples) and dicts are traversed
        recursively, as Lightning's default hook does. Other leaves pass
        through unchanged.

        ``dataloader_idx`` is accepted (and ignored) because newer Lightning
        releases pass it as a third argument.
        """
        return _move_to_device(batch, device)