train running

This commit is contained in:
Steffen Illium
2020-03-09 21:41:50 +01:00
parent daed810958
commit 6cc978e464
7 changed files with 68 additions and 57 deletions

View File

@@ -46,7 +46,6 @@ class TrajDataset(Dataset):
while True:
trajectory = self.map.get_random_trajectory()
# TODO: Sanity Check this while true loop...
alternative = self.map.generate_alternative(trajectory)
label = self.map.are_homotopic(trajectory, alternative)
if self.preserve_equal_samples and label == self.last_label:
@@ -56,18 +55,13 @@ class TrajDataset(Dataset):
self.last_label = label
if self.mode.lower() in ['all_in_map', 'separated_arrays']:
blank_trajectory_space = torch.zeros(self.map.shape)
blank_alternative_space = torch.zeros(self.map.shape)
for index in trajectory.vertices:
blank_trajectory_space[index] = 1
for index in alternative.vertices:
blank_alternative_space[index] = 1
map_array = torch.as_tensor(self.map.as_array).float()
if self.mode == 'separated_arrays':
return (map_array, blank_trajectory_space, int(label)), blank_alternative_space
return (map_array, trajectory.draw_in_array(self.map_shape), int(label)), \
alternative.draw_in_array(self.map_shape)
else:
return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
return torch.cat((map_array, trajectory.draw_in_array(self.map_shape),
alternative.draw_in_array(self.map_shape))), int(label)
elif self.mode == 'vectors':
return trajectory.vertices, alternative.vertices, label, self.mapname