Debugging Validation and testing
This commit is contained in:
@@ -57,11 +57,11 @@ class TrajDataset(Dataset):
|
||||
if self.mode.lower() in ['all_in_map', 'separated_arrays']:
|
||||
map_array = torch.as_tensor(self.map.as_array).float()
|
||||
if self.mode == 'separated_arrays':
|
||||
return (map_array, trajectory.draw_in_array(self.map_shape), int(label)), \
|
||||
alternative.draw_in_array(self.map_shape)
|
||||
return (map_array, torch.as_tensor(trajectory.draw_in_array(self.map_shape)).float(), int(label)), \
|
||||
torch.as_tensor(alternative.draw_in_array(self.map_shape)).float()
|
||||
else:
|
||||
return torch.cat((map_array, trajectory.draw_in_array(self.map_shape),
|
||||
alternative.draw_in_array(self.map_shape))), int(label)
|
||||
return torch.cat((map_array, torch.as_tensor(trajectory.draw_in_array(self.map_shape)).float(),
|
||||
torch.as_tensor(alternative.draw_in_array(self.map_shape)).float())), int(label)
|
||||
|
||||
elif self.mode == 'vectors':
|
||||
return trajectory.vertices, alternative.vertices, label, self.mapname
|
||||
|
||||
Reference in New Issue
Block a user