Future Prediction Training as Parameter
This commit is contained in:
parent
3e9ef013b3
commit
aa802cb2be
@ -24,6 +24,7 @@ args.add_argument('--size', default=9)
|
|||||||
args.add_argument('--latent_dim', default=2)
|
args.add_argument('--latent_dim', default=2)
|
||||||
args.add_argument('--model', default='AE_Model')
|
args.add_argument('--model', default='AE_Model')
|
||||||
args.add_argument('--refresh', type=strtobool, default=False)
|
args.add_argument('--refresh', type=strtobool, default=False)
|
||||||
|
args.add_argument('--future_predictions', type=strtobool, default=False)
|
||||||
|
|
||||||
|
|
||||||
class AE_Model(AutoEncoder_LO, LightningModule):
|
class AE_Model(AutoEncoder_LO, LightningModule):
|
||||||
@ -34,7 +35,7 @@ class AE_Model(AutoEncoder_LO, LightningModule):
|
|||||||
self.latent_dim = parameters.latent_dim
|
self.latent_dim = parameters.latent_dim
|
||||||
self.features = parameters.features
|
self.features = parameters.features
|
||||||
self.step = parameters.step
|
self.step = parameters.step
|
||||||
super(AE_Model, self).__init__()
|
super(AE_Model, self).__init__(train_on_predictions=parameters.future_predictions)
|
||||||
self.network = AutoEncoder(self.latent_dim, self.features)
|
self.network = AutoEncoder(self.latent_dim, self.features)
|
||||||
|
|
||||||
|
|
||||||
@ -46,7 +47,7 @@ class VAE_Model(VAE_LO, LightningModule):
|
|||||||
self.latent_dim = parameters.latent_dim
|
self.latent_dim = parameters.latent_dim
|
||||||
self.features = parameters.features
|
self.features = parameters.features
|
||||||
self.step = parameters.step
|
self.step = parameters.step
|
||||||
super(VAE_Model, self).__init__()
|
super(VAE_Model, self).__init__(train_on_predictions=parameters.future_predictions)
|
||||||
self.network = VariationalAE(self.latent_dim, self.features)
|
self.network = VariationalAE(self.latent_dim, self.features)
|
||||||
|
|
||||||
|
|
||||||
@ -58,7 +59,7 @@ class AAE_Model(AdversarialAE_LO, LightningModule):
|
|||||||
self.latent_dim = parameters.latent_dim
|
self.latent_dim = parameters.latent_dim
|
||||||
self.features = parameters.features
|
self.features = parameters.features
|
||||||
self.step = parameters.step
|
self.step = parameters.step
|
||||||
super(AAE_Model, self).__init__()
|
super(AAE_Model, self).__init__(train_on_predictions=parameters.future_predictions)
|
||||||
self.normal = Normal(0, 1)
|
self.normal = Normal(0, 1)
|
||||||
self.network = AdversarialAE(self.latent_dim, self.features)
|
self.network = AdversarialAE(self.latent_dim, self.features)
|
||||||
pass
|
pass
|
||||||
@ -72,7 +73,7 @@ class SAAE_Model(SeparatingAAE_LO, LightningModule):
|
|||||||
self.latent_dim = parameters.latent_dim
|
self.latent_dim = parameters.latent_dim
|
||||||
self.features = parameters.features
|
self.features = parameters.features
|
||||||
self.step = parameters.step
|
self.step = parameters.step
|
||||||
super(SAAE_Model, self).__init__()
|
super(SAAE_Model, self).__init__(train_on_predictions=parameters.future_predictions)
|
||||||
self.normal = Normal(0, 1)
|
self.normal = Normal(0, 1)
|
||||||
self.network = SeperatingAAE(self.latent_dim, self.features)
|
self.network = SeperatingAAE(self.latent_dim, self.features)
|
||||||
pass
|
pass
|
||||||
|
Loading…
x
Reference in New Issue
Block a user