Data Loaders and Stuff

This commit is contained in:
illiumst
2019-09-29 19:59:09 +02:00
parent aa802cb2be
commit 221565f4ec
26 changed files with 267 additions and 18 deletions

View File

@ -31,8 +31,8 @@ class LightningModuleOverrides:
return self.network.forward(x)
@data_loader
def tng_dataloader(self):
num_workers = 0 # os.cpu_count() // 2
def train_dataloader(self):
num_workers = 0 # os.cpu_count() // 2
return DataLoader(DataContainer(os.path.join('data', 'training'), self.size, self.step),
shuffle=True, batch_size=10000, num_workers=num_workers)
"""
@ -73,6 +73,17 @@ class LightningModule(pl.LightningModule, ABC):
# REQUIRED
raise NotImplementedError
@abstractmethod
def configure_optimizers(self):
# REQUIRED
raise NotImplementedError
@pl.data_loader
def train_dataloader(self):
# REQUIRED
raise NotImplementedError
"""
def validation_step(self, batch, batch_nb):
# OPTIONAL
pass
@ -81,19 +92,6 @@ class LightningModule(pl.LightningModule, ABC):
# OPTIONAL
pass
@abstractmethod
def configure_optimizers(self):
# REQUIRED
raise NotImplementedError
@pl.data_loader
def tng_dataloader(self):
# REQUIRED
raise NotImplementedError
# return DataLoader(MNIST(os.getcwd(), train=True, download=True,
# transform=transforms.ToTensor()), batch_size=32)
"""
@pl.data_loader
def val_dataloader(self):
# OPTIONAL