Commit 1f4edae95c (parent 1b5a7dc69e): "Kurz vorm durchdrehen" (English: about to lose my mind)
@@ -18,10 +18,12 @@ class TrajDataset(Dataset):
     def map_shape(self):
         return self.map.as_array.shape

-    def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
-                 length=100000, mode='separated_arrays', embedding_size=None, preserve_equal_samples=False, **kwargs):
+    def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw', normalized=True,
+                 length=100000, mode='separated_arrays', embedding_size=None, preserve_equal_samples=False,
+                 **kwargs):
         super(TrajDataset, self).__init__()
         assert mode.lower() in ['vectors', 'all_in_map', 'separated_arrays', 'just_route']
+        self.normalized = normalized
         self.preserve_equal_samples = preserve_equal_samples
         self.mode = mode
         self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
@@ -58,6 +60,10 @@ class TrajDataset(Dataset):
         trajectory = trajectory.draw_in_array(self.map_shape)
         alternative = alternative.draw_in_array(self.map_shape)
         if self.mode == 'separated_arrays':
+            if self.normalized:
+                map_array = map_array / V.WHITE
+                trajectory = trajectory / V.WHITE
+                alternative = alternative / V.WHITE
             return (map_array, trajectory, label), alternative
         else:
             return np.concatenate((map_array, trajectory, alternative)), label
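The four added lines rescale the 8-bit grayscale arrays from {0, ..., 255} into [0, 1] before they reach the model. A minimal sketch of the effect, assuming V.WHITE == 255 (as defined in lib/variables.py) and a small uint8 example array:

    import numpy as np

    WHITE = 255  # mirrors V.WHITE
    map_array = np.array([[0, 127, 255]], dtype=np.uint8)
    normalized = map_array / WHITE   # float64 result in [0, 1]: [[0., 0.498, 1.]]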
@@ -86,8 +92,9 @@ class TrajData(object):
     def name(self):
         return self.__class__.__name__

-    def __init__(self, map_root, length=100000, mode='separated_arrays', **_):
+    def __init__(self, map_root, length=100000, mode='separated_arrays', normalized=True, **_):

+        self.normalized = normalized
         self.mode = mode
         self.maps_root = Path(map_root)
         self.length = length
@@ -100,7 +107,7 @@ class TrajData(object):
         # find max image size among available maps:
         max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
         return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
-                                          mode=self.mode, embedding_size=max_map_size,
+                                          mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
                                           preserve_equal_samples=True)
                               for map_file in map_files])

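The max_map_size expression kept in context above packs the largest width and height over all map files into one (channels, height, width) tuple; PIL's Image.size is (width, height), hence the reversed(...). The same computation on a few assumed example sizes:

    sizes = [(100, 80), (96, 120), (110, 90)]             # PIL-style (width, height)
    per_axis_max = tuple(map(max, *sizes))                 # (110, 120)
    max_map_size = (1, ) + tuple(reversed(per_axis_max))   # (1, 120, 110) = (C, H, W)
    print(max_map_size)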
@@ -27,20 +27,21 @@ class CNNRouteGeneratorModel(LightningBaseModule):
     def training_step(self, batch_xy, batch_nb, *args, **kwargs):
         batch_x, alternative = batch_xy
         generated_alternative, z, mu, logvar = self(batch_x)
-        mse_loss = self.criterion(generated_alternative, alternative)
+        element_wise_loss = self.criterion(generated_alternative, alternative)
         # see Appendix B from VAE paper:
         # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
         # https://arxiv.org/abs/1312.6114
         # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
+
         kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
         # Dimensional Resizing TODO: Does This make sense? Sanity Check it!
-        kld_loss /= reduce(mul, self.in_shape)
+        # kld_loss /= reduce(mul, self.in_shape)

-        loss = (kld_loss + mse_loss) / 2
-        return dict(loss=loss, log=dict(loss=loss, mse_loss=mse_loss, kld_loss=kld_loss))
+        loss = (kld_loss + element_wise_loss) / 2
+        return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))

     def _test_val_step(self, batch_xy, batch_nb, *args):
-        batch_x, alternative = batch_xy
+        batch_x, _ = batch_xy
         map_array, trajectory, label = batch_x

         generated_alternative, z, mu, logvar = self(batch_x)
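The renamed element_wise_loss keeps the reconstruction term generic (the criterion is nn.MSELoss here and nn.BCELoss in the discriminated subclass), and the KLD term is the closed form quoted in the comment. A self-contained sketch of the objective and of the reparameterization that produces z, assuming the same mu/logvar parameterization; note that the mean-reduced reconstruction term and the summed KLD live on very different scales, which is presumably what the "Dimensional Resizing" TODO is about:

    import torch
    import torch.nn.functional as F

    def vae_loss(recon, target, mu, logvar):
        element_wise_loss = F.mse_loss(recon, target)                       # mean over all elements
        # KLD(N(mu, sigma^2) || N(0, 1)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return (kld_loss + element_wise_loss) / 2

    def reparameterize(mu, logvar):
        std = torch.exp(0.5 * logvar)              # sigma
        return mu + std * torch.randn_like(std)    # z = mu + sigma * eps, eps ~ N(0, I)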
@@ -48,18 +49,12 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         return dict(batch_nb=batch_nb, label=label, generated_alternative=generated_alternative, pred_label=-1)

     def _test_val_epoch_end(self, outputs, test=False):
-
-        labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
-
-        if test:
-            self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf())
-            plt.clf()
         maps, trajectories, labels, val_restul_dict = self.generate_random()

         from lib.visualization.generator_eval import GeneratorVisualizer
         g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
         fig = g.draw()
-        self.logger.log_image(f'{self.name}_Output', fig)
+        self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)

         return dict(epoch=self.current_epoch)

@@ -81,69 +76,88 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         if not issubclassed:
             # Dataset
             self.dataset = TrajData(self.hparams.data_param.map_root, mode='separated_arrays',
-                                    length=self.hparams.data_param.dataset_length)
+                                    length=self.hparams.data_param.dataset_length, normalized=True)
             self.criterion = nn.MSELoss()

         # Additional Attributes
         self.in_shape = self.dataset.map_shapes_max
         # Todo: Better naming and size in Parameters
-        self.feature_dim = 10
-        self.lat_dim = self.feature_dim + self.feature_dim + 1
+        self.feature_dim = self.hparams.model_param.lat_dim * 10
+        self.feature_mixed_dim = self.feature_dim + self.feature_dim + 1

         # NN Nodes
         ###################################################
         #
         # Utils
-        self.relu = nn.ReLU()
+        self.activation = nn.ReLU()
         self.sigmoid = nn.Sigmoid()

         #
         # Map Encoder
         self.map_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
-                                     conv_filters=self.hparams.model_param.filters[0])
+                                     conv_filters=self.hparams.model_param.filters[0],
+                                     use_norm=self.hparams.model_param.use_norm,
+                                     use_bias=self.hparams.model_param.use_bias)

         self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
-                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[0])
+                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[0],
+                                        use_norm=self.hparams.model_param.use_norm,
+                                        use_bias=self.hparams.model_param.use_bias)
         self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
-                                     conv_filters=self.hparams.model_param.filters[1])
+                                     conv_filters=self.hparams.model_param.filters[1],
+                                     use_norm=self.hparams.model_param.use_norm,
+                                     use_bias=self.hparams.model_param.use_bias)

         self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
-                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[1])
+                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[1],
+                                        use_norm=self.hparams.model_param.use_norm,
+                                        use_bias=self.hparams.model_param.use_bias)
         self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
-                                     conv_filters=self.hparams.model_param.filters[2])
+                                     conv_filters=self.hparams.model_param.filters[2],
+                                     use_norm=self.hparams.model_param.use_norm,
+                                     use_bias=self.hparams.model_param.use_bias)

         self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
-                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[2])
+                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[2],
+                                        use_norm=self.hparams.model_param.use_norm,
+                                        use_bias=self.hparams.model_param.use_bias)
         self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
-                                     conv_filters=self.hparams.model_param.filters[2]*2)
+                                     conv_filters=self.hparams.model_param.filters[2]*2,
+                                     use_norm=self.hparams.model_param.use_norm,
+                                     use_bias=self.hparams.model_param.use_bias)

         self.map_flat = Flatten(self.map_conv_3.shape)

         self.map_lin = nn.Linear(reduce(mul, self.map_conv_3.shape), self.feature_dim)

         #
         # Trajectory Encoder
         self.traj_conv_1 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=0,
-                                      conv_filters=self.hparams.model_param.filters[0])
+                                      conv_filters=self.hparams.model_param.filters[0],
+                                      use_norm=self.hparams.model_param.use_norm,
+                                      use_bias=self.hparams.model_param.use_bias)

         self.traj_conv_2 = ConvModule(self.traj_conv_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
-                                      conv_filters=self.hparams.model_param.filters[0])
+                                      conv_filters=self.hparams.model_param.filters[0],
+                                      use_norm=self.hparams.model_param.use_norm,
+                                      use_bias=self.hparams.model_param.use_bias)

         self.traj_conv_3 = ConvModule(self.traj_conv_2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
-                                      conv_filters=self.hparams.model_param.filters[0])
+                                      conv_filters=self.hparams.model_param.filters[0],
+                                      use_norm=self.hparams.model_param.use_norm,
+                                      use_bias=self.hparams.model_param.use_bias)

         self.traj_flat = Flatten(self.traj_conv_3.shape)

         self.traj_lin = nn.Linear(reduce(mul, self.traj_conv_3.shape), self.feature_dim)

         #
         # Mixed Encoder
-        self.mixed_lin = nn.Linear(self.lat_dim, self.lat_dim)
+        self.mixed_lin = nn.Linear(self.feature_mixed_dim, self.feature_mixed_dim)
+        self.mixed_norm = nn.BatchNorm1d(self.feature_mixed_dim) if self.hparams.model_param.use_norm else lambda x: x

         #
         # Variational Bottleneck
-        self.mu = nn.Linear(self.lat_dim, self.hparams.model_param.lat_dim)
-        self.logvar = nn.Linear(self.lat_dim, self.hparams.model_param.lat_dim)
+        self.mu = nn.Linear(self.feature_mixed_dim, self.hparams.model_param.lat_dim)
+        self.logvar = nn.Linear(self.feature_mixed_dim, self.hparams.model_param.lat_dim)

         #
         # Alternative Generator
@@ -153,13 +167,17 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         self.reshape_to_map = Flatten(reduce(mul, self.traj_conv_3.shape), self.traj_conv_3.shape)

         self.alt_deconv_1 = DeConvModule(self.traj_conv_3.shape, self.hparams.model_param.filters[2],
-                                         conv_padding=0, conv_kernel=5, conv_stride=1)
+                                         conv_padding=0, conv_kernel=5, conv_stride=1,
+                                         use_norm=self.hparams.model_param.use_norm)
         self.alt_deconv_2 = DeConvModule(self.alt_deconv_1.shape, self.hparams.model_param.filters[1],
-                                         conv_padding=0, conv_kernel=3, conv_stride=1)
+                                         conv_padding=0, conv_kernel=3, conv_stride=1,
+                                         use_norm=self.hparams.model_param.use_norm)
         self.alt_deconv_3 = DeConvModule(self.alt_deconv_2.shape, self.hparams.model_param.filters[0],
-                                         conv_padding=1, conv_kernel=3, conv_stride=1)
+                                         conv_padding=1, conv_kernel=3, conv_stride=1,
+                                         use_norm=self.hparams.model_param.use_norm)
         self.alt_deconv_out = DeConvModule(self.alt_deconv_3.shape, 1, activation=None,
-                                           conv_padding=1, conv_kernel=3, conv_stride=1)
+                                           conv_padding=1, conv_kernel=3, conv_stride=1,
+                                           use_norm=self.hparams.model_param.use_norm)

     def forward(self, batch_x):
         #
@@ -173,7 +191,6 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         #
         # Generate
         alt_tensor = self.generate(z)
-
         return alt_tensor, z, mu, logvar

     @staticmethod
@@ -184,13 +201,16 @@ class CNNRouteGeneratorModel(LightningBaseModule):

     def generate(self, z):
         alt_tensor = self.alt_lin_1(z)
+        alt_tensor = self.activation(alt_tensor)
         alt_tensor = self.alt_lin_2(alt_tensor)
+        alt_tensor = self.activation(alt_tensor)
         alt_tensor = self.reshape_to_map(alt_tensor)
         alt_tensor = self.alt_deconv_1(alt_tensor)
         alt_tensor = self.alt_deconv_2(alt_tensor)
         alt_tensor = self.alt_deconv_3(alt_tensor)
         alt_tensor = self.alt_deconv_out(alt_tensor)
-        # alt_tensor = self.sigmoid(alt_tensor)
+        # alt_tensor = self.activation(alt_tensor)
+        alt_tensor = self.sigmoid(alt_tensor)
         return alt_tensor

     def encode(self, map_array, trajectory, label):
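generate() now applies the activation between the two linear layers and finishes with a sigmoid instead of leaving the deconvolution output raw. With the dataset targets normalized to [0, 1], the squashed output sits in the same range as the target images. A tiny, purely illustrative sketch:

    import torch

    raw = torch.tensor([-3.0, 0.0, 4.0])   # unbounded deconv output
    out = torch.sigmoid(raw)               # tensor([0.0474, 0.5000, 0.9820]), same range as the targets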
@@ -211,23 +231,26 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         traj_tensor = self.traj_lin(traj_tensor)

         mixed_tensor = torch.cat((map_tensor, traj_tensor, label.float().unsqueeze(-1)), dim=1)
-        mixed_tensor = self.relu(mixed_tensor)
+        mixed_tensor = self.mixed_norm(mixed_tensor)
+        mixed_tensor = self.activation(mixed_tensor)
         mixed_tensor = self.mixed_lin(mixed_tensor)
-        mixed_tensor = self.relu(mixed_tensor)
+        mixed_tensor = self.mixed_norm(mixed_tensor)
+        mixed_tensor = self.activation(mixed_tensor)

         #
         # Parameter and Sampling
         mu = self.mu(mixed_tensor)
         logvar = self.logvar(mixed_tensor)
+        # logvar = torch.clamp(logvar, min=0, max=10)
         z = self.reparameterize(mu, logvar)
         return z, mu, logvar

     def generate_random(self, n=6):
         maps = [self.map_storage[choice(self.map_storage.keys_list)] for _ in range(n)]

-        trajectories = [x.get_random_trajectory() for x in maps] * 2
+        trajectories = [x.get_random_trajectory() for x in maps]
         trajectories = [x.draw_in_array(self.map_storage.max_map_size) for x in trajectories]
-        trajectories = [torch.as_tensor(x, dtype=torch.float32) for x in trajectories]
+        trajectories = [torch.as_tensor(x, dtype=torch.float32) for x in trajectories] * 2
         trajectories = self._move_to_model_device(torch.stack(trajectories))

         maps = [torch.as_tensor(x.as_array, dtype=torch.float32) for x in maps] * 2
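encode() concatenates the map embedding, the trajectory embedding, and the scalar homotopy label along the feature axis, which is exactly where feature_mixed_dim = feature_dim + feature_dim + 1 from __init__ comes from. A shape-only sketch, assuming a batch of 4 and feature_dim = 40 (lat_dim * 10 with the new main.py default lat_dim = 4):

    import torch

    batch, feature_dim = 4, 40
    map_tensor = torch.randn(batch, feature_dim)
    traj_tensor = torch.randn(batch, feature_dim)
    label = torch.ones(batch)                             # one scalar label per sample

    mixed = torch.cat((map_tensor, traj_tensor, label.float().unsqueeze(-1)), dim=1)
    assert mixed.shape == (batch, 2 * feature_dim + 1)    # feature_mixed_dim = 81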
@@ -294,14 +317,14 @@ class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
         roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
         if test:
             # self.logger.log_metrics(score_dict)
-            self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf())
+            self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf(), step=self.global_step)
             plt.clf()
         maps, trajectories, labels, val_restul_dict = self.generate_random()

         from lib.visualization.generator_eval import GeneratorVisualizer
         g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
         fig = g.draw()
-        self.logger.log_image(f'{self.name}_Output', fig)
+        self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)

         return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)

@@ -330,4 +353,4 @@ class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
         self.criterion = nn.BCELoss()

         self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route',
-                                length=self.hparams.data_param.dataset_length)
+                                length=self.hparams.data_param.dataset_length, normalized=True)
@@ -4,11 +4,11 @@ import torch
 from torch import nn
 from lib.modules.utils import AutoPad, Interpolate


 #
 # Sub - Modules
 ###################


 class ConvModule(nn.Module):

     @property
@@ -60,7 +60,7 @@ class DeConvModule(nn.Module):
     def __init__(self, in_shape, conv_filters=3, conv_kernel=5, conv_stride=1, conv_padding=0,
                  dropout: Union[int, float] = 0, autopad=False,
                  activation: Union[None, nn.Module] = nn.ReLU, interpolation_scale=None,
-                 use_bias=True, normalize=False):
+                 use_bias=True, use_norm=False):
         super(DeConvModule, self).__init__()
         in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
         self.padding = conv_padding
@@ -70,7 +70,7 @@ class DeConvModule(nn.Module):

         self.autopad = AutoPad() if autopad else lambda x: x
         self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
-        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if normalize else lambda x: x
+        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else lambda x: x
         self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
         self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, conv_kernel, bias=use_bias,
                                           padding=self.padding, stride=self.stride)
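Renaming normalize to use_norm lines DeConvModule up with the use_norm flag now threaded through ConvModule and the hparams; the "real layer or identity passthrough" construction itself is unchanged. A stripped-down sketch of that pattern outside the repo's classes (nn.Identity() would do the same job and register as a submodule, the lambda simply matches the existing style):

    import torch
    from torch import nn

    class NormOptionalBlock(nn.Module):
        def __init__(self, channels, use_norm=False):
            super().__init__()
            # BatchNorm when requested, otherwise a no-op passthrough
            self.norm = nn.BatchNorm2d(channels, eps=1e-04, affine=False) if use_norm else (lambda x: x)

        def forward(self, x):
            return self.norm(x)

    x = torch.randn(8, 3, 16, 16)
    print(NormOptionalBlock(3, use_norm=True)(x).shape)   # torch.Size([8, 3, 16, 16])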
@@ -24,7 +24,7 @@ class Generator(nn.Module):
         self.lat_dim = lat_dim
         self.dropout = dropout
         self.l1 = nn.Linear(self.lat_dim, reduce(mul, re_shape), bias=use_bias)
-        # re_shape = (self.lat_dim // reduce(mul, re_shape[1:]), ) + tuple(re_shape[1:])
+        # re_shape = (self.feature_mixed_dim // reduce(mul, re_shape[1:]), ) + tuple(re_shape[1:])

         self.flat = Flatten(to=re_shape)

@@ -67,6 +67,23 @@ class AutoPad(nn.Module):
         return x


+class WeightInit:
+
+    def __init__(self, in_place_init_function):
+        self.in_place_init_function = in_place_init_function
+
+    def __call__(self, m):
+        if hasattr(m, 'weight'):
+            if isinstance(m.weight, torch.Tensor):
+                if m.weight.ndim < 2:
+                    m.weight.data.fill_(0.01)
+                else:
+                    self.in_place_init_function(m.weight)
+        if hasattr(m, 'bias'):
+            if isinstance(m.bias, torch.Tensor):
+                m.bias.data.fill_(0.01)
+
+
 class LightningBaseModule(pl.LightningModule, ABC):

     @classmethod
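WeightInit wraps one of torch's in-place initializers so it can be handed to Module.apply, which visits every submodule: tensors with fewer than two dimensions (biases, norm weights) get a constant fill, everything else gets the chosen initializer. A minimal usage sketch with a throwaway model; the class body is repeated here only so the snippet runs on its own:

    import torch
    from torch import nn

    class WeightInit:
        def __init__(self, in_place_init_function):
            self.in_place_init_function = in_place_init_function

        def __call__(self, m):
            if hasattr(m, 'weight') and isinstance(m.weight, torch.Tensor):
                if m.weight.ndim < 2:
                    m.weight.data.fill_(0.01)
                else:
                    self.in_place_init_function(m.weight)
            if hasattr(m, 'bias') and isinstance(m.bias, torch.Tensor):
                m.bias.data.fill_(0.01)

    model = nn.Sequential(nn.Linear(8, 4), nn.BatchNorm1d(4), nn.ReLU(), nn.Linear(4, 1))
    model.apply(WeightInit(nn.init.xavier_normal_))   # same initializer main.py now passes to init_weights
    print(model[0].bias)                              # every bias entry is 0.01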
@@ -128,15 +145,9 @@ class LightningBaseModule(pl.LightningModule, ABC):
     def test_epoch_end(self, outputs):
         raise NotImplementedError

-    def init_weights(self):
-        def _weight_init(m):
-            if hasattr(m, 'weight'):
-                if isinstance(m.weight, torch.Tensor):
-                    torch.nn.init.xavier_uniform_(m.weight)
-            if hasattr(m, 'bias'):
-                if isinstance(m.bias, torch.Tensor):
-                    m.bias.data.fill_(0.01)
-        self.apply(_weight_init)
+    def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
+        weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
+        self.apply(weight_initializer)

     # Dataloaders
     # ================================================================================
@@ -8,6 +8,7 @@ from pathlib import Path
 from lib.models.generators.cnn import CNNRouteGeneratorModel, CNNRouteGeneratorDiscriminated
 from lib.models.homotopy_classification.cnn_based import ConvHomDetector
 from lib.utils.model_io import ModelParameters
+from lib.utils.transforms import AsArray


 def is_jsonable(x):
@@ -5,10 +5,12 @@ from pytorch_lightning.loggers.neptune import NeptuneLogger
 from pytorch_lightning.loggers.test_tube import TestTubeLogger

 from lib.utils.config import Config
+import numpy as np

 class Logger(LightningLoggerBase):

+    media_dir = 'media'

     @property
     def experiment(self):
         if self.debug:
@@ -84,7 +86,9 @@ class Logger(LightningLoggerBase):

     def log_image(self, name, image, **kwargs):
         self.neptunelogger.log_image(name, image, **kwargs)
-        image.savefig(self.log_dir / name)
+        step = kwargs.get('step', None)
+        name = f'{step}_{name}' if step is not None else name
+        image.savefig(self.log_dir / self.media_dir / name)

     def save(self):
         self.testtubelogger.save()
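log_image now routes the on-disk copy into the media/ subdirectory and, when a step keyword is passed (as the epoch-end hooks now do with step=self.global_step), prefixes the filename with it, so later epochs no longer overwrite earlier figures. A small sketch of just the naming logic; the directory creation is an assumption for this standalone example, not something the diff shows:

    from pathlib import Path
    import matplotlib.pyplot as plt

    log_dir, media_dir = Path('output/run_0'), 'media'
    (log_dir / media_dir).mkdir(parents=True, exist_ok=True)   # assumed here so savefig has a target

    def save_image(fig, name, step=None):
        name = f'{step}_{name}' if step is not None else name
        fig.savefig(log_dir / media_dir / name)

    save_image(plt.figure(), 'CNNRouteGenerator_Output', step=1200)   # lands in output/run_0/media/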
@@ -8,5 +8,4 @@ class AsArray(object):

     def __call__(self, x):
         array = np.zeros((self.width, self.height))
-
         return array
@@ -5,3 +5,5 @@ HOMOTOPIC = 1
 ALTERNATIVE = 0
 WHITE = 255
 BLACK = 0
+
+DPI = 100
@@ -1,36 +1,49 @@
-import torch
-
 import matplotlib.pyplot as plt
 from mpl_toolkits.axisartist.axes_grid import ImageGrid
-from tqdm import tqdm
-from typing import List
+import lib.variables as V


 class GeneratorVisualizer(object):

     def __init__(self, maps, trajectories, labels, val_result_dict):
         # val_results = dict(discriminated_bce_loss, batch_nb, pred_label, label, generated_alternative)
-        self.generated_alternatives = val_result_dict['generated_alternative']
-        self.pred_labels = val_result_dict['pred_label']
+        self.alternatives = val_result_dict['generated_alternative']
         self.labels = labels
         self.trajectories = trajectories
         self.maps = maps
+        self._map_width, self._map_height = self.maps[0].squeeze().shape
         self.column_dict_list = self._build_column_dict_list()
+        self._cols = len(self.column_dict_list)
+        self._rows = len(self.column_dict_list[0])

     def _build_column_dict_list(self):
-        dict_list = []
-        for idx in range(self.maps.shape[0]):
-            image = (self.maps[idx] + self.trajectories[idx] + self.generated_alternatives[idx]).cpu().numpy().squeeze()
-            label = int(self.labels[idx])
-            dict_list.append(dict(image=image, label=label))
-        half_size = int(len(dict_list) // 2)
-        return dict_list[:half_size], dict_list[half_size:]
+        trajectories = []
+        non_hom_alternatives = []
+        hom_alternatives = []
+
+        for idx in range(self.alternatives.shape[0]):
+            image = (self.alternatives[idx]).cpu().numpy().squeeze()
+            label = self.labels[idx].item()
+            if label == V.HOMOTOPIC:
+                hom_alternatives.append(dict(image=image, label='Homotopic'))
+            else:
+                non_hom_alternatives.append(dict(image=image, label='NonHomotopic'))
+        for idx in range(max(len(hom_alternatives), len(non_hom_alternatives))):
+            image = (self.maps[idx] + self.trajectories[idx]).cpu().numpy().squeeze()
+            label = 'original'
+            trajectories.append(dict(image=image, label=label))
+
+        return trajectories, hom_alternatives, non_hom_alternatives

     def draw(self):
-        fig = plt.figure()
+        padding = 0.25
+        additional_size = self._cols * padding + 3 * padding
+        width = (self._map_width * self._cols) / V.DPI + additional_size
+        height = (self._map_height * self._rows) / V.DPI + additional_size
+        fig = plt.figure(figsize=(width, height), dpi=V.DPI)
         grid = ImageGrid(fig, 111,  # similar to subplot(111)
-                         nrows_ncols=(len(self.column_dict_list[0]), len(self.column_dict_list)),
-                         axes_pad=0.2,  # pad between axes in inch.
+                         nrows_ncols=(self._rows, self._cols),
+                         axes_pad=padding,  # pad between axes in inch.
                          )

         for idx in range(len(grid.axes_all)):
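draw() now derives the figure size from the map's pixel dimensions: a matplotlib figure of figsize=(w, h) inches rendered at a given dpi covers w * dpi by h * dpi pixels, so dividing the pixel extent of the tile grid by V.DPI (plus a small padding budget) keeps each tile near native resolution. The arithmetic on assumed 100x80-pixel maps with 3 columns and 2 rows:

    DPI = 100                        # mirrors V.DPI from lib/variables.py
    map_w, map_h = 100, 80           # pixels per tile (example values)
    cols, rows, padding = 3, 2, 0.25

    additional = cols * padding + 3 * padding          # 1.5 inches of padding budget
    width_in = (map_w * cols) / DPI + additional       # 3.0 + 1.5 = 4.5 inches
    height_in = (map_h * rows) / DPI + additional      # 1.6 + 1.5 = 3.1 inches
    # rendered size ~ (width_in * DPI, height_in * DPI) = (450, 310) pixels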
@@ -40,4 +53,5 @@ class GeneratorVisualizer(object):
             grid[idx].imshow(current_image)
             grid[idx].title.set_text(current_label)
         fig.cbar_mode = 'single'
+        fig.tight_layout()
         return fig
main.py (27 changes)
@@ -28,14 +28,16 @@ main_arg_parser = ArgumentParser(description="parser for fast-neural-style")

 # Main Parameters
 main_arg_parser.add_argument("--main_debug", type=strtobool, default=False, help="")
-main_arg_parser.add_argument("--main_eval", type=strtobool, default=False, help="")
+main_arg_parser.add_argument("--main_eval", type=strtobool, default=True, help="")
 main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")

 # Data Parameters
 main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
-main_arg_parser.add_argument("--data_dataset_length", type=int, default=10000, help="")
+main_arg_parser.add_argument("--data_dataset_length", type=int, default=100000, help="")
 main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
 main_arg_parser.add_argument("--data_map_root", type=str, default='res/shapes', help="")
+main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True, help="")
+

 # Transformations
 main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
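The boolean flags above rely on strtobool as the argparse type, so command-line values such as "true", "False" or "1" are parsed into 1/0 instead of being treated as always-truthy strings. A quick sketch of the behaviour with one flag:

    from argparse import ArgumentParser
    from distutils.util import strtobool

    parser = ArgumentParser()
    parser.add_argument("--data_normalized", type=strtobool, default=True, help="")

    print(parser.parse_args(["--data_normalized", "false"]).data_normalized)   # 0
    print(parser.parse_args([]).data_normalized)                               # True (default used as-is)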
@@ -43,16 +45,16 @@ main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
 # Transformations
 main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
 main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
-main_arg_parser.add_argument("--train_epochs", type=int, default=10, help="")
-main_arg_parser.add_argument("--train_batch_size", type=int, default=256, help="")
+main_arg_parser.add_argument("--train_epochs", type=int, default=20, help="")
+main_arg_parser.add_argument("--train_batch_size", type=int, default=164, help="")
 main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")

 # Model
 main_arg_parser.add_argument("--model_type", type=str, default="CNNRouteGenerator", help="")
-main_arg_parser.add_argument("--model_activation", type=str, default="relu", help="")
+main_arg_parser.add_argument("--model_activation", type=str, default="elu", help="")
 main_arg_parser.add_argument("--model_filters", type=str, default="[32, 16, 4]", help="")
 main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
-main_arg_parser.add_argument("--model_lat_dim", type=int, default=2, help="")
+main_arg_parser.add_argument("--model_lat_dim", type=int, default=4, help="")
 main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
 main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=True, help="")
 main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="")
@@ -93,10 +95,10 @@ def run_lightning_loop(config_obj):
     # =============================================================================
     # Init
     model: LightningBaseModule = config_obj.model_class(config_obj.model_paramters)
-    model.init_weights()
+    model.init_weights(torch.nn.init.xavier_normal_)
     if model.name == 'CNNRouteGeneratorDiscriminated':
         # ToDo: Make this dependent on the used seed
-        path = Path(Path(config_obj.train.outpath) / 'classifier_cnn' / 'trained')
+        path = Path(Path(config_obj.train.outpath) / 'classifier_cnn' / 'version_0')
         disc_model = SavedLightningModels.load_checkpoint(path).restore()
         model.set_discriminator(disc_model)

@@ -107,14 +109,14 @@ def run_lightning_loop(config_obj):
                      weights_save_path=logger.log_dir,
                      gpus=[0] if torch.cuda.is_available() else None,
                      check_val_every_n_epoch=1,
+                     num_sanity_val_steps=0,
                      # row_log_interval=(model.n_train_batches * 0.1),  # TODO: Better Value / Setting
                      # log_save_interval=(model.n_train_batches * 0.2),  # TODO: Better Value / Setting
                      checkpoint_callback=checkpoint_callback,
                      logger=logger,
+                     val_percent_check=0.05,
                      fast_dev_run=config_obj.main.debug,
-                     early_stop_callback=None,
-                     val_percent_check=0.10,
-                     num_sanity_val_steps=1,
+                     early_stop_callback=None
                      )

     # Train It
@@ -125,7 +127,8 @@ def run_lightning_loop(config_obj):
     model.save_to_disk(logger.log_dir)

     # Evaluate It
-    trainer.test()
+    if config_obj.main.eval:
+        trainer.test()

     return model

@@ -18,7 +18,7 @@ if __name__ == '__main__':
     # use_bias, activation, model, use_norm, max_epochs, filters
     cnn_classifier = dict(train_epochs=10, model_use_bias=True, model_use_norm=True, model_activation='leaky_relu',
                           model_type='classifier_cnn', model_filters=[16, 32, 64], data_batchsize=512)
-    # use_bias, activation, model, use_norm, max_epochs, sr, lat_dim, filters
+    # use_bias, activation, model, use_norm, max_epochs, sr, feature_mixed_dim, filters

     for arg_dict in [cnn_classifier]:
         for seed in range(5):