Variational Generator
This commit is contained in:
@@ -5,8 +5,10 @@ from typing import Union, List
|
||||
import torch
|
||||
from random import choice
|
||||
from torch.utils.data import ConcatDataset, Dataset
|
||||
import numpy as np
|
||||
|
||||
from lib.objects.map import Map
|
||||
import lib.variables as V
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@@ -36,13 +38,10 @@ class TrajDataset(Dataset):
|
||||
|
||||
if self.mode.lower() == 'just_route':
|
||||
trajectory = self.map.get_random_trajectory()
|
||||
trajectory_space = trajectory.draw_in_array(self.map.shape)
|
||||
label = choice([0, 1])
|
||||
blank_trajectory_space = torch.zeros(self.map.shape)
|
||||
for index in trajectory.vertices:
|
||||
blank_trajectory_space[index] = 1
|
||||
|
||||
map_array = torch.as_tensor(self.map.as_array).float()
|
||||
return (map_array, blank_trajectory_space), label
|
||||
return (map_array, trajectory_space), label
|
||||
|
||||
while True:
|
||||
trajectory = self.map.get_random_trajectory()
|
||||
@@ -55,13 +54,13 @@ class TrajDataset(Dataset):
|
||||
|
||||
self.last_label = label
|
||||
if self.mode.lower() in ['all_in_map', 'separated_arrays']:
|
||||
map_array = torch.as_tensor(self.map.as_array).float()
|
||||
map_array = self.map.as_array
|
||||
trajectory = trajectory.draw_in_array(self.map_shape)
|
||||
alternative = alternative.draw_in_array(self.map_shape)
|
||||
if self.mode == 'separated_arrays':
|
||||
return (map_array, torch.as_tensor(trajectory.draw_in_array(self.map_shape)).float(), int(label)), \
|
||||
torch.as_tensor(alternative.draw_in_array(self.map_shape)).float()
|
||||
return (map_array, trajectory, label), alternative
|
||||
else:
|
||||
return torch.cat((map_array, torch.as_tensor(trajectory.draw_in_array(self.map_shape)).float(),
|
||||
torch.as_tensor(alternative.draw_in_array(self.map_shape)).float())), int(label)
|
||||
return np.concatenate((map_array, trajectory, alternative)), label
|
||||
|
||||
elif self.mode == 'vectors':
|
||||
return trajectory.vertices, alternative.vertices, label, self.mapname
|
||||
|
||||
Reference in New Issue
Block a user