Debugging

This commit is contained in:
Si11ium
2020-02-28 19:11:53 +01:00
parent 7b3f781d19
commit 44f6589259
18 changed files with 134 additions and 78 deletions

3
.idea/deployment.xml generated
View File

@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="PublishConfigData" serverName="traj_gen-AiMachine"> <component name="PublishConfigData" autoUpload="On explicit save action" serverName="traj_gen-AiMachine">
<serverData> <serverData>
<paths name="ErLoWa-AiMachine"> <paths name="ErLoWa-AiMachine">
<serverdata> <serverdata>
@@ -17,5 +17,6 @@
</serverdata> </serverdata>
</paths> </paths>
</serverData> </serverData>
<option name="myAutoUpload" value="ON_EXPLICIT_SAVE" />
</component> </component>
</project> </project>

View File

@@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4"> <module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager"> <component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" /> <content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Python 3.7 (traj_gen)" jdkType="Python SDK" /> <orderEntry type="jdk" jdkName="traj_gen@AiMachine" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" /> <orderEntry type="sourceFolder" forTests="false" />
</component> </component>
</module> </module>

2
.idea/misc.xml generated
View File

@@ -3,5 +3,5 @@
<component name="JavaScriptSettings"> <component name="JavaScriptSettings">
<option name="languageLevel" value="ES6" /> <option name="languageLevel" value="ES6" />
</component> </component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (traj_gen)" project-jdk-type="Python SDK" /> <component name="ProjectRootManager" version="2" project-jdk-name="traj_gen@AiMachine" project-jdk-type="Python SDK" />
</project> </project>

View File

@@ -1,6 +1,6 @@
import shelve import shelve
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union, List
import torch import torch
from random import choice from random import choice
@@ -17,7 +17,7 @@ class TrajDataset(Dataset):
return self.map.as_array.shape return self.map.as_array.shape
def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw', def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
length=100.000, all_in_map=True, **kwargs): length=100000, all_in_map=True, **kwargs):
super(TrajDataset, self).__init__() super(TrajDataset, self).__init__()
self.all_in_map = all_in_map self.all_in_map = all_in_map
self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp' self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
@@ -34,11 +34,11 @@ class TrajDataset(Dataset):
alternative = self.map.generate_alternative(trajectory) alternative = self.map.generate_alternative(trajectory)
label = choice([0, 1]) label = choice([0, 1])
if self.all_in_map: if self.all_in_map:
blank_trajectory_space = torch.zeros_like(self.map.as_array) blank_trajectory_space = torch.zeros(self.map.shape)
blank_trajectory_space[trajectory.vertices] = 1 blank_trajectory_space[trajectory.vertices] = 1
blank_alternative_space = torch.zeros_like(self.map.as_array) blank_alternative_space = torch.zeros(self.map.shape)
blank_alternative_space[trajectory.vertices] = 1 blank_alternative_space[trajectory.np_vertices] = 1
map_array = torch.as_tensor(self.map.as_array) map_array = torch.as_tensor(self.map.as_array)
label = self.map.are_homotopic(trajectory, alternative) label = self.map.are_homotopic(trajectory, alternative)
@@ -56,7 +56,7 @@ class TrajData(object):
@property @property
def map_shapes_max(self): def map_shapes_max(self):
shapes = self.map_shapes shapes = self.map_shapes
return map(max, zip(*shapes)) return list(map(max, zip(*shapes)))
@property @property
def name(self): def name(self):
@@ -66,12 +66,12 @@ class TrajData(object):
self.all_in_map = all_in_map self.all_in_map = all_in_map
self.maps_root = Path(map_root) if map_root else Path() / 'res' / 'maps' self.maps_root = Path(map_root) if map_root else Path() / 'res' / 'maps'
self._dataset = self._load_datasets()
self.length = length self.length = length
self._dataset = self._load_datasets()
def _load_datasets(self): def _load_datasets(self):
map_files = list(self.maps_root.glob('*.bmp')) map_files = list(self.maps_root.glob('*.bmp'))
equal_split = self.length // len(map_files) equal_split = int(self.length // len(map_files))
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_image.name, length=equal_split, return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_image.name, length=equal_split,
all_in_map=self.all_in_map) for map_image in map_files]) all_in_map=self.all_in_map) for map_image in map_files])

View File

View File

@@ -1,3 +1,6 @@
from functools import reduce
from operator import mul
import torch import torch
from torch import nn from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
@@ -13,7 +16,7 @@ class ConvHomDetector(LightningBaseModule):
name = 'CNNHomotopyClassifier' name = 'CNNHomotopyClassifier'
def configure_optimizers(self): def configure_optimizers(self):
return Adam(self.parameters(), lr=self.lr) return Adam(self.parameters(), lr=self.hparams.lr)
def validation_step(self, *args, **kwargs): def validation_step(self, *args, **kwargs):
pass pass
@@ -32,29 +35,36 @@ class ConvHomDetector(LightningBaseModule):
def __init__(self, *params): def __init__(self, *params):
super(ConvHomDetector, self).__init__(*params) super(ConvHomDetector, self).__init__(*params)
# Dataset # Dataset
self.dataset = TrajData(self.hparams.data_param.data_root) self.dataset = TrajData(self.hparams.data_param.root)
# Additional Attributes # Additional Attributes
self.map_shape = self.dataset.map_shapes_max self.map_shape = self.dataset.map_shapes_max
# Model Paramters
self.in_shape = self.dataset.map_shapes_max
assert len(self.in_shape) == 3, f'Image or map shape has to have 3 dims, but had: {len(self.in_shape)}'
# NN Nodes # NN Nodes
# ============================ # ============================
# Convolutional Map Processing # Convolutional Map Processing
# #
self.map_res_1 = ResidualModule(self.in_shape, ConvModule, 3, self.map_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1,
conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 3,
**dict(conv_kernel=3, conv_stride=1, **dict(conv_kernel=3, conv_stride=1,
conv_padding=0, conv_filters=self.hparams.model_param.filters[0])) conv_padding=1, conv_filters=self.hparams.model_param.filters[0]))
self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=5, conv_stride=1, self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=5, conv_stride=1,
conv_padding=0, conv_filters=self.hparams.model_param.filters[0]) conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 3, self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 3,
**dict(conv_kernel=3, conv_stride=1, **dict(conv_kernel=3, conv_stride=1,
conv_padding=0, conv_filters=self.hparams.model_param.filters[0])) conv_padding=1, conv_filters=self.hparams.model_param.filters[0]))
self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=5, conv_stride=1, self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=5, conv_stride=1,
conv_padding=0, conv_filters=self.hparams.model_param.filters[0]) conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 3, self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 3,
**dict(conv_kernel=3, conv_stride=1, **dict(conv_kernel=3, conv_stride=1,
conv_padding=0, conv_filters=self.hparams.model_param.filters[0])) conv_padding=1, conv_filters=self.hparams.model_param.filters[0]))
self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=5, conv_stride=1, self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=5, conv_stride=1,
conv_padding=0, conv_filters=self.hparams.model_param.filters[0]) conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
@@ -64,12 +74,13 @@ class ConvHomDetector(LightningBaseModule):
# Classifier # Classifier
# #
self.linear = nn.Linear(self.flatten.shape.item(), self.hparams.model_param.classes * 10) self.linear = nn.Linear(reduce(mul, self.flatten.shape), self.hparams.model_param.classes * 10)
self.classifier = nn.Linear(self.linear.shape, self.hparams.model_param.classes) self.classifier = nn.Linear(self.hparams.model_param.classes * 10, self.hparams.model_param.classes)
self.softmax = nn.Softmax() self.softmax = nn.Softmax()
def forward(self, x): def forward(self, x):
tensor = self.map_res_1(x) tensor = self.map_conv_0(x)
tensor = self.map_res_1(tensor)
tensor = self.map_conv_1(tensor) tensor = self.map_conv_1(tensor)
tensor = self.map_res_2(tensor) tensor = self.map_res_2(tensor)
tensor = self.map_conv_2(tensor) tensor = self.map_conv_2(tensor)

View File

@@ -36,9 +36,9 @@ class ConvModule(nn.Module):
self.stride = conv_stride self.stride = conv_stride
# Modules # Modules
self.dropout = nn.Dropout2d(dropout) if dropout else False self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else False self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else lambda x: x
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else False self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else lambda x: x
self.conv = conv_class(in_channels, conv_filters, conv_kernel, bias=use_bias, self.conv = conv_class(in_channels, conv_filters, conv_kernel, bias=use_bias,
padding=self.padding, stride=self.stride padding=self.padding, stride=self.stride
) )
@@ -47,8 +47,8 @@ class ConvModule(nn.Module):
x = self.norm(x) if self.norm else x x = self.norm(x) if self.norm else x
tensor = self.conv(x) tensor = self.conv(x)
tensor = self.dropout(tensor) if self.dropout else tensor tensor = self.dropout(tensor)
tensor = self.pooling(tensor) if self.pooling else tensor tensor = self.pooling(tensor)
tensor = self.activation(tensor) tensor = self.activation(tensor)
return tensor return tensor
@@ -72,23 +72,23 @@ class DeConvModule(nn.Module):
self.in_shape = in_shape self.in_shape = in_shape
self.conv_filters = conv_filters self.conv_filters = conv_filters
self.autopad = AutoPad() if autopad else False self.autopad = AutoPad() if autopad else lambda x: x
self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else False self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if normalize else False self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if normalize else lambda x: x
self.dropout = nn.Dropout2d(dropout) if dropout else False self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, conv_kernel, bias=use_bias, self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, conv_kernel, bias=use_bias,
padding=self.padding, stride=self.stride) padding=self.padding, stride=self.stride)
self.activation = activation() if activation else None self.activation = activation() if activation else lambda x: x
def forward(self, x): def forward(self, x):
x = self.norm(x) if self.norm else x x = self.norm(x)
x = self.dropout(x) if self.dropout else x x = self.dropout(x)
x = self.autopad(x) if self.autopad else x x = self.autopad(x)
x = self.interpolation(x) if self.interpolation else x x = self.interpolation(x)
tensor = self.de_conv(x) tensor = self.de_conv(x)
tensor = self.activation(tensor) if self.activation else tensor tensor = self.activation(tensor)
return tensor return tensor
@@ -100,12 +100,13 @@ class ResidualModule(nn.Module):
output = self(x) output = self(x)
return output.shape[1:] return output.shape[1:]
def __init__(self, in_shape, module_class, n, **module_paramters): def __init__(self, in_shape, module_class, n, activation=None, **module_paramters):
assert n >= 1 assert n >= 1
super(ResidualModule, self).__init__() super(ResidualModule, self).__init__()
self.in_shape = in_shape self.in_shape = in_shape
module_paramters.update(in_shape=in_shape) module_paramters.update(in_shape=in_shape)
self.residual_block = [module_class(**module_paramters) for x in range(n)] self.activation = activation() if activation else lambda x: x
self.residual_block = [module_class(**module_paramters) for _ in range(n)]
assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.' assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'
def forward(self, x): def forward(self, x):
@@ -114,6 +115,7 @@ class ResidualModule(nn.Module):
# noinspection PyUnboundLocalVariable # noinspection PyUnboundLocalVariable
tensor = tensor + x tensor = tensor + x
tensor = self.activation(tensor)
return tensor return tensor

View File

@@ -123,6 +123,10 @@ class LightningBaseModule(pl.LightningModule, ABC):
batch_size=self.hparams.data_param.batchsize, batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker) num_workers=self.hparams.data_param.worker)
@property
def data_len(self):
return len(self.dataset.train_dataset)
def configure_optimizers(self): def configure_optimizers(self):
raise NotImplementedError raise NotImplementedError

View File

@@ -1,7 +1,5 @@
import shelve import shelve
from pathlib import Path from pathlib import Path
from collections import UserDict
import copy import copy
from math import sqrt from math import sqrt
@@ -46,6 +44,10 @@ class Map(object):
return self.map_array return self.map_array
def __init__(self, name='', array_like_map_representation=None): def __init__(self, name='', array_like_map_representation=None):
if array_like_map_representation is not None:
if array_like_map_representation.ndim == 2:
array_like_map_representation = np.expand_dims(array_like_map_representation, axis=0)
assert array_like_map_representation.ndim == 3
self.map_array: np.ndarray = array_like_map_representation self.map_array: np.ndarray = array_like_map_representation
self.name = name self.name = name
pass pass
@@ -63,22 +65,19 @@ class Map(object):
# Check pixels for their color (determine if walkable) # Check pixels for their color (determine if walkable)
for idx, value in np.ndenumerate(self.map_array): for idx, value in np.ndenumerate(self.map_array):
if value == self.white: if value == self.white:
try:
y, x = idx
except ValueError:
y, x, channels = idx
idx = (y, x)
# IF walkable, add node # IF walkable, add node
graph.add_node((y, x), count=0) graph.add_node(idx, count=0)
# Fully connect to all surrounding neighbors # Fully connect to all surrounding neighbors
for n, (xdif, ydif, weight) in enumerate(neighbors): for n, (xdif, ydif, weight) in enumerate(neighbors):
# Differentiate between 8 and 4 neighbors # Differentiate between 8 and 4 neighbors
if not full_neighbors and n >= 2: if not full_neighbors and n >= 2:
break break
query_node = (y + ydif, x + xdif) # ToDO: make this explicite and less ugly
query_node = idx[:1] + (idx[1] + ydif,) + (idx[2] + xdif,)
if graph.has_node(query_node): if graph.has_node(query_node):
graph.add_edge(idx, query_node, weight=weight) graph.add_edge(idx, query_node, weight=weight)
return graph return graph
@classmethod @classmethod
@@ -87,7 +86,7 @@ class Map(object):
# Turn the image to single Channel Greyscale # Turn the image to single Channel Greyscale
if image.mode != 'L': if image.mode != 'L':
image = image.convert('L') image = image.convert('L')
map_array = np.array(image) map_array = np.expand_dims(np.array(image), axis=0)
return cls(name=imagepath.name, array_like_map_representation=map_array) return cls(name=imagepath.name, array_like_map_representation=map_array)
def simple_trajectory_between(self, start, dest): def simple_trajectory_between(self, start, dest):

View File

@@ -2,8 +2,9 @@ from math import atan2
from typing import List, Tuple, Union from typing import List, Tuple, Union
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from lib.objects import variables as V from lib import variables as V
import numpy as np
class Trajectory(object): class Trajectory(object):
@@ -31,6 +32,10 @@ class Trajectory(object):
def as_paired_list(self): def as_paired_list(self):
return list(zip(self.vertices[:-1], self.vertices[1:])) return list(zip(self.vertices[:-1], self.vertices[1:]))
@property
def np_vertices(self):
return [np.array(vertice) for vertice in self.vertices]
def __init__(self, vertices: Union[List[Tuple[int]], None] = None): def __init__(self, vertices: Union[List[Tuple[int]], None] = None):
assert any((isinstance(vertices, list), vertices is None)) assert any((isinstance(vertices, list), vertices is None))
if vertices is not None: if vertices is not None:

View File

@@ -1,9 +0,0 @@
from pathlib import Path
_ROOT = Path('..')
HOMOTOPIC = 0
ALTERNATIVE = 1
_key_1 = 'eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vdWkubmVwdHVuZS5haSIsImFwaV91cmwiOiJodHRwczovL3VpLm'
_key_2 = '5lcHR1bmUuYWkiLCJhcGlfa2V5IjoiZmI0OGMzNzUtOTg1NS00Yzg2LThjMzYtMWFiYjUwMDUyMjVlIn0='
NEPTUNE_KEY = _key_1 + _key_2

View File

@@ -5,6 +5,7 @@ from collections import defaultdict
from configparser import ConfigParser from configparser import ConfigParser
from pathlib import Path from pathlib import Path
from lib.models.homotopy_classification.cnn_based import ConvHomDetector
from lib.utils.model_io import ModelParameters from lib.utils.model_io import ModelParameters
@@ -24,6 +25,15 @@ class Config(ConfigParser):
# for section in self.default_sections: # for section in self.default_sections:
# self.__setattr__(section, property(lambda x :x._get_namespace_for_section(section)) # self.__setattr__(section, property(lambda x :x._get_namespace_for_section(section))
@property
def model_class(self):
model_dict = dict(classifier_cnn=ConvHomDetector)
try:
return model_dict[self.get('model', 'type')]
except KeyError as e:
raise KeyError(rf'The model alias you provided ("{self.get("model", "type")}") does not exist! \n'
f'Try one of these:\n{list(model_dict.keys())}')
@property @property
def main(self): def main(self):
return self._get_namespace_for_section('main') return self._get_namespace_for_section('main')

View File

@@ -16,6 +16,12 @@ class Logger(LightningLoggerBase):
else: else:
return self.neptunelogger.experiment return self.neptunelogger.experiment
@property
def log_dir(self):
if self.debug:
return Path(self.outpath)
return Path(self.experiment.log_dir).parent
@property @property
def name(self): def name(self):
return self.config.model.type return self.config.model.type
@@ -50,10 +56,9 @@ class Logger(LightningLoggerBase):
self.debug = debug self.debug = debug
self.config = config self.config = config
self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name) self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
self._neptune_kwargs = dict(offline_mode= self.debug, self._neptune_kwargs = dict(offline_mode=self.debug,
api_key=self.config.project.neptune_key, api_key=self.config.project.neptune_key,
project_name=self.project_name, project_name=self.project_name,
name=self.name,
upload_source_files=list()) upload_source_files=list())
self.neptunelogger = NeptuneLogger(**self._neptune_kwargs) self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
self.testtubelogger = TestTubeLogger(**self._testtube_kwargs) self.testtubelogger = TestTubeLogger(**self._testtube_kwargs)
@@ -67,3 +72,6 @@ class Logger(LightningLoggerBase):
self.neptunelogger.log_metrics(metrics, step_num) self.neptunelogger.log_metrics(metrics, step_num)
self.testtubelogger.log_metrics(metrics, step_num) self.testtubelogger.log_metrics(metrics, step_num)
pass pass
def log_config_as_ini(self):
self.config.write(self.log_dir)

5
lib/variables.py Normal file
View File

@@ -0,0 +1,5 @@
from pathlib import Path
_ROOT = Path('..')
HOMOTOPIC = 0
ALTERNATIVE = 1

58
main.py
View File

@@ -7,10 +7,11 @@ from argparse import ArgumentParser
import warnings import warnings
import torch
from pytorch_lightning import Trainer from pytorch_lightning import Trainer
from torch.utils.data import DataLoader from pytorch_lightning.callbacks import ModelCheckpoint
from dataset.dataset import TrajDataset from lib.modules.utils import LightningBaseModule
from lib.utils.config import Config from lib.utils.config import Config
from lib.utils.logging import Logger from lib.utils.logging import Logger
@@ -33,7 +34,7 @@ main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="") main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="") main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="")
main_arg_parser.add_argument("--data_root", type=str, default='/data/', help="") main_arg_parser.add_argument("--data_root", type=str, default='/data/', help="")
main_arg_parser.add_argument("--map_root", type=str, default='/res/maps', help="") main_arg_parser.add_argument("--data_map_root", type=str, default='/res/maps', help="")
# Transformations # Transformations
main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="") main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
@@ -46,7 +47,7 @@ main_arg_parser.add_argument("--train_batch_size", type=int, default=512, help="
main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="") main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
# Model # Model
main_arg_parser.add_argument("--model_type", type=str, default="LeNetAE", help="") main_arg_parser.add_argument("--model_type", type=str, default="classifier_cnn", help="")
main_arg_parser.add_argument("--model_activation", type=str, default="relu", help="") main_arg_parser.add_argument("--model_activation", type=str, default="relu", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 16, 4]", help="") main_arg_parser.add_argument("--model_filters", type=str, default="[32, 16, 4]", help="")
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="") main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
@@ -63,28 +64,47 @@ main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.geten
args = main_arg_parser.parse_args() args = main_arg_parser.parse_args()
config = Config.read_namespace(args) config = Config.read_namespace(args)
################
# TESTING ONLY #
# =============================================================================
hparams = config.model_paramters
dataset = TrajDataset('data', mapname='tate', alternatives=1000, trajectories=10000)
dataloader = DataLoader(dataset=dataset.train_dataset, shuffle=True,
batch_size=hparams.data_param.batchsize,
num_workers=hparams.data_param.worker)
# Logger # Logger
# ============================================================================= # =============================================================================
logger = Logger(config, debug=True) logger = Logger(config, debug=True)
# Trainer # Checkpoint Callback
# ============================================================================= # =============================================================================
trainer = Trainer(logger=logger) checkpoint_callback = ModelCheckpoint(
filepath=str(logger.log_dir / 'ckpt_weights'),
verbose=True,
period=1
)
# Model # Model
# ============================================================================= # =============================================================================
model = None # Init
model: LightningBaseModule = config.model_class(config.model_paramters)
model.init_weights()
# Trainer
# =============================================================================
trainer = Trainer(max_nb_epochs=config.train.epochs,
show_progress_bar=True,
weights_save_path=logger.log_dir,
gpus=[0] if torch.cuda.is_available() else None,
row_log_interval=model.data_len // 40, # TODO: Better Value / Setting
log_save_interval=model.data_len // 10, # TODO: Better Value / Setting
checkpoint_callback=checkpoint_callback,
logger=logger,
fast_dev_run=config.get('main', 'debug'),
early_stop_callback=None
)
if __name__ == '__main__': if __name__ == "__main__":
next(iter(dataloader)) # Check Cuda availability
pass print(f'Cuda is {"" if torch.cuda.is_available() else "not"} available!!!')
# Train it
trainer.fit(model)
# Save the last state & all parameters
config.exp_path.mkdir(parents=True, exist_ok=True) # Todo: do i need this?
trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
model.save_to_disk(logger.log_dir)
# TODO: Eval here!