ResidualModule and New Parameters, Speed Manipulation

This commit is contained in:
Si11ium
2020-05-12 12:37:26 +02:00
parent 3fbc98dfa3
commit 28bfcfdce3
8 changed files with 181 additions and 78 deletions

View File

@@ -22,24 +22,32 @@ class BandwiseConvMultiheadClassifier(BinaryMaskDatasetFunction,
batch_x, batch_y = batch_xy
y = self(batch_x)
y, bands_y = y.main_out, y.bands
bands_y_losses = [self.criterion(band_y, batch_y) for band_y in bands_y]
bands_y_losses = [self.bce_loss(band_y, batch_y) for band_y in bands_y]
return_dict = {f'band_{band_idx}_loss': band_y for band_idx, band_y in enumerate(bands_y_losses)}
overall_loss = self.criterion(y, batch_y)
combined_loss = overall_loss + torch.stack(bands_y_losses).sum()
return_dict.update(loss=combined_loss, overall_loss=overall_loss)
last_bce_loss = self.bce_loss(y, batch_y)
return_dict.update(last_bce_loss=last_bce_loss)
bands_y_losses.append(last_bce_loss)
combined_loss = torch.stack(bands_y_losses).mean()
return_dict.update(loss=combined_loss)
return return_dict
def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
batch_x, batch_y = batch_xy
y = self(batch_x)
y, bands_y = y.main_out, y.bands
bands_y_losses = [self.criterion(band_y, batch_y) for band_y in bands_y]
bands_y_losses = [self.bce_loss(band_y, batch_y) for band_y in bands_y]
return_dict = {f'band_{band_idx}_val_loss': band_y for band_idx, band_y in enumerate(bands_y_losses)}
overall_loss = self.criterion(y, batch_y)
combined_loss = overall_loss + torch.stack(bands_y_losses).sum()
val_abs_loss = self.absolute_loss(y, batch_y)
return_dict.update(val_bce_loss=combined_loss, val_abs_loss=val_abs_loss,
last_bce_loss = self.bce_loss(y, batch_y)
return_dict.update(last_bce_loss=last_bce_loss)
bands_y_losses.append(last_bce_loss)
combined_loss = torch.stack(bands_y_losses).mean()
return_dict.update(val_bce_loss=combined_loss,
batch_idx=batch_idx, y=y, batch_y=batch_y
)
return return_dict
@@ -56,7 +64,6 @@ class BandwiseConvMultiheadClassifier(BinaryMaskDatasetFunction,
# Additional parameters
self.in_shape = self.dataset.train_dataset.sample_shape
self.conv_filters = self.params.filters
self.criterion = nn.BCELoss()
self.n_band_sections = 4
k = 3 # Base Kernel Value
@@ -69,7 +76,7 @@ class BandwiseConvMultiheadClassifier(BinaryMaskDatasetFunction,
last_shape = self.split.shape
conv_list = ModuleList()
for filters in self.conv_filters:
conv_list.append(ConvModule(last_shape, filters, (k, k*4), conv_stride=(1, 2),
conv_list.append(ConvModule(last_shape, filters, (k,k), conv_stride=(1, 1),
**self.params.module_kwargs))
last_shape = conv_list[-1].shape
# self.conv_list.append(ConvModule(last_shape, 1, 1, conv_stride=1, **self.params.module_kwargs))