diff --git a/lib/models/generators/cnn.py b/lib/models/generators/cnn.py index 3383029..6167b6e 100644 --- a/lib/models/generators/cnn.py +++ b/lib/models/generators/cnn.py @@ -148,10 +148,10 @@ class CNNRouteGeneratorModel(LightningBaseModule): self.reshape_to_map = Flatten(reduce(mul, reshape_shape), reshape_shape) self.alt_deconv_1 = DeConvModule(reshape_shape, self.hparams.model_param.filters[2], - conv_padding=0, conv_kernel=9, conv_stride=1, + conv_padding=0, conv_kernel=13, conv_stride=1, use_norm=self.hparams.model_param.use_norm) self.alt_deconv_2 = DeConvModule(self.alt_deconv_1.shape, self.hparams.model_param.filters[1], - conv_padding=0, conv_kernel=5, conv_stride=1, + conv_padding=0, conv_kernel=7, conv_stride=1, use_norm=self.hparams.model_param.use_norm) self.alt_deconv_3 = DeConvModule(self.alt_deconv_2.shape, self.hparams.model_param.filters[0], conv_padding=1, conv_kernel=5, conv_stride=1,