"""Training script for an autoencoder model built around ``BasicVAE``."""
import pytorch_lightning as pl
import torch
from pytorch_lightning import Trainer
from torch.nn import BatchNorm1d
from torch.nn.functional import mse_loss
from torch.optim import Adam
from torch.utils.data import DataLoader

from dataset import DataContainer
from networks.basic_vae import BasicVAE, VAELightningOverrides
from networks.modules import LightningModule


class AEModel(VAELightningOverrides, LightningModule):
    """Lightning module that wraps a ``BasicVAE`` network.

    The training/validation step logic is not defined here; it is
    presumably supplied by ``VAELightningOverrides`` -- TODO confirm.
    """

    def __init__(self, dataParams: dict):
        """Store the data parameters and build the wrapped network.

        Args:
            dataParams: Configuration dict. It is passed as-is to
                ``BasicVAE`` and unpacked as keyword arguments into
                ``DataContainer`` when the data loader is created.
        """
        super().__init__()
        self.dataParams = dataParams

        # noinspection PyUnresolvedReferences
        self.network = BasicVAE(self.dataParams)

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        # NOTE(review): calling ``.forward`` directly skips any hooks that
        # ``__call__`` would run if BasicVAE is a torch ``nn.Module`` --
        # confirm whether ``self.network(x)`` is preferable.
        return self.network.forward(x)

    def configure_optimizers(self):
        """Return a one-element list holding an Adam optimizer (lr=0.02)."""
        # TODO: Where do I get the parameters from?
        return [Adam(self.parameters(), lr=0.02)]

    @pl.data_loader
    def tng_dataloader(self):
        """Build the shuffled ``DataLoader`` (batch size 100) over the data.

        The dataset is a ``DataContainer`` rooted at ``'data'`` and
        configured with the keyword arguments in ``self.dataParams``.
        """
        dataset = DataContainer('data', **self.dataParams)
        return DataLoader(dataset, shuffle=True, batch_size=100)


if __name__ == '__main__':
    features = 6
    ae = AEModel(
        dataParams=dict(
            refresh=False,
            size=5,
            step=5,
            features=features,
            transforms=[BatchNorm1d(features)],
        )
    )

    trainer = Trainer()
    trainer.fit(ae)