Data Loaders and Stuff
This commit is contained in:
@ -31,8 +31,8 @@ class LightningModuleOverrides:
|
||||
return self.network.forward(x)
|
||||
|
||||
@data_loader
|
||||
def tng_dataloader(self):
|
||||
num_workers = 0 # os.cpu_count() // 2
|
||||
def train_dataloader(self):
|
||||
num_workers = 0 # os.cpu_count() // 2
|
||||
return DataLoader(DataContainer(os.path.join('data', 'training'), self.size, self.step),
|
||||
shuffle=True, batch_size=10000, num_workers=num_workers)
|
||||
"""
|
||||
@ -73,6 +73,17 @@ class LightningModule(pl.LightningModule, ABC):
|
||||
# REQUIRED
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def configure_optimizers(self):
|
||||
# REQUIRED
|
||||
raise NotImplementedError
|
||||
|
||||
@pl.data_loader
|
||||
def train_dataloader(self):
|
||||
# REQUIRED
|
||||
raise NotImplementedError
|
||||
|
||||
"""
|
||||
def validation_step(self, batch, batch_nb):
|
||||
# OPTIONAL
|
||||
pass
|
||||
@ -81,19 +92,6 @@ class LightningModule(pl.LightningModule, ABC):
|
||||
# OPTIONAL
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def configure_optimizers(self):
|
||||
# REQUIRED
|
||||
raise NotImplementedError
|
||||
|
||||
@pl.data_loader
|
||||
def tng_dataloader(self):
|
||||
# REQUIRED
|
||||
raise NotImplementedError
|
||||
# return DataLoader(MNIST(os.getcwd(), train=True, download=True,
|
||||
# transform=transforms.ToTensor()), batch_size=32)
|
||||
|
||||
"""
|
||||
@pl.data_loader
|
||||
def val_dataloader(self):
|
||||
# OPTIONAL
|
||||
|
Reference in New Issue
Block a user