train running
This commit is contained in:
@@ -46,7 +46,6 @@ class TrajDataset(Dataset):
|
||||
|
||||
while True:
|
||||
trajectory = self.map.get_random_trajectory()
|
||||
# TODO: Sanity Check this while true loop...
|
||||
alternative = self.map.generate_alternative(trajectory)
|
||||
label = self.map.are_homotopic(trajectory, alternative)
|
||||
if self.preserve_equal_samples and label == self.last_label:
|
||||
@@ -56,18 +55,13 @@ class TrajDataset(Dataset):
|
||||
|
||||
self.last_label = label
|
||||
if self.mode.lower() in ['all_in_map', 'separated_arrays']:
|
||||
blank_trajectory_space = torch.zeros(self.map.shape)
|
||||
blank_alternative_space = torch.zeros(self.map.shape)
|
||||
for index in trajectory.vertices:
|
||||
blank_trajectory_space[index] = 1
|
||||
for index in alternative.vertices:
|
||||
blank_alternative_space[index] = 1
|
||||
|
||||
map_array = torch.as_tensor(self.map.as_array).float()
|
||||
if self.mode == 'separated_arrays':
|
||||
return (map_array, blank_trajectory_space, int(label)), blank_alternative_space
|
||||
return (map_array, trajectory.draw_in_array(self.map_shape), int(label)), \
|
||||
alternative.draw_in_array(self.map_shape)
|
||||
else:
|
||||
return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
|
||||
return torch.cat((map_array, trajectory.draw_in_array(self.map_shape),
|
||||
alternative.draw_in_array(self.map_shape))), int(label)
|
||||
|
||||
elif self.mode == 'vectors':
|
||||
return trajectory.vertices, alternative.vertices, label, self.mapname
|
||||
|
||||
Reference in New Issue
Block a user