Refactoring ML_Lib

This commit is contained in:
Si11ium 2020-04-08 14:50:16 +02:00
parent 25c0e8e358
commit 27660c1458
32 changed files with 70 additions and 79 deletions

View File

@ -6,8 +6,8 @@ import torch
from torch.utils.data import Dataset, ConcatDataset
from datasets.utils import DatasetMapping
from lib.preprocessing.generator import Generator
from lib.objects.map import Map
from ml_lib.preprocessing.generator import Generator
from ml_lib.objects.map import Map
class TrajPairDataset(Dataset):

View File

@ -7,7 +7,7 @@ from pathlib import Path
from tqdm import tqdm
from lib.objects.map import Map
from ml_lib.objects.map import Map
class Generator:

View File

@ -14,11 +14,11 @@ from torch.utils.data import ConcatDataset, Dataset
import numpy as np
from tqdm import tqdm
from lib.objects.map import Map
import lib.variables as V
from ml_lib.objects.map import Map
import ml_lib.variables as V
from PIL import Image
from lib.utils.tools import write_to_shelve
from ml_lib.utils.tools import write_to_shelve
class TrajDataShelve(VisionDataset):

View File

@ -7,7 +7,7 @@ from sklearn.cluster import Birch, DBSCAN, KMeans
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import lib.variables as V
import ml_lib.variables as V
import numpy as np

View File

@ -11,10 +11,10 @@ import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from lib.modules.utils import LightningBaseModule
from lib.utils.config import Config
from lib.utils.logging import Logger
from lib.utils.model_io import SavedLightningModels
from ml_lib.modules.utils import LightningBaseModule
from ml_lib.utils.config import Config
from ml_lib.utils.logging import Logger
from ml_lib.utils.model_io import SavedLightningModels
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)

View File

@ -2,7 +2,7 @@ from typing import Union
import torch
from torch import nn
from lib.modules.utils import AutoPad, Interpolate
from ml_lib.modules.utils import AutoPad, Interpolate
#

View File

@ -3,9 +3,9 @@ from typing import List
import torch
from torch import nn
from lib.modules.utils import FlipTensor
from lib.objects.map import MapStorage, Map
from lib.objects.trajectory import Trajectory
from ml_lib.modules.utils import FlipTensor
from ml_lib.objects.map import MapStorage, Map
from ml_lib.objects.trajectory import Trajectory
class BinaryHomotopicLoss(nn.Module):

View File

@ -5,12 +5,12 @@ from collections import defaultdict
from configparser import ConfigParser
from pathlib import Path
from lib.models.generators.cnn import CNNRouteGeneratorModel
from lib.models.generators.cnn_discriminated import CNNRouteGeneratorDiscriminated
from ml_lib.models.generators.cnn import CNNRouteGeneratorModel
from ml_lib.models.generators.cnn_discriminated import CNNRouteGeneratorDiscriminated
from lib.models.homotopy_classification.cnn_based import ConvHomDetector
from lib.utils.model_io import ModelParameters
from lib.utils.transforms import AsArray
from ml_lib.models.homotopy_classification.cnn_based import ConvHomDetector
from ml_lib.utils.model_io import ModelParameters
from ml_lib.utils.transforms import AsArray
def is_jsonable(x):

View File

@ -4,7 +4,7 @@ from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.neptune import NeptuneLogger
from pytorch_lightning.loggers.test_tube import TestTubeLogger
from lib.utils.config import Config
from ml_lib.utils.config import Config
import numpy as np

View File

@ -1,26 +1,26 @@
from pathlib import Path
import matplotlib.pyplot as plt
class Plotter(object):
def __init__(self, root_path=''):
self.root_path = Path(root_path)
def save_current_figure(self, path, extention='.png'):
fig, _ = plt.gcf(), plt.gca()
# Prepare save location and check img file extention
path = self.root_path / Path(path if str(path).endswith(extention) else f'{str(path)}{extention}')
path.parent.mkdir(exist_ok=True, parents=True)
fig.savefig(path)
fig.clf()
def show_current_figure(self):
fig, _ = plt.gcf(), plt.gca()
fig.show()
fig.clf()
if __name__ == '__main__':
output_root = Path('..') / 'output'
p = Plotter(output_root)
p.save_current_figure('test.png')
from pathlib import Path
import matplotlib.pyplot as plt
class Plotter(object):
def __init__(self, root_path=''):
self.root_path = Path(root_path)
def save_current_figure(self, path, extention='.png'):
fig, _ = plt.gcf(), plt.gca()
# Prepare save location and check img file extention
path = self.root_path / Path(path if str(path).endswith(extention) else f'{str(path)}{extention}')
path.parent.mkdir(exist_ok=True, parents=True)
fig.savefig(path)
fig.clf()
def show_current_figure(self):
fig, _ = plt.gcf(), plt.gca()
fig.show()
fig.clf()
if __name__ == '__main__':
output_root = Path('..') / 'output'
p = Plotter(output_root)
p.save_current_figure('test.png')

View File

@ -10,12 +10,12 @@ from torch.optim import Adam
from datasets.mnist import MyMNIST
from datasets.trajectory_dataset import TrajData
from lib.modules.blocks import ConvModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten
from ml_lib.modules.blocks import ConvModule, DeConvModule
from ml_lib.modules.utils import LightningBaseModule, Flatten
import matplotlib.pyplot as plt
import lib.variables as V
from lib.visualization.generator_eval import GeneratorVisualizer
import variables as V
from generator_eval import GeneratorVisualizer
class CNNRouteGeneratorModel(LightningBaseModule):

View File

@ -1,18 +1,12 @@
from random import choices, seed
import numpy as np
import torch
from functools import reduce
from operator import mul
from torch import nn
from torch.optim import Adam
from datasets.trajectory_dataset import TrajData
from lib.evaluation.classification import ROCEvaluation
from lib.models.generators.cnn import CNNRouteGeneratorModel
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten
from ml_lib.evaluation.classification import ROCEvaluation
from models.generators.cnn import CNNRouteGeneratorModel
import matplotlib.pyplot as plt
@ -79,7 +73,7 @@ class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
maps, trajectories, labels, val_restul_dict = self.generate_random()
from lib.visualization.generator_eval import GeneratorVisualizer
from generator_eval import GeneratorVisualizer
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
fig = g.draw()
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)

View File

@ -1,7 +1,7 @@
from lib.modules.losses import BinaryHomotopicLoss
from lib.modules.utils import LightningBaseModule
from lib.objects.map import Map
from lib.objects.trajectory import Trajectory
from ml_lib.modules.losses import BinaryHomotopicLoss
from ml_lib.modules.utils import LightningBaseModule
from objects.map import Map
from objects.trajectory import Trajectory
import torch.nn as nn

View File

@ -8,9 +8,9 @@ from torch import nn
from torch.optim import Adam
from datasets.trajectory_dataset import TrajData
from lib.evaluation.classification import ROCEvaluation
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten
from ml_lib.evaluation.classification import ROCEvaluation
from ml_lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from ml_lib.modules.utils import LightningBaseModule, Flatten
import matplotlib.pyplot as plt
@ -55,7 +55,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
def _test_val_epoch_end(self, outputs, test=False):
maps, trajectories, labels, val_restul_dict = self.generate_random()
from lib.visualization.generator_eval import GeneratorVisualizer
from ml_lib.visualization.generator_eval import GeneratorVisualizer
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
fig = g.draw()
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
@ -312,7 +312,7 @@ class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
maps, trajectories, labels, val_restul_dict = self.generate_random()
from lib.visualization.generator_eval import GeneratorVisualizer
from ml_lib.visualization.generator_eval import GeneratorVisualizer
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
fig = g.draw()
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)

View File

@ -8,9 +8,9 @@ from torch.optim import Adam
from torch.utils.data import DataLoader
from datasets.trajectory_dataset import TrajData
from lib.evaluation.classification import ROCEvaluation
from lib.modules.utils import LightningBaseModule, Flatten
from lib.modules.blocks import ConvModule, ResidualModule
from ml_lib.evaluation.classification import ROCEvaluation
from ml_lib.modules.utils import LightningBaseModule, Flatten
from ml_lib.modules.blocks import ConvModule, ResidualModule
import matplotlib.pyplot as plt

View File

@ -1,6 +1,6 @@
import warnings
from lib.utils.config import Config
from ml_lib.utils.config import Config
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)

View File

@ -11,8 +11,8 @@ from PIL import Image, ImageDraw
import networkx as nx
from matplotlib import pyplot as plt
from lib.objects.trajectory import Trajectory
import lib.variables as V
from ml_lib.objects.trajectory import Trajectory
import ml_lib.variables as V
class Map(object):

View File

@ -2,7 +2,7 @@ from math import atan2
from typing import List, Tuple, Union
from matplotlib import pyplot as plt
from lib import variables as V
from ml_lib import variables as V
import numpy as np

View File

@ -1,6 +1,3 @@
from pathlib import Path
_ROOT = Path('..')
# Labels for classes
HOMOTOPIC = 1
ALTERNATIVE = 0