Lightning integration basic ae, dataloaders and dataset
This commit is contained in:
3
.idea/ae_toolbox_torch.iml
generated
3
.idea/ae_toolbox_torch.iml
generated
@ -5,6 +5,9 @@
|
|||||||
<orderEntry type="inheritedJdk" />
|
<orderEntry type="inheritedJdk" />
|
||||||
<orderEntry type="sourceFolder" forTests="false" />
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
</component>
|
</component>
|
||||||
|
<component name="PyDocumentationSettings">
|
||||||
|
<option name="renderExternalDocumentation" value="true" />
|
||||||
|
</component>
|
||||||
<component name="TestRunnerService">
|
<component name="TestRunnerService">
|
||||||
<option name="PROJECT_TEST_RUNNER" value="Unittests" />
|
<option name="PROJECT_TEST_RUNNER" value="Unittests" />
|
||||||
</component>
|
</component>
|
||||||
|
9
.idea/dictionaries/illium.xml
generated
Normal file
9
.idea/dictionaries/illium.xml
generated
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
<component name="ProjectDictionaryState">
|
||||||
|
<dictionary name="illium">
|
||||||
|
<words>
|
||||||
|
<w>dataloader</w>
|
||||||
|
<w>datasets</w>
|
||||||
|
<w>isovists</w>
|
||||||
|
</words>
|
||||||
|
</dictionary>
|
||||||
|
</component>
|
7
.idea/other.xml
generated
Normal file
7
.idea/other.xml
generated
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="PySciProjectComponent">
|
||||||
|
<option name="PY_SCI_VIEW" value="true" />
|
||||||
|
<option name="PY_SCI_VIEW_SUGGESTED" value="true" />
|
||||||
|
</component>
|
||||||
|
</project>
|
6
.idea/vcs.xml
generated
Normal file
6
.idea/vcs.xml
generated
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
164
.idea/workspace.xml
generated
164
.idea/workspace.xml
generated
@ -1,9 +1,171 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<project version="4">
|
<project version="4">
|
||||||
|
<component name="ChangeListManager">
|
||||||
|
<list default="true" id="5955480a-c876-43d5-afd7-8717f51f413e" name="Default Changelist" comment="">
|
||||||
|
<change afterPath="$PROJECT_DIR$/.idea/dictionaries/illium.xml" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/.idea/other.xml" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/.idea/vcs.xml" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/basic_ae_lightning_torch.py" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/data/dataset.py" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/networks/basic_ae.py" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/networks/modules.py" afterDir="false" />
|
||||||
|
<change beforePath="$PROJECT_DIR$/.idea/ae_toolbox_torch.iml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/ae_toolbox_torch.iml" afterDir="false" />
|
||||||
|
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
||||||
|
</list>
|
||||||
|
<option name="EXCLUDED_CONVERTED_TO_IGNORED" value="true" />
|
||||||
|
<option name="SHOW_DIALOG" value="false" />
|
||||||
|
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
||||||
|
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
|
||||||
|
<option name="LAST_RESOLUTION" value="IGNORE" />
|
||||||
|
</component>
|
||||||
|
<component name="FileTemplateManagerImpl">
|
||||||
|
<option name="RECENT_TEMPLATES">
|
||||||
|
<list>
|
||||||
|
<option value="Python Script" />
|
||||||
|
</list>
|
||||||
|
</option>
|
||||||
|
</component>
|
||||||
|
<component name="Git.Settings">
|
||||||
|
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
|
||||||
|
</component>
|
||||||
<component name="ProjectId" id="1Omeu5sz43kySmz8qHfWIcA2dn0" />
|
<component name="ProjectId" id="1Omeu5sz43kySmz8qHfWIcA2dn0" />
|
||||||
|
<component name="ProjectLevelVcsManager" settingsEditedManually="true" />
|
||||||
<component name="PropertiesComponent">
|
<component name="PropertiesComponent">
|
||||||
|
<property name="ASKED_ADD_EXTERNAL_FILES" value="true" />
|
||||||
|
<property name="SHARE_PROJECT_CONFIGURATION_FILES" value="true" />
|
||||||
<property name="WebServerToolWindowFactoryState" value="false" />
|
<property name="WebServerToolWindowFactoryState" value="false" />
|
||||||
<property name="nodejs_interpreter_path.stuck_in_default_project" value="undefined stuck path" />
|
<property name="nodejs_interpreter_path.stuck_in_default_project" value="undefined stuck path" />
|
||||||
<property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
|
<property name="settings.editor.selected.configurable" value="pyconsole" />
|
||||||
|
</component>
|
||||||
|
<component name="PyConsoleOptionsProvider">
|
||||||
|
<option name="myPythonConsoleState">
|
||||||
|
<console-settings module-name="ae_toolbox_torch" is-module-sdk="true">
|
||||||
|
<option name="myUseModuleSdk" value="true" />
|
||||||
|
<option name="myModuleName" value="ae_toolbox_torch" />
|
||||||
|
</console-settings>
|
||||||
|
</option>
|
||||||
|
<option name="myShowDebugConsoleByDefault" value="true" />
|
||||||
|
</component>
|
||||||
|
<component name="RunDashboard">
|
||||||
|
<option name="ruleStates">
|
||||||
|
<list>
|
||||||
|
<RuleState>
|
||||||
|
<option name="name" value="ConfigurationTypeDashboardGroupingRule" />
|
||||||
|
</RuleState>
|
||||||
|
<RuleState>
|
||||||
|
<option name="name" value="StatusDashboardGroupingRule" />
|
||||||
|
</RuleState>
|
||||||
|
</list>
|
||||||
|
</option>
|
||||||
|
</component>
|
||||||
|
<component name="RunManager" selected="Python.basic_ae_lightning_torch">
|
||||||
|
<configuration default="true" type="PythonConfigurationType" factoryName="Python">
|
||||||
|
<module name="ae_toolbox_torch" />
|
||||||
|
<option name="INTERPRETER_OPTIONS" value="" />
|
||||||
|
<option name="PARENT_ENVS" value="true" />
|
||||||
|
<envs>
|
||||||
|
<env name="PYTHONUNBUFFERED" value="1" />
|
||||||
|
</envs>
|
||||||
|
<option name="SDK_HOME" value="" />
|
||||||
|
<option name="WORKING_DIRECTORY" value="" />
|
||||||
|
<option name="IS_MODULE_SDK" value="false" />
|
||||||
|
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||||
|
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||||
|
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
||||||
|
<option name="SCRIPT_NAME" value="" />
|
||||||
|
<option name="PARAMETERS" value="" />
|
||||||
|
<option name="SHOW_COMMAND_LINE" value="true" />
|
||||||
|
<option name="EMULATE_TERMINAL" value="false" />
|
||||||
|
<option name="MODULE_MODE" value="false" />
|
||||||
|
<option name="REDIRECT_INPUT" value="false" />
|
||||||
|
<option name="INPUT_FILE" value="" />
|
||||||
|
<method v="2" />
|
||||||
|
</configuration>
|
||||||
|
<configuration name="basic_ae_lightning_torch" type="PythonConfigurationType" factoryName="Python" temporary="true">
|
||||||
|
<module name="ae_toolbox_torch" />
|
||||||
|
<option name="INTERPRETER_OPTIONS" value="" />
|
||||||
|
<option name="PARENT_ENVS" value="true" />
|
||||||
|
<envs>
|
||||||
|
<env name="PYTHONUNBUFFERED" value="1" />
|
||||||
|
</envs>
|
||||||
|
<option name="SDK_HOME" value="" />
|
||||||
|
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
|
||||||
|
<option name="IS_MODULE_SDK" value="true" />
|
||||||
|
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||||
|
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||||
|
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
||||||
|
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/basic_ae_lightning_torch.py" />
|
||||||
|
<option name="PARAMETERS" value="" />
|
||||||
|
<option name="SHOW_COMMAND_LINE" value="true" />
|
||||||
|
<option name="EMULATE_TERMINAL" value="false" />
|
||||||
|
<option name="MODULE_MODE" value="false" />
|
||||||
|
<option name="REDIRECT_INPUT" value="false" />
|
||||||
|
<option name="INPUT_FILE" value="" />
|
||||||
|
<method v="2" />
|
||||||
|
</configuration>
|
||||||
|
<configuration name="dataset" type="PythonConfigurationType" factoryName="Python" temporary="true">
|
||||||
|
<module name="ae_toolbox_torch" />
|
||||||
|
<option name="INTERPRETER_OPTIONS" value="" />
|
||||||
|
<option name="PARENT_ENVS" value="true" />
|
||||||
|
<envs>
|
||||||
|
<env name="PYTHONUNBUFFERED" value="1" />
|
||||||
|
</envs>
|
||||||
|
<option name="SDK_HOME" value="" />
|
||||||
|
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/data" />
|
||||||
|
<option name="IS_MODULE_SDK" value="true" />
|
||||||
|
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||||
|
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||||
|
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
||||||
|
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/data/dataset.py" />
|
||||||
|
<option name="PARAMETERS" value="" />
|
||||||
|
<option name="SHOW_COMMAND_LINE" value="true" />
|
||||||
|
<option name="EMULATE_TERMINAL" value="false" />
|
||||||
|
<option name="MODULE_MODE" value="false" />
|
||||||
|
<option name="REDIRECT_INPUT" value="false" />
|
||||||
|
<option name="INPUT_FILE" value="" />
|
||||||
|
<method v="2" />
|
||||||
|
</configuration>
|
||||||
|
<recent_temporary>
|
||||||
|
<list>
|
||||||
|
<item itemvalue="Python.basic_ae_lightning_torch" />
|
||||||
|
<item itemvalue="Python.dataset" />
|
||||||
|
</list>
|
||||||
|
</recent_temporary>
|
||||||
|
</component>
|
||||||
|
<component name="SvnConfiguration">
|
||||||
|
<configuration />
|
||||||
|
</component>
|
||||||
|
<component name="TaskManager">
|
||||||
|
<task active="true" id="Default" summary="Default task">
|
||||||
|
<changelist id="5955480a-c876-43d5-afd7-8717f51f413e" name="Default Changelist" comment="" />
|
||||||
|
<created>1564587418949</created>
|
||||||
|
<option name="number" value="Default" />
|
||||||
|
<option name="presentableId" value="Default" />
|
||||||
|
<updated>1564587418949</updated>
|
||||||
|
<workItem from="1564587420277" duration="6891000" />
|
||||||
|
<workItem from="1565364574595" duration="1092000" />
|
||||||
|
<workItem from="1565592214301" duration="53660000" />
|
||||||
|
</task>
|
||||||
|
<servers />
|
||||||
|
</component>
|
||||||
|
<component name="TypeScriptGeneratedFilesManager">
|
||||||
|
<option name="version" value="1" />
|
||||||
|
</component>
|
||||||
|
<component name="Vcs.Log.Tabs.Properties">
|
||||||
|
<option name="TAB_STATES">
|
||||||
|
<map>
|
||||||
|
<entry key="MAIN">
|
||||||
|
<value>
|
||||||
|
<State>
|
||||||
|
<option name="COLUMN_ORDER" />
|
||||||
|
</State>
|
||||||
|
</value>
|
||||||
|
</entry>
|
||||||
|
</map>
|
||||||
|
</option>
|
||||||
|
</component>
|
||||||
|
<component name="com.intellij.coverage.CoverageDataManagerImpl">
|
||||||
|
<SUITE FILE_PATH="coverage/ae_toolbox_torch$basic_ae_lightning_torch.coverage" NAME="basic_ae_lightning_torch Coverage Results" MODIFIED="1565790288699" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
|
||||||
|
<SUITE FILE_PATH="coverage/ae_toolbox_torch$dataset.coverage" NAME="dataset Coverage Results" MODIFIED="1565772669750" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/data" />
|
||||||
</component>
|
</component>
|
||||||
</project>
|
</project>
|
42
basic_ae_lightning_torch.py
Normal file
42
basic_ae_lightning_torch.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
from networks.basic_ae import BasicAE
|
||||||
|
from networks.modules import LightningModule
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
from torch.nn.functional import mse_loss
|
||||||
|
from torch.optim import Adam
|
||||||
|
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from data.dataset import DataContainer
|
||||||
|
|
||||||
|
from pytorch_lightning import Trainer
|
||||||
|
|
||||||
|
|
||||||
|
class AEModel(LightningModule):
|
||||||
|
|
||||||
|
def __init__(self, dataParams: dict):
|
||||||
|
super(AEModel, self).__init__()
|
||||||
|
self.dataParams = dataParams
|
||||||
|
self.network = BasicAE(self.dataParams)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.network.forward(x)
|
||||||
|
|
||||||
|
def training_step(self, x, batch_nb):
|
||||||
|
z, x_hat = self.forward(x)
|
||||||
|
return {'loss': mse_loss(x, x_hat)}
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
# ToDo: Where do i get the Paramers from?
|
||||||
|
return [Adam(self.parameters(), lr=0.02)]
|
||||||
|
|
||||||
|
@pl.data_loader
|
||||||
|
def tng_dataloader(self):
|
||||||
|
return DataLoader(DataContainer('data', **self.dataParams), shuffle=True, batch_size=100)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
ae = AEModel(
|
||||||
|
dict(refresh=False, size=5, step=5, features=6)
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = Trainer()
|
||||||
|
trainer.fit(ae)
|
252
data/dataset.py
Normal file
252
data/dataset.py
Normal file
@ -0,0 +1,252 @@
|
|||||||
|
import argparse
|
||||||
|
from collections import defaultdict
|
||||||
|
from distutils.util import strtobool
|
||||||
|
import os
|
||||||
|
import ast
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset, ConcatDataset
|
||||||
|
|
||||||
|
|
||||||
|
# Command line argument parsing
|
||||||
|
def build_parse_commands():
|
||||||
|
# Init the Command Line Arguments Parser
|
||||||
|
arg_parser = argparse.ArgumentParser(description='VAE and GSE Autoencoder with latent Space Clustering Approaches')
|
||||||
|
|
||||||
|
# Specify a pretrained weight file to load.
|
||||||
|
arg_parser.add_argument('--model_file', nargs='?', default='',
|
||||||
|
help='Specify a pretrained model file to load.')
|
||||||
|
# Specify a pretrained weight file to load.
|
||||||
|
arg_parser.add_argument('--files', nargs='?', default='',
|
||||||
|
help='Set the raw data location. Should be filled with maps and trajectories')
|
||||||
|
|
||||||
|
# Set a fixed prng seed.
|
||||||
|
arg_parser.add_argument('--seed', nargs='?', default=-999, help='Set a fixed prng seed.')
|
||||||
|
|
||||||
|
# DataSet parameters
|
||||||
|
arg_parser.add_argument('--size', nargs='?', default=9, help='Set a trajectory length; the number of isovists.')
|
||||||
|
arg_parser.add_argument('--step', nargs='?', default=5, help='Set a fixed stepsize between isovist centers.')
|
||||||
|
arg_parser.add_argument('--overlapping', nargs='?', default=True, help='Whether the Isovists should overlap.')
|
||||||
|
|
||||||
|
# Specify the Map to use in Training and visualization
|
||||||
|
arg_parser.add_argument('-p', '--print_on_map', default=False, type=strtobool,
|
||||||
|
help='Whether trajcetories should be colored and displayed on a map.')
|
||||||
|
arg_parser.add_argument('-l', '--print_latent', default=False, type=strtobool,
|
||||||
|
help='Whether latent encoding space should be colored and displayed.')
|
||||||
|
arg_parser.add_argument('-d', '--divided_latent_viz', default=False, type=strtobool,
|
||||||
|
help='Whether latent encoding space should be colored and displayed seperatein saae case.')
|
||||||
|
return arg_parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractDataset(ConcatDataset, ABC):
|
||||||
|
|
||||||
|
# maps = ['hotel', 'tum','gallery', 'queens', 'oet']
|
||||||
|
@property
|
||||||
|
def maps(self):
|
||||||
|
return ['hotel', 'tum','gallery', 'queens', 'oet']
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def raw_filenames(self):
|
||||||
|
raise NotImplementedError('Specify the file ending here')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def raw_paths(self):
|
||||||
|
return [os.path.join(self.path, 'raw', x) for x in self.raw_filenames]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def processed_filenames(self):
|
||||||
|
return [f'{x}_{self.__class__.__name__}.to' for x in self.maps]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def processed_paths(self):
|
||||||
|
return [os.path.join(self.path, 'processed', x) for x in self.processed_filenames]
|
||||||
|
|
||||||
|
def __init__(self, path, refresh=False, **kwargs):
|
||||||
|
self.path = path
|
||||||
|
self.refresh = refresh
|
||||||
|
super(AbstractDataset, self).__init__(datasets=self._load_datasets())
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def process(self, filepath):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _load_datasets(self):
|
||||||
|
if self.refresh:
|
||||||
|
for filepath in self.processed_paths:
|
||||||
|
try:
|
||||||
|
os.remove(filepath)
|
||||||
|
print('Processed Location "Refreshed" (We deleted the Files)')
|
||||||
|
except FileNotFoundError:
|
||||||
|
print('You meant to refresh the allready processed dataset, but there were none...')
|
||||||
|
print('continue processing')
|
||||||
|
pass
|
||||||
|
datasets = []
|
||||||
|
# ToDo: Make this nicer
|
||||||
|
for map_idx, _ in tqdm(enumerate(self.maps),
|
||||||
|
total=len(self.maps), unit="files"
|
||||||
|
):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
datasets.append(torch.load(self.processed_paths[map_idx]))
|
||||||
|
print(f'Dataset "{self.processed_paths[map_idx]}" loaded')
|
||||||
|
break
|
||||||
|
except FileNotFoundError:
|
||||||
|
os.makedirs(os.path.join(*os.path.split(self.processed_paths[map_idx])[:-1]), exist_ok=True)
|
||||||
|
processed = self.process(self.raw_paths[map_idx])
|
||||||
|
torch.save(processed, self.processed_paths[map_idx])
|
||||||
|
continue
|
||||||
|
return datasets
|
||||||
|
|
||||||
|
|
||||||
|
class DataContainer(AbstractDataset):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def calculate_model_shapes(size, step, **kwargs):
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
@property
|
||||||
|
def raw_filenames(self):
|
||||||
|
return [f'{x}_trajec.csv' for x in self.maps]
|
||||||
|
|
||||||
|
def __init__(self, path, size, step, **kwargs):
|
||||||
|
self.size = size
|
||||||
|
self.step = step
|
||||||
|
super(DataContainer, self).__init__(path, **kwargs)
|
||||||
|
pass
|
||||||
|
|
||||||
|
def process(self, filepath):
|
||||||
|
dataDict = defaultdict(list)
|
||||||
|
with open(filepath, 'r') as f:
|
||||||
|
delimiter = ','
|
||||||
|
# Separate the header
|
||||||
|
headers = f.readline().rstrip().split(delimiter)
|
||||||
|
headers.remove('inDoor')
|
||||||
|
# Iterate over every line and convert it to float / value
|
||||||
|
# ToDo: Make this nicer
|
||||||
|
for line in tqdm(f, total=len(self.maps), unit="lines"):
|
||||||
|
if line == '':
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
for attr, x in zip(headers, line.rstrip().split(delimiter)[None:None]):
|
||||||
|
if attr not in ['inDoor']:
|
||||||
|
dataDict[attr].append(ast.literal_eval(x))
|
||||||
|
return Trajectories(self.size, self.step, headers, **dataDict)
|
||||||
|
|
||||||
|
|
||||||
|
class Trajectories(Dataset):
|
||||||
|
|
||||||
|
# As in "To take hold of isovists and isovist fields" - M. L. Benedikt, read only measures specified by Benedikt
|
||||||
|
@property
|
||||||
|
def isovistMeasures(self):
|
||||||
|
return ['X', 'Z', 'realSurfacePerimeter', 'occlusionValue', 'area', 'variance', 'skewness', 'circularity_ben']
|
||||||
|
|
||||||
|
@property
|
||||||
|
def features(self):
|
||||||
|
return len(self.isovistMeasures)
|
||||||
|
|
||||||
|
def __init__(self, size, step, headers, **kwargs):
|
||||||
|
super(Trajectories, self).__init__()
|
||||||
|
self.size: int = size
|
||||||
|
self.step: int = step
|
||||||
|
self.headers: list = headers
|
||||||
|
|
||||||
|
dataDict = dict()
|
||||||
|
for key, val in kwargs.items():
|
||||||
|
if key in self.isovistMeasures:
|
||||||
|
dataDict[key] = torch.tensor(val)
|
||||||
|
# Check if all keys are of same length
|
||||||
|
assert len(set(x.size()[0] for x in dataDict.values() if torch.is_tensor(x))) <= 1
|
||||||
|
self.data = torch.stack([dataDict[key] for key in self.isovistMeasures], dim=-1)
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __iter_tenors__(self):
|
||||||
|
return
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for i in range(len(self)):
|
||||||
|
yield self[i]
|
||||||
|
|
||||||
|
def __getitem__(self, item, coords=False):
|
||||||
|
"""
|
||||||
|
Return a trajectory sample from the dataset by a specific key.
|
||||||
|
:param item: The index number of the trajectory to return.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
subList = self.data[item:item + self.size * self.step or None:self.step]
|
||||||
|
xy, tensor = subList[:, 2], subList[:, 2:]
|
||||||
|
return (xy, tensor) if coords else tensor
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
total_len = self.data.size()[0]
|
||||||
|
return total_len - (self.size * self.step - (self.step - 1))
|
||||||
|
|
||||||
|
|
||||||
|
class MapContainer(AbstractDataset):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def raw_filenames(self):
|
||||||
|
return [f'{x}_map.csv' for x in self.maps]
|
||||||
|
|
||||||
|
def __init__(self, path, **kwargs):
|
||||||
|
super(MapContainer, self).__init__(path, **kwargs)
|
||||||
|
pass
|
||||||
|
|
||||||
|
def process(self, filepath):
|
||||||
|
dataDict = defaultdict(list)
|
||||||
|
with open(filepath, 'r') as f:
|
||||||
|
delimiter = ','
|
||||||
|
# Separate the header
|
||||||
|
headers = f.readline().rstrip().split(delimiter)
|
||||||
|
# Iterate over every line and convert it to float / value
|
||||||
|
# ToDo: Make this nicer
|
||||||
|
for line in tqdm(f):
|
||||||
|
if line == '':
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
for attr, x in zip(headers, line.rstrip().split(delimiter)[None:None]):
|
||||||
|
dataDict[attr].append(ast.literal_eval(x))
|
||||||
|
|
||||||
|
return Map(np.asarray([dataDict[head] for head in headers]))
|
||||||
|
|
||||||
|
|
||||||
|
class Map(object):
|
||||||
|
|
||||||
|
def __init__(self, mapData: np.ndarray):
|
||||||
|
"""
|
||||||
|
This is a Container Class for triangulated basemaps in csv format.
|
||||||
|
:param mapData: The map as np.ndarray, already read from disk.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.map: np.ndarray = mapData
|
||||||
|
|
||||||
|
self.minx, self.maxx = np.min(self.map[[0, 2, 4]]), np.max(self.map[[0, 2, 4]])
|
||||||
|
self.miny, self.maxy = np.min(self.map[[1, 3, 5]]), np.max(self.map[[1, 3, 5]])
|
||||||
|
|
||||||
|
print('BaseMap Initialized')
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.map.shape[0]
|
||||||
|
|
||||||
|
def vertices(self):
|
||||||
|
vertices = self.map.reshape((-1, 2, 3))
|
||||||
|
return vertices
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
return self.map[item].reshape(3, 2)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = build_parse_commands()
|
||||||
|
if args.seed != -999:
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
|
||||||
|
# d = DataContainer(args.files, args.size, args.step)
|
||||||
|
m = MapContainer(args.files, refresh=True)
|
||||||
|
print(len(m[1]))
|
52
networks/basic_ae.py
Normal file
52
networks/basic_ae.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
from torch.nn import Sequential, Linear, GRU
|
||||||
|
from data.dataset import DataContainer
|
||||||
|
|
||||||
|
from .modules import *
|
||||||
|
|
||||||
|
|
||||||
|
#######################
|
||||||
|
# Basic AE-Implementation
|
||||||
|
class BasicAE(Module, ABC):
|
||||||
|
|
||||||
|
def __init__(self, dataParams, **kwargs):
|
||||||
|
super(BasicAE, self).__init__()
|
||||||
|
self.dataParams = dataParams
|
||||||
|
self.latent_dim = kwargs.get('latent_dim', 2)
|
||||||
|
self.encoder = self._build_encoder()
|
||||||
|
self.decoder = self._build_decoder()
|
||||||
|
|
||||||
|
|
||||||
|
def _build_encoder(self):
|
||||||
|
encoder = Sequential()
|
||||||
|
encoder.add_module(f'EncoderLinear_{1}', Linear(6, 10, bias=True))
|
||||||
|
encoder.add_module(f'EncoderLinear_{2}', Linear(10, 10, bias=True))
|
||||||
|
gru = Sequential()
|
||||||
|
gru.add_module('Encoder', TimeDistributed(encoder))
|
||||||
|
gru.add_module('GRU', GRU(10, self.latent_dim))
|
||||||
|
return gru
|
||||||
|
|
||||||
|
def _build_decoder(self):
|
||||||
|
decoder = Sequential()
|
||||||
|
decoder.add_module(f'DecoderLinear_{1}', Linear(10, 10, bias=True))
|
||||||
|
decoder.add_module(f'DecoderLinear_{2}', Linear(10, self.dataParams['features'], bias=True))
|
||||||
|
|
||||||
|
gru = Sequential()
|
||||||
|
# There needs to be ab propper bat
|
||||||
|
gru.add_module('Repeater', Repeater((1, self.dataParams['size'], -1)))
|
||||||
|
gru.add_module('GRU', GRU(self.latent_dim, 10))
|
||||||
|
gru.add_module('GRU Filter', RNNOutputFilter())
|
||||||
|
gru.add_module('Decoder', TimeDistributed(decoder))
|
||||||
|
return gru
|
||||||
|
|
||||||
|
def forward(self, batch):
|
||||||
|
batch_size = batch.shape[0]
|
||||||
|
self.decoder.Repeater.shape = (batch_size, ) + self.decoder.Repeater.shape[-2:]
|
||||||
|
# outputs, hidden (Batch, Timesteps aka. Size, Features / Latent Dim Size)
|
||||||
|
outputs, _ = self.encoder(batch)
|
||||||
|
z = outputs[:, -1]
|
||||||
|
x_hat = self.decoder(z)
|
||||||
|
return z, x_hat
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
raise PermissionError('Get out of here - never run this module')
|
103
networks/modules.py
Normal file
103
networks/modules.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
import torch
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
from torch.nn import Module
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
######################
|
||||||
|
# Abstract Network class following the Lightning Syntax
|
||||||
|
class LightningModule(pl.LightningModule, ABC):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(LightningModule, self).__init__()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def forward(self, x):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def training_step(self, batch, batch_nb):
|
||||||
|
# REQUIRED
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_nb):
|
||||||
|
# OPTIONAL
|
||||||
|
pass
|
||||||
|
|
||||||
|
def validation_end(self, outputs):
|
||||||
|
# OPTIONAL
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def configure_optimizers(self):
|
||||||
|
# REQUIRED
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@pl.data_loader
|
||||||
|
def tng_dataloader(self):
|
||||||
|
# REQUIRED
|
||||||
|
raise NotImplementedError
|
||||||
|
# return DataLoader(MNIST(os.getcwd(), train=True, download=True,
|
||||||
|
# transform=transforms.ToTensor()), batch_size=32)
|
||||||
|
|
||||||
|
@pl.data_loader
|
||||||
|
def val_dataloader(self):
|
||||||
|
# OPTIONAL
|
||||||
|
pass
|
||||||
|
|
||||||
|
@pl.data_loader
|
||||||
|
def test_dataloader(self):
|
||||||
|
# OPTIONAL
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
#######################
|
||||||
|
# Utility Modules
|
||||||
|
class TimeDistributed(Module):
|
||||||
|
def __init__(self, module, batch_first=True):
|
||||||
|
super(TimeDistributed, self).__init__()
|
||||||
|
self.module = module
|
||||||
|
self.batch_first = batch_first
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
if len(x.size()) <= 2:
|
||||||
|
return self.module(x)
|
||||||
|
|
||||||
|
# Squash samples and timesteps into a single axis
|
||||||
|
x_reshape = x.contiguous().view(-1, x.size(-1)) # (samples * timesteps, input_size)
|
||||||
|
|
||||||
|
y = self.module(x_reshape)
|
||||||
|
|
||||||
|
# We have to reshape Y
|
||||||
|
if self.batch_first:
|
||||||
|
y = y.contiguous().view(x.size(0), -1, y.size(-1)) # (samples, timesteps, output_size)
|
||||||
|
else:
|
||||||
|
y = y.view(-1, x.size(1), y.size(-1)) # (timesteps, samples, output_size)
|
||||||
|
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
class Repeater(Module):
|
||||||
|
def __init__(self, shape):
|
||||||
|
super(Repeater, self).__init__()
|
||||||
|
self.shape = shape
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
x.unsqueeze_(-2)
|
||||||
|
return x.expand(self.shape)
|
||||||
|
|
||||||
|
class RNNOutputFilter(Module):
|
||||||
|
|
||||||
|
def __init__(self, return_output=True):
|
||||||
|
super(RNNOutputFilter, self).__init__()
|
||||||
|
self.return_output = return_output
|
||||||
|
|
||||||
|
def forward(self, x: tuple):
|
||||||
|
outputs, hidden = x
|
||||||
|
return outputs if self.return_output else hidden
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
raise PermissionError('Get out of here - never run this module')
|
Reference in New Issue
Block a user