Merge remote-tracking branch 'origin/master'
# Conflicts: # res/shapes/shapes_1.bmp # res/shapes/shapes_2.bmp # res/shapes/shapes_3.bmp # res/shapes/shapes_4.bmp # res/shapes/shapes_5.bmp # res/shapes/shapes_6.bmp
41
.gitignore
vendored
@@ -2,43 +2,11 @@
|
||||
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
|
||||
|
||||
# User-specific stuff
|
||||
.idea/
|
||||
|
||||
# Generated files
|
||||
.idea/**/contentModel.xml
|
||||
|
||||
# Sensitive or high-churn files
|
||||
.idea/**/dataSources/
|
||||
.idea/**/dataSources.ids
|
||||
.idea/**/dataSources.local.xml
|
||||
.idea/**/sqlDataSources.xml
|
||||
.idea/**/dynamic.xml
|
||||
.idea/**/uiDesigner.xml
|
||||
.idea/**/dbnavigator.xml
|
||||
|
||||
# Gradle
|
||||
.idea/**/gradle.xml
|
||||
.idea/**/libraries
|
||||
|
||||
# Gradle and Maven with auto-import
|
||||
# When using Gradle or Maven with auto-import, you should exclude module files,
|
||||
# since they will be recreated, and may cause churn. Uncomment if using
|
||||
# auto-import.
|
||||
# .idea/artifacts
|
||||
# .idea/compiler.xml
|
||||
# .idea/jarRepositories.xml
|
||||
# .idea/modules.xml
|
||||
# .idea/*.iml
|
||||
# .idea/modules
|
||||
# *.iml
|
||||
# *.ipr
|
||||
.idea/**
|
||||
|
||||
# CMake
|
||||
cmake-build-*/
|
||||
|
||||
# Mongo Explorer plugin
|
||||
.idea/**/mongoSettings.xml
|
||||
|
||||
# File-based project format
|
||||
*.iws
|
||||
|
||||
@@ -59,10 +27,3 @@ com_crashlytics_export_strings.xml
|
||||
crashlytics.properties
|
||||
crashlytics-build.properties
|
||||
fabric.properties
|
||||
|
||||
# Editor-based Rest Client
|
||||
.idea/httpRequests
|
||||
|
||||
# Android studio 3.1+ serialized cache file
|
||||
.idea/caches/build_file_checksums.ser
|
||||
/.idea/inspectionProfiles/
|
||||
|
2
.idea/.gitignore
generated
vendored
@@ -1,2 +0,0 @@
|
||||
# Default ignored files
|
||||
/workspace.xml
|
22
.idea/deployment.xml
generated
@@ -1,22 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="PublishConfigData" autoUpload="On explicit save action" serverName="steffen@aimachine:22" showAutoUploadSettingsWarning="false">
|
||||
<serverData>
|
||||
<paths name="erlowa@aimachine">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping deploy="/" local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="steffen@aimachine:22">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping deploy="/" local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
</serverData>
|
||||
<option name="myAutoUpload" value="ON_EXPLICIT_SAVE" />
|
||||
</component>
|
||||
</project>
|
23
.idea/dictionaries/steffen.xml
generated
@@ -1,23 +0,0 @@
|
||||
<component name="ProjectDictionaryState">
|
||||
<dictionary name="steffen">
|
||||
<words>
|
||||
<w>autopad</w>
|
||||
<w>conv</w>
|
||||
<w>convolutional</w>
|
||||
<w>dataloader</w>
|
||||
<w>dataloaders</w>
|
||||
<w>datasets</w>
|
||||
<w>homotopic</w>
|
||||
<w>hparams</w>
|
||||
<w>hyperparamter</w>
|
||||
<w>kingma</w>
|
||||
<w>logvar</w>
|
||||
<w>mapname</w>
|
||||
<w>mapnames</w>
|
||||
<w>numlayers</w>
|
||||
<w>reparameterize</w>
|
||||
<w>softmax</w>
|
||||
<w>traj</w>
|
||||
</words>
|
||||
</dictionary>
|
||||
</component>
|
8
.idea/hom_traj_gen.iml
generated
@@ -1,8 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="Remote Python 3.7.6 (sftp://steffen@aimachine:22/home/steffen/envs/traj_gen/bin/python)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
7
.idea/inspectionProfiles/profiles_settings.xml
generated
@@ -1,7 +0,0 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="PROJECT_PROFILE" value="Default" />
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
10
.idea/misc.xml
generated
@@ -1,10 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="JavaScriptSettings">
|
||||
<option name="languageLevel" value="ES6" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="traj_gen@ai-machine" project-jdk-type="Python SDK" />
|
||||
<component name="PyPackaging">
|
||||
<option name="earlyReleasesAsUpgrades" value="true" />
|
||||
</component>
|
||||
</project>
|
8
.idea/modules.xml
generated
@@ -1,8 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/hom_traj_gen.iml" filepath="$PROJECT_DIR$/.idea/hom_traj_gen.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
6
.idea/vcs.xml
generated
@@ -1,6 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
15
.idea/webResources.xml
generated
@@ -1,15 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="WebResourcesPaths">
|
||||
<contentEntries>
|
||||
<entry url="file://$PROJECT_DIR$">
|
||||
<entryData>
|
||||
<resourceRoots>
|
||||
<path value="file://$PROJECT_DIR$/res" />
|
||||
<path value="file://$PROJECT_DIR$/data" />
|
||||
</resourceRoots>
|
||||
</entryData>
|
||||
</entry>
|
||||
</contentEntries>
|
||||
</component>
|
||||
</project>
|
@@ -102,7 +102,7 @@ class TrajData(object):
|
||||
|
||||
def _load_datasets(self):
|
||||
map_files = list(self.maps_root.glob('*.bmp'))
|
||||
equal_split = int(self.length // len(map_files))
|
||||
equal_split = int(self.length // len(map_files)) or 1
|
||||
|
||||
# find max image size among available maps:
|
||||
max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
|
||||
|
@@ -1,3 +1,7 @@
|
||||
from statistics import mean
|
||||
|
||||
from random import choice
|
||||
|
||||
import torch
|
||||
from functools import reduce
|
||||
from operator import mul
|
||||
@@ -6,9 +10,12 @@ from torch import nn
|
||||
from torch.optim import Adam
|
||||
|
||||
from datasets.trajectory_dataset import TrajData
|
||||
from lib.evaluation.classification import ROCEvaluation
|
||||
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
|
||||
from lib.modules.utils import LightningBaseModule, Flatten
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
|
||||
@@ -33,14 +40,53 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
# https://arxiv.org/abs/1312.6114
|
||||
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
|
||||
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
|
||||
# Dimensional Resizing
|
||||
kld_loss /= self.in_shape
|
||||
|
||||
loss = (kld_loss + discriminated_bce_loss) / 2
|
||||
return dict(loss=loss, log=dict(loss=loss,
|
||||
discriminated_bce_loss=discriminated_bce_loss,
|
||||
kld_loss=kld_loss)
|
||||
)
|
||||
|
||||
def test_step(self, *args, **kwargs):
|
||||
pass
|
||||
def _test_val_step(self, batch_xy, batch_nb, *args):
|
||||
batch_x, label = batch_xy
|
||||
|
||||
generated_alternative, z, mu, logvar = self(batch_x + [label, ])
|
||||
map_array, trajectory = batch_x
|
||||
|
||||
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
|
||||
pred_label = self.discriminator(map_stack)
|
||||
|
||||
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
|
||||
return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
|
||||
pred_label=pred_label, label=label, generated_alternative=generated_alternative)
|
||||
|
||||
def validation_step(self, *args):
|
||||
return self._test_val_step(*args)
|
||||
|
||||
def validation_epoch_end(self, outputs):
|
||||
evaluation = ROCEvaluation(plot_roc=True)
|
||||
pred_label = torch.cat([x['pred_label'] for x in outputs])
|
||||
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
|
||||
mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()
|
||||
|
||||
# Sci-py call ROC eval call is eval(true_label, prediction)
|
||||
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
|
||||
# self.logger.log_metrics(score_dict)
|
||||
self.logger.log_image(f'{self.name}_ROC-Curve_E{self.current_epoch}', plt.gcf())
|
||||
plt.clf()
|
||||
maps, trajectories, labels, val_restul_dict = self.generate_random()
|
||||
|
||||
from lib.visualization.generator_eval import GeneratorVisualizer
|
||||
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
|
||||
fig = g.draw()
|
||||
self.logger.log_image(f'{self.name}_Output_E{self.current_epoch}', fig)
|
||||
|
||||
return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)
|
||||
|
||||
def test_step(self, *args):
|
||||
return self._test_val_step(*args)
|
||||
|
||||
@property
|
||||
def discriminator(self):
|
||||
@@ -57,12 +103,14 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
super(CNNRouteGeneratorModel, self).__init__(*params)
|
||||
|
||||
# Dataset
|
||||
self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route')
|
||||
self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route',
|
||||
length=self.hparams.data_param.dataset_length)
|
||||
|
||||
# Additional Attributes
|
||||
self.in_shape = self.dataset.map_shapes_max
|
||||
# Todo: Better naming and size in Parameters
|
||||
self.feature_dim = 10
|
||||
self.lat_dim = self.feature_dim + self.feature_dim + 1
|
||||
self._disc = None
|
||||
|
||||
# NN Nodes
|
||||
@@ -70,6 +118,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
#
|
||||
# Utils
|
||||
self.relu = nn.ReLU()
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.criterion = nn.MSELoss()
|
||||
|
||||
#
|
||||
@@ -111,10 +160,14 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
|
||||
self.traj_lin = nn.Linear(reduce(mul, self.traj_conv_3.shape), self.feature_dim)
|
||||
|
||||
#
|
||||
# Mixed Encoder
|
||||
self.mixed_lin = nn.Linear(self.lat_dim, self.lat_dim)
|
||||
|
||||
#
|
||||
# Variational Bottleneck
|
||||
self.mu = nn.Linear(self.feature_dim + self.feature_dim + 1, self.hparams.model_param.lat_dim)
|
||||
self.logvar = nn.Linear(self.feature_dim + self.feature_dim + 1, self.hparams.model_param.lat_dim)
|
||||
self.mu = nn.Linear(self.lat_dim, self.hparams.model_param.lat_dim)
|
||||
self.logvar = nn.Linear(self.lat_dim, self.hparams.model_param.lat_dim)
|
||||
|
||||
#
|
||||
# Alternative Generator
|
||||
@@ -139,6 +192,32 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
|
||||
#
|
||||
# Encode
|
||||
z, mu, logvar = self.encode(map_array, trajectory, label)
|
||||
|
||||
#
|
||||
# Generate
|
||||
alt_tensor = self.generate(z)
|
||||
|
||||
return alt_tensor, z, mu, logvar
|
||||
|
||||
@staticmethod
|
||||
def reparameterize(mu, logvar):
|
||||
std = torch.exp(0.5 * logvar)
|
||||
eps = torch.randn_like(std)
|
||||
return mu + eps * std
|
||||
|
||||
def generate(self, z):
|
||||
alt_tensor = self.alt_lin_1(z)
|
||||
alt_tensor = self.alt_lin_2(alt_tensor)
|
||||
alt_tensor = self.reshape_to_map(alt_tensor)
|
||||
alt_tensor = self.alt_deconv_1(alt_tensor)
|
||||
alt_tensor = self.alt_deconv_2(alt_tensor)
|
||||
alt_tensor = self.alt_deconv_3(alt_tensor)
|
||||
alt_tensor = self.alt_deconv_out(alt_tensor)
|
||||
alt_tensor = self.sigmoid(alt_tensor)
|
||||
return alt_tensor
|
||||
|
||||
def encode(self, map_array, trajectory, label):
|
||||
map_tensor = self.map_conv_0(map_array)
|
||||
map_tensor = self.map_res_1(map_tensor)
|
||||
map_tensor = self.map_conv_1(map_tensor)
|
||||
@@ -157,27 +236,19 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
|
||||
mixed_tensor = torch.cat((map_tensor, traj_tensor, label.float().unsqueeze(-1)), dim=1)
|
||||
mixed_tensor = self.relu(mixed_tensor)
|
||||
mixed_tensor = self.mixed_lin(mixed_tensor)
|
||||
mixed_tensor = self.relu(mixed_tensor)
|
||||
|
||||
#
|
||||
# Parameter and Sampling
|
||||
mu = self.mu(mixed_tensor)
|
||||
logvar = self.logvar(mixed_tensor)
|
||||
z = self.reparameterize(mu, logvar)
|
||||
return z, mu, logvar
|
||||
|
||||
#
|
||||
# Generate
|
||||
alt_tensor = self.alt_lin_1(z)
|
||||
alt_tensor = self.alt_lin_2(alt_tensor)
|
||||
alt_tensor = self.reshape_to_map(alt_tensor)
|
||||
alt_tensor = self.alt_deconv_1(alt_tensor)
|
||||
alt_tensor = self.alt_deconv_2(alt_tensor)
|
||||
alt_tensor = self.alt_deconv_3(alt_tensor)
|
||||
alt_tensor = self.alt_deconv_out(alt_tensor)
|
||||
|
||||
return alt_tensor, z, mu, logvar
|
||||
|
||||
@staticmethod
|
||||
def reparameterize(mu, logvar):
|
||||
std = torch.exp(0.5 * logvar)
|
||||
eps = torch.randn_like(std)
|
||||
return mu + eps * std
|
||||
def generate_random(self, n=6):
|
||||
maps = [self.map_storage[choice(self.map_storage.keys)] for _ in range(n)]
|
||||
trajectories = torch.stack([x.get_random_trajectory() for x in maps] * 2)
|
||||
maps = torch.stack([x.as_2d_array for x in maps] * 2)
|
||||
labels = torch.as_tensor([0] * n + [1] * n)
|
||||
return maps, trajectories, labels, self._test_val_step(maps, trajectories, labels)
|
||||
|
@@ -57,7 +57,8 @@ class ConvHomDetector(LightningBaseModule):
|
||||
# Model Parameters
|
||||
self.in_shape = self.dataset.map_shapes_max
|
||||
assert len(self.in_shape) == 3, f'Image or map shape has to have 3 dims, but had: {len(self.in_shape)}'
|
||||
self.criterion = nn.BCEWithLogitsLoss()
|
||||
self.criterion = nn.BCELoss()
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
# NN Nodes
|
||||
# ============================
|
||||
@@ -100,4 +101,5 @@ class ConvHomDetector(LightningBaseModule):
|
||||
tensor = self.flatten(tensor)
|
||||
tensor = self.linear(tensor)
|
||||
tensor = self.classifier(tensor)
|
||||
tensor = self.sigmoid(tensor)
|
||||
return tensor
|
||||
|
@@ -90,7 +90,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
||||
# Data loading
|
||||
# =============================================================================
|
||||
# Map Object
|
||||
self.map_storage = MapStorage(self.hparams.data_param.map_root)
|
||||
self.map_storage = MapStorage(self.hparams.data_param.map_root, load_all=True)
|
||||
|
||||
def size(self):
|
||||
return self.shape
|
||||
@@ -143,19 +143,19 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
||||
# Train Dataloader
|
||||
def train_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
|
||||
batch_size=self.hparams.data_param.batchsize,
|
||||
batch_size=self.hparams.train_param.batch_size,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
# Test Dataloader
|
||||
def test_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
|
||||
batch_size=self.hparams.data_param.batchsize,
|
||||
batch_size=self.hparams.train_param.batch_size,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
# Validation Dataloader
|
||||
def val_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
|
||||
batch_size=self.hparams.data_param.batchsize,
|
||||
batch_size=self.hparams.train_param.batch_size,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
|
||||
|
@@ -146,6 +146,7 @@ class Map(object):
|
||||
|
||||
img = Image.new('L', (self.height, self.width), 0)
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
draw.polygon(polyline, outline=self.white, fill=self.white)
|
||||
|
||||
a = (np.where(np.asarray(img) == self.white, 1, 0) * np.where(self.as_2d_array == self.black, 1, 0)).sum()
|
||||
@@ -166,6 +167,10 @@ class Map(object):
|
||||
|
||||
class MapStorage(object):
|
||||
|
||||
@property
|
||||
def keys(self):
|
||||
return list(self.data.keys())
|
||||
|
||||
def __init__(self, map_root, load_all=False):
|
||||
self.data = dict()
|
||||
self.map_root = Path(map_root)
|
||||
@@ -174,11 +179,11 @@ class MapStorage(object):
|
||||
_ = self[map_file.name]
|
||||
|
||||
def __getitem__(self, item):
|
||||
if item in hasattr(self, item):
|
||||
return self.__getattribute__(item)
|
||||
if item in self.data.keys():
|
||||
return self.data.get(item)
|
||||
else:
|
||||
with shelve.open(self.map_root / f'{item}.pik', flag='r') as d:
|
||||
self.__setattr__(item, d['map']['map'])
|
||||
current_map = Map().from_image(self.map_root / item)
|
||||
self.data.__setitem__(item, np.asarray(current_map))
|
||||
return self[item]
|
||||
|
||||
|
||||
|
43
lib/visualization/generator_eval.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import torch
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from mpl_toolkits.axisartist.axes_grid import ImageGrid
|
||||
from tqdm import tqdm
|
||||
from typing import List
|
||||
|
||||
|
||||
class GeneratorVisualizer(object):
|
||||
|
||||
def __init__(self, maps, trajectories, labels, val_result_dict):
|
||||
# val_results = dict(discriminated_bce_loss, batch_nb, pred_label, label, generated_alternative)
|
||||
self.generated_alternatives = val_result_dict['generated_alternative']
|
||||
self.pred_labels = val_result_dict['pred_label']
|
||||
self.labels = labels
|
||||
self.trajectories = trajectories
|
||||
self.maps = maps
|
||||
self.column_dict_list = self._build_column_dict_list()
|
||||
|
||||
def _build_column_dict_list(self):
|
||||
dict_list = []
|
||||
for idx in range(self.maps):
|
||||
image = self.maps[idx] + self.trajectories[idx] + self.generated_alternatives
|
||||
label = self.labels[idx]
|
||||
dict_list.append(dict(image=image, label=label))
|
||||
half_size = int(len(dict_list) // 2)
|
||||
return dict_list[:half_size], dict_list[half_size:]
|
||||
|
||||
def draw(self):
|
||||
fig = plt.figure()
|
||||
grid = ImageGrid(fig, 111, # similar to subplot(111)
|
||||
nrows_ncols=(len(self.column_dict_list[0]), len(self.column_dict_list)),
|
||||
axes_pad=0.2, # pad between axes in inch.
|
||||
)
|
||||
|
||||
for idx in grid.axes_all:
|
||||
row, col = divmod(idx, len(self.column_dict_list))
|
||||
current_image = self.column_dict_list[col]['image'][row]
|
||||
current_label = self.column_dict_list[col]['label'][row]
|
||||
grid[idx].imshow(current_image)
|
||||
grid[idx].title.set_text(current_label)
|
||||
fig.cbar_mode = 'single'
|
||||
return fig
|
3
main.py
@@ -33,7 +33,7 @@ main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
|
||||
|
||||
# Data Parameters
|
||||
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
|
||||
main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="")
|
||||
main_arg_parser.add_argument("--data_dataset_length", type=int, default=10000, help="")
|
||||
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
|
||||
main_arg_parser.add_argument("--data_map_root", type=str, default='res/shapes', help="")
|
||||
|
||||
@@ -106,6 +106,7 @@ def run_lightning_loop(config_obj):
|
||||
show_progress_bar=True,
|
||||
weights_save_path=logger.log_dir,
|
||||
gpus=[0] if torch.cuda.is_available() else None,
|
||||
check_val_every_n_epoch=1,
|
||||
# row_log_interval=(model.n_train_batches * 0.1), # TODO: Better Value / Setting
|
||||
# log_save_interval=(model.n_train_batches * 0.2), # TODO: Better Value / Setting
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
|
Before Width: | Height: | Size: 1.7 KiB After Width: | Height: | Size: 1.6 KiB |
Before Width: | Height: | Size: 1.7 KiB After Width: | Height: | Size: 1.6 KiB |
Before Width: | Height: | Size: 1.7 KiB After Width: | Height: | Size: 1.6 KiB |
BIN
res/shapes/shapes_3.png
Normal file
After Width: | Height: | Size: 831 B |
Before Width: | Height: | Size: 1.7 KiB After Width: | Height: | Size: 1.6 KiB |
Before Width: | Height: | Size: 1.7 KiB After Width: | Height: | Size: 1.6 KiB |
Before Width: | Height: | Size: 1.7 KiB After Width: | Height: | Size: 1.6 KiB |