ae_toolbox_torch/dataset.py
2021-02-01 09:59:56 +01:00

283 lines
10 KiB
Python

import argparse
import bisect
from collections import defaultdict
from distutils.util import strtobool
import os
import ast
from abc import ABC, abstractmethod
from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data import Dataset, ConcatDataset
def _str_to_bool(value):
    """Parse a command-line truth value into a ``bool``.

    Accepts the same spellings as ``distutils.util.strtobool`` ('y', 'yes', 't',
    'true', 'on', '1' and 'n', 'no', 'f', 'false', 'off', '0', case-insensitive)
    without needing ``distutils``, which PEP 632 deprecated and Python 3.12 removed.
    Returns ``True``/``False`` rather than ``1``/``0``; these compare equal.

    :raises ValueError: for any other spelling; argparse reports it as an invalid value.
    """
    normalized = str(value).strip().lower()
    if normalized in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    if normalized in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    raise ValueError(f'invalid truth value {value!r}')


# Command line argument parsing
def build_parse_commands(argv=None):
    """Build the command-line parser and parse the arguments.

    :param argv: Argument list to parse; ``None`` (default) parses ``sys.argv[1:]``.
    :return: The parsed ``argparse.Namespace``.
    """
    # Init the Command Line Arguments Parser
    arg_parser = argparse.ArgumentParser(description='VAE and GSE Autoencoder with latent Space Clustering Approaches')
    # Specify a pretrained model file to load.
    arg_parser.add_argument('--model_file', nargs='?', default='',
                            help='Specify a pretrained model file to load.')
    # Root directory of the raw data.
    arg_parser.add_argument('--files', nargs='?', default='',
                            help='Set the raw data location. Should be filled with maps and trajectories')
    # Set a fixed prng seed; -999 means "do not seed". type=int so the value can be fed to the PRNGs.
    arg_parser.add_argument('--seed', nargs='?', default=-999, type=int, help='Set a fixed prng seed.')
    # DataSet parameters. Without type=int, command-line values would arrive as strings.
    arg_parser.add_argument('--size', nargs='?', default=9, type=int,
                            help='Set a trajectory length; the number of isovists.')
    arg_parser.add_argument('--step', nargs='?', default=5, type=int,
                            help='Set a fixed stepsize between isovist centers.')
    # Parsed as a boolean so "--overlapping False" really is falsy (the raw string
    # "False" would be truthy); a bare "--overlapping" means True.
    arg_parser.add_argument('--overlapping', nargs='?', default=True, const=True, type=_str_to_bool,
                            help='Whether the Isovists should overlap.')
    # Visualization switches
    arg_parser.add_argument('-p', '--print_on_map', default=False, type=_str_to_bool,
                            help='Whether trajcetories should be colored and displayed on a map.')
    arg_parser.add_argument('-l', '--print_latent', default=False, type=_str_to_bool,
                            help='Whether latent encoding space should be colored and displayed.')
    arg_parser.add_argument('-d', '--divided_latent_viz', default=False, type=_str_to_bool,
                            help='Whether latent encoding space should be colored and displayed seperatein saae case.')
    return arg_parser.parse_args(argv)
class AbstractDataset(ConcatDataset, ABC):
    """Concatenation of one cached, processed dataset per map found under ``<path>/raw``.

    Map names are the entry names in ``<path>/raw`` up to their first ``_``.
    Subclasses name each map's raw file (:attr:`raw_filenames`) and turn it into
    a dataset (:meth:`process`). Processed datasets are cached with ``torch.save``
    as ``<path>/processed/<map>_<ClassName>.to`` and reused on later runs unless
    ``refresh=True``.
    """

    @property
    @abstractmethod
    def raw_filenames(self):
        """Raw filenames, one per entry of ``self.maps`` and in the same order."""
        raise NotImplementedError('Specify the file ending here')

    @property
    def raw_paths(self):
        """Full paths of the raw files under ``<path>/raw``."""
        return [os.path.join(self.path, 'raw', x) for x in self.raw_filenames]

    @property
    def processed_filenames(self):
        """Cache filenames ``<map>_<ClassName>.to``, one per map."""
        return [f'{x}_{self.__class__.__name__}.to' for x in self.maps]

    @property
    def processed_paths(self):
        """Full paths of the cache files under ``<path>/processed``."""
        return [os.path.join(self.path, 'processed', x) for x in self.processed_filenames]

    def __init__(self, path, refresh=False, **kwargs):
        """
        :param path: Dataset root; must contain a ``raw`` subdirectory.
        :param refresh: Delete existing cache files first, forcing re-processing.
        :param kwargs: Accepted so subclasses can forward freely; ignored here.
        """
        self.path = path
        self.refresh = refresh
        # Sorted so the map order (and therefore every global sample index) is the
        # same on every run; iteration order of a set of str depends on hash seeding.
        with os.scandir(os.path.join(self.path, 'raw')) as entries:
            self.maps = sorted({entry.name.split('_')[0] for entry in entries})
        super(AbstractDataset, self).__init__(datasets=self._load_datasets())

    def to(self, device):
        """Move every sub-dataset to ``device`` (each must provide ``.to``); returns ``self``."""
        self.datasets = [dataset.to(device) for dataset in self.datasets]
        return self

    @abstractmethod
    def process(self, filepath):
        """Build and return the dataset for the raw file at ``filepath``."""
        raise NotImplementedError

    def _load_datasets(self):
        """Return the dataset of every map in map order, processing and caching missing ones."""
        if self.refresh:
            self._delete_processed_files()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        datasets = []
        path_pairs = list(zip(self.raw_paths, self.processed_paths))
        for raw_path, processed_path in tqdm(path_pairs, total=len(path_pairs), unit="files"):
            try:
                dataset = self._load_processed(processed_path, device)
            except FileNotFoundError:
                os.makedirs(os.path.dirname(processed_path), exist_ok=True)
                processed = self.process(raw_path)
                tqdm.write(f'Dataset "{processed_path}" processed')
                torch.save(processed, processed_path)
                # Reload from disk so a fresh dataset is mapped onto `device` exactly like a cached one.
                dataset = self._load_processed(processed_path, device)
            datasets.append(dataset)
        return datasets

    def _delete_processed_files(self):
        """Remove existing cache files; missing ones are reported, not treated as errors."""
        for filepath in self.processed_paths:
            try:
                os.remove(filepath)
                print('Processed Location "Refreshed" (We deleted the Files)')
            except FileNotFoundError:
                print('You meant to refresh the allready processed dataset, but there were none...')
                print('continue processing')

    @staticmethod
    def _load_processed(filepath, device):
        """``torch.load`` one cache file onto ``device``.

        The cache holds whole pickled dataset objects. torch >= 2.6 defaults to
        ``weights_only=True``, which refuses such objects, so ``weights_only=False``
        is requested explicitly. Older torch releases without that keyword raise
        ``TypeError`` for it; in that case the load is retried without it.
        NOTE: this unpickles arbitrary objects -- only load cache files written by this code.
        """
        try:
            return torch.load(filepath, map_location=device, weights_only=False)
        except TypeError:
            return torch.load(filepath, map_location=device)
class DataContainer(AbstractDataset):
    """Per-map :class:`Trajectories` datasets, built from ``<map>_trajec.csv`` files and concatenated."""

    @property
    def raw_filenames(self):
        """One trajectory CSV per map: ``<map>_trajec.csv``."""
        return [f'{x}_trajec.csv' for x in self.maps]

    def __init__(self, path, size, step, **kwargs):
        """
        :param path: Dataset root containing ``raw``/``processed`` subdirectories.
        :param size: Rows per sample window, forwarded to :class:`Trajectories`.
        :param step: Row stride inside a window, forwarded to :class:`Trajectories`.
        :param kwargs: Forwarded to the base class (e.g. ``refresh``).
        """
        # Must be set before super().__init__(), which may already call process().
        self.size = size
        self.step = step
        super(DataContainer, self).__init__(path, **kwargs)

    def process(self, filepath):
        """Parse one comma-separated trajectory file into a normalized :class:`Trajectories`.

        The first line is the header. Every cell is parsed with ``ast.literal_eval``;
        the ``inDoor`` column is dropped and blank lines are skipped.
        """
        delimiter = ','
        skipped_columns = {'inDoor'}
        # Only used as the progress bar total (header line excluded).
        with open(filepath, 'r') as f:
            total_lines = max(sum(1 for _ in f) - 1, 0)
        data_dict = defaultdict(list)
        with open(filepath, 'r') as f:
            # Separate the header
            all_headers = f.readline().rstrip().split(delimiter)
            for line in tqdm(f, total=total_lines, unit=" lines", mininterval=1, miniters=1000):
                line = line.strip()
                # Lines read from a file keep their newline: a blank line is "\n", never "".
                if not line:
                    continue
                # Zip against the FULL header so every value stays aligned with its own
                # column, even when a skipped column is not the last one.
                for attr, x in zip(all_headers, line.split(delimiter)):
                    if attr not in skipped_columns:
                        data_dict[attr].append(ast.literal_eval(x))
        headers = [h for h in all_headers if h not in skipped_columns]
        return Trajectories(self.size, self.step, headers, **data_dict, normalize=True)

    def get_both_by_key(self, item):
        """Return the owning sub-dataset's ``get_both_by_key`` result for global index ``item``.

        Negative indices count from the end. The index is mapped to
        (sub-dataset, local index) through ``cumulative_sizes``, the same way
        ``ConcatDataset.__getitem__`` does.

        :raises ValueError: if ``-item`` exceeds the dataset length.
        """
        if item < 0:
            if -item > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            item = len(self) + item
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, item)
        sample_idx = item if dataset_idx == 0 else item - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx].get_both_by_key(sample_idx)
class Trajectories(Dataset):
    """Sliding-window dataset over the rows of one trajectory's isovist measures.

    ``self.data`` has one row per record and one column per entry of
    :attr:`isovistMeasures`. Sample ``i`` is ``(x, y)``: ``x`` holds up to
    ``size`` rows taken every ``step`` rows starting at row ``i``, and ``y`` is
    the same window shifted by one row. Both drop the first two columns
    (``X``, ``Z``), which are available via :meth:`get_coordinates_by_key`.
    """

    # As in "To take hold of isovists and isovist fields" - M. L. Benedikt, read only measures specified by Benedikt
    @property
    def isovistMeasures(self):
        return ['X', 'Z', 'realSurfacePerimeter', 'occlusionValue', 'area', 'variance', 'skewness', 'circularity_ben']

    @property
    def features(self):
        # NOTE(review): counts X and Z too, although samples only carry the remaining measures.
        return len(self.isovistMeasures)

    def __init__(self, size, step, headers, normalize=True, **kwargs):
        """
        :param size: Rows per window.
        :param step: Row stride inside a window.
        :param headers: Column names of the source file; stored, not used here.
        :param normalize: z-score every measure except ``X`` and ``Z``.
        :param kwargs: Column name -> list of values; only :attr:`isovistMeasures` are kept.
        """
        super(Trajectories, self).__init__()
        self.size: int = size
        self.step: int = step
        self.headers: list = headers
        self.normalize: bool = normalize
        self.data = self.__init_data_(**kwargs)

    def __init_data_(self, **kwargs: dict):
        """Stack the measure columns into a float tensor of shape (rows, len(isovistMeasures)).

        :raises KeyError: if a measure column is missing.
        :raises ValueError: if the measure columns differ in length.
        """
        measures = self.isovistMeasures
        # Cast to the default float dtype: integer columns would otherwise make
        # torch.stack / std_mean below fail or silently truncate.
        columns = {key: torch.tensor(val, dtype=torch.get_default_dtype())
                   for key, val in kwargs.items() if key in measures}
        missing = [key for key in measures if key not in columns]
        if missing:
            raise KeyError(f'Missing isovist measures: {missing}')
        if len({column.size(0) for column in columns.values()}) > 1:
            raise ValueError('All isovist measures must have the same number of rows')
        data = torch.stack([columns[key] for key in measures], dim=-1)
        if self.normalize:
            # All but x,y
            std, mean = torch.std_mean(data[:, 2:], dim=0)
            # Constant columns would give 0/0 = NaN; only center them.
            std[std == 0] = 1
            data[:, 2:] = (data[:, 2:] - mean) / std
        return data

    def _window(self, start):
        """Rows ``start, start + step, ...`` (at most ``size`` of them) of ``self.data``."""
        stop = start + self.size * self.step
        return self.data[start:stop or None:self.step]

    def __iter__(self):
        # Required: slicing past the end never raises IndexError, so Python's
        # fallback __getitem__-based iteration would never terminate.
        for i in range(len(self)):
            yield self[i]

    def __getitem__(self, item):
        """Return ``(x, y)``: the window at ``item`` and the one starting a row later, without X/Z.

        :raises TypeError: if ``item`` is not an ``int``.
        """
        if not isinstance(item, int):
            raise TypeError(f"Item-Key has to be Integer, but was {type(item)}")
        x = self._window(item)[:, 2:]
        y = self._window(item + 1)[:, 2:]
        return x, y

    def get_isovist_measures_by_key(self, item):
        """Return only ``x`` of sample ``item``."""
        return self[item][0]

    def get_coordinates_by_key(self, item):
        """Return the ``X``/``Z`` columns of the window starting at ``item``."""
        return self._window(item)[:, :2]

    def get_both_by_key(self, item):
        """Return all columns (coordinates and measures) of the window starting at ``item``."""
        return self._window(item)

    def __len__(self):
        # The target window starts one row later, so the last valid start is
        # rows - 1 - (size - 1) * step; clamp at 0 for files shorter than a window.
        total_len = self.data.size(0)
        return max(0, total_len - (self.size * self.step - (self.step - 1)))

    def to(self, device):
        """Move the data tensor to ``device``; returns ``self``."""
        self.data = self.data.to(device)
        return self
class MapContainer(AbstractDataset):
    """Per-map :class:`Map` objects, built from ``<map>_map.csv`` files and concatenated."""

    @property
    def raw_filenames(self):
        """One map CSV per map: ``<map>_map.csv``."""
        return [f'{x}_map.csv' for x in self.maps]

    def __init__(self, path, **kwargs):
        """
        :param path: Dataset root containing ``raw``/``processed`` subdirectories.
        :param kwargs: Forwarded to the base class (e.g. ``refresh``).
        """
        super(MapContainer, self).__init__(path, **kwargs)

    def process(self, filepath):
        """Parse one comma-separated map file into a :class:`Map` named after the file.

        The first line is the header; every cell is parsed with ``ast.literal_eval``
        and blank lines are skipped. The Map receives one array row per header column.
        """
        delimiter = ','
        data_dict = defaultdict(list)
        with open(filepath, 'r') as f:
            # Separate the header
            headers = f.readline().rstrip().split(delimiter)
            for line in tqdm(f):
                line = line.strip()
                # Lines read from a file keep their newline: a blank line is "\n", never "".
                if not line:
                    continue
                for attr, x in zip(headers, line.split(delimiter)):
                    data_dict[attr].append(ast.literal_eval(x))
        return Map(np.asarray([data_dict[head] for head in headers]),
                   name=os.path.splitext(os.path.basename(filepath))[0])
class Map(object):
    """Container for a triangulated base map: one triangle per row of ``self.map``."""

    def __init__(self, mapData: np.ndarray, name='MapName'):
        """
        This is a Container Class for triangulated basemaps in csv format.
        :param mapData: The map as np.ndarray, already read from disk, with one row per
            coordinate column; it is transposed so that each row of ``self.map`` is one
            triangle of 6 values.
        :param name: Name of the map.
        """
        self.map: np.ndarray = np.transpose(mapData)
        self.name = name
        # Each triangle row is three consecutive (x, y) pairs (see __getitem__), so x
        # values are columns 0, 2, 4 and y values columns 1, 3, 5. Indexing columns
        # (not rows, as self.map[[0, 2, 4]] would) is what makes these real bounds.
        xs = self.map[:, [0, 2, 4]]
        ys = self.map[:, [1, 3, 5]]
        self.minx, self.maxx = np.min(xs), np.max(xs)
        self.miny, self.maxy = np.min(ys), np.max(ys)
        print('BaseMap Initialized')

    def __len__(self):
        """Number of triangles."""
        return self.map.shape[0]

    def vertices(self):
        """Return all triangles reshaped to ``(n, 2, 3)``.

        NOTE(review): reshape((-1, 2, 3)) groups a row as [[x1, y1, x2], [y2, x3, y3]],
        which does not match __getitem__'s (3, 2) vertex layout. Confirm the layout
        callers expect before relying on this.
        """
        vertices = self.map.reshape((-1, 2, 3))
        return vertices

    def __getitem__(self, item):
        """Return triangle ``item`` as a (3, 2) array of vertex coordinates."""
        return self.map[item].reshape(3, 2)
if __name__ == '__main__':
    args = build_parse_commands()
    # int() so seeding works whether the parser returns the seed as str or int;
    # a bare "--seed" yields None and is treated like "no seed".
    seed = args.seed
    if seed is not None and int(seed) != -999:
        np.random.seed(int(seed))
        torch.manual_seed(int(seed))
    m = MapContainer(args.files, refresh=True)
    print(len(m[1]))