CNN Model Body
This commit is contained in:
parent
1ce8d5993b
commit
2e60b19fa6
2
.idea/dictionaries/steffen.xml
generated
2
.idea/dictionaries/steffen.xml
generated
@ -2,6 +2,8 @@
|
||||
<dictionary name="steffen">
|
||||
<words>
|
||||
<w>conv</w>
|
||||
<w>homotopic</w>
|
||||
<w>hyperparamter</w>
|
||||
<w>numlayers</w>
|
||||
</words>
|
||||
</dictionary>
|
||||
|
15
.idea/webResources.xml
generated
Normal file
15
.idea/webResources.xml
generated
Normal file
@ -0,0 +1,15 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="WebResourcesPaths">
|
||||
<contentEntries>
|
||||
<entry url="file://$PROJECT_DIR$">
|
||||
<entryData>
|
||||
<resourceRoots>
|
||||
<path value="file://$PROJECT_DIR$/res" />
|
||||
<path value="file://$PROJECT_DIR$/data" />
|
||||
</resourceRoots>
|
||||
</entryData>
|
||||
</entry>
|
||||
</contentEntries>
|
||||
</component>
|
||||
</project>
|
@ -0,0 +1,21 @@
|
||||
from PIL import ImageDraw
|
||||
from PIL import Image
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def are_homotopic(map_array, trajectory, other_trajectory):
|
||||
|
||||
polyline = trajectory.vertices.copy()
|
||||
polyline.extend(reversed(other_trajectory.vertices))
|
||||
|
||||
height, width = map_array.shape
|
||||
|
||||
img = Image.new('L', (height, width), 0)
|
||||
ImageDraw.Draw(img).polygon(polyline, outline=1, fill=1)
|
||||
|
||||
a = (np.array(img) * map_array).sum()
|
||||
if a >= 1:
|
||||
return False
|
||||
else:
|
||||
return True
|
@ -1,125 +1,132 @@
|
||||
from pathlib import Path
|
||||
|
||||
import copy
|
||||
from math import sqrt
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image, ImageDraw
|
||||
import networkx as nx
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from lib.objects.trajectory import Trajectory
|
||||
|
||||
|
||||
class Map(object):
|
||||
|
||||
white = [1, 255]
|
||||
black = [0]
|
||||
|
||||
def __copy__(self):
|
||||
return copy.deepcopy(self)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self.map_array.shape
|
||||
|
||||
@property
|
||||
def width(self):
|
||||
return self.shape[0]
|
||||
|
||||
@property
|
||||
def height(self):
|
||||
return self.shape[1]
|
||||
|
||||
@property
|
||||
def as_graph(self):
|
||||
return self._G
|
||||
|
||||
@property
|
||||
def as_array(self):
|
||||
return self.map_array
|
||||
|
||||
def __init__(self, name='', array_like_map_representation=None):
|
||||
self.map_array: np.ndarray = array_like_map_representation
|
||||
self.name = name
|
||||
pass
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
super(Map, self).__setattr__(key, value)
|
||||
if key == 'map_array' and self.map_array is not None:
|
||||
self._G = self._build_graph()
|
||||
|
||||
def _build_graph(self, full_neighbors=True):
|
||||
graph = nx.Graph()
|
||||
# Do checks in order: up - left - upperLeft - lowerLeft
|
||||
neighbors = [(0, -1, 1), (-1, 0, 1), (-1, -1, sqrt(2)), (-1, 1, sqrt(2))]
|
||||
|
||||
# Check pixels for their color (determine if walkable)
|
||||
for idx, value in np.ndenumerate(self.map_array):
|
||||
if value in self.white:
|
||||
y, x = idx
|
||||
# IF walkable, add node
|
||||
graph.add_node((y, x), count=0)
|
||||
# Fully connect to all surrounding neighbors
|
||||
for n, (xdif, ydif, weight) in enumerate(neighbors):
|
||||
# Differentiate between 8 and 4 neighbors
|
||||
if not full_neighbors and n >= 2:
|
||||
break
|
||||
|
||||
query_node = (y + ydif, x + xdif)
|
||||
if graph.has_node(query_node):
|
||||
graph.add_edge(idx, query_node, weight=weight)
|
||||
return graph
|
||||
|
||||
@classmethod
|
||||
def from_image(cls, imagepath: Path):
|
||||
with Image.open(imagepath) as image:
|
||||
return cls(name=imagepath.name, array_like_map_representation=np.array(image))
|
||||
|
||||
def simple_trajectory_between(self, start, dest):
|
||||
vertices = list(nx.shortest_path(self._G, start, dest))
|
||||
trajectory = Trajectory(vertices)
|
||||
return trajectory
|
||||
|
||||
def get_valid_position(self):
|
||||
not_found, valid_position = True, (-9999, -9999)
|
||||
while not_found:
|
||||
valid_position = int(np.random.choice(self.height, 1)), int(np.random.choice(self.width, 1))
|
||||
if self._G.has_node(valid_position):
|
||||
not_found = False
|
||||
pass
|
||||
return valid_position
|
||||
|
||||
def get_trajectory_from_vertices(self, *args):
|
||||
coords = list()
|
||||
for start, dest in zip(args[:-1], args[1:]):
|
||||
coords.extend(nx.shortest_path(self._G, start, dest))
|
||||
return Trajectory(coords)
|
||||
|
||||
def get_random_trajectory(self):
|
||||
start = self.get_valid_position()
|
||||
dest = self.get_valid_position()
|
||||
return self.simple_trajectory_between(start, dest)
|
||||
|
||||
def are_homotopic(self, trajectory, other_trajectory):
|
||||
if not all(isinstance(x, Trajectory) for x in [trajectory, other_trajectory]):
|
||||
raise TypeError
|
||||
polyline = trajectory.vertices.copy()
|
||||
polyline.extend(reversed(other_trajectory.vertices))
|
||||
|
||||
img = Image.new('L', (self.height, self.width), 0)
|
||||
ImageDraw.Draw(img).polygon(polyline, outline=1, fill=1)
|
||||
|
||||
a = (np.array(img) * self.map_array).sum()
|
||||
if a >= 1:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def draw(self):
|
||||
fig, ax = plt.gcf(), plt.gca()
|
||||
# The standard colormaps also all have reversed versions.
|
||||
# They have the same names with _r tacked on to the end.
|
||||
# https: // matplotlib.org / api / pyplot_summary.html?highlight = colormaps
|
||||
img = ax.imshow(self.as_array, cmap='Greys_r')
|
||||
return dict(img=img, fig=fig, ax=ax)
|
||||
from pathlib import Path
|
||||
|
||||
import copy
|
||||
from math import sqrt
|
||||
from random import choice
|
||||
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image, ImageDraw
|
||||
import networkx as nx
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from lib.objects.trajectory import Trajectory
|
||||
|
||||
|
||||
class Map(object):
|
||||
|
||||
# This setting is for Img mode "L" aka GreyScale Image; values: 0-255
|
||||
white = 255
|
||||
black = 0
|
||||
|
||||
def __copy__(self):
|
||||
return copy.deepcopy(self)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self.map_array.shape
|
||||
|
||||
@property
|
||||
def width(self):
|
||||
return self.shape[0]
|
||||
|
||||
@property
|
||||
def height(self):
|
||||
return self.shape[1]
|
||||
|
||||
@property
|
||||
def as_graph(self):
|
||||
return self._G
|
||||
|
||||
@property
|
||||
def as_array(self):
|
||||
return self.map_array
|
||||
|
||||
def __init__(self, name='', array_like_map_representation=None):
|
||||
self.map_array: np.ndarray = array_like_map_representation
|
||||
self.name = name
|
||||
pass
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
super(Map, self).__setattr__(key, value)
|
||||
if key == 'map_array' and self.map_array is not None:
|
||||
self._G = self._build_graph()
|
||||
|
||||
def _build_graph(self, full_neighbors=True):
|
||||
graph = nx.Graph()
|
||||
# Do checks in order: up - left - upperLeft - lowerLeft
|
||||
neighbors = [(0, -1, 1), (-1, 0, 1), (-1, -1, sqrt(2)), (-1, 1, sqrt(2))]
|
||||
|
||||
# Check pixels for their color (determine if walkable)
|
||||
for idx, value in np.ndenumerate(self.map_array):
|
||||
if value == self.white:
|
||||
try:
|
||||
y, x = idx
|
||||
except ValueError:
|
||||
y, x, channels = idx
|
||||
idx = (y, x)
|
||||
# IF walkable, add node
|
||||
graph.add_node((y, x), count=0)
|
||||
# Fully connect to all surrounding neighbors
|
||||
for n, (xdif, ydif, weight) in enumerate(neighbors):
|
||||
# Differentiate between 8 and 4 neighbors
|
||||
if not full_neighbors and n >= 2:
|
||||
break
|
||||
|
||||
query_node = (y + ydif, x + xdif)
|
||||
if graph.has_node(query_node):
|
||||
graph.add_edge(idx, query_node, weight=weight)
|
||||
return graph
|
||||
|
||||
@classmethod
|
||||
def from_image(cls, imagepath: Path):
|
||||
with Image.open(imagepath) as image:
|
||||
# Turn the image to single Channel Greyscale
|
||||
if image.mode != 'L':
|
||||
image = image.convert('L')
|
||||
map_array = np.array(image)
|
||||
return cls(name=imagepath.name, array_like_map_representation=map_array)
|
||||
|
||||
def simple_trajectory_between(self, start, dest):
|
||||
vertices = list(nx.shortest_path(self._G, start, dest))
|
||||
trajectory = Trajectory(vertices)
|
||||
return trajectory
|
||||
|
||||
def get_valid_position(self):
|
||||
valid_position = choice(list(self._G.nodes))
|
||||
return valid_position
|
||||
|
||||
def get_trajectory_from_vertices(self, *args):
|
||||
coords = list()
|
||||
for start, dest in zip(args[:-1], args[1:]):
|
||||
coords.extend(nx.shortest_path(self._G, start, dest))
|
||||
return Trajectory(coords)
|
||||
|
||||
def get_random_trajectory(self):
|
||||
start = self.get_valid_position()
|
||||
dest = self.get_valid_position()
|
||||
return self.simple_trajectory_between(start, dest)
|
||||
|
||||
def are_homotopic(self, trajectory, other_trajectory):
|
||||
if not all(isinstance(x, Trajectory) for x in [trajectory, other_trajectory]):
|
||||
raise TypeError
|
||||
polyline = trajectory.vertices.copy()
|
||||
polyline.extend(reversed(other_trajectory.vertices))
|
||||
|
||||
img = Image.new('L', (self.height, self.width), 0)
|
||||
draw = ImageDraw.Draw(img)
|
||||
draw.polygon(polyline, outline=255, fill=255)
|
||||
|
||||
a = (np.array(img) * np.where(self.map_array == self.white, 0, 1)).sum()
|
||||
if a >= 1:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def draw(self):
|
||||
fig, ax = plt.gcf(), plt.gca()
|
||||
# The standard colormaps also all have reversed versions.
|
||||
# They have the same names with _r tacked on to the end.
|
||||
# https: // matplotlib.org / api / pyplot_summary.html?highlight = colormaps
|
||||
img = ax.imshow(self.as_array, cmap='Greys_r')
|
||||
return dict(img=img, fig=fig, ax=ax)
|
||||
|
@ -1,121 +1,124 @@
|
||||
import multiprocessing as mp
|
||||
import pickle
|
||||
import shelve
|
||||
from collections import defaultdict
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from tqdm import trange
|
||||
|
||||
from lib.objects.map import Map
|
||||
|
||||
|
||||
class Generator:
|
||||
|
||||
possible_modes = ['one_patching']
|
||||
|
||||
def __init__(self, data_root, map_obj, binary=True):
|
||||
self.binary: bool = binary
|
||||
self.map: Map = map_obj
|
||||
|
||||
self.data_root = Path(data_root)
|
||||
|
||||
def generate_n_trajectories_m_alternatives(self, n, m, dataset_name='', **kwargs):
|
||||
trajectories_with_alternatives = list()
|
||||
for _ in trange(n, desc='Processing Trajectories'):
|
||||
trajectory = self.map.get_random_trajectory()
|
||||
alternatives, labels = self.generate_n_alternatives(trajectory, m, dataset_name=dataset_name, **kwargs)
|
||||
trajectories_with_alternatives.append(dict(trajectory=trajectory, alternatives=alternatives, labels=labels))
|
||||
return trajectories_with_alternatives
|
||||
|
||||
def generate_alternatives(self, trajectory, output: Union[mp.
|
||||
Queue, None] = None, mode='one_patching'):
|
||||
start, dest = trajectory.endpoints
|
||||
if mode == 'one_patching':
|
||||
patch = self.map.get_valid_position()
|
||||
alternative = self.map.get_trajectory_from_vertices(start, patch, dest)
|
||||
else:
|
||||
raise RuntimeError(f'mode checking went wrong...')
|
||||
|
||||
if output:
|
||||
output.put(alternative)
|
||||
return alternative
|
||||
|
||||
def generate_n_alternatives(self, trajectory, n, dataset_name: Union[str, Path] = '',
|
||||
mode='one_patching', equal_samples=True):
|
||||
assert mode in self.possible_modes, f'Parameter "mode" must be either {self.possible_modes}, but was {mode}.'
|
||||
# Define an output queue
|
||||
output = mp.Queue()
|
||||
# Setup a list of processes that we want to run
|
||||
processes = [mp.Process(target=self.generate_alternatives,
|
||||
kwargs=dict(trajectory=trajectory, output=output, mode=mode))
|
||||
for _ in range(n)]
|
||||
# Run processes
|
||||
for p in processes:
|
||||
p.start()
|
||||
# Exit the completed processes
|
||||
for p in processes:
|
||||
p.join()
|
||||
# Get process results from the output queue
|
||||
results = [output.get() for _ in processes]
|
||||
|
||||
# label per homotopic class
|
||||
homotopy_classes = defaultdict(list)
|
||||
homotopy_classes[0].append(trajectory)
|
||||
for i in range(len(results)):
|
||||
alternative = results[i]
|
||||
class_not_found, label = True, None
|
||||
# check for homotopy class
|
||||
for label in homotopy_classes.keys():
|
||||
if self.map.are_homotopic(homotopy_classes[label][0], alternative):
|
||||
homotopy_classes[label].append(alternative)
|
||||
class_not_found = False
|
||||
break
|
||||
if class_not_found:
|
||||
label = len(homotopy_classes)
|
||||
homotopy_classes[label].append(alternative)
|
||||
|
||||
# There should be as much homotopic samples as non-homotopic samples
|
||||
if equal_samples:
|
||||
homotopy_classes = self._remove_unequal(homotopy_classes)
|
||||
|
||||
# Compose lists of alternatives with labels
|
||||
alternatives, labels = list(), list()
|
||||
for key in homotopy_classes.keys():
|
||||
alternatives.extend([homotopy_classes[key]])
|
||||
labels.extend([key] * len(homotopy_classes[key]))
|
||||
|
||||
# Write to disk
|
||||
if dataset_name:
|
||||
self.write_to_disk(dataset_name, trajectory, alternatives, labels)
|
||||
|
||||
# Return
|
||||
return alternatives, labels
|
||||
|
||||
def write_to_disk(self, filepath, trajectory, alternatives, labels):
|
||||
dataset_name = filepath if filepath.endswith('.pik') else f'{filepath}.pik'
|
||||
self.data_root.mkdir(exist_ok=True, parents=True)
|
||||
with shelve.open(str(self.data_root / dataset_name), protocol=pickle.HIGHEST_PROTOCOL) as f:
|
||||
new_key = len(f)
|
||||
f[f'trajectory_{new_key}'] = dict(alternatives=alternatives,
|
||||
trajectory=trajectory,
|
||||
labels=labels)
|
||||
if 'map' not in f:
|
||||
f['map'] = dict(map=self.map, name=f'map_{self.map.name}')
|
||||
|
||||
@staticmethod
|
||||
def _remove_unequal(hom_dict):
|
||||
hom_dict = hom_dict.copy()
|
||||
|
||||
counter = len(hom_dict)
|
||||
while sum([len(hom_dict[class_id]) for class_id in range(len(hom_dict))]) > len(hom_dict[0]):
|
||||
if counter > len(hom_dict):
|
||||
counter = len(hom_dict)
|
||||
if counter in hom_dict:
|
||||
if len(hom_dict[counter]) == 0:
|
||||
del hom_dict[counter]
|
||||
else:
|
||||
del hom_dict[counter][-1]
|
||||
counter -= 1
|
||||
return hom_dict
|
||||
import multiprocessing as mp
|
||||
import pickle
|
||||
import shelve
|
||||
from collections import defaultdict
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from tqdm import trange
|
||||
|
||||
from lib.objects.map import Map
|
||||
from lib.utils.parallel import run_n_in_parallel
|
||||
|
||||
class Generator:
|
||||
|
||||
possible_modes = ['one_patching']
|
||||
|
||||
def __init__(self, data_root, map_obj, binary=True):
|
||||
self.binary: bool = binary
|
||||
self.map: Map = map_obj
|
||||
|
||||
self.data_root = Path(data_root)
|
||||
|
||||
def generate_n_trajectories_m_alternatives(self, n, m, dataset_name='', **kwargs):
|
||||
trajectories_with_alternatives = list()
|
||||
for _ in trange(n, desc='Processing Trajectories'):
|
||||
trajectory = self.map.get_random_trajectory()
|
||||
alternatives, labels = self.generate_n_alternatives(trajectory, m, dataset_name=dataset_name, **kwargs)
|
||||
if not alternatives or labels:
|
||||
continue
|
||||
else:
|
||||
trajectories_with_alternatives.append(
|
||||
dict(trajectory=trajectory, alternatives=alternatives, labels=labels)
|
||||
)
|
||||
return trajectories_with_alternatives
|
||||
|
||||
def generate_alternatives(self, trajectory, output: Union[mp.
|
||||
Queue, None] = None, mode='one_patching'):
|
||||
start, dest = trajectory.endpoints
|
||||
if mode == 'one_patching':
|
||||
patch = self.map.get_valid_position()
|
||||
alternative = self.map.get_trajectory_from_vertices(start, patch, dest)
|
||||
else:
|
||||
raise RuntimeError(f'mode checking went wrong...')
|
||||
|
||||
if output:
|
||||
output.put(alternative)
|
||||
return alternative
|
||||
|
||||
def generate_n_alternatives(self, trajectory, n, dataset_name: Union[str, Path] = '',
|
||||
mode='one_patching', equal_samples=True, binary_check=True):
|
||||
assert mode in self.possible_modes, f'Parameter "mode" must be either {self.possible_modes}, but was {mode}.'
|
||||
# Define an output queue
|
||||
#output = mp.Queue()
|
||||
|
||||
results = run_n_in_parallel(self.generate_alternatives, n, trajectory=trajectory, mode=mode) # , output=output)
|
||||
|
||||
# Get process results from the output queue
|
||||
#results = [output.get() for _ in range(n)]
|
||||
|
||||
# label per homotopic class
|
||||
homotopy_classes = defaultdict(list)
|
||||
homotopy_classes[0].append(trajectory)
|
||||
for i in range(len(results)):
|
||||
alternative = results[i]
|
||||
class_not_found = True
|
||||
# check for homotopy class
|
||||
for label in homotopy_classes.keys():
|
||||
if self.map.are_homotopic(homotopy_classes[label][0], alternative):
|
||||
homotopy_classes[label].append(alternative)
|
||||
class_not_found = False
|
||||
break
|
||||
if class_not_found:
|
||||
label = 1 if binary_check else len(homotopy_classes)
|
||||
homotopy_classes[label].append(alternative)
|
||||
|
||||
# There should be as much homotopic samples as non-homotopic samples
|
||||
if equal_samples:
|
||||
homotopy_classes = self._remove_unequal(homotopy_classes)
|
||||
if not homotopy_classes:
|
||||
return None, None
|
||||
|
||||
# Compose lists of alternatives with labels
|
||||
alternatives, labels = list(), list()
|
||||
for key in homotopy_classes.keys():
|
||||
alternatives.extend(homotopy_classes[key])
|
||||
labels.extend([key] * len(homotopy_classes[key]))
|
||||
|
||||
# Write to disk
|
||||
if dataset_name:
|
||||
self.write_to_disk(dataset_name, trajectory, alternatives, labels)
|
||||
|
||||
# Return
|
||||
return alternatives, labels
|
||||
|
||||
def write_to_disk(self, filepath, trajectory, alternatives, labels):
|
||||
dataset_name = filepath if filepath.endswith('.pik') else f'{filepath}.pik'
|
||||
self.data_root.mkdir(exist_ok=True, parents=True)
|
||||
with shelve.open(str(self.data_root / dataset_name), protocol=pickle.HIGHEST_PROTOCOL) as f:
|
||||
new_key = len(f)
|
||||
f[f'trajectory_{new_key}'] = dict(alternatives=alternatives,
|
||||
trajectory=trajectory,
|
||||
labels=labels)
|
||||
if 'map' not in f:
|
||||
f['map'] = dict(map=self.map, name=f'map_{self.map.name}')
|
||||
|
||||
@staticmethod
|
||||
def _remove_unequal(hom_dict):
|
||||
# We argue, that there will always be more non-homotopic routes than homotopic alternatives.
|
||||
# TODO: Otherwise introduce a second condition / loop
|
||||
hom_dict = hom_dict.copy()
|
||||
if len(hom_dict[0]) <= 1:
|
||||
return None
|
||||
counter = len(hom_dict)
|
||||
while sum([len(hom_dict[class_id]) for class_id in range(1, len(hom_dict))]) > len(hom_dict[0]):
|
||||
if counter == 0:
|
||||
counter = len(hom_dict)
|
||||
if counter in hom_dict:
|
||||
if len(hom_dict[counter]) == 0:
|
||||
del hom_dict[counter]
|
||||
else:
|
||||
del hom_dict[counter][-1]
|
||||
counter -= 1
|
||||
return hom_dict
|
||||
|
@ -5,6 +5,8 @@ from collections import defaultdict
|
||||
from configparser import ConfigParser
|
||||
from pathlib import Path
|
||||
|
||||
from lib.utils.model_io import ModelParameters
|
||||
|
||||
|
||||
def is_jsonable(x):
|
||||
import json
|
||||
@ -43,6 +45,10 @@ class Config(ConfigParser):
|
||||
return self._get_namespace_for_section('project')
|
||||
###################################################
|
||||
|
||||
@property
|
||||
def model_paramters(self):
|
||||
return ModelParameters(self.model, self.train, self.data)
|
||||
|
||||
@property
|
||||
def tags(self, ):
|
||||
return [f'{key}: {val}' for key, val in self.serializable.items()]
|
||||
|
@ -50,7 +50,7 @@ class Logger(LightningLoggerBase):
|
||||
self.debug = debug
|
||||
self.config = config
|
||||
self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
|
||||
self._neptune_kwargs = dict(offline_mode=not self.debug,
|
||||
self._neptune_kwargs = dict(offline_mode= self.debug,
|
||||
api_key=self.config.project.neptune_key,
|
||||
project_name=self.project_name,
|
||||
name=self.name,
|
||||
|
23
lib/utils/parallel.py
Normal file
23
lib/utils/parallel.py
Normal file
@ -0,0 +1,23 @@
|
||||
import multiprocessing as mp
|
||||
import time
|
||||
|
||||
|
||||
def run_n_in_parallel(f, n, **kwargs):
|
||||
output = mp.Queue()
|
||||
kwargs.update(output=output)
|
||||
# Setup a list of processes that we want to run
|
||||
processes = [mp.Process(target=f, kwargs=kwargs) for _ in range(n)]
|
||||
# Run processes
|
||||
results = []
|
||||
for p in processes:
|
||||
p.start()
|
||||
while len(results) != n:
|
||||
time.sleep(1)
|
||||
# Get process results from the output queue
|
||||
results.extend([output.get() for _ in processes])
|
||||
|
||||
# Exit the completed processes
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
return results
|
159
main.py
159
main.py
@ -1,71 +1,88 @@
|
||||
# Imports
|
||||
# =============================================================================
|
||||
import os
|
||||
from distutils.util import strtobool
|
||||
from pathlib import Path
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import warnings
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from dataset.dataset import TrajData
|
||||
from lib.utils.config import Config
|
||||
from lib.utils.logging import Logger
|
||||
|
||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
_ROOT = Path(__file__).parent
|
||||
|
||||
# Paramter Configuration
|
||||
# =============================================================================
|
||||
# Argument Parser
|
||||
main_arg_parser = ArgumentParser(description="parser for fast-neural-style")
|
||||
|
||||
# Main Parameters
|
||||
main_arg_parser.add_argument("--main_debug", type=strtobool, default=False, help="")
|
||||
main_arg_parser.add_argument("--main_eval", type=strtobool, default=False, help="")
|
||||
main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
|
||||
|
||||
# Data Parameters
|
||||
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
|
||||
main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="")
|
||||
main_arg_parser.add_argument("--data_root", type=str, default='../data/rpoot', help="")
|
||||
|
||||
# Transformations
|
||||
main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
|
||||
|
||||
# Transformations
|
||||
main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
|
||||
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
|
||||
main_arg_parser.add_argument("--train_epochs", type=int, default=10, help="")
|
||||
main_arg_parser.add_argument("--train_batch_size", type=int, default=512, help="")
|
||||
main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
|
||||
|
||||
# Model
|
||||
main_arg_parser.add_argument("--model_type", type=str, default="LeNetAE", help="")
|
||||
main_arg_parser.add_argument("--model_activation", type=str, default="relu", help="")
|
||||
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 16, 4]", help="")
|
||||
main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
|
||||
main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=True, help="")
|
||||
main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="")
|
||||
|
||||
# Project
|
||||
main_arg_parser.add_argument("--project_name", type=str, default='traj-gen', help="")
|
||||
main_arg_parser.add_argument("--project_owner", type=str, default='si11ium', help="")
|
||||
main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.getenv('NEPTUNE_KEY'), help="")
|
||||
|
||||
# Parse it
|
||||
args = main_arg_parser.parse_args()
|
||||
config = Config.read_namespace(args)
|
||||
|
||||
# Trainer loading
|
||||
# =============================================================================
|
||||
trainer = Trainer(logger=Logger(config, debug=True))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(next(iter(train_dataloader)))
|
||||
pass
|
||||
# Imports
|
||||
# =============================================================================
|
||||
import os
|
||||
from distutils.util import strtobool
|
||||
from pathlib import Path
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import warnings
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from dataset.dataset import TrajData
|
||||
from lib.utils.config import Config
|
||||
from lib.utils.logging import Logger
|
||||
|
||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
_ROOT = Path(__file__).parent
|
||||
|
||||
# Paramter Configuration
|
||||
# =============================================================================
|
||||
# Argument Parser
|
||||
main_arg_parser = ArgumentParser(description="parser for fast-neural-style")
|
||||
|
||||
# Main Parameters
|
||||
main_arg_parser.add_argument("--main_debug", type=strtobool, default=False, help="")
|
||||
main_arg_parser.add_argument("--main_eval", type=strtobool, default=False, help="")
|
||||
main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
|
||||
|
||||
# Data Parameters
|
||||
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
|
||||
main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="")
|
||||
main_arg_parser.add_argument("--data_root", type=str, default='../data/rpoot', help="")
|
||||
|
||||
# Transformations
|
||||
main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
|
||||
|
||||
# Transformations
|
||||
main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
|
||||
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
|
||||
main_arg_parser.add_argument("--train_epochs", type=int, default=10, help="")
|
||||
main_arg_parser.add_argument("--train_batch_size", type=int, default=512, help="")
|
||||
main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
|
||||
|
||||
# Model
|
||||
main_arg_parser.add_argument("--model_type", type=str, default="LeNetAE", help="")
|
||||
main_arg_parser.add_argument("--model_activation", type=str, default="relu", help="")
|
||||
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 16, 4]", help="")
|
||||
main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
|
||||
main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=True, help="")
|
||||
main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="")
|
||||
|
||||
# Project
|
||||
main_arg_parser.add_argument("--project_name", type=str, default='traj-gen', help="")
|
||||
main_arg_parser.add_argument("--project_owner", type=str, default='si11ium', help="")
|
||||
main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.getenv('NEPTUNE_KEY'), help="")
|
||||
|
||||
# Parse it
|
||||
args = main_arg_parser.parse_args()
|
||||
config = Config.read_namespace(args)
|
||||
|
||||
################
|
||||
# TESTING ONLY #
|
||||
# =============================================================================
|
||||
hparams = config.model_paramters
|
||||
dataset = TrajData('data', mapname='tate', alternatives=100, trajectories=10000)
|
||||
dataloader = DataLoader(dataset=dataset.train_dataset, shuffle=True,
|
||||
batch_size=hparams.data_param.batchsize,
|
||||
num_workers=hparams.data_param.worker)
|
||||
|
||||
# Logger
|
||||
# =============================================================================
|
||||
logger = Logger(config, debug=True)
|
||||
|
||||
# Trainer
|
||||
# =============================================================================
|
||||
trainer = Trainer(logger=logger)
|
||||
|
||||
# Model
|
||||
# =============================================================================
|
||||
model = None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
next(iter(dataloader))
|
||||
pass
|
||||
|
Loading…
x
Reference in New Issue
Block a user