Future Prediction Training as Parameter

This commit is contained in:
Si11ium 2019-09-29 11:52:39 +02:00
parent 3e9ef013b3
commit aa802cb2be

View File

@ -24,6 +24,7 @@ args.add_argument('--size', default=9)
args.add_argument('--latent_dim', default=2)
args.add_argument('--model', default='AE_Model')
args.add_argument('--refresh', type=strtobool, default=False)
args.add_argument('--future_predictions', type=strtobool, default=False)
class AE_Model(AutoEncoder_LO, LightningModule):
@ -34,7 +35,7 @@ class AE_Model(AutoEncoder_LO, LightningModule):
self.latent_dim = parameters.latent_dim
self.features = parameters.features
self.step = parameters.step
super(AE_Model, self).__init__()
super(AE_Model, self).__init__(train_on_predictions=parameters.future_predictions)
self.network = AutoEncoder(self.latent_dim, self.features)
@ -46,7 +47,7 @@ class VAE_Model(VAE_LO, LightningModule):
self.latent_dim = parameters.latent_dim
self.features = parameters.features
self.step = parameters.step
super(VAE_Model, self).__init__()
super(VAE_Model, self).__init__(train_on_predictions=parameters.future_predictions)
self.network = VariationalAE(self.latent_dim, self.features)
@ -58,7 +59,7 @@ class AAE_Model(AdversarialAE_LO, LightningModule):
self.latent_dim = parameters.latent_dim
self.features = parameters.features
self.step = parameters.step
super(AAE_Model, self).__init__()
super(AAE_Model, self).__init__(train_on_predictions=parameters.future_predictions)
self.normal = Normal(0, 1)
self.network = AdversarialAE(self.latent_dim, self.features)
pass
@ -72,7 +73,7 @@ class SAAE_Model(SeparatingAAE_LO, LightningModule):
self.latent_dim = parameters.latent_dim
self.features = parameters.features
self.step = parameters.step
super(SAAE_Model, self).__init__()
super(SAAE_Model, self).__init__(train_on_predictions=parameters.future_predictions)
self.normal = Normal(0, 1)
self.network = SeperatingAAE(self.latent_dim, self.features)
pass