Variational Generator

This commit is contained in:
Si11ium
2020-03-10 16:59:51 +01:00
parent 21e7e31805
commit 1b5a7dc69e
10 changed files with 177 additions and 95 deletions

View File

@@ -5,8 +5,10 @@ from typing import Union, List
import torch
from random import choice
from torch.utils.data import ConcatDataset, Dataset
import numpy as np
from lib.objects.map import Map
import lib.variables as V
from PIL import Image
@@ -36,13 +38,10 @@ class TrajDataset(Dataset):
if self.mode.lower() == 'just_route':
trajectory = self.map.get_random_trajectory()
trajectory_space = trajectory.draw_in_array(self.map.shape)
label = choice([0, 1])
blank_trajectory_space = torch.zeros(self.map.shape)
for index in trajectory.vertices:
blank_trajectory_space[index] = 1
map_array = torch.as_tensor(self.map.as_array).float()
return (map_array, blank_trajectory_space), label
return (map_array, trajectory_space), label
while True:
trajectory = self.map.get_random_trajectory()
@@ -55,13 +54,13 @@ class TrajDataset(Dataset):
self.last_label = label
if self.mode.lower() in ['all_in_map', 'separated_arrays']:
map_array = torch.as_tensor(self.map.as_array).float()
map_array = self.map.as_array
trajectory = trajectory.draw_in_array(self.map_shape)
alternative = alternative.draw_in_array(self.map_shape)
if self.mode == 'separated_arrays':
return (map_array, torch.as_tensor(trajectory.draw_in_array(self.map_shape)).float(), int(label)), \
torch.as_tensor(alternative.draw_in_array(self.map_shape)).float()
return (map_array, trajectory, label), alternative
else:
return torch.cat((map_array, torch.as_tensor(trajectory.draw_in_array(self.map_shape)).float(),
torch.as_tensor(alternative.draw_in_array(self.map_shape)).float())), int(label)
return np.concatenate((map_array, trajectory, alternative)), label
elif self.mode == 'vectors':
return trajectory.vertices, alternative.vertices, label, self.mapname