Future Prediction Training

Si11ium 2019-09-29 11:50:38 +02:00
parent a70c9b7fef
commit 3e9ef013b3
8 changed files with 86 additions and 88 deletions

View File

@@ -195,17 +195,21 @@ class Trajectories(Dataset):
             yield self[i]

     def __getitem__(self, item):
-        return self.data[item:item + self.size * self.step or None:self.step][:, 2:]
+        assert isinstance(item, int), f"Item-Key has to be Integer, but was {type(item)}"
+        x = self.data[item:item + self.size * self.step or None:self.step][:, 2:]
+        futureItem = item + 1
+        y = self.data[futureItem:futureItem + self.size * self.step or None:self.step][:, 2:]
+        return x, y

     def get_isovist_measures_by_key(self, item):
-        return self[item]
+        return self[item][0]

     def get_coordinates_by_key(self, item):
         return self.data[item:item + self.size * self.step or None:self.step][:, :2]

     def get_both_by_key(self, item):
         data = self.data[item:item + self.size * self.step or None:self.step]
-        return data
+        return data[0]

     def __len__(self):
         total_len = self.data.size()[0]
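For orientation: the reworked __getitem__ pairs each strided window with the window that starts one index later, so every sample becomes an (x, y) pair in which y is the one-step-ahead "future" of x. A minimal sketch of that pairing on a plain tensor, with hypothetical size/step values and without the [:, 2:] column slice that drops the coordinate columns in the real Trajectories class:

import torch

# Hypothetical stand-ins for the dataset's attributes.
size, step = 3, 2                          # window length and stride
data = torch.arange(40.).reshape(20, 2)    # 20 timesteps, 2 features

def window(item):
    # Same slicing pattern as Trajectories.__getitem__: `size` rows, `step` apart.
    return data[item:item + size * step or None:step]

item = 4
x = window(item)         # current window
y = window(item + 1)     # window shifted by one index -> the prediction target
print(x.shape, y.shape)  # torch.Size([3, 2]) torch.Size([3, 2])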

View File

@@ -28,16 +28,19 @@ class AdversarialAE(AutoEncoder):

 class AdversarialAE_LO(LightningModuleOverrides):

-    def __init__(self):
+    def __init__(self, train_on_predictions=False):
         super(AdversarialAE_LO, self).__init__()
+        self.train_on_predictions = train_on_predictions

     def training_step(self, batch, _, optimizer_i):
+        x, y = batch
+        z, x_hat = self.forward(x)
         if optimizer_i == 0:
             # ---------------------
             # Train Discriminator
             # ---------------------p
             # latent_fake, reconstruction
-            latent_fake = self.network.encoder.forward(batch)
+            latent_fake = z
             latent_real = self.normal.sample(latent_fake.shape).to(device)

             # Evaluate the input
@@ -57,9 +60,7 @@ class AdversarialAE_LO(LightningModuleOverrides):
             # ---------------------
             # Train AutoEncoder
             # ---------------------
-            # z, x_hat
-            _, batch_hat = self.forward(batch)
-            loss = mse_loss(batch, batch_hat)
+            loss = mse_loss(y, x_hat) if self.train_on_predictions else mse_loss(x, x_hat)
             return {'loss': loss}
         else:
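The training step now runs the encoder/decoder once up front and reuses z and x_hat in both optimizer branches. Most of the discriminator branch lies outside this hunk; below is a self-contained sketch of the routing, with a generic adversarial-autoencoder discriminator objective standing in for the elided part and throwaway tensors for the data:

import torch
from torch.nn.functional import binary_cross_entropy, mse_loss

def training_step_sketch(x, y, z, x_hat, discriminator, optimizer_i, train_on_predictions=False):
    # Optimizer 0 -> discriminator, optimizer 1 -> autoencoder (toy version of the pattern above).
    if optimizer_i == 0:
        latent_real = torch.randn_like(z)            # samples from the prior
        real_pred = discriminator(latent_real)
        fake_pred = discriminator(z.detach())        # encoder output treated as "fake"
        real_loss = binary_cross_entropy(real_pred, torch.ones_like(real_pred))
        fake_loss = binary_cross_entropy(fake_pred, torch.zeros_like(fake_pred))
        return {'loss': (real_loss + fake_loss) / 2}
    else:
        target = y if train_on_predictions else x    # future window vs. plain reconstruction
        return {'loss': mse_loss(target, x_hat)}

# Toy wiring: a sigmoid "discriminator" and random stand-in tensors.
disc = torch.nn.Sequential(torch.nn.Linear(2, 1), torch.nn.Sigmoid())
x, y, x_hat, z = torch.randn(4, 6), torch.randn(4, 6), torch.randn(4, 6), torch.randn(4, 2)
print(training_step_sketch(x, y, z, x_hat, disc, optimizer_i=0))
print(training_step_sketch(x, y, z, x_hat, disc, optimizer_i=1, train_on_predictions=True))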

View File

@@ -37,7 +37,7 @@ class AE_WithAttention_LO(LightningModuleOverrides):
         # ToDo: We need a new loss function, fullfilling all attention needs
         # z, x_hat
         _, x_hat = self.forward(x)
-        loss = mse_loss(x, x_hat)
+        loss = mse_loss(y, x_hat) if self.train_on_predictions else mse_loss(x, x_hat)
         return {'loss': loss}

     def configure_optimizers(self):

View File

@@ -9,13 +9,13 @@ from torch import Tensor

 # Basic AE-Implementation
 class AutoEncoder(AbstractNeuralNetwork, ABC):

-    def __init__(self, latent_dim: int=0, features: int = 0, **kwargs):
+    def __init__(self, latent_dim: int=0, features: int = 0, use_norm=True, **kwargs):
         assert latent_dim and features
         super(AutoEncoder, self).__init__()
         self.latent_dim = latent_dim
         self.features = features
-        self.encoder = Encoder(self.latent_dim)
-        self.decoder = Decoder(self.latent_dim, self.features)
+        self.encoder = Encoder(self.latent_dim, use_norm=use_norm)
+        self.decoder = Decoder(self.latent_dim, self.features, use_norm=use_norm)

     def forward(self, batch: Tensor):
         # Encoder
@@ -30,13 +30,15 @@ class AutoEncoder(AbstractNeuralNetwork, ABC):

 class AutoEncoder_LO(LightningModuleOverrides):

-    def __init__(self):
+    def __init__(self, train_on_predictions=False):
         super(AutoEncoder_LO, self).__init__()
+        self.train_on_predictions = train_on_predictions

-    def training_step(self, x, batch_nb):
+    def training_step(self, batch, batch_nb):
+        x, y = batch
         # z, x_hat
         _, x_hat = self.forward(x)
-        loss = mse_loss(x, x_hat)
+        loss = mse_loss(y, x_hat) if self.train_on_predictions else mse_loss(x, x_hat)
         return {'loss': loss}

     def configure_optimizers(self):
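Reassembled from the hunk above, the plain autoencoder's Lightning override now reads as follows (LightningModuleOverrides and forward are assumed to come from the repo's networks.modules, unchanged):

from torch.nn.functional import mse_loss

class AutoEncoder_LO(LightningModuleOverrides):

    def __init__(self, train_on_predictions=False):
        super(AutoEncoder_LO, self).__init__()
        self.train_on_predictions = train_on_predictions

    def training_step(self, batch, batch_nb):
        x, y = batch                      # current window, one-step-ahead window
        # z, x_hat
        _, x_hat = self.forward(x)
        loss = mse_loss(y, x_hat) if self.train_on_predictions else mse_loss(x, x_hat)
        return {'loss': loss}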

View File

@@ -6,7 +6,7 @@ import torch
 from torch import randn
 import pytorch_lightning as pl
 from pytorch_lightning import data_loader
-from torch.nn import Module, Linear, ReLU, Sigmoid, Dropout, GRU
+from torch.nn import Module, Linear, ReLU, Sigmoid, Dropout, GRU, Tanh
 from torchvision.transforms import Normalize

 from abc import ABC, abstractmethod
@@ -33,8 +33,7 @@ class LightningModuleOverrides:
     @data_loader
     def tng_dataloader(self):
         num_workers = 0  # os.cpu_count() // 2
-        return DataLoader(DataContainer(os.path.join('data', 'training'),
-                                        self.size, self.step, transforms=[Normalize]),
+        return DataLoader(DataContainer(os.path.join('data', 'training'), self.size, self.step),
                           shuffle=True, batch_size=10000, num_workers=num_workers)
     """
     @data_loader
@@ -193,17 +192,23 @@ class Discriminator(Module):

 class DecoderLinearStack(Module):

-    def __init__(self, out_shape):
+    def __init__(self, out_shape, use_norm=True):
         super(DecoderLinearStack, self).__init__()
         self.l1 = Linear(10, 100, bias=True)
+        self.norm1 = torch.nn.BatchNorm1d(100) if use_norm else False
         self.l2 = Linear(100, out_shape, bias=True)
+        self.norm2 = torch.nn.BatchNorm1d(out_shape) if use_norm else False
         self.activation = ReLU()
-        self.activation_out = Sigmoid()
+        self.activation_out = Tanh()

     def forward(self, x):
         tensor = self.l1(x)
+        if self.norm1:
+            tensor = self.norm1(tensor)
         tensor = self.activation(tensor)
         tensor = self.l2(tensor)
+        if self.norm2:
+            tensor = self.norm2(tensor)
         tensor = self.activation_out(tensor)
         return tensor
@@ -213,62 +218,64 @@ class EncoderLinearStack(Module):

     @property
     def shape(self):
         x = randn(self.features).unsqueeze(0)
+        x = torch.cat((x,x,x,x,x))
         output = self(x)
         return output.shape[1:]

-    def __init__(self, features=6, separated=False, use_bias=True):
+    def __init__(self, features=6, factor=10, use_bias=True, use_norm=True):
         super(EncoderLinearStack, self).__init__()
         # FixMe: Get Hardcoded shit out of here
-        self.separated = separated
         self.features = features
-        if self.separated:
-            self.l1s = [Linear(1, 10, bias=use_bias) for _ in range(self.features)]
-            self.l2s = [Linear(10, 5, bias=use_bias) for _ in range(self.features)]
-        else:
-            self.l1 = Linear(self.features, self.features * 10, bias=use_bias)
-            self.l2 = Linear(self.features * 10, self.features * 5, bias=use_bias)
-            self.l3 = Linear(self.features * 5, 10, use_bias)
+        self.l1 = Linear(self.features, self.features * factor, bias=use_bias)
+        self.l2 = Linear(self.features * factor, self.features * factor//2, bias=use_bias)
+        self.l3 = Linear(self.features * factor//2, factor, use_bias)
+        self.norm1 = torch.nn.BatchNorm1d(self.features * factor) if use_norm else False
+        self.norm2 = torch.nn.BatchNorm1d(self.features * factor//2) if use_norm else False
+        self.norm3 = torch.nn.BatchNorm1d(factor) if use_norm else False
         self.activation = ReLU()

     def forward(self, x):
-        if self.separated:
-            x = x.unsqueeze(-1)
-            tensors = [self.l1s[idx](x[:, idx, :]) for idx in range(len(self.l1s))]
-            tensors = [self.activation(tensor) for tensor in tensors]
-            tensors = [self.l2s[idx](tensors[idx]) for idx in range(len(self.l2s))]
-            tensors = [self.activation(tensor) for tensor in tensors]
-            tensor = torch.cat(tensors, dim=-1)
-        else:
-            tensor = self.l1(x)
-            tensor = self.activation(tensor)
-            tensor = self.l2(tensor)
+        tensor = self.l1(x)
+        if self.norm1:
+            tensor = self.norm1(tensor)
+        tensor = self.activation(tensor)
+        tensor = self.l2(tensor)
+        if self.norm2:
+            tensor = self.norm2(tensor)
+        tensor = self.activation(tensor)
         tensor = self.l3(tensor)
+        if self.norm3:
+            tensor = self.norm3(tensor)
         tensor = self.activation(tensor)
         return tensor


 class Encoder(Module):

-    def __init__(self, lat_dim, variational=False, separate_features=False, with_dense=True, features=6):
+    def __init__(self, lat_dim, variational=False, use_dense=True, features=6, use_norm=True):
         self.lat_dim = lat_dim
         self.features = features
+        self.lstm_cells = 10
         self.variational = variational
         super(Encoder, self).__init__()
-        self.l_stack = TimeDistributed(EncoderLinearStack(separated=separate_features,
-                                                          features=features)) if with_dense else False
-        self.gru = GRU(10 if with_dense else self.features, 10, batch_first=True)
+        self.l_stack = TimeDistributed(EncoderLinearStack(features, use_norm=use_norm)) if use_dense else False
+        self.gru = GRU(10 if use_dense else self.features, self.lstm_cells, batch_first=True)
         self.filter = RNNOutputFilter(only_last=True)
+        self.norm = torch.nn.BatchNorm1d(self.lstm_cells) if use_norm else False
         if variational:
-            self.mu = Linear(10, self.lat_dim)
-            self.logvar = Linear(10, self.lat_dim)
+            self.mu = Linear(self.lstm_cells, self.lat_dim)
+            self.logvar = Linear(self.lstm_cells, self.lat_dim)
         else:
-            self.lat_dim_layer = Linear(10, self.lat_dim)
+            self.lat_dim_layer = Linear(self.lstm_cells, self.lat_dim)

     def forward(self, x):
         if self.l_stack:
             x = self.l_stack(x)
         tensor = self.gru(x)
         tensor = self.filter(tensor)
+        if self.norm:
+            tensor = self.norm(tensor)
         if self.variational:
             tensor = self.mu(tensor), self.logvar(tensor)
         else:
@@ -316,17 +323,20 @@ class PoolingEncoder(Module):

 class Decoder(Module):

-    def __init__(self, latent_dim, *args, variational=False):
+    def __init__(self, latent_dim, *args, lstm_cells=10, use_norm=True, variational=False):
         self.variational = variational
         super(Decoder, self).__init__()
-        self.g = GRU(latent_dim, 10, batch_first=True)
+        self.gru = GRU(latent_dim, lstm_cells, batch_first=True)
+        self.norm = TimeDistributed(torch.nn.BatchNorm1d(lstm_cells) if use_norm else False)
         self.filter = RNNOutputFilter()
-        self.l_stack = TimeDistributed(DecoderLinearStack(*args))
+        self.l_stack = TimeDistributed(DecoderLinearStack(*args, use_norm=use_norm))
         pass

     def forward(self, x):
-        tensor = self.g(x)
+        tensor = self.gru(x)
         tensor = self.filter(tensor)
+        if self.norm:
+            tensor = self.norm(tensor)
         tensor = self.l_stack(tensor)
         return tensor
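All the use_norm changes in this file follow one pattern: each block optionally owns a BatchNorm1d (stored as False when disabled) and applies it behind an if in forward, and the decoder output moves from Sigmoid to Tanh. A self-contained sketch of that pattern with a made-up two-layer stack, not the repo's DecoderLinearStack itself:

import torch
from torch.nn import Module, Linear, ReLU, Tanh, BatchNorm1d

class NormedStack(Module):
    # Toy stand-in: optional BatchNorm after each Linear, Tanh on the way out.
    def __init__(self, in_shape, out_shape, use_norm=True):
        super(NormedStack, self).__init__()
        self.l1 = Linear(in_shape, 100, bias=True)
        self.norm1 = BatchNorm1d(100) if use_norm else False
        self.l2 = Linear(100, out_shape, bias=True)
        self.norm2 = BatchNorm1d(out_shape) if use_norm else False
        self.activation = ReLU()
        self.activation_out = Tanh()

    def forward(self, x):
        tensor = self.l1(x)
        if self.norm1:                  # skipped entirely when use_norm=False
            tensor = self.norm1(tensor)
        tensor = self.activation(tensor)
        tensor = self.l2(tensor)
        if self.norm2:
            tensor = self.norm2(tensor)
        return self.activation_out(tensor)

# BatchNorm1d needs more than one sample in training mode, which is also why the
# shape property above now feeds a concatenated batch of five rows.
print(NormedStack(6, 4)(torch.randn(5, 6)).shape)  # torch.Size([5, 4])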

View File

@@ -6,14 +6,14 @@ import torch

 class SeperatingAAE(Module):

-    def __init__(self, latent_dim, features):
+    def __init__(self, latent_dim, features, use_norm=True):
         super(SeperatingAAE, self).__init__()
         self.latent_dim = latent_dim
         self.features = features
         self.spatial_encoder = PoolingEncoder(self.latent_dim)
-        self.temporal_encoder = Encoder(self.latent_dim, with_dense=False)
-        self.decoder = Decoder(self.latent_dim * 2, self.features)
+        self.temporal_encoder = Encoder(self.latent_dim, use_dense=False, use_norm=use_norm)
+        self.decoder = Decoder(self.latent_dim * 2, self.features, use_norm=use_norm)
         self.spatial_discriminator = Discriminator(self.latent_dim, self.features)
         self.temporal_discriminator = Discriminator(self.latent_dim, self.features)
@@ -29,22 +29,15 @@ class SeperatingAAE(Module):
         return z_spatial, z_temporal, x_hat

-class SuperSeperatingAAE(SeperatingAAE):
-    def __init__(self, *args):
-        super(SuperSeperatingAAE, self).__init__(*args)
-        self.temporal_encoder = Encoder(self.latent_dim, separate_features=True)
-
-    def forward(self, batch):
-        return batch
-

 class SeparatingAAE_LO(LightningModuleOverrides):

-    def __init__(self):
+    def __init__(self, train_on_predictions=False):
         super(SeparatingAAE_LO, self).__init__()
+        self.train_on_predictions = train_on_predictions

     def training_step(self, batch, _, optimizer_i):
-        spatial_latent_fake, temporal_latent_fake, batch_hat = self.network.forward(batch)
+        x, y = batch
+        spatial_latent_fake, temporal_latent_fake, x_hat = self.network.forward(x)
         if optimizer_i == 0:
             # ---------------------
             # Train temporal Discriminator
@@ -93,7 +86,7 @@ class SeparatingAAE_LO(LightningModuleOverrides):
             # ---------------------
             # Train AutoEncoder
             # ---------------------
-            loss = mse_loss(batch, batch_hat)
+            loss = mse_loss(y, x_hat) if self.train_on_predictions else mse_loss(x, x_hat)
             return {'loss': loss}
         else:

View File

@@ -12,13 +12,13 @@ class VariationalAE(AbstractNeuralNetwork, ABC):
     def name(self):
         return self.__class__.__name__

-    def __init__(self, latent_dim=0, features=0, **kwargs):
+    def __init__(self, latent_dim=0, features=0, use_norm=True, **kwargs):
         assert latent_dim and features
         super(VariationalAE, self).__init__()
         self.features = features
         self.latent_dim = latent_dim
-        self.encoder = Encoder(self.latent_dim, variational=True)
-        self.decoder = Decoder(self.latent_dim, self.features, variational=True)
+        self.encoder = Encoder(self.latent_dim, variational=True, use_norm=use_norm)
+        self.decoder = Decoder(self.latent_dim, self.features, variational=True, use_norm=use_norm)

     @staticmethod
     def reparameterize(mu, logvar):
@@ -37,12 +37,14 @@ class VariationalAE(AbstractNeuralNetwork, ABC):

 class VAE_LO(LightningModuleOverrides):

-    def __init__(self):
+    def __init__(self, train_on_predictions=False):
         super(VAE_LO, self).__init__()
+        self.train_on_predictions=train_on_predictions

-    def training_step(self, x, _):
+    def training_step(self, batch, _):
+        x, y = batch
         mu, logvar, x_hat = self.forward(x)
-        BCE = mse_loss(x_hat, x, reduction='mean')
+        BCE = mse_loss(y, x_hat) if self.train_on_predictions else mse_loss(x, x_hat)
         # see Appendix B from VAE paper:
         # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
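The renamed reconstruction term still feeds the usual VAE objective; the lines after this hunk presumably add the KL term from the cited paper. For reference, the standard combination (a sketch, not the repo's exact code):

import torch
from torch.nn.functional import mse_loss

def vae_loss_sketch(x, y, x_hat, mu, logvar, train_on_predictions=False):
    # Reconstruction against the input or the one-step-ahead target ...
    recon = mse_loss(y if train_on_predictions else x, x_hat)
    # ... plus KL(q(z|x) || N(0, I)), Appendix B of Kingma & Welling (2014).
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kld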

View File

@@ -10,7 +10,7 @@ from distutils.util import strtobool
 from networks.auto_encoder import AutoEncoder, AutoEncoder_LO
 from networks.variational_auto_encoder import VariationalAE, VAE_LO
 from networks.adverserial_auto_encoder import AdversarialAE_LO, AdversarialAE
-from networks.seperating_adversarial_auto_encoder import SeperatingAAE, SeparatingAAE_LO, SuperSeperatingAAE
+from networks.seperating_adversarial_auto_encoder import SeperatingAAE, SeparatingAAE_LO
 from networks.modules import LightningModule

 from pytorch_lightning import Trainer
@@ -22,7 +22,7 @@ args.add_argument('--step', default=5)
 args.add_argument('--features', default=6)
 args.add_argument('--size', default=9)
 args.add_argument('--latent_dim', default=2)
-args.add_argument('--model', default='VAE_Model')
+args.add_argument('--model', default='AE_Model')
 args.add_argument('--refresh', type=strtobool, default=False)
@@ -78,20 +78,6 @@ class SAAE_Model(SeparatingAAE_LO, LightningModule):
     pass

-class SSAAE_Model(SeparatingAAE_LO, LightningModule):
-    def __init__(self, parameters: Namespace):
-        assert all([x in parameters for x in ['step', 'size', 'latent_dim', 'features']])
-        self.size = parameters.size
-        self.latent_dim = parameters.latent_dim
-        self.features = parameters.features
-        self.step = parameters.step
-        super(SSAAE_Model, self).__init__()
-        self.normal = Normal(0, 1)
-        self.network = SuperSeperatingAAE(self.latent_dim, self.features)
-        pass
-

 if __name__ == '__main__':
     arguments = args.parse_args()