Parameter Adjustmens and Ensemble Model Implementation
This commit is contained in:
@ -24,17 +24,6 @@ class F_x(object):
|
||||
|
||||
class ShapeMixin:
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
x = torch.randn(self.in_shape).unsqueeze(0)
|
||||
output: torch.Tensor = self(x)
|
||||
return output.shape[1:]
|
||||
|
||||
|
||||
# Utility - Modules
|
||||
###################
|
||||
class Flatten(nn.Module):
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
try:
|
||||
@ -45,6 +34,11 @@ class Flatten(nn.Module):
|
||||
print(e)
|
||||
return -1
|
||||
|
||||
|
||||
# Utility - Modules
|
||||
###################
|
||||
class Flatten(ShapeMixin, nn.Module):
|
||||
|
||||
def __init__(self, in_shape, to=-1):
|
||||
assert isinstance(to, int) or isinstance(to, tuple)
|
||||
super(Flatten, self).__init__()
|
||||
@ -172,29 +166,6 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
||||
self.apply(weight_initializer)
|
||||
|
||||
|
||||
class BaseModuleMixin_Dataloaders(ABC):
|
||||
|
||||
# Dataloaders
|
||||
# ================================================================================
|
||||
# Train Dataloader
|
||||
def train_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
|
||||
batch_size=self.params.batch_size,
|
||||
num_workers=self.params.worker)
|
||||
|
||||
# Test Dataloader
|
||||
def test_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
|
||||
batch_size=self.params.batch_size,
|
||||
num_workers=self.params.worker)
|
||||
|
||||
# Validation Dataloader
|
||||
def val_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
|
||||
batch_size=self.params.batch_size,
|
||||
num_workers=self.params.worker)
|
||||
|
||||
|
||||
class FilterLayer(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
@ -253,7 +224,7 @@ class HorizontalSplitter(nn.Module):
|
||||
self.in_shape = in_shape
|
||||
|
||||
self.channel, self.height, self.width = self.in_shape
|
||||
self.new_height = (self.height // self.n) + 1 if self.height % self.n != 0 else 0
|
||||
self.new_height = (self.height // self.n) + (1 if self.height % self.n != 0 else 0)
|
||||
|
||||
self.shape = (self.channel, self.new_height, self.width)
|
||||
self.autopad = AutoPadToShape(self.shape)
|
||||
|
Reference in New Issue
Block a user