Debugging
This commit is contained in:
3
.idea/deployment.xml
generated
3
.idea/deployment.xml
generated
@@ -1,6 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="PublishConfigData" serverName="traj_gen-AiMachine">
|
||||
<component name="PublishConfigData" autoUpload="On explicit save action" serverName="traj_gen-AiMachine">
|
||||
<serverData>
|
||||
<paths name="ErLoWa-AiMachine">
|
||||
<serverdata>
|
||||
@@ -17,5 +17,6 @@
|
||||
</serverdata>
|
||||
</paths>
|
||||
</serverData>
|
||||
<option name="myAutoUpload" value="ON_EXPLICIT_SAVE" />
|
||||
</component>
|
||||
</project>
|
2
.idea/hom_traj_gen.iml
generated
2
.idea/hom_traj_gen.iml
generated
@@ -2,7 +2,7 @@
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="Python 3.7 (traj_gen)" jdkType="Python SDK" />
|
||||
<orderEntry type="jdk" jdkName="traj_gen@AiMachine" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
2
.idea/misc.xml
generated
2
.idea/misc.xml
generated
@@ -3,5 +3,5 @@
|
||||
<component name="JavaScriptSettings">
|
||||
<option name="languageLevel" value="ES6" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (traj_gen)" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="traj_gen@AiMachine" project-jdk-type="Python SDK" />
|
||||
</project>
|
@@ -1,6 +1,6 @@
|
||||
import shelve
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Union, List
|
||||
|
||||
import torch
|
||||
from random import choice
|
||||
@@ -17,7 +17,7 @@ class TrajDataset(Dataset):
|
||||
return self.map.as_array.shape
|
||||
|
||||
def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
|
||||
length=100.000, all_in_map=True, **kwargs):
|
||||
length=100000, all_in_map=True, **kwargs):
|
||||
super(TrajDataset, self).__init__()
|
||||
self.all_in_map = all_in_map
|
||||
self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
|
||||
@@ -34,11 +34,11 @@ class TrajDataset(Dataset):
|
||||
alternative = self.map.generate_alternative(trajectory)
|
||||
label = choice([0, 1])
|
||||
if self.all_in_map:
|
||||
blank_trajectory_space = torch.zeros_like(self.map.as_array)
|
||||
blank_trajectory_space = torch.zeros(self.map.shape)
|
||||
blank_trajectory_space[trajectory.vertices] = 1
|
||||
|
||||
blank_alternative_space = torch.zeros_like(self.map.as_array)
|
||||
blank_alternative_space[trajectory.vertices] = 1
|
||||
blank_alternative_space = torch.zeros(self.map.shape)
|
||||
blank_alternative_space[trajectory.np_vertices] = 1
|
||||
|
||||
map_array = torch.as_tensor(self.map.as_array)
|
||||
label = self.map.are_homotopic(trajectory, alternative)
|
||||
@@ -56,7 +56,7 @@ class TrajData(object):
|
||||
@property
|
||||
def map_shapes_max(self):
|
||||
shapes = self.map_shapes
|
||||
return map(max, zip(*shapes))
|
||||
return list(map(max, zip(*shapes)))
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
@@ -66,12 +66,12 @@ class TrajData(object):
|
||||
|
||||
self.all_in_map = all_in_map
|
||||
self.maps_root = Path(map_root) if map_root else Path() / 'res' / 'maps'
|
||||
self._dataset = self._load_datasets()
|
||||
self.length = length
|
||||
self._dataset = self._load_datasets()
|
||||
|
||||
def _load_datasets(self):
|
||||
map_files = list(self.maps_root.glob('*.bmp'))
|
||||
equal_split = self.length // len(map_files)
|
||||
equal_split = int(self.length // len(map_files))
|
||||
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_image.name, length=equal_split,
|
||||
all_in_map=self.all_in_map) for map_image in map_files])
|
||||
|
||||
|
0
lib/models/generators/recurrent.py
Normal file
0
lib/models/generators/recurrent.py
Normal file
@@ -1,3 +1,6 @@
|
||||
from functools import reduce
|
||||
from operator import mul
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
@@ -13,7 +16,7 @@ class ConvHomDetector(LightningBaseModule):
|
||||
name = 'CNNHomotopyClassifier'
|
||||
|
||||
def configure_optimizers(self):
|
||||
return Adam(self.parameters(), lr=self.lr)
|
||||
return Adam(self.parameters(), lr=self.hparams.lr)
|
||||
|
||||
def validation_step(self, *args, **kwargs):
|
||||
pass
|
||||
@@ -32,29 +35,36 @@ class ConvHomDetector(LightningBaseModule):
|
||||
|
||||
def __init__(self, *params):
|
||||
super(ConvHomDetector, self).__init__(*params)
|
||||
|
||||
# Dataset
|
||||
self.dataset = TrajData(self.hparams.data_param.data_root)
|
||||
self.dataset = TrajData(self.hparams.data_param.root)
|
||||
|
||||
# Additional Attributes
|
||||
self.map_shape = self.dataset.map_shapes_max
|
||||
|
||||
# Model Paramters
|
||||
self.in_shape = self.dataset.map_shapes_max
|
||||
assert len(self.in_shape) == 3, f'Image or map shape has to have 3 dims, but had: {len(self.in_shape)}'
|
||||
|
||||
# NN Nodes
|
||||
# ============================
|
||||
# Convolutional Map Processing
|
||||
#
|
||||
self.map_res_1 = ResidualModule(self.in_shape, ConvModule, 3,
|
||||
self.map_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1,
|
||||
conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
|
||||
self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 3,
|
||||
**dict(conv_kernel=3, conv_stride=1,
|
||||
conv_padding=0, conv_filters=self.hparams.model_param.filters[0]))
|
||||
conv_padding=1, conv_filters=self.hparams.model_param.filters[0]))
|
||||
self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=5, conv_stride=1,
|
||||
conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
|
||||
self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 3,
|
||||
**dict(conv_kernel=3, conv_stride=1,
|
||||
conv_padding=0, conv_filters=self.hparams.model_param.filters[0]))
|
||||
conv_padding=1, conv_filters=self.hparams.model_param.filters[0]))
|
||||
self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=5, conv_stride=1,
|
||||
conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
|
||||
self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 3,
|
||||
**dict(conv_kernel=3, conv_stride=1,
|
||||
conv_padding=0, conv_filters=self.hparams.model_param.filters[0]))
|
||||
conv_padding=1, conv_filters=self.hparams.model_param.filters[0]))
|
||||
self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=5, conv_stride=1,
|
||||
conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
|
||||
|
||||
@@ -64,12 +74,13 @@ class ConvHomDetector(LightningBaseModule):
|
||||
# Classifier
|
||||
#
|
||||
|
||||
self.linear = nn.Linear(self.flatten.shape.item(), self.hparams.model_param.classes * 10)
|
||||
self.classifier = nn.Linear(self.linear.shape, self.hparams.model_param.classes)
|
||||
self.linear = nn.Linear(reduce(mul, self.flatten.shape), self.hparams.model_param.classes * 10)
|
||||
self.classifier = nn.Linear(self.hparams.model_param.classes * 10, self.hparams.model_param.classes)
|
||||
self.softmax = nn.Softmax()
|
||||
|
||||
def forward(self, x):
|
||||
tensor = self.map_res_1(x)
|
||||
tensor = self.map_conv_0(x)
|
||||
tensor = self.map_res_1(tensor)
|
||||
tensor = self.map_conv_1(tensor)
|
||||
tensor = self.map_res_2(tensor)
|
||||
tensor = self.map_conv_2(tensor)
|
||||
|
@@ -36,9 +36,9 @@ class ConvModule(nn.Module):
|
||||
self.stride = conv_stride
|
||||
|
||||
# Modules
|
||||
self.dropout = nn.Dropout2d(dropout) if dropout else False
|
||||
self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else False
|
||||
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else False
|
||||
self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
|
||||
self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else lambda x: x
|
||||
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else lambda x: x
|
||||
self.conv = conv_class(in_channels, conv_filters, conv_kernel, bias=use_bias,
|
||||
padding=self.padding, stride=self.stride
|
||||
)
|
||||
@@ -47,8 +47,8 @@ class ConvModule(nn.Module):
|
||||
x = self.norm(x) if self.norm else x
|
||||
|
||||
tensor = self.conv(x)
|
||||
tensor = self.dropout(tensor) if self.dropout else tensor
|
||||
tensor = self.pooling(tensor) if self.pooling else tensor
|
||||
tensor = self.dropout(tensor)
|
||||
tensor = self.pooling(tensor)
|
||||
tensor = self.activation(tensor)
|
||||
return tensor
|
||||
|
||||
@@ -72,23 +72,23 @@ class DeConvModule(nn.Module):
|
||||
self.in_shape = in_shape
|
||||
self.conv_filters = conv_filters
|
||||
|
||||
self.autopad = AutoPad() if autopad else False
|
||||
self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else False
|
||||
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if normalize else False
|
||||
self.dropout = nn.Dropout2d(dropout) if dropout else False
|
||||
self.autopad = AutoPad() if autopad else lambda x: x
|
||||
self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
|
||||
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if normalize else lambda x: x
|
||||
self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
|
||||
self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, conv_kernel, bias=use_bias,
|
||||
padding=self.padding, stride=self.stride)
|
||||
|
||||
self.activation = activation() if activation else None
|
||||
self.activation = activation() if activation else lambda x: x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x) if self.norm else x
|
||||
x = self.dropout(x) if self.dropout else x
|
||||
x = self.autopad(x) if self.autopad else x
|
||||
x = self.interpolation(x) if self.interpolation else x
|
||||
x = self.norm(x)
|
||||
x = self.dropout(x)
|
||||
x = self.autopad(x)
|
||||
x = self.interpolation(x)
|
||||
|
||||
tensor = self.de_conv(x)
|
||||
tensor = self.activation(tensor) if self.activation else tensor
|
||||
tensor = self.activation(tensor)
|
||||
return tensor
|
||||
|
||||
|
||||
@@ -100,12 +100,13 @@ class ResidualModule(nn.Module):
|
||||
output = self(x)
|
||||
return output.shape[1:]
|
||||
|
||||
def __init__(self, in_shape, module_class, n, **module_paramters):
|
||||
def __init__(self, in_shape, module_class, n, activation=None, **module_paramters):
|
||||
assert n >= 1
|
||||
super(ResidualModule, self).__init__()
|
||||
self.in_shape = in_shape
|
||||
module_paramters.update(in_shape=in_shape)
|
||||
self.residual_block = [module_class(**module_paramters) for x in range(n)]
|
||||
self.activation = activation() if activation else lambda x: x
|
||||
self.residual_block = [module_class(**module_paramters) for _ in range(n)]
|
||||
assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'
|
||||
|
||||
def forward(self, x):
|
||||
@@ -114,6 +115,7 @@ class ResidualModule(nn.Module):
|
||||
|
||||
# noinspection PyUnboundLocalVariable
|
||||
tensor = tensor + x
|
||||
tensor = self.activation(tensor)
|
||||
return tensor
|
||||
|
||||
|
||||
|
@@ -123,6 +123,10 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
||||
batch_size=self.hparams.data_param.batchsize,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
@property
|
||||
def data_len(self):
|
||||
return len(self.dataset.train_dataset)
|
||||
|
||||
def configure_optimizers(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
@@ -1,7 +1,5 @@
|
||||
import shelve
|
||||
from pathlib import Path
|
||||
from collections import UserDict
|
||||
|
||||
|
||||
import copy
|
||||
from math import sqrt
|
||||
@@ -46,6 +44,10 @@ class Map(object):
|
||||
return self.map_array
|
||||
|
||||
def __init__(self, name='', array_like_map_representation=None):
|
||||
if array_like_map_representation is not None:
|
||||
if array_like_map_representation.ndim == 2:
|
||||
array_like_map_representation = np.expand_dims(array_like_map_representation, axis=0)
|
||||
assert array_like_map_representation.ndim == 3
|
||||
self.map_array: np.ndarray = array_like_map_representation
|
||||
self.name = name
|
||||
pass
|
||||
@@ -63,22 +65,19 @@ class Map(object):
|
||||
# Check pixels for their color (determine if walkable)
|
||||
for idx, value in np.ndenumerate(self.map_array):
|
||||
if value == self.white:
|
||||
try:
|
||||
y, x = idx
|
||||
except ValueError:
|
||||
y, x, channels = idx
|
||||
idx = (y, x)
|
||||
# IF walkable, add node
|
||||
graph.add_node((y, x), count=0)
|
||||
graph.add_node(idx, count=0)
|
||||
# Fully connect to all surrounding neighbors
|
||||
for n, (xdif, ydif, weight) in enumerate(neighbors):
|
||||
# Differentiate between 8 and 4 neighbors
|
||||
if not full_neighbors and n >= 2:
|
||||
break
|
||||
|
||||
query_node = (y + ydif, x + xdif)
|
||||
# ToDO: make this explicite and less ugly
|
||||
query_node = idx[:1] + (idx[1] + ydif,) + (idx[2] + xdif,)
|
||||
if graph.has_node(query_node):
|
||||
graph.add_edge(idx, query_node, weight=weight)
|
||||
|
||||
return graph
|
||||
|
||||
@classmethod
|
||||
@@ -87,7 +86,7 @@ class Map(object):
|
||||
# Turn the image to single Channel Greyscale
|
||||
if image.mode != 'L':
|
||||
image = image.convert('L')
|
||||
map_array = np.array(image)
|
||||
map_array = np.expand_dims(np.array(image), axis=0)
|
||||
return cls(name=imagepath.name, array_like_map_representation=map_array)
|
||||
|
||||
def simple_trajectory_between(self, start, dest):
|
||||
|
@@ -2,8 +2,9 @@ from math import atan2
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from matplotlib import pyplot as plt
|
||||
from lib.objects import variables as V
|
||||
from lib import variables as V
|
||||
|
||||
import numpy as np
|
||||
|
||||
class Trajectory(object):
|
||||
|
||||
@@ -31,6 +32,10 @@ class Trajectory(object):
|
||||
def as_paired_list(self):
|
||||
return list(zip(self.vertices[:-1], self.vertices[1:]))
|
||||
|
||||
@property
|
||||
def np_vertices(self):
|
||||
return [np.array(vertice) for vertice in self.vertices]
|
||||
|
||||
def __init__(self, vertices: Union[List[Tuple[int]], None] = None):
|
||||
assert any((isinstance(vertices, list), vertices is None))
|
||||
if vertices is not None:
|
||||
|
@@ -1,9 +0,0 @@
|
||||
from pathlib import Path
|
||||
_ROOT = Path('..')
|
||||
|
||||
HOMOTOPIC = 0
|
||||
ALTERNATIVE = 1
|
||||
|
||||
_key_1 = 'eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vdWkubmVwdHVuZS5haSIsImFwaV91cmwiOiJodHRwczovL3VpLm'
|
||||
_key_2 = '5lcHR1bmUuYWkiLCJhcGlfa2V5IjoiZmI0OGMzNzUtOTg1NS00Yzg2LThjMzYtMWFiYjUwMDUyMjVlIn0='
|
||||
NEPTUNE_KEY = _key_1 + _key_2
|
@@ -5,6 +5,7 @@ from collections import defaultdict
|
||||
from configparser import ConfigParser
|
||||
from pathlib import Path
|
||||
|
||||
from lib.models.homotopy_classification.cnn_based import ConvHomDetector
|
||||
from lib.utils.model_io import ModelParameters
|
||||
|
||||
|
||||
@@ -24,6 +25,15 @@ class Config(ConfigParser):
|
||||
# for section in self.default_sections:
|
||||
# self.__setattr__(section, property(lambda x :x._get_namespace_for_section(section))
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
model_dict = dict(classifier_cnn=ConvHomDetector)
|
||||
try:
|
||||
return model_dict[self.get('model', 'type')]
|
||||
except KeyError as e:
|
||||
raise KeyError(rf'The model alias you provided ("{self.get("model", "type")}") does not exist! \n'
|
||||
f'Try one of these:\n{list(model_dict.keys())}')
|
||||
|
||||
@property
|
||||
def main(self):
|
||||
return self._get_namespace_for_section('main')
|
||||
|
@@ -16,6 +16,12 @@ class Logger(LightningLoggerBase):
|
||||
else:
|
||||
return self.neptunelogger.experiment
|
||||
|
||||
@property
|
||||
def log_dir(self):
|
||||
if self.debug:
|
||||
return Path(self.outpath)
|
||||
return Path(self.experiment.log_dir).parent
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.config.model.type
|
||||
@@ -53,7 +59,6 @@ class Logger(LightningLoggerBase):
|
||||
self._neptune_kwargs = dict(offline_mode=self.debug,
|
||||
api_key=self.config.project.neptune_key,
|
||||
project_name=self.project_name,
|
||||
name=self.name,
|
||||
upload_source_files=list())
|
||||
self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
|
||||
self.testtubelogger = TestTubeLogger(**self._testtube_kwargs)
|
||||
@@ -67,3 +72,6 @@ class Logger(LightningLoggerBase):
|
||||
self.neptunelogger.log_metrics(metrics, step_num)
|
||||
self.testtubelogger.log_metrics(metrics, step_num)
|
||||
pass
|
||||
|
||||
def log_config_as_ini(self):
|
||||
self.config.write(self.log_dir)
|
||||
|
5
lib/variables.py
Normal file
5
lib/variables.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from pathlib import Path
|
||||
_ROOT = Path('..')
|
||||
|
||||
HOMOTOPIC = 0
|
||||
ALTERNATIVE = 1
|
58
main.py
58
main.py
@@ -7,10 +7,11 @@ from argparse import ArgumentParser
|
||||
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from pytorch_lightning import Trainer
|
||||
from torch.utils.data import DataLoader
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
|
||||
from dataset.dataset import TrajDataset
|
||||
from lib.modules.utils import LightningBaseModule
|
||||
from lib.utils.config import Config
|
||||
from lib.utils.logging import Logger
|
||||
|
||||
@@ -33,7 +34,7 @@ main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
|
||||
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
|
||||
main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="")
|
||||
main_arg_parser.add_argument("--data_root", type=str, default='/data/', help="")
|
||||
main_arg_parser.add_argument("--map_root", type=str, default='/res/maps', help="")
|
||||
main_arg_parser.add_argument("--data_map_root", type=str, default='/res/maps', help="")
|
||||
|
||||
# Transformations
|
||||
main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
|
||||
@@ -46,7 +47,7 @@ main_arg_parser.add_argument("--train_batch_size", type=int, default=512, help="
|
||||
main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
|
||||
|
||||
# Model
|
||||
main_arg_parser.add_argument("--model_type", type=str, default="LeNetAE", help="")
|
||||
main_arg_parser.add_argument("--model_type", type=str, default="classifier_cnn", help="")
|
||||
main_arg_parser.add_argument("--model_activation", type=str, default="relu", help="")
|
||||
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 16, 4]", help="")
|
||||
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
|
||||
@@ -63,28 +64,47 @@ main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.geten
|
||||
args = main_arg_parser.parse_args()
|
||||
config = Config.read_namespace(args)
|
||||
|
||||
################
|
||||
# TESTING ONLY #
|
||||
# =============================================================================
|
||||
hparams = config.model_paramters
|
||||
dataset = TrajDataset('data', mapname='tate', alternatives=1000, trajectories=10000)
|
||||
dataloader = DataLoader(dataset=dataset.train_dataset, shuffle=True,
|
||||
batch_size=hparams.data_param.batchsize,
|
||||
num_workers=hparams.data_param.worker)
|
||||
|
||||
# Logger
|
||||
# =============================================================================
|
||||
logger = Logger(config, debug=True)
|
||||
|
||||
# Trainer
|
||||
# Checkpoint Callback
|
||||
# =============================================================================
|
||||
trainer = Trainer(logger=logger)
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
filepath=str(logger.log_dir / 'ckpt_weights'),
|
||||
verbose=True,
|
||||
period=1
|
||||
)
|
||||
|
||||
# Model
|
||||
# =============================================================================
|
||||
model = None
|
||||
# Init
|
||||
model: LightningBaseModule = config.model_class(config.model_paramters)
|
||||
model.init_weights()
|
||||
|
||||
# Trainer
|
||||
# =============================================================================
|
||||
trainer = Trainer(max_nb_epochs=config.train.epochs,
|
||||
show_progress_bar=True,
|
||||
weights_save_path=logger.log_dir,
|
||||
gpus=[0] if torch.cuda.is_available() else None,
|
||||
row_log_interval=model.data_len // 40, # TODO: Better Value / Setting
|
||||
log_save_interval=model.data_len // 10, # TODO: Better Value / Setting
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
logger=logger,
|
||||
fast_dev_run=config.get('main', 'debug'),
|
||||
early_stop_callback=None
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
next(iter(dataloader))
|
||||
pass
|
||||
if __name__ == "__main__":
|
||||
# Check Cuda availability
|
||||
print(f'Cuda is {"" if torch.cuda.is_available() else "not"} available!!!')
|
||||
# Train it
|
||||
trainer.fit(model)
|
||||
|
||||
# Save the last state & all parameters
|
||||
config.exp_path.mkdir(parents=True, exist_ok=True) # Todo: do i need this?
|
||||
trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
|
||||
model.save_to_disk(logger.log_dir)
|
||||
|
||||
# TODO: Eval here!
|
||||
|
Reference in New Issue
Block a user