diff --git a/datasets/paired_dataset.py b/datasets/paired_dataset.py index 3d152f7..0f10574 100644 --- a/datasets/paired_dataset.py +++ b/datasets/paired_dataset.py @@ -6,8 +6,8 @@ import torch from torch.utils.data import Dataset, ConcatDataset from datasets.utils import DatasetMapping -from lib.preprocessing.generator import Generator -from lib.objects.map import Map +from ml_lib.preprocessing.generator import Generator +from ml_lib.objects.map import Map class TrajPairDataset(Dataset): diff --git a/datasets/preprocessing/generator.py b/datasets/preprocessing/generator.py index 9831a70..e179d27 100644 --- a/datasets/preprocessing/generator.py +++ b/datasets/preprocessing/generator.py @@ -7,7 +7,7 @@ from pathlib import Path from tqdm import tqdm -from lib.objects.map import Map +from ml_lib.objects.map import Map class Generator: diff --git a/datasets/trajectory_dataset.py b/datasets/trajectory_dataset.py index c76afae..3c6a28a 100644 --- a/datasets/trajectory_dataset.py +++ b/datasets/trajectory_dataset.py @@ -14,11 +14,11 @@ from torch.utils.data import ConcatDataset, Dataset import numpy as np from tqdm import tqdm -from lib.objects.map import Map -import lib.variables as V +from ml_lib.objects.map import Map +import ml_lib.variables as V from PIL import Image -from lib.utils.tools import write_to_shelve +from ml_lib.utils.tools import write_to_shelve class TrajDataShelve(VisionDataset): diff --git a/generator_eval.py b/generator_eval.py index b895816..42a7c46 100644 --- a/generator_eval.py +++ b/generator_eval.py @@ -7,7 +7,7 @@ from sklearn.cluster import Birch, DBSCAN, KMeans from sklearn.decomposition import PCA from sklearn.manifold import TSNE -import lib.variables as V +import ml_lib.variables as V import numpy as np diff --git a/main.py b/main.py index d2a0db9..d3e4d8a 100644 --- a/main.py +++ b/main.py @@ -11,10 +11,10 @@ import torch from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping -from lib.modules.utils import LightningBaseModule -from lib.utils.config import Config -from lib.utils.logging import Logger -from lib.utils.model_io import SavedLightningModels +from ml_lib.modules.utils import LightningBaseModule +from ml_lib.utils.config import Config +from ml_lib.utils.logging import Logger +from ml_lib.utils.model_io import SavedLightningModels warnings.filterwarnings('ignore', category=FutureWarning) warnings.filterwarnings('ignore', category=UserWarning) diff --git a/lib/__init__.py b/ml_lib/__init__.py similarity index 100% rename from lib/__init__.py rename to ml_lib/__init__.py diff --git a/lib/evaluation/__init__.py b/ml_lib/evaluation/__init__.py similarity index 100% rename from lib/evaluation/__init__.py rename to ml_lib/evaluation/__init__.py diff --git a/lib/evaluation/classification.py b/ml_lib/evaluation/classification.py similarity index 100% rename from lib/evaluation/classification.py rename to ml_lib/evaluation/classification.py diff --git a/lib/examples/__init__.py b/ml_lib/examples/__init__.py similarity index 100% rename from lib/examples/__init__.py rename to ml_lib/examples/__init__.py diff --git a/lib/modules/__init__.py b/ml_lib/modules/__init__.py similarity index 100% rename from lib/modules/__init__.py rename to ml_lib/modules/__init__.py diff --git a/lib/modules/blocks.py b/ml_lib/modules/blocks.py similarity index 98% rename from lib/modules/blocks.py rename to ml_lib/modules/blocks.py index 2713906..316f528 100644 --- a/lib/modules/blocks.py +++ b/ml_lib/modules/blocks.py @@ -2,7 +2,7 @@ from typing import Union import torch from torch import nn -from lib.modules.utils import AutoPad, Interpolate +from ml_lib.modules.utils import AutoPad, Interpolate # diff --git a/lib/modules/losses.py b/ml_lib/modules/losses.py similarity index 78% rename from lib/modules/losses.py rename to ml_lib/modules/losses.py index b47a81e..e6e94d1 100644 --- a/lib/modules/losses.py +++ b/ml_lib/modules/losses.py @@ -3,9 +3,9 @@ from typing import List import torch from torch import nn -from lib.modules.utils import FlipTensor -from lib.objects.map import MapStorage, Map -from lib.objects.trajectory import Trajectory +from ml_lib.modules.utils import FlipTensor +from ml_lib.objects.map import MapStorage, Map +from ml_lib.objects.trajectory import Trajectory class BinaryHomotopicLoss(nn.Module): diff --git a/lib/modules/model_parts.py b/ml_lib/modules/model_parts.py similarity index 100% rename from lib/modules/model_parts.py rename to ml_lib/modules/model_parts.py diff --git a/lib/modules/utils.py b/ml_lib/modules/utils.py similarity index 100% rename from lib/modules/utils.py rename to ml_lib/modules/utils.py diff --git a/lib/utils/__init__.py b/ml_lib/utils/__init__.py similarity index 100% rename from lib/utils/__init__.py rename to ml_lib/utils/__init__.py diff --git a/lib/utils/config.py b/ml_lib/utils/config.py similarity index 92% rename from lib/utils/config.py rename to ml_lib/utils/config.py index 673bcfb..cd5ca57 100644 --- a/lib/utils/config.py +++ b/ml_lib/utils/config.py @@ -5,12 +5,12 @@ from collections import defaultdict from configparser import ConfigParser from pathlib import Path -from lib.models.generators.cnn import CNNRouteGeneratorModel -from lib.models.generators.cnn_discriminated import CNNRouteGeneratorDiscriminated +from ml_lib.models.generators.cnn import CNNRouteGeneratorModel +from ml_lib.models.generators.cnn_discriminated import CNNRouteGeneratorDiscriminated -from lib.models.homotopy_classification.cnn_based import ConvHomDetector -from lib.utils.model_io import ModelParameters -from lib.utils.transforms import AsArray +from ml_lib.models.homotopy_classification.cnn_based import ConvHomDetector +from ml_lib.utils.model_io import ModelParameters +from ml_lib.utils.transforms import AsArray def is_jsonable(x): diff --git a/lib/utils/logging.py b/ml_lib/utils/logging.py similarity index 99% rename from lib/utils/logging.py rename to ml_lib/utils/logging.py index 6a16c2b..74f233d 100644 --- a/lib/utils/logging.py +++ b/ml_lib/utils/logging.py @@ -4,7 +4,7 @@ from pytorch_lightning.loggers.base import LightningLoggerBase from pytorch_lightning.loggers.neptune import NeptuneLogger from pytorch_lightning.loggers.test_tube import TestTubeLogger -from lib.utils.config import Config +from ml_lib.utils.config import Config import numpy as np diff --git a/lib/utils/model_io.py b/ml_lib/utils/model_io.py similarity index 100% rename from lib/utils/model_io.py rename to ml_lib/utils/model_io.py diff --git a/lib/utils/parallel.py b/ml_lib/utils/parallel.py similarity index 100% rename from lib/utils/parallel.py rename to ml_lib/utils/parallel.py diff --git a/lib/utils/tools.py b/ml_lib/utils/tools.py similarity index 100% rename from lib/utils/tools.py rename to ml_lib/utils/tools.py diff --git a/lib/utils/transforms.py b/ml_lib/utils/transforms.py similarity index 100% rename from lib/utils/transforms.py rename to ml_lib/utils/transforms.py diff --git a/lib/visualization/__init__.py b/ml_lib/visualization/__init__.py similarity index 100% rename from lib/visualization/__init__.py rename to ml_lib/visualization/__init__.py diff --git a/lib/visualization/tools.py b/ml_lib/visualization/tools.py similarity index 96% rename from lib/visualization/tools.py rename to ml_lib/visualization/tools.py index 8ba6651..1337421 100644 --- a/lib/visualization/tools.py +++ b/ml_lib/visualization/tools.py @@ -1,26 +1,26 @@ -from pathlib import Path -import matplotlib.pyplot as plt - - -class Plotter(object): - def __init__(self, root_path=''): - self.root_path = Path(root_path) - - def save_current_figure(self, path, extention='.png'): - fig, _ = plt.gcf(), plt.gca() - # Prepare save location and check img file extention - path = self.root_path / Path(path if str(path).endswith(extention) else f'{str(path)}{extention}') - path.parent.mkdir(exist_ok=True, parents=True) - fig.savefig(path) - fig.clf() - - def show_current_figure(self): - fig, _ = plt.gcf(), plt.gca() - fig.show() - fig.clf() - - -if __name__ == '__main__': - output_root = Path('..') / 'output' - p = Plotter(output_root) - p.save_current_figure('test.png') +from pathlib import Path +import matplotlib.pyplot as plt + + +class Plotter(object): + def __init__(self, root_path=''): + self.root_path = Path(root_path) + + def save_current_figure(self, path, extention='.png'): + fig, _ = plt.gcf(), plt.gca() + # Prepare save location and check img file extention + path = self.root_path / Path(path if str(path).endswith(extention) else f'{str(path)}{extention}') + path.parent.mkdir(exist_ok=True, parents=True) + fig.savefig(path) + fig.clf() + + def show_current_figure(self): + fig, _ = plt.gcf(), plt.gca() + fig.show() + fig.clf() + + +if __name__ == '__main__': + output_root = Path('..') / 'output' + p = Plotter(output_root) + p.save_current_figure('test.png') diff --git a/models/generators/cnn.py b/models/generators/cnn.py index 00bf2c4..221bae6 100644 --- a/models/generators/cnn.py +++ b/models/generators/cnn.py @@ -10,12 +10,12 @@ from torch.optim import Adam from datasets.mnist import MyMNIST from datasets.trajectory_dataset import TrajData -from lib.modules.blocks import ConvModule, DeConvModule -from lib.modules.utils import LightningBaseModule, Flatten +from ml_lib.modules.blocks import ConvModule, DeConvModule +from ml_lib.modules.utils import LightningBaseModule, Flatten import matplotlib.pyplot as plt -import lib.variables as V -from lib.visualization.generator_eval import GeneratorVisualizer +import variables as V +from generator_eval import GeneratorVisualizer class CNNRouteGeneratorModel(LightningBaseModule): diff --git a/models/generators/cnn_discriminated.py b/models/generators/cnn_discriminated.py index 9857a01..5376ee3 100644 --- a/models/generators/cnn_discriminated.py +++ b/models/generators/cnn_discriminated.py @@ -1,18 +1,12 @@ -from random import choices, seed -import numpy as np - import torch from functools import reduce from operator import mul from torch import nn -from torch.optim import Adam from datasets.trajectory_dataset import TrajData -from lib.evaluation.classification import ROCEvaluation -from lib.models.generators.cnn import CNNRouteGeneratorModel -from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule -from lib.modules.utils import LightningBaseModule, Flatten +from ml_lib.evaluation.classification import ROCEvaluation +from models.generators.cnn import CNNRouteGeneratorModel import matplotlib.pyplot as plt @@ -79,7 +73,7 @@ class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel): maps, trajectories, labels, val_restul_dict = self.generate_random() - from lib.visualization.generator_eval import GeneratorVisualizer + from generator_eval import GeneratorVisualizer g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict) fig = g.draw() self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step) diff --git a/models/generators/full.py b/models/generators/full.py index 3b56206..353f6e1 100644 --- a/models/generators/full.py +++ b/models/generators/full.py @@ -1,7 +1,7 @@ -from lib.modules.losses import BinaryHomotopicLoss -from lib.modules.utils import LightningBaseModule -from lib.objects.map import Map -from lib.objects.trajectory import Trajectory +from ml_lib.modules.losses import BinaryHomotopicLoss +from ml_lib.modules.utils import LightningBaseModule +from objects.map import Map +from objects.trajectory import Trajectory import torch.nn as nn diff --git a/models/generators/recurrent.py b/models/generators/recurrent.py index 08ae85d..853fc41 100644 --- a/models/generators/recurrent.py +++ b/models/generators/recurrent.py @@ -8,9 +8,9 @@ from torch import nn from torch.optim import Adam from datasets.trajectory_dataset import TrajData -from lib.evaluation.classification import ROCEvaluation -from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule -from lib.modules.utils import LightningBaseModule, Flatten +from ml_lib.evaluation.classification import ROCEvaluation +from ml_lib.modules.blocks import ConvModule, ResidualModule, DeConvModule +from ml_lib.modules.utils import LightningBaseModule, Flatten import matplotlib.pyplot as plt @@ -55,7 +55,7 @@ class CNNRouteGeneratorModel(LightningBaseModule): def _test_val_epoch_end(self, outputs, test=False): maps, trajectories, labels, val_restul_dict = self.generate_random() - from lib.visualization.generator_eval import GeneratorVisualizer + from ml_lib.visualization.generator_eval import GeneratorVisualizer g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict) fig = g.draw() self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step) @@ -312,7 +312,7 @@ class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel): maps, trajectories, labels, val_restul_dict = self.generate_random() - from lib.visualization.generator_eval import GeneratorVisualizer + from ml_lib.visualization.generator_eval import GeneratorVisualizer g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict) fig = g.draw() self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step) diff --git a/models/homotopy_classification/cnn_based.py b/models/homotopy_classification/cnn_based.py index e3078a0..9b632a6 100644 --- a/models/homotopy_classification/cnn_based.py +++ b/models/homotopy_classification/cnn_based.py @@ -8,9 +8,9 @@ from torch.optim import Adam from torch.utils.data import DataLoader from datasets.trajectory_dataset import TrajData -from lib.evaluation.classification import ROCEvaluation -from lib.modules.utils import LightningBaseModule, Flatten -from lib.modules.blocks import ConvModule, ResidualModule +from ml_lib.evaluation.classification import ROCEvaluation +from ml_lib.modules.utils import LightningBaseModule, Flatten +from ml_lib.modules.blocks import ConvModule, ResidualModule import matplotlib.pyplot as plt diff --git a/multi_run.py b/multi_run.py index d535d48..3d0e377 100644 --- a/multi_run.py +++ b/multi_run.py @@ -1,6 +1,6 @@ import warnings -from lib.utils.config import Config +from ml_lib.utils.config import Config warnings.filterwarnings('ignore', category=FutureWarning) warnings.filterwarnings('ignore', category=UserWarning) diff --git a/objects/map.py b/objects/map.py index 726d292..98759a9 100644 --- a/objects/map.py +++ b/objects/map.py @@ -11,8 +11,8 @@ from PIL import Image, ImageDraw import networkx as nx from matplotlib import pyplot as plt -from lib.objects.trajectory import Trajectory -import lib.variables as V +from ml_lib.objects.trajectory import Trajectory +import ml_lib.variables as V class Map(object): diff --git a/objects/trajectory.py b/objects/trajectory.py index 4581529..2c23d47 100644 --- a/objects/trajectory.py +++ b/objects/trajectory.py @@ -2,7 +2,7 @@ from math import atan2 from typing import List, Tuple, Union from matplotlib import pyplot as plt -from lib import variables as V +from ml_lib import variables as V import numpy as np diff --git a/lib/variables.py b/variables.py similarity index 77% rename from lib/variables.py rename to variables.py index 9b27a6c..06bd575 100644 --- a/lib/variables.py +++ b/variables.py @@ -1,6 +1,3 @@ -from pathlib import Path -_ROOT = Path('..') - # Labels for classes HOMOTOPIC = 1 ALTERNATIVE = 0