Debugging
This commit is contained in:
3
.idea/deployment.xml
generated
3
.idea/deployment.xml
generated
@@ -1,6 +1,6 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<project version="4">
|
<project version="4">
|
||||||
<component name="PublishConfigData" serverName="traj_gen-AiMachine">
|
<component name="PublishConfigData" autoUpload="On explicit save action" serverName="traj_gen-AiMachine">
|
||||||
<serverData>
|
<serverData>
|
||||||
<paths name="ErLoWa-AiMachine">
|
<paths name="ErLoWa-AiMachine">
|
||||||
<serverdata>
|
<serverdata>
|
||||||
@@ -17,5 +17,6 @@
|
|||||||
</serverdata>
|
</serverdata>
|
||||||
</paths>
|
</paths>
|
||||||
</serverData>
|
</serverData>
|
||||||
|
<option name="myAutoUpload" value="ON_EXPLICIT_SAVE" />
|
||||||
</component>
|
</component>
|
||||||
</project>
|
</project>
|
2
.idea/hom_traj_gen.iml
generated
2
.idea/hom_traj_gen.iml
generated
@@ -2,7 +2,7 @@
|
|||||||
<module type="PYTHON_MODULE" version="4">
|
<module type="PYTHON_MODULE" version="4">
|
||||||
<component name="NewModuleRootManager">
|
<component name="NewModuleRootManager">
|
||||||
<content url="file://$MODULE_DIR$" />
|
<content url="file://$MODULE_DIR$" />
|
||||||
<orderEntry type="jdk" jdkName="Python 3.7 (traj_gen)" jdkType="Python SDK" />
|
<orderEntry type="jdk" jdkName="traj_gen@AiMachine" jdkType="Python SDK" />
|
||||||
<orderEntry type="sourceFolder" forTests="false" />
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
</component>
|
</component>
|
||||||
</module>
|
</module>
|
2
.idea/misc.xml
generated
2
.idea/misc.xml
generated
@@ -3,5 +3,5 @@
|
|||||||
<component name="JavaScriptSettings">
|
<component name="JavaScriptSettings">
|
||||||
<option name="languageLevel" value="ES6" />
|
<option name="languageLevel" value="ES6" />
|
||||||
</component>
|
</component>
|
||||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (traj_gen)" project-jdk-type="Python SDK" />
|
<component name="ProjectRootManager" version="2" project-jdk-name="traj_gen@AiMachine" project-jdk-type="Python SDK" />
|
||||||
</project>
|
</project>
|
@@ -1,6 +1,6 @@
|
|||||||
import shelve
|
import shelve
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from random import choice
|
from random import choice
|
||||||
@@ -17,7 +17,7 @@ class TrajDataset(Dataset):
|
|||||||
return self.map.as_array.shape
|
return self.map.as_array.shape
|
||||||
|
|
||||||
def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
|
def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
|
||||||
length=100.000, all_in_map=True, **kwargs):
|
length=100000, all_in_map=True, **kwargs):
|
||||||
super(TrajDataset, self).__init__()
|
super(TrajDataset, self).__init__()
|
||||||
self.all_in_map = all_in_map
|
self.all_in_map = all_in_map
|
||||||
self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
|
self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
|
||||||
@@ -34,11 +34,11 @@ class TrajDataset(Dataset):
|
|||||||
alternative = self.map.generate_alternative(trajectory)
|
alternative = self.map.generate_alternative(trajectory)
|
||||||
label = choice([0, 1])
|
label = choice([0, 1])
|
||||||
if self.all_in_map:
|
if self.all_in_map:
|
||||||
blank_trajectory_space = torch.zeros_like(self.map.as_array)
|
blank_trajectory_space = torch.zeros(self.map.shape)
|
||||||
blank_trajectory_space[trajectory.vertices] = 1
|
blank_trajectory_space[trajectory.vertices] = 1
|
||||||
|
|
||||||
blank_alternative_space = torch.zeros_like(self.map.as_array)
|
blank_alternative_space = torch.zeros(self.map.shape)
|
||||||
blank_alternative_space[trajectory.vertices] = 1
|
blank_alternative_space[trajectory.np_vertices] = 1
|
||||||
|
|
||||||
map_array = torch.as_tensor(self.map.as_array)
|
map_array = torch.as_tensor(self.map.as_array)
|
||||||
label = self.map.are_homotopic(trajectory, alternative)
|
label = self.map.are_homotopic(trajectory, alternative)
|
||||||
@@ -56,7 +56,7 @@ class TrajData(object):
|
|||||||
@property
|
@property
|
||||||
def map_shapes_max(self):
|
def map_shapes_max(self):
|
||||||
shapes = self.map_shapes
|
shapes = self.map_shapes
|
||||||
return map(max, zip(*shapes))
|
return list(map(max, zip(*shapes)))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
@@ -66,12 +66,12 @@ class TrajData(object):
|
|||||||
|
|
||||||
self.all_in_map = all_in_map
|
self.all_in_map = all_in_map
|
||||||
self.maps_root = Path(map_root) if map_root else Path() / 'res' / 'maps'
|
self.maps_root = Path(map_root) if map_root else Path() / 'res' / 'maps'
|
||||||
self._dataset = self._load_datasets()
|
|
||||||
self.length = length
|
self.length = length
|
||||||
|
self._dataset = self._load_datasets()
|
||||||
|
|
||||||
def _load_datasets(self):
|
def _load_datasets(self):
|
||||||
map_files = list(self.maps_root.glob('*.bmp'))
|
map_files = list(self.maps_root.glob('*.bmp'))
|
||||||
equal_split = self.length // len(map_files)
|
equal_split = int(self.length // len(map_files))
|
||||||
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_image.name, length=equal_split,
|
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_image.name, length=equal_split,
|
||||||
all_in_map=self.all_in_map) for map_image in map_files])
|
all_in_map=self.all_in_map) for map_image in map_files])
|
||||||
|
|
||||||
|
0
lib/models/generators/recurrent.py
Normal file
0
lib/models/generators/recurrent.py
Normal file
@@ -1,3 +1,6 @@
|
|||||||
|
from functools import reduce
|
||||||
|
from operator import mul
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -13,7 +16,7 @@ class ConvHomDetector(LightningBaseModule):
|
|||||||
name = 'CNNHomotopyClassifier'
|
name = 'CNNHomotopyClassifier'
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
return Adam(self.parameters(), lr=self.lr)
|
return Adam(self.parameters(), lr=self.hparams.lr)
|
||||||
|
|
||||||
def validation_step(self, *args, **kwargs):
|
def validation_step(self, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
@@ -32,29 +35,36 @@ class ConvHomDetector(LightningBaseModule):
|
|||||||
|
|
||||||
def __init__(self, *params):
|
def __init__(self, *params):
|
||||||
super(ConvHomDetector, self).__init__(*params)
|
super(ConvHomDetector, self).__init__(*params)
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
self.dataset = TrajData(self.hparams.data_param.data_root)
|
self.dataset = TrajData(self.hparams.data_param.root)
|
||||||
|
|
||||||
# Additional Attributes
|
# Additional Attributes
|
||||||
self.map_shape = self.dataset.map_shapes_max
|
self.map_shape = self.dataset.map_shapes_max
|
||||||
|
|
||||||
|
# Model Paramters
|
||||||
|
self.in_shape = self.dataset.map_shapes_max
|
||||||
|
assert len(self.in_shape) == 3, f'Image or map shape has to have 3 dims, but had: {len(self.in_shape)}'
|
||||||
|
|
||||||
# NN Nodes
|
# NN Nodes
|
||||||
# ============================
|
# ============================
|
||||||
# Convolutional Map Processing
|
# Convolutional Map Processing
|
||||||
#
|
#
|
||||||
self.map_res_1 = ResidualModule(self.in_shape, ConvModule, 3,
|
self.map_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1,
|
||||||
|
conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
|
||||||
|
self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 3,
|
||||||
**dict(conv_kernel=3, conv_stride=1,
|
**dict(conv_kernel=3, conv_stride=1,
|
||||||
conv_padding=0, conv_filters=self.hparams.model_param.filters[0]))
|
conv_padding=1, conv_filters=self.hparams.model_param.filters[0]))
|
||||||
self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=5, conv_stride=1,
|
self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=5, conv_stride=1,
|
||||||
conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
|
conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
|
||||||
self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 3,
|
self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 3,
|
||||||
**dict(conv_kernel=3, conv_stride=1,
|
**dict(conv_kernel=3, conv_stride=1,
|
||||||
conv_padding=0, conv_filters=self.hparams.model_param.filters[0]))
|
conv_padding=1, conv_filters=self.hparams.model_param.filters[0]))
|
||||||
self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=5, conv_stride=1,
|
self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=5, conv_stride=1,
|
||||||
conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
|
conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
|
||||||
self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 3,
|
self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 3,
|
||||||
**dict(conv_kernel=3, conv_stride=1,
|
**dict(conv_kernel=3, conv_stride=1,
|
||||||
conv_padding=0, conv_filters=self.hparams.model_param.filters[0]))
|
conv_padding=1, conv_filters=self.hparams.model_param.filters[0]))
|
||||||
self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=5, conv_stride=1,
|
self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=5, conv_stride=1,
|
||||||
conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
|
conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
|
||||||
|
|
||||||
@@ -64,12 +74,13 @@ class ConvHomDetector(LightningBaseModule):
|
|||||||
# Classifier
|
# Classifier
|
||||||
#
|
#
|
||||||
|
|
||||||
self.linear = nn.Linear(self.flatten.shape.item(), self.hparams.model_param.classes * 10)
|
self.linear = nn.Linear(reduce(mul, self.flatten.shape), self.hparams.model_param.classes * 10)
|
||||||
self.classifier = nn.Linear(self.linear.shape, self.hparams.model_param.classes)
|
self.classifier = nn.Linear(self.hparams.model_param.classes * 10, self.hparams.model_param.classes)
|
||||||
self.softmax = nn.Softmax()
|
self.softmax = nn.Softmax()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
tensor = self.map_res_1(x)
|
tensor = self.map_conv_0(x)
|
||||||
|
tensor = self.map_res_1(tensor)
|
||||||
tensor = self.map_conv_1(tensor)
|
tensor = self.map_conv_1(tensor)
|
||||||
tensor = self.map_res_2(tensor)
|
tensor = self.map_res_2(tensor)
|
||||||
tensor = self.map_conv_2(tensor)
|
tensor = self.map_conv_2(tensor)
|
||||||
|
@@ -36,9 +36,9 @@ class ConvModule(nn.Module):
|
|||||||
self.stride = conv_stride
|
self.stride = conv_stride
|
||||||
|
|
||||||
# Modules
|
# Modules
|
||||||
self.dropout = nn.Dropout2d(dropout) if dropout else False
|
self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
|
||||||
self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else False
|
self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else lambda x: x
|
||||||
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else False
|
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else lambda x: x
|
||||||
self.conv = conv_class(in_channels, conv_filters, conv_kernel, bias=use_bias,
|
self.conv = conv_class(in_channels, conv_filters, conv_kernel, bias=use_bias,
|
||||||
padding=self.padding, stride=self.stride
|
padding=self.padding, stride=self.stride
|
||||||
)
|
)
|
||||||
@@ -47,8 +47,8 @@ class ConvModule(nn.Module):
|
|||||||
x = self.norm(x) if self.norm else x
|
x = self.norm(x) if self.norm else x
|
||||||
|
|
||||||
tensor = self.conv(x)
|
tensor = self.conv(x)
|
||||||
tensor = self.dropout(tensor) if self.dropout else tensor
|
tensor = self.dropout(tensor)
|
||||||
tensor = self.pooling(tensor) if self.pooling else tensor
|
tensor = self.pooling(tensor)
|
||||||
tensor = self.activation(tensor)
|
tensor = self.activation(tensor)
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
@@ -72,23 +72,23 @@ class DeConvModule(nn.Module):
|
|||||||
self.in_shape = in_shape
|
self.in_shape = in_shape
|
||||||
self.conv_filters = conv_filters
|
self.conv_filters = conv_filters
|
||||||
|
|
||||||
self.autopad = AutoPad() if autopad else False
|
self.autopad = AutoPad() if autopad else lambda x: x
|
||||||
self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else False
|
self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
|
||||||
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if normalize else False
|
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if normalize else lambda x: x
|
||||||
self.dropout = nn.Dropout2d(dropout) if dropout else False
|
self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
|
||||||
self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, conv_kernel, bias=use_bias,
|
self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, conv_kernel, bias=use_bias,
|
||||||
padding=self.padding, stride=self.stride)
|
padding=self.padding, stride=self.stride)
|
||||||
|
|
||||||
self.activation = activation() if activation else None
|
self.activation = activation() if activation else lambda x: x
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.norm(x) if self.norm else x
|
x = self.norm(x)
|
||||||
x = self.dropout(x) if self.dropout else x
|
x = self.dropout(x)
|
||||||
x = self.autopad(x) if self.autopad else x
|
x = self.autopad(x)
|
||||||
x = self.interpolation(x) if self.interpolation else x
|
x = self.interpolation(x)
|
||||||
|
|
||||||
tensor = self.de_conv(x)
|
tensor = self.de_conv(x)
|
||||||
tensor = self.activation(tensor) if self.activation else tensor
|
tensor = self.activation(tensor)
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
@@ -100,12 +100,13 @@ class ResidualModule(nn.Module):
|
|||||||
output = self(x)
|
output = self(x)
|
||||||
return output.shape[1:]
|
return output.shape[1:]
|
||||||
|
|
||||||
def __init__(self, in_shape, module_class, n, **module_paramters):
|
def __init__(self, in_shape, module_class, n, activation=None, **module_paramters):
|
||||||
assert n >= 1
|
assert n >= 1
|
||||||
super(ResidualModule, self).__init__()
|
super(ResidualModule, self).__init__()
|
||||||
self.in_shape = in_shape
|
self.in_shape = in_shape
|
||||||
module_paramters.update(in_shape=in_shape)
|
module_paramters.update(in_shape=in_shape)
|
||||||
self.residual_block = [module_class(**module_paramters) for x in range(n)]
|
self.activation = activation() if activation else lambda x: x
|
||||||
|
self.residual_block = [module_class(**module_paramters) for _ in range(n)]
|
||||||
assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'
|
assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@@ -114,6 +115,7 @@ class ResidualModule(nn.Module):
|
|||||||
|
|
||||||
# noinspection PyUnboundLocalVariable
|
# noinspection PyUnboundLocalVariable
|
||||||
tensor = tensor + x
|
tensor = tensor + x
|
||||||
|
tensor = self.activation(tensor)
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
@@ -123,6 +123,10 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
|||||||
batch_size=self.hparams.data_param.batchsize,
|
batch_size=self.hparams.data_param.batchsize,
|
||||||
num_workers=self.hparams.data_param.worker)
|
num_workers=self.hparams.data_param.worker)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data_len(self):
|
||||||
|
return len(self.dataset.train_dataset)
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@@ -1,7 +1,5 @@
|
|||||||
import shelve
|
import shelve
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from collections import UserDict
|
|
||||||
|
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
from math import sqrt
|
from math import sqrt
|
||||||
@@ -46,6 +44,10 @@ class Map(object):
|
|||||||
return self.map_array
|
return self.map_array
|
||||||
|
|
||||||
def __init__(self, name='', array_like_map_representation=None):
|
def __init__(self, name='', array_like_map_representation=None):
|
||||||
|
if array_like_map_representation is not None:
|
||||||
|
if array_like_map_representation.ndim == 2:
|
||||||
|
array_like_map_representation = np.expand_dims(array_like_map_representation, axis=0)
|
||||||
|
assert array_like_map_representation.ndim == 3
|
||||||
self.map_array: np.ndarray = array_like_map_representation
|
self.map_array: np.ndarray = array_like_map_representation
|
||||||
self.name = name
|
self.name = name
|
||||||
pass
|
pass
|
||||||
@@ -63,22 +65,19 @@ class Map(object):
|
|||||||
# Check pixels for their color (determine if walkable)
|
# Check pixels for their color (determine if walkable)
|
||||||
for idx, value in np.ndenumerate(self.map_array):
|
for idx, value in np.ndenumerate(self.map_array):
|
||||||
if value == self.white:
|
if value == self.white:
|
||||||
try:
|
|
||||||
y, x = idx
|
|
||||||
except ValueError:
|
|
||||||
y, x, channels = idx
|
|
||||||
idx = (y, x)
|
|
||||||
# IF walkable, add node
|
# IF walkable, add node
|
||||||
graph.add_node((y, x), count=0)
|
graph.add_node(idx, count=0)
|
||||||
# Fully connect to all surrounding neighbors
|
# Fully connect to all surrounding neighbors
|
||||||
for n, (xdif, ydif, weight) in enumerate(neighbors):
|
for n, (xdif, ydif, weight) in enumerate(neighbors):
|
||||||
# Differentiate between 8 and 4 neighbors
|
# Differentiate between 8 and 4 neighbors
|
||||||
if not full_neighbors and n >= 2:
|
if not full_neighbors and n >= 2:
|
||||||
break
|
break
|
||||||
|
|
||||||
query_node = (y + ydif, x + xdif)
|
# ToDO: make this explicite and less ugly
|
||||||
|
query_node = idx[:1] + (idx[1] + ydif,) + (idx[2] + xdif,)
|
||||||
if graph.has_node(query_node):
|
if graph.has_node(query_node):
|
||||||
graph.add_edge(idx, query_node, weight=weight)
|
graph.add_edge(idx, query_node, weight=weight)
|
||||||
|
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -87,7 +86,7 @@ class Map(object):
|
|||||||
# Turn the image to single Channel Greyscale
|
# Turn the image to single Channel Greyscale
|
||||||
if image.mode != 'L':
|
if image.mode != 'L':
|
||||||
image = image.convert('L')
|
image = image.convert('L')
|
||||||
map_array = np.array(image)
|
map_array = np.expand_dims(np.array(image), axis=0)
|
||||||
return cls(name=imagepath.name, array_like_map_representation=map_array)
|
return cls(name=imagepath.name, array_like_map_representation=map_array)
|
||||||
|
|
||||||
def simple_trajectory_between(self, start, dest):
|
def simple_trajectory_between(self, start, dest):
|
||||||
|
@@ -2,8 +2,9 @@ from math import atan2
|
|||||||
from typing import List, Tuple, Union
|
from typing import List, Tuple, Union
|
||||||
|
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from lib.objects import variables as V
|
from lib import variables as V
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
class Trajectory(object):
|
class Trajectory(object):
|
||||||
|
|
||||||
@@ -31,6 +32,10 @@ class Trajectory(object):
|
|||||||
def as_paired_list(self):
|
def as_paired_list(self):
|
||||||
return list(zip(self.vertices[:-1], self.vertices[1:]))
|
return list(zip(self.vertices[:-1], self.vertices[1:]))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def np_vertices(self):
|
||||||
|
return [np.array(vertice) for vertice in self.vertices]
|
||||||
|
|
||||||
def __init__(self, vertices: Union[List[Tuple[int]], None] = None):
|
def __init__(self, vertices: Union[List[Tuple[int]], None] = None):
|
||||||
assert any((isinstance(vertices, list), vertices is None))
|
assert any((isinstance(vertices, list), vertices is None))
|
||||||
if vertices is not None:
|
if vertices is not None:
|
||||||
|
@@ -1,9 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
_ROOT = Path('..')
|
|
||||||
|
|
||||||
HOMOTOPIC = 0
|
|
||||||
ALTERNATIVE = 1
|
|
||||||
|
|
||||||
_key_1 = 'eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vdWkubmVwdHVuZS5haSIsImFwaV91cmwiOiJodHRwczovL3VpLm'
|
|
||||||
_key_2 = '5lcHR1bmUuYWkiLCJhcGlfa2V5IjoiZmI0OGMzNzUtOTg1NS00Yzg2LThjMzYtMWFiYjUwMDUyMjVlIn0='
|
|
||||||
NEPTUNE_KEY = _key_1 + _key_2
|
|
@@ -5,6 +5,7 @@ from collections import defaultdict
|
|||||||
from configparser import ConfigParser
|
from configparser import ConfigParser
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from lib.models.homotopy_classification.cnn_based import ConvHomDetector
|
||||||
from lib.utils.model_io import ModelParameters
|
from lib.utils.model_io import ModelParameters
|
||||||
|
|
||||||
|
|
||||||
@@ -24,6 +25,15 @@ class Config(ConfigParser):
|
|||||||
# for section in self.default_sections:
|
# for section in self.default_sections:
|
||||||
# self.__setattr__(section, property(lambda x :x._get_namespace_for_section(section))
|
# self.__setattr__(section, property(lambda x :x._get_namespace_for_section(section))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_class(self):
|
||||||
|
model_dict = dict(classifier_cnn=ConvHomDetector)
|
||||||
|
try:
|
||||||
|
return model_dict[self.get('model', 'type')]
|
||||||
|
except KeyError as e:
|
||||||
|
raise KeyError(rf'The model alias you provided ("{self.get("model", "type")}") does not exist! \n'
|
||||||
|
f'Try one of these:\n{list(model_dict.keys())}')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def main(self):
|
def main(self):
|
||||||
return self._get_namespace_for_section('main')
|
return self._get_namespace_for_section('main')
|
||||||
|
@@ -16,6 +16,12 @@ class Logger(LightningLoggerBase):
|
|||||||
else:
|
else:
|
||||||
return self.neptunelogger.experiment
|
return self.neptunelogger.experiment
|
||||||
|
|
||||||
|
@property
|
||||||
|
def log_dir(self):
|
||||||
|
if self.debug:
|
||||||
|
return Path(self.outpath)
|
||||||
|
return Path(self.experiment.log_dir).parent
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return self.config.model.type
|
return self.config.model.type
|
||||||
@@ -50,10 +56,9 @@ class Logger(LightningLoggerBase):
|
|||||||
self.debug = debug
|
self.debug = debug
|
||||||
self.config = config
|
self.config = config
|
||||||
self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
|
self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
|
||||||
self._neptune_kwargs = dict(offline_mode= self.debug,
|
self._neptune_kwargs = dict(offline_mode=self.debug,
|
||||||
api_key=self.config.project.neptune_key,
|
api_key=self.config.project.neptune_key,
|
||||||
project_name=self.project_name,
|
project_name=self.project_name,
|
||||||
name=self.name,
|
|
||||||
upload_source_files=list())
|
upload_source_files=list())
|
||||||
self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
|
self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
|
||||||
self.testtubelogger = TestTubeLogger(**self._testtube_kwargs)
|
self.testtubelogger = TestTubeLogger(**self._testtube_kwargs)
|
||||||
@@ -67,3 +72,6 @@ class Logger(LightningLoggerBase):
|
|||||||
self.neptunelogger.log_metrics(metrics, step_num)
|
self.neptunelogger.log_metrics(metrics, step_num)
|
||||||
self.testtubelogger.log_metrics(metrics, step_num)
|
self.testtubelogger.log_metrics(metrics, step_num)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def log_config_as_ini(self):
|
||||||
|
self.config.write(self.log_dir)
|
||||||
|
5
lib/variables.py
Normal file
5
lib/variables.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
_ROOT = Path('..')
|
||||||
|
|
||||||
|
HOMOTOPIC = 0
|
||||||
|
ALTERNATIVE = 1
|
58
main.py
58
main.py
@@ -7,10 +7,11 @@ from argparse import ArgumentParser
|
|||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
import torch
|
||||||
from pytorch_lightning import Trainer
|
from pytorch_lightning import Trainer
|
||||||
from torch.utils.data import DataLoader
|
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||||
|
|
||||||
from dataset.dataset import TrajDataset
|
from lib.modules.utils import LightningBaseModule
|
||||||
from lib.utils.config import Config
|
from lib.utils.config import Config
|
||||||
from lib.utils.logging import Logger
|
from lib.utils.logging import Logger
|
||||||
|
|
||||||
@@ -33,7 +34,7 @@ main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
|
|||||||
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
|
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
|
||||||
main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="")
|
main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="")
|
||||||
main_arg_parser.add_argument("--data_root", type=str, default='/data/', help="")
|
main_arg_parser.add_argument("--data_root", type=str, default='/data/', help="")
|
||||||
main_arg_parser.add_argument("--map_root", type=str, default='/res/maps', help="")
|
main_arg_parser.add_argument("--data_map_root", type=str, default='/res/maps', help="")
|
||||||
|
|
||||||
# Transformations
|
# Transformations
|
||||||
main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
|
main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
|
||||||
@@ -46,7 +47,7 @@ main_arg_parser.add_argument("--train_batch_size", type=int, default=512, help="
|
|||||||
main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
|
main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
|
||||||
|
|
||||||
# Model
|
# Model
|
||||||
main_arg_parser.add_argument("--model_type", type=str, default="LeNetAE", help="")
|
main_arg_parser.add_argument("--model_type", type=str, default="classifier_cnn", help="")
|
||||||
main_arg_parser.add_argument("--model_activation", type=str, default="relu", help="")
|
main_arg_parser.add_argument("--model_activation", type=str, default="relu", help="")
|
||||||
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 16, 4]", help="")
|
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 16, 4]", help="")
|
||||||
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
|
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
|
||||||
@@ -63,28 +64,47 @@ main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.geten
|
|||||||
args = main_arg_parser.parse_args()
|
args = main_arg_parser.parse_args()
|
||||||
config = Config.read_namespace(args)
|
config = Config.read_namespace(args)
|
||||||
|
|
||||||
################
|
|
||||||
# TESTING ONLY #
|
|
||||||
# =============================================================================
|
|
||||||
hparams = config.model_paramters
|
|
||||||
dataset = TrajDataset('data', mapname='tate', alternatives=1000, trajectories=10000)
|
|
||||||
dataloader = DataLoader(dataset=dataset.train_dataset, shuffle=True,
|
|
||||||
batch_size=hparams.data_param.batchsize,
|
|
||||||
num_workers=hparams.data_param.worker)
|
|
||||||
|
|
||||||
# Logger
|
# Logger
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
logger = Logger(config, debug=True)
|
logger = Logger(config, debug=True)
|
||||||
|
|
||||||
# Trainer
|
# Checkpoint Callback
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
trainer = Trainer(logger=logger)
|
checkpoint_callback = ModelCheckpoint(
|
||||||
|
filepath=str(logger.log_dir / 'ckpt_weights'),
|
||||||
|
verbose=True,
|
||||||
|
period=1
|
||||||
|
)
|
||||||
|
|
||||||
# Model
|
# Model
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
model = None
|
# Init
|
||||||
|
model: LightningBaseModule = config.model_class(config.model_paramters)
|
||||||
|
model.init_weights()
|
||||||
|
|
||||||
|
# Trainer
|
||||||
|
# =============================================================================
|
||||||
|
trainer = Trainer(max_nb_epochs=config.train.epochs,
|
||||||
|
show_progress_bar=True,
|
||||||
|
weights_save_path=logger.log_dir,
|
||||||
|
gpus=[0] if torch.cuda.is_available() else None,
|
||||||
|
row_log_interval=model.data_len // 40, # TODO: Better Value / Setting
|
||||||
|
log_save_interval=model.data_len // 10, # TODO: Better Value / Setting
|
||||||
|
checkpoint_callback=checkpoint_callback,
|
||||||
|
logger=logger,
|
||||||
|
fast_dev_run=config.get('main', 'debug'),
|
||||||
|
early_stop_callback=None
|
||||||
|
)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
next(iter(dataloader))
|
# Check Cuda availability
|
||||||
pass
|
print(f'Cuda is {"" if torch.cuda.is_available() else "not"} available!!!')
|
||||||
|
# Train it
|
||||||
|
trainer.fit(model)
|
||||||
|
|
||||||
|
# Save the last state & all parameters
|
||||||
|
config.exp_path.mkdir(parents=True, exist_ok=True) # Todo: do i need this?
|
||||||
|
trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
|
||||||
|
model.save_to_disk(logger.log_dir)
|
||||||
|
|
||||||
|
# TODO: Eval here!
|
||||||
|
Reference in New Issue
Block a user