project Refactor, CNN Classifier Basics
This commit is contained in:
@ -83,9 +83,9 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
||||
print(e)
|
||||
return -1
|
||||
|
||||
def __init__(self, params):
|
||||
def __init__(self, hparams):
|
||||
super(LightningBaseModule, self).__init__()
|
||||
self.hparams = params
|
||||
self.hparams = hparams
|
||||
|
||||
# Data loading
|
||||
# =============================================================================
|
||||
@ -109,6 +109,10 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
||||
def data_len(self):
|
||||
return len(self.dataset.train_dataset)
|
||||
|
||||
@property
|
||||
def n_train_batches(self):
|
||||
return len(self.train_dataloader())
|
||||
|
||||
def configure_optimizers(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@ -121,7 +125,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
||||
def test_step(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def test_end(self, outputs):
|
||||
def test_epoch_end(self, outputs):
|
||||
raise NotImplementedError
|
||||
|
||||
def init_weights(self):
|
||||
@ -134,6 +138,26 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
||||
m.bias.data.fill_(0.01)
|
||||
self.apply(_weight_init)
|
||||
|
||||
# Dataloaders
|
||||
# ================================================================================
|
||||
# Train Dataloader
|
||||
def train_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
|
||||
batch_size=self.hparams.data_param.batchsize,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
# Test Dataloader
|
||||
def test_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
|
||||
batch_size=self.hparams.data_param.batchsize,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
# Validation Dataloader
|
||||
def val_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
|
||||
batch_size=self.hparams.data_param.batchsize,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
|
||||
class FilterLayer(nn.Module):
|
||||
|
||||
|
Reference in New Issue
Block a user