New Model, Many Changes

This commit is contained in:
Si11ium
2020-11-21 09:28:26 +01:00
parent 7bac9e984b
commit be097a111a
12 changed files with 349 additions and 125 deletions

View File

@@ -4,7 +4,7 @@ from torch import nn
from torch.nn import ModuleList
from ml_lib.modules.blocks import ConvModule, LinearModule
from ml_lib.modules.util import (LightningBaseModule, HorizontalSplitter, HorizontalMerger)
from ml_lib.modules.util import (LightningBaseModule, Splitter, Merger)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
BaseDataloadersMixin)
@@ -33,7 +33,7 @@ class BandwiseConvClassifier(BinaryMaskDatasetMixin,
# Modules
# =============================================================================
self.split = HorizontalSplitter(self.in_shape, self.n_band_sections)
self.split = Splitter(self.in_shape, self.n_band_sections)
k = 3
self.band_list = ModuleList()
@@ -48,7 +48,7 @@ class BandwiseConvClassifier(BinaryMaskDatasetMixin,
# last_shape = self.conv_list[-1].shape
self.band_list.append(conv_list)
self.merge = HorizontalMerger(self.band_list[-1][-1].shape, self.n_band_sections)
self.merge = Merger(self.band_list[-1][-1].shape, self.n_band_sections)
self.full_1 = LinearModule(self.merge.shape, self.params.lat_dim, **self.params.module_kwargs)
self.full_2 = LinearModule(self.full_1.shape, self.full_1.shape * 2, **self.params.module_kwargs)