Refactoring ML_Lib
This commit is contained in:
@ -10,12 +10,12 @@ from torch.optim import Adam
|
||||
|
||||
from datasets.mnist import MyMNIST
|
||||
from datasets.trajectory_dataset import TrajData
|
||||
from lib.modules.blocks import ConvModule, DeConvModule
|
||||
from lib.modules.utils import LightningBaseModule, Flatten
|
||||
from ml_lib.modules.blocks import ConvModule, DeConvModule
|
||||
from ml_lib.modules.utils import LightningBaseModule, Flatten
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import lib.variables as V
|
||||
from lib.visualization.generator_eval import GeneratorVisualizer
|
||||
import variables as V
|
||||
from generator_eval import GeneratorVisualizer
|
||||
|
||||
|
||||
class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
|
@ -1,18 +1,12 @@
|
||||
from random import choices, seed
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from functools import reduce
|
||||
from operator import mul
|
||||
|
||||
from torch import nn
|
||||
from torch.optim import Adam
|
||||
|
||||
from datasets.trajectory_dataset import TrajData
|
||||
from lib.evaluation.classification import ROCEvaluation
|
||||
from lib.models.generators.cnn import CNNRouteGeneratorModel
|
||||
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
|
||||
from lib.modules.utils import LightningBaseModule, Flatten
|
||||
from ml_lib.evaluation.classification import ROCEvaluation
|
||||
from models.generators.cnn import CNNRouteGeneratorModel
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
@ -79,7 +73,7 @@ class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
|
||||
|
||||
maps, trajectories, labels, val_restul_dict = self.generate_random()
|
||||
|
||||
from lib.visualization.generator_eval import GeneratorVisualizer
|
||||
from generator_eval import GeneratorVisualizer
|
||||
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
|
||||
fig = g.draw()
|
||||
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
|
||||
|
@ -1,7 +1,7 @@
|
||||
from lib.modules.losses import BinaryHomotopicLoss
|
||||
from lib.modules.utils import LightningBaseModule
|
||||
from lib.objects.map import Map
|
||||
from lib.objects.trajectory import Trajectory
|
||||
from ml_lib.modules.losses import BinaryHomotopicLoss
|
||||
from ml_lib.modules.utils import LightningBaseModule
|
||||
from objects.map import Map
|
||||
from objects.trajectory import Trajectory
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
@ -8,9 +8,9 @@ from torch import nn
|
||||
from torch.optim import Adam
|
||||
|
||||
from datasets.trajectory_dataset import TrajData
|
||||
from lib.evaluation.classification import ROCEvaluation
|
||||
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
|
||||
from lib.modules.utils import LightningBaseModule, Flatten
|
||||
from ml_lib.evaluation.classification import ROCEvaluation
|
||||
from ml_lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
|
||||
from ml_lib.modules.utils import LightningBaseModule, Flatten
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
@ -55,7 +55,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
def _test_val_epoch_end(self, outputs, test=False):
|
||||
maps, trajectories, labels, val_restul_dict = self.generate_random()
|
||||
|
||||
from lib.visualization.generator_eval import GeneratorVisualizer
|
||||
from ml_lib.visualization.generator_eval import GeneratorVisualizer
|
||||
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
|
||||
fig = g.draw()
|
||||
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
|
||||
@ -312,7 +312,7 @@ class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
|
||||
|
||||
maps, trajectories, labels, val_restul_dict = self.generate_random()
|
||||
|
||||
from lib.visualization.generator_eval import GeneratorVisualizer
|
||||
from ml_lib.visualization.generator_eval import GeneratorVisualizer
|
||||
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
|
||||
fig = g.draw()
|
||||
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
|
||||
|
@ -8,9 +8,9 @@ from torch.optim import Adam
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from datasets.trajectory_dataset import TrajData
|
||||
from lib.evaluation.classification import ROCEvaluation
|
||||
from lib.modules.utils import LightningBaseModule, Flatten
|
||||
from lib.modules.blocks import ConvModule, ResidualModule
|
||||
from ml_lib.evaluation.classification import ROCEvaluation
|
||||
from ml_lib.modules.utils import LightningBaseModule, Flatten
|
||||
from ml_lib.modules.blocks import ConvModule, ResidualModule
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user