Refactoring ML_Lib

This commit is contained in:
Si11ium
2020-04-08 14:50:16 +02:00
parent 25c0e8e358
commit 27660c1458
32 changed files with 70 additions and 79 deletions

View File

@ -10,12 +10,12 @@ from torch.optim import Adam
from datasets.mnist import MyMNIST
from datasets.trajectory_dataset import TrajData
from lib.modules.blocks import ConvModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten
from ml_lib.modules.blocks import ConvModule, DeConvModule
from ml_lib.modules.utils import LightningBaseModule, Flatten
import matplotlib.pyplot as plt
import lib.variables as V
from lib.visualization.generator_eval import GeneratorVisualizer
import variables as V
from generator_eval import GeneratorVisualizer
class CNNRouteGeneratorModel(LightningBaseModule):

View File

@ -1,18 +1,12 @@
from random import choices, seed
import numpy as np
import torch
from functools import reduce
from operator import mul
from torch import nn
from torch.optim import Adam
from datasets.trajectory_dataset import TrajData
from lib.evaluation.classification import ROCEvaluation
from lib.models.generators.cnn import CNNRouteGeneratorModel
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten
from ml_lib.evaluation.classification import ROCEvaluation
from models.generators.cnn import CNNRouteGeneratorModel
import matplotlib.pyplot as plt
@ -79,7 +73,7 @@ class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
maps, trajectories, labels, val_restul_dict = self.generate_random()
from lib.visualization.generator_eval import GeneratorVisualizer
from generator_eval import GeneratorVisualizer
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
fig = g.draw()
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)

View File

@ -1,7 +1,7 @@
from lib.modules.losses import BinaryHomotopicLoss
from lib.modules.utils import LightningBaseModule
from lib.objects.map import Map
from lib.objects.trajectory import Trajectory
from ml_lib.modules.losses import BinaryHomotopicLoss
from ml_lib.modules.utils import LightningBaseModule
from objects.map import Map
from objects.trajectory import Trajectory
import torch.nn as nn

View File

@ -8,9 +8,9 @@ from torch import nn
from torch.optim import Adam
from datasets.trajectory_dataset import TrajData
from lib.evaluation.classification import ROCEvaluation
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten
from ml_lib.evaluation.classification import ROCEvaluation
from ml_lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from ml_lib.modules.utils import LightningBaseModule, Flatten
import matplotlib.pyplot as plt
@ -55,7 +55,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
def _test_val_epoch_end(self, outputs, test=False):
maps, trajectories, labels, val_restul_dict = self.generate_random()
from lib.visualization.generator_eval import GeneratorVisualizer
from ml_lib.visualization.generator_eval import GeneratorVisualizer
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
fig = g.draw()
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
@ -312,7 +312,7 @@ class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
maps, trajectories, labels, val_restul_dict = self.generate_random()
from lib.visualization.generator_eval import GeneratorVisualizer
from ml_lib.visualization.generator_eval import GeneratorVisualizer
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
fig = g.draw()
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)

View File

@ -8,9 +8,9 @@ from torch.optim import Adam
from torch.utils.data import DataLoader
from datasets.trajectory_dataset import TrajData
from lib.evaluation.classification import ROCEvaluation
from lib.modules.utils import LightningBaseModule, Flatten
from lib.modules.blocks import ConvModule, ResidualModule
from ml_lib.evaluation.classification import ROCEvaluation
from ml_lib.modules.utils import LightningBaseModule, Flatten
from ml_lib.modules.blocks import ConvModule, ResidualModule
import matplotlib.pyplot as plt