Refactoring ML_Lib

This commit is contained in:
Si11ium 2020-04-08 14:50:16 +02:00
parent 25c0e8e358
commit 27660c1458
32 changed files with 70 additions and 79 deletions

View File

@ -6,8 +6,8 @@ import torch
from torch.utils.data import Dataset, ConcatDataset from torch.utils.data import Dataset, ConcatDataset
from datasets.utils import DatasetMapping from datasets.utils import DatasetMapping
from lib.preprocessing.generator import Generator from ml_lib.preprocessing.generator import Generator
from lib.objects.map import Map from ml_lib.objects.map import Map
class TrajPairDataset(Dataset): class TrajPairDataset(Dataset):

View File

@ -7,7 +7,7 @@ from pathlib import Path
from tqdm import tqdm from tqdm import tqdm
from lib.objects.map import Map from ml_lib.objects.map import Map
class Generator: class Generator:

View File

@ -14,11 +14,11 @@ from torch.utils.data import ConcatDataset, Dataset
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
from lib.objects.map import Map from ml_lib.objects.map import Map
import lib.variables as V import ml_lib.variables as V
from PIL import Image from PIL import Image
from lib.utils.tools import write_to_shelve from ml_lib.utils.tools import write_to_shelve
class TrajDataShelve(VisionDataset): class TrajDataShelve(VisionDataset):

View File

@ -7,7 +7,7 @@ from sklearn.cluster import Birch, DBSCAN, KMeans
from sklearn.decomposition import PCA from sklearn.decomposition import PCA
from sklearn.manifold import TSNE from sklearn.manifold import TSNE
import lib.variables as V import ml_lib.variables as V
import numpy as np import numpy as np

View File

@ -11,10 +11,10 @@ import torch
from pytorch_lightning import Trainer from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from lib.modules.utils import LightningBaseModule from ml_lib.modules.utils import LightningBaseModule
from lib.utils.config import Config from ml_lib.utils.config import Config
from lib.utils.logging import Logger from ml_lib.utils.logging import Logger
from lib.utils.model_io import SavedLightningModels from ml_lib.utils.model_io import SavedLightningModels
warnings.filterwarnings('ignore', category=FutureWarning) warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning) warnings.filterwarnings('ignore', category=UserWarning)

View File

@ -2,7 +2,7 @@ from typing import Union
import torch import torch
from torch import nn from torch import nn
from lib.modules.utils import AutoPad, Interpolate from ml_lib.modules.utils import AutoPad, Interpolate
# #

View File

@ -3,9 +3,9 @@ from typing import List
import torch import torch
from torch import nn from torch import nn
from lib.modules.utils import FlipTensor from ml_lib.modules.utils import FlipTensor
from lib.objects.map import MapStorage, Map from ml_lib.objects.map import MapStorage, Map
from lib.objects.trajectory import Trajectory from ml_lib.objects.trajectory import Trajectory
class BinaryHomotopicLoss(nn.Module): class BinaryHomotopicLoss(nn.Module):

View File

@ -5,12 +5,12 @@ from collections import defaultdict
from configparser import ConfigParser from configparser import ConfigParser
from pathlib import Path from pathlib import Path
from lib.models.generators.cnn import CNNRouteGeneratorModel from ml_lib.models.generators.cnn import CNNRouteGeneratorModel
from lib.models.generators.cnn_discriminated import CNNRouteGeneratorDiscriminated from ml_lib.models.generators.cnn_discriminated import CNNRouteGeneratorDiscriminated
from lib.models.homotopy_classification.cnn_based import ConvHomDetector from ml_lib.models.homotopy_classification.cnn_based import ConvHomDetector
from lib.utils.model_io import ModelParameters from ml_lib.utils.model_io import ModelParameters
from lib.utils.transforms import AsArray from ml_lib.utils.transforms import AsArray
def is_jsonable(x): def is_jsonable(x):

View File

@ -4,7 +4,7 @@ from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.neptune import NeptuneLogger from pytorch_lightning.loggers.neptune import NeptuneLogger
from pytorch_lightning.loggers.test_tube import TestTubeLogger from pytorch_lightning.loggers.test_tube import TestTubeLogger
from lib.utils.config import Config from ml_lib.utils.config import Config
import numpy as np import numpy as np

View File

@ -1,26 +1,26 @@
from pathlib import Path from pathlib import Path
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
class Plotter(object): class Plotter(object):
def __init__(self, root_path=''): def __init__(self, root_path=''):
self.root_path = Path(root_path) self.root_path = Path(root_path)
def save_current_figure(self, path, extention='.png'): def save_current_figure(self, path, extention='.png'):
fig, _ = plt.gcf(), plt.gca() fig, _ = plt.gcf(), plt.gca()
# Prepare save location and check img file extention # Prepare save location and check img file extention
path = self.root_path / Path(path if str(path).endswith(extention) else f'{str(path)}{extention}') path = self.root_path / Path(path if str(path).endswith(extention) else f'{str(path)}{extention}')
path.parent.mkdir(exist_ok=True, parents=True) path.parent.mkdir(exist_ok=True, parents=True)
fig.savefig(path) fig.savefig(path)
fig.clf() fig.clf()
def show_current_figure(self): def show_current_figure(self):
fig, _ = plt.gcf(), plt.gca() fig, _ = plt.gcf(), plt.gca()
fig.show() fig.show()
fig.clf() fig.clf()
if __name__ == '__main__': if __name__ == '__main__':
output_root = Path('..') / 'output' output_root = Path('..') / 'output'
p = Plotter(output_root) p = Plotter(output_root)
p.save_current_figure('test.png') p.save_current_figure('test.png')

View File

@ -10,12 +10,12 @@ from torch.optim import Adam
from datasets.mnist import MyMNIST from datasets.mnist import MyMNIST
from datasets.trajectory_dataset import TrajData from datasets.trajectory_dataset import TrajData
from lib.modules.blocks import ConvModule, DeConvModule from ml_lib.modules.blocks import ConvModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten from ml_lib.modules.utils import LightningBaseModule, Flatten
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import lib.variables as V import variables as V
from lib.visualization.generator_eval import GeneratorVisualizer from generator_eval import GeneratorVisualizer
class CNNRouteGeneratorModel(LightningBaseModule): class CNNRouteGeneratorModel(LightningBaseModule):

View File

@ -1,18 +1,12 @@
from random import choices, seed
import numpy as np
import torch import torch
from functools import reduce from functools import reduce
from operator import mul from operator import mul
from torch import nn from torch import nn
from torch.optim import Adam
from datasets.trajectory_dataset import TrajData from datasets.trajectory_dataset import TrajData
from lib.evaluation.classification import ROCEvaluation from ml_lib.evaluation.classification import ROCEvaluation
from lib.models.generators.cnn import CNNRouteGeneratorModel from models.generators.cnn import CNNRouteGeneratorModel
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -79,7 +73,7 @@ class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
maps, trajectories, labels, val_restul_dict = self.generate_random() maps, trajectories, labels, val_restul_dict = self.generate_random()
from lib.visualization.generator_eval import GeneratorVisualizer from generator_eval import GeneratorVisualizer
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict) g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
fig = g.draw() fig = g.draw()
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step) self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)

View File

@ -1,7 +1,7 @@
from lib.modules.losses import BinaryHomotopicLoss from ml_lib.modules.losses import BinaryHomotopicLoss
from lib.modules.utils import LightningBaseModule from ml_lib.modules.utils import LightningBaseModule
from lib.objects.map import Map from objects.map import Map
from lib.objects.trajectory import Trajectory from objects.trajectory import Trajectory
import torch.nn as nn import torch.nn as nn

View File

@ -8,9 +8,9 @@ from torch import nn
from torch.optim import Adam from torch.optim import Adam
from datasets.trajectory_dataset import TrajData from datasets.trajectory_dataset import TrajData
from lib.evaluation.classification import ROCEvaluation from ml_lib.evaluation.classification import ROCEvaluation
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule from ml_lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten from ml_lib.modules.utils import LightningBaseModule, Flatten
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -55,7 +55,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
def _test_val_epoch_end(self, outputs, test=False): def _test_val_epoch_end(self, outputs, test=False):
maps, trajectories, labels, val_restul_dict = self.generate_random() maps, trajectories, labels, val_restul_dict = self.generate_random()
from lib.visualization.generator_eval import GeneratorVisualizer from ml_lib.visualization.generator_eval import GeneratorVisualizer
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict) g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
fig = g.draw() fig = g.draw()
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step) self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
@ -312,7 +312,7 @@ class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
maps, trajectories, labels, val_restul_dict = self.generate_random() maps, trajectories, labels, val_restul_dict = self.generate_random()
from lib.visualization.generator_eval import GeneratorVisualizer from ml_lib.visualization.generator_eval import GeneratorVisualizer
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict) g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
fig = g.draw() fig = g.draw()
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step) self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)

View File

@ -8,9 +8,9 @@ from torch.optim import Adam
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from datasets.trajectory_dataset import TrajData from datasets.trajectory_dataset import TrajData
from lib.evaluation.classification import ROCEvaluation from ml_lib.evaluation.classification import ROCEvaluation
from lib.modules.utils import LightningBaseModule, Flatten from ml_lib.modules.utils import LightningBaseModule, Flatten
from lib.modules.blocks import ConvModule, ResidualModule from ml_lib.modules.blocks import ConvModule, ResidualModule
import matplotlib.pyplot as plt import matplotlib.pyplot as plt

View File

@ -1,6 +1,6 @@
import warnings import warnings
from lib.utils.config import Config from ml_lib.utils.config import Config
warnings.filterwarnings('ignore', category=FutureWarning) warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning) warnings.filterwarnings('ignore', category=UserWarning)

View File

@ -11,8 +11,8 @@ from PIL import Image, ImageDraw
import networkx as nx import networkx as nx
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from lib.objects.trajectory import Trajectory from ml_lib.objects.trajectory import Trajectory
import lib.variables as V import ml_lib.variables as V
class Map(object): class Map(object):

View File

@ -2,7 +2,7 @@ from math import atan2
from typing import List, Tuple, Union from typing import List, Tuple, Union
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from lib import variables as V from ml_lib import variables as V
import numpy as np import numpy as np

View File

@ -1,6 +1,3 @@
from pathlib import Path
_ROOT = Path('..')
# Labels for classes # Labels for classes
HOMOTOPIC = 1 HOMOTOPIC = 1
ALTERNATIVE = 0 ALTERNATIVE = 0