from networks.basic_ae import BasicAE, AELightningOverrides
from networks.modules import LightningModule
from torch.optim import Adam
from torch.utils.data import DataLoader
from pytorch_lightning import data_loader
from dataset import DataContainer
from torch.nn import BatchNorm1d
from pytorch_lightning import Trainer


class AEModel(AELightningOverrides, LightningModule):
    """Lightning wrapper that trains a ``BasicAE`` on a ``DataContainer``.

    Args:
        dataParams: Options forwarded unchanged both to ``BasicAE`` and, as
            keyword arguments, to ``DataContainer('data', ...)``.
        lr: Learning rate for the Adam optimizer (default ``0.02``, the value
            previously hard-coded).
        batch_size: Mini-batch size of the training DataLoader (default
            ``100``, the value previously hard-coded).
    """

    def __init__(self, dataParams: dict, lr: float = 0.02, batch_size: int = 100):
        super().__init__()
        self.dataParams = dataParams
        self.learning_rate = lr
        self.batch_size = batch_size
        # noinspection PyUnresolvedReferences
        self.network = BasicAE(self.dataParams)

    def configure_optimizers(self):
        """Return a single-element list holding an Adam optimizer over all parameters."""
        return [Adam(self.parameters(), lr=self.learning_rate)]

    @data_loader
    def tng_dataloader(self):
        """Build the shuffled training DataLoader over ``DataContainer('data', **dataParams)``."""
        return DataLoader(
            DataContainer('data', **self.dataParams),
            shuffle=True,
            batch_size=self.batch_size,
        )

    def forward(self, x):
        """Run the wrapped autoencoder on ``x`` and return its output."""
        # Invoke the module via __call__ rather than .forward() so that any
        # registered forward hooks on the network are executed.
        return self.network(x)


if __name__ == '__main__':
    features = 6
    ae = AEModel(
        dataParams=dict(
            refresh=False,
            size=5,
            step=5,
            features=features,
            transforms=[BatchNorm1d(features)],
        )
    )
    trainer = Trainer()
    trainer.fit(ae)