Future Prediction Training as Parameter
This commit is contained in:
parent
3e9ef013b3
commit
aa802cb2be
@ -24,6 +24,7 @@ args.add_argument('--size', default=9)
|
||||
args.add_argument('--latent_dim', default=2)
|
||||
args.add_argument('--model', default='AE_Model')
|
||||
args.add_argument('--refresh', type=strtobool, default=False)
|
||||
args.add_argument('--future_predictions', type=strtobool, default=False)
|
||||
|
||||
|
||||
class AE_Model(AutoEncoder_LO, LightningModule):
|
||||
@ -34,7 +35,7 @@ class AE_Model(AutoEncoder_LO, LightningModule):
|
||||
self.latent_dim = parameters.latent_dim
|
||||
self.features = parameters.features
|
||||
self.step = parameters.step
|
||||
super(AE_Model, self).__init__()
|
||||
super(AE_Model, self).__init__(train_on_predictions=parameters.future_predictions)
|
||||
self.network = AutoEncoder(self.latent_dim, self.features)
|
||||
|
||||
|
||||
@ -46,7 +47,7 @@ class VAE_Model(VAE_LO, LightningModule):
|
||||
self.latent_dim = parameters.latent_dim
|
||||
self.features = parameters.features
|
||||
self.step = parameters.step
|
||||
super(VAE_Model, self).__init__()
|
||||
super(VAE_Model, self).__init__(train_on_predictions=parameters.future_predictions)
|
||||
self.network = VariationalAE(self.latent_dim, self.features)
|
||||
|
||||
|
||||
@ -58,7 +59,7 @@ class AAE_Model(AdversarialAE_LO, LightningModule):
|
||||
self.latent_dim = parameters.latent_dim
|
||||
self.features = parameters.features
|
||||
self.step = parameters.step
|
||||
super(AAE_Model, self).__init__()
|
||||
super(AAE_Model, self).__init__(train_on_predictions=parameters.future_predictions)
|
||||
self.normal = Normal(0, 1)
|
||||
self.network = AdversarialAE(self.latent_dim, self.features)
|
||||
pass
|
||||
@ -72,7 +73,7 @@ class SAAE_Model(SeparatingAAE_LO, LightningModule):
|
||||
self.latent_dim = parameters.latent_dim
|
||||
self.features = parameters.features
|
||||
self.step = parameters.step
|
||||
super(SAAE_Model, self).__init__()
|
||||
super(SAAE_Model, self).__init__(train_on_predictions=parameters.future_predictions)
|
||||
self.normal = Normal(0, 1)
|
||||
self.network = SeperatingAAE(self.latent_dim, self.features)
|
||||
pass
|
||||
|
Loading…
x
Reference in New Issue
Block a user