ResidualModule and New Parameters, Speed Manipulation
This commit is contained in:
@@ -22,24 +22,32 @@ class BandwiseConvMultiheadClassifier(BinaryMaskDatasetFunction,
|
||||
batch_x, batch_y = batch_xy
|
||||
y = self(batch_x)
|
||||
y, bands_y = y.main_out, y.bands
|
||||
bands_y_losses = [self.criterion(band_y, batch_y) for band_y in bands_y]
|
||||
bands_y_losses = [self.bce_loss(band_y, batch_y) for band_y in bands_y]
|
||||
return_dict = {f'band_{band_idx}_loss': band_y for band_idx, band_y in enumerate(bands_y_losses)}
|
||||
overall_loss = self.criterion(y, batch_y)
|
||||
combined_loss = overall_loss + torch.stack(bands_y_losses).sum()
|
||||
return_dict.update(loss=combined_loss, overall_loss=overall_loss)
|
||||
|
||||
last_bce_loss = self.bce_loss(y, batch_y)
|
||||
return_dict.update(last_bce_loss=last_bce_loss)
|
||||
|
||||
bands_y_losses.append(last_bce_loss)
|
||||
combined_loss = torch.stack(bands_y_losses).mean()
|
||||
|
||||
return_dict.update(loss=combined_loss)
|
||||
return return_dict
|
||||
|
||||
def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
|
||||
batch_x, batch_y = batch_xy
|
||||
y = self(batch_x)
|
||||
y, bands_y = y.main_out, y.bands
|
||||
bands_y_losses = [self.criterion(band_y, batch_y) for band_y in bands_y]
|
||||
bands_y_losses = [self.bce_loss(band_y, batch_y) for band_y in bands_y]
|
||||
return_dict = {f'band_{band_idx}_val_loss': band_y for band_idx, band_y in enumerate(bands_y_losses)}
|
||||
overall_loss = self.criterion(y, batch_y)
|
||||
combined_loss = overall_loss + torch.stack(bands_y_losses).sum()
|
||||
|
||||
val_abs_loss = self.absolute_loss(y, batch_y)
|
||||
return_dict.update(val_bce_loss=combined_loss, val_abs_loss=val_abs_loss,
|
||||
last_bce_loss = self.bce_loss(y, batch_y)
|
||||
return_dict.update(last_bce_loss=last_bce_loss)
|
||||
|
||||
bands_y_losses.append(last_bce_loss)
|
||||
combined_loss = torch.stack(bands_y_losses).mean()
|
||||
|
||||
return_dict.update(val_bce_loss=combined_loss,
|
||||
batch_idx=batch_idx, y=y, batch_y=batch_y
|
||||
)
|
||||
return return_dict
|
||||
@@ -56,7 +64,6 @@ class BandwiseConvMultiheadClassifier(BinaryMaskDatasetFunction,
|
||||
# Additional parameters
|
||||
self.in_shape = self.dataset.train_dataset.sample_shape
|
||||
self.conv_filters = self.params.filters
|
||||
self.criterion = nn.BCELoss()
|
||||
self.n_band_sections = 4
|
||||
k = 3 # Base Kernel Value
|
||||
|
||||
@@ -69,7 +76,7 @@ class BandwiseConvMultiheadClassifier(BinaryMaskDatasetFunction,
|
||||
last_shape = self.split.shape
|
||||
conv_list = ModuleList()
|
||||
for filters in self.conv_filters:
|
||||
conv_list.append(ConvModule(last_shape, filters, (k, k*4), conv_stride=(1, 2),
|
||||
conv_list.append(ConvModule(last_shape, filters, (k,k), conv_stride=(1, 1),
|
||||
**self.params.module_kwargs))
|
||||
last_shape = conv_list[-1].shape
|
||||
# self.conv_list.append(ConvModule(last_shape, 1, 1, conv_stride=1, **self.params.module_kwargs))
|
||||
|
||||
Reference in New Issue
Block a user