smaller task train
This commit is contained in:
parent
c52b398819
commit
b6c8859081
@ -295,8 +295,7 @@ def flat_for_store(parameters):
|
|||||||
|
|
||||||
|
|
||||||
def train_self_replication(model, optimizer, st_stps) -> dict:
|
def train_self_replication(model, optimizer, st_stps) -> dict:
|
||||||
for _ in range(st_stps):
|
self_train_loss = model.combined_self_train(optimizer, st_stps)
|
||||||
self_train_loss = model.combined_self_train(optimizer)
|
|
||||||
# noinspection PyUnboundLocalVariable
|
# noinspection PyUnboundLocalVariable
|
||||||
stp_log = dict(Metric='Self Train Loss', Score=self_train_loss.item())
|
stp_log = dict(Metric='Self Train Loss', Score=self_train_loss.item())
|
||||||
return stp_log
|
return stp_log
|
||||||
|
@ -68,7 +68,8 @@ if __name__ == '__main__':
|
|||||||
mean_self_tain_loss = []
|
mean_self_tain_loss = []
|
||||||
|
|
||||||
for batch, (batch_x, batch_y) in tenumerate(dataloader):
|
for batch, (batch_x, batch_y) in tenumerate(dataloader):
|
||||||
# self_train_loss, _ = net.self_train(2, save_history=False, learning_rate=0.004)
|
self_train_loss, _ = net.self_train(2, save_history=False, learning_rate=0.004)
|
||||||
|
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
input_data = net.input_weight_matrix()
|
input_data = net.input_weight_matrix()
|
||||||
|
21
network.py
21
network.py
@ -463,19 +463,20 @@ class MetaNet(nn.Module):
|
|||||||
def particles(self):
|
def particles(self):
|
||||||
return (cell for metalayer in self.all_layers for cell in metalayer.particles)
|
return (cell for metalayer in self.all_layers for cell in metalayer.particles)
|
||||||
|
|
||||||
def combined_self_train(self, optimizer, reduction='mean'):
|
def combined_self_train(self, optimizer, n_st_steps, reduction='mean'):
|
||||||
optimizer.zero_grad()
|
|
||||||
losses = []
|
losses = []
|
||||||
n = 10
|
|
||||||
for particle in self.particles:
|
for particle in self.particles:
|
||||||
# Intergrate optimizer and backward function
|
for _ in range(n_st_steps):
|
||||||
input_data = particle.input_weight_matrix()
|
optimizer.zero_grad()
|
||||||
target_data = particle.create_target_weights(input_data)
|
# Intergrate optimizer and backward function
|
||||||
output = particle(input_data)
|
input_data = particle.input_weight_matrix()
|
||||||
losses.append(F.mse_loss(output, target_data, reduction=reduction))
|
target_data = particle.create_target_weights(input_data)
|
||||||
|
output = particle(input_data)
|
||||||
|
losses.append(F.mse_loss(output, target_data, reduction=reduction))
|
||||||
|
losses.backward()
|
||||||
|
optimizer.step()
|
||||||
losses = torch.hstack(losses).sum(dim=-1, keepdim=True)
|
losses = torch.hstack(losses).sum(dim=-1, keepdim=True)
|
||||||
losses.backward()
|
|
||||||
optimizer.step()
|
|
||||||
return losses.detach()
|
return losses.detach()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
Loading…
x
Reference in New Issue
Block a user