Refactoring ML_Lib
This commit is contained in:
parent
25c0e8e358
commit
27660c1458
@ -6,8 +6,8 @@ import torch
|
|||||||
from torch.utils.data import Dataset, ConcatDataset
|
from torch.utils.data import Dataset, ConcatDataset
|
||||||
|
|
||||||
from datasets.utils import DatasetMapping
|
from datasets.utils import DatasetMapping
|
||||||
from lib.preprocessing.generator import Generator
|
from ml_lib.preprocessing.generator import Generator
|
||||||
from lib.objects.map import Map
|
from ml_lib.objects.map import Map
|
||||||
|
|
||||||
|
|
||||||
class TrajPairDataset(Dataset):
|
class TrajPairDataset(Dataset):
|
||||||
|
@ -7,7 +7,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from lib.objects.map import Map
|
from ml_lib.objects.map import Map
|
||||||
|
|
||||||
|
|
||||||
class Generator:
|
class Generator:
|
||||||
|
@ -14,11 +14,11 @@ from torch.utils.data import ConcatDataset, Dataset
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from lib.objects.map import Map
|
from ml_lib.objects.map import Map
|
||||||
import lib.variables as V
|
import ml_lib.variables as V
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from lib.utils.tools import write_to_shelve
|
from ml_lib.utils.tools import write_to_shelve
|
||||||
|
|
||||||
|
|
||||||
class TrajDataShelve(VisionDataset):
|
class TrajDataShelve(VisionDataset):
|
||||||
|
@ -7,7 +7,7 @@ from sklearn.cluster import Birch, DBSCAN, KMeans
|
|||||||
from sklearn.decomposition import PCA
|
from sklearn.decomposition import PCA
|
||||||
from sklearn.manifold import TSNE
|
from sklearn.manifold import TSNE
|
||||||
|
|
||||||
import lib.variables as V
|
import ml_lib.variables as V
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
8
main.py
8
main.py
@ -11,10 +11,10 @@ import torch
|
|||||||
from pytorch_lightning import Trainer
|
from pytorch_lightning import Trainer
|
||||||
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
||||||
|
|
||||||
from lib.modules.utils import LightningBaseModule
|
from ml_lib.modules.utils import LightningBaseModule
|
||||||
from lib.utils.config import Config
|
from ml_lib.utils.config import Config
|
||||||
from lib.utils.logging import Logger
|
from ml_lib.utils.logging import Logger
|
||||||
from lib.utils.model_io import SavedLightningModels
|
from ml_lib.utils.model_io import SavedLightningModels
|
||||||
|
|
||||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||||
warnings.filterwarnings('ignore', category=UserWarning)
|
warnings.filterwarnings('ignore', category=UserWarning)
|
||||||
|
@ -2,7 +2,7 @@ from typing import Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from lib.modules.utils import AutoPad, Interpolate
|
from ml_lib.modules.utils import AutoPad, Interpolate
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
@ -3,9 +3,9 @@ from typing import List
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from lib.modules.utils import FlipTensor
|
from ml_lib.modules.utils import FlipTensor
|
||||||
from lib.objects.map import MapStorage, Map
|
from ml_lib.objects.map import MapStorage, Map
|
||||||
from lib.objects.trajectory import Trajectory
|
from ml_lib.objects.trajectory import Trajectory
|
||||||
|
|
||||||
|
|
||||||
class BinaryHomotopicLoss(nn.Module):
|
class BinaryHomotopicLoss(nn.Module):
|
@ -5,12 +5,12 @@ from collections import defaultdict
|
|||||||
from configparser import ConfigParser
|
from configparser import ConfigParser
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from lib.models.generators.cnn import CNNRouteGeneratorModel
|
from ml_lib.models.generators.cnn import CNNRouteGeneratorModel
|
||||||
from lib.models.generators.cnn_discriminated import CNNRouteGeneratorDiscriminated
|
from ml_lib.models.generators.cnn_discriminated import CNNRouteGeneratorDiscriminated
|
||||||
|
|
||||||
from lib.models.homotopy_classification.cnn_based import ConvHomDetector
|
from ml_lib.models.homotopy_classification.cnn_based import ConvHomDetector
|
||||||
from lib.utils.model_io import ModelParameters
|
from ml_lib.utils.model_io import ModelParameters
|
||||||
from lib.utils.transforms import AsArray
|
from ml_lib.utils.transforms import AsArray
|
||||||
|
|
||||||
|
|
||||||
def is_jsonable(x):
|
def is_jsonable(x):
|
@ -4,7 +4,7 @@ from pytorch_lightning.loggers.base import LightningLoggerBase
|
|||||||
from pytorch_lightning.loggers.neptune import NeptuneLogger
|
from pytorch_lightning.loggers.neptune import NeptuneLogger
|
||||||
from pytorch_lightning.loggers.test_tube import TestTubeLogger
|
from pytorch_lightning.loggers.test_tube import TestTubeLogger
|
||||||
|
|
||||||
from lib.utils.config import Config
|
from ml_lib.utils.config import Config
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
@ -1,26 +1,26 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
class Plotter(object):
|
class Plotter(object):
|
||||||
def __init__(self, root_path=''):
|
def __init__(self, root_path=''):
|
||||||
self.root_path = Path(root_path)
|
self.root_path = Path(root_path)
|
||||||
|
|
||||||
def save_current_figure(self, path, extention='.png'):
|
def save_current_figure(self, path, extention='.png'):
|
||||||
fig, _ = plt.gcf(), plt.gca()
|
fig, _ = plt.gcf(), plt.gca()
|
||||||
# Prepare save location and check img file extention
|
# Prepare save location and check img file extention
|
||||||
path = self.root_path / Path(path if str(path).endswith(extention) else f'{str(path)}{extention}')
|
path = self.root_path / Path(path if str(path).endswith(extention) else f'{str(path)}{extention}')
|
||||||
path.parent.mkdir(exist_ok=True, parents=True)
|
path.parent.mkdir(exist_ok=True, parents=True)
|
||||||
fig.savefig(path)
|
fig.savefig(path)
|
||||||
fig.clf()
|
fig.clf()
|
||||||
|
|
||||||
def show_current_figure(self):
|
def show_current_figure(self):
|
||||||
fig, _ = plt.gcf(), plt.gca()
|
fig, _ = plt.gcf(), plt.gca()
|
||||||
fig.show()
|
fig.show()
|
||||||
fig.clf()
|
fig.clf()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
output_root = Path('..') / 'output'
|
output_root = Path('..') / 'output'
|
||||||
p = Plotter(output_root)
|
p = Plotter(output_root)
|
||||||
p.save_current_figure('test.png')
|
p.save_current_figure('test.png')
|
@ -10,12 +10,12 @@ from torch.optim import Adam
|
|||||||
|
|
||||||
from datasets.mnist import MyMNIST
|
from datasets.mnist import MyMNIST
|
||||||
from datasets.trajectory_dataset import TrajData
|
from datasets.trajectory_dataset import TrajData
|
||||||
from lib.modules.blocks import ConvModule, DeConvModule
|
from ml_lib.modules.blocks import ConvModule, DeConvModule
|
||||||
from lib.modules.utils import LightningBaseModule, Flatten
|
from ml_lib.modules.utils import LightningBaseModule, Flatten
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import lib.variables as V
|
import variables as V
|
||||||
from lib.visualization.generator_eval import GeneratorVisualizer
|
from generator_eval import GeneratorVisualizer
|
||||||
|
|
||||||
|
|
||||||
class CNNRouteGeneratorModel(LightningBaseModule):
|
class CNNRouteGeneratorModel(LightningBaseModule):
|
||||||
|
@ -1,18 +1,12 @@
|
|||||||
from random import choices, seed
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from operator import mul
|
from operator import mul
|
||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.optim import Adam
|
|
||||||
|
|
||||||
from datasets.trajectory_dataset import TrajData
|
from datasets.trajectory_dataset import TrajData
|
||||||
from lib.evaluation.classification import ROCEvaluation
|
from ml_lib.evaluation.classification import ROCEvaluation
|
||||||
from lib.models.generators.cnn import CNNRouteGeneratorModel
|
from models.generators.cnn import CNNRouteGeneratorModel
|
||||||
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
|
|
||||||
from lib.modules.utils import LightningBaseModule, Flatten
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
@ -79,7 +73,7 @@ class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
|
|||||||
|
|
||||||
maps, trajectories, labels, val_restul_dict = self.generate_random()
|
maps, trajectories, labels, val_restul_dict = self.generate_random()
|
||||||
|
|
||||||
from lib.visualization.generator_eval import GeneratorVisualizer
|
from generator_eval import GeneratorVisualizer
|
||||||
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
|
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
|
||||||
fig = g.draw()
|
fig = g.draw()
|
||||||
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
|
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from lib.modules.losses import BinaryHomotopicLoss
|
from ml_lib.modules.losses import BinaryHomotopicLoss
|
||||||
from lib.modules.utils import LightningBaseModule
|
from ml_lib.modules.utils import LightningBaseModule
|
||||||
from lib.objects.map import Map
|
from objects.map import Map
|
||||||
from lib.objects.trajectory import Trajectory
|
from objects.trajectory import Trajectory
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
@ -8,9 +8,9 @@ from torch import nn
|
|||||||
from torch.optim import Adam
|
from torch.optim import Adam
|
||||||
|
|
||||||
from datasets.trajectory_dataset import TrajData
|
from datasets.trajectory_dataset import TrajData
|
||||||
from lib.evaluation.classification import ROCEvaluation
|
from ml_lib.evaluation.classification import ROCEvaluation
|
||||||
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
|
from ml_lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
|
||||||
from lib.modules.utils import LightningBaseModule, Flatten
|
from ml_lib.modules.utils import LightningBaseModule, Flatten
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
@ -55,7 +55,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
|||||||
def _test_val_epoch_end(self, outputs, test=False):
|
def _test_val_epoch_end(self, outputs, test=False):
|
||||||
maps, trajectories, labels, val_restul_dict = self.generate_random()
|
maps, trajectories, labels, val_restul_dict = self.generate_random()
|
||||||
|
|
||||||
from lib.visualization.generator_eval import GeneratorVisualizer
|
from ml_lib.visualization.generator_eval import GeneratorVisualizer
|
||||||
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
|
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
|
||||||
fig = g.draw()
|
fig = g.draw()
|
||||||
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
|
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
|
||||||
@ -312,7 +312,7 @@ class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
|
|||||||
|
|
||||||
maps, trajectories, labels, val_restul_dict = self.generate_random()
|
maps, trajectories, labels, val_restul_dict = self.generate_random()
|
||||||
|
|
||||||
from lib.visualization.generator_eval import GeneratorVisualizer
|
from ml_lib.visualization.generator_eval import GeneratorVisualizer
|
||||||
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
|
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
|
||||||
fig = g.draw()
|
fig = g.draw()
|
||||||
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
|
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
|
||||||
|
@ -8,9 +8,9 @@ from torch.optim import Adam
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from datasets.trajectory_dataset import TrajData
|
from datasets.trajectory_dataset import TrajData
|
||||||
from lib.evaluation.classification import ROCEvaluation
|
from ml_lib.evaluation.classification import ROCEvaluation
|
||||||
from lib.modules.utils import LightningBaseModule, Flatten
|
from ml_lib.modules.utils import LightningBaseModule, Flatten
|
||||||
from lib.modules.blocks import ConvModule, ResidualModule
|
from ml_lib.modules.blocks import ConvModule, ResidualModule
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from lib.utils.config import Config
|
from ml_lib.utils.config import Config
|
||||||
|
|
||||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||||
warnings.filterwarnings('ignore', category=UserWarning)
|
warnings.filterwarnings('ignore', category=UserWarning)
|
||||||
|
@ -11,8 +11,8 @@ from PIL import Image, ImageDraw
|
|||||||
import networkx as nx
|
import networkx as nx
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
from lib.objects.trajectory import Trajectory
|
from ml_lib.objects.trajectory import Trajectory
|
||||||
import lib.variables as V
|
import ml_lib.variables as V
|
||||||
|
|
||||||
|
|
||||||
class Map(object):
|
class Map(object):
|
||||||
|
@ -2,7 +2,7 @@ from math import atan2
|
|||||||
from typing import List, Tuple, Union
|
from typing import List, Tuple, Union
|
||||||
|
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from lib import variables as V
|
from ml_lib import variables as V
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -1,6 +1,3 @@
|
|||||||
from pathlib import Path
|
|
||||||
_ROOT = Path('..')
|
|
||||||
|
|
||||||
# Labels for classes
|
# Labels for classes
|
||||||
HOMOTOPIC = 1
|
HOMOTOPIC = 1
|
||||||
ALTERNATIVE = 0
|
ALTERNATIVE = 0
|
Loading…
x
Reference in New Issue
Block a user