Lightning integration basic ae, dataloaders and dataset
This commit is contained in:
parent
fbc776c359
commit
fbe0600e24
3
.idea/ae_toolbox_torch.iml
generated
3
.idea/ae_toolbox_torch.iml
generated
@ -5,6 +5,9 @@
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="PyDocumentationSettings">
|
||||
<option name="renderExternalDocumentation" value="true" />
|
||||
</component>
|
||||
<component name="TestRunnerService">
|
||||
<option name="PROJECT_TEST_RUNNER" value="Unittests" />
|
||||
</component>
|
||||
|
9
.idea/dictionaries/illium.xml
generated
Normal file
9
.idea/dictionaries/illium.xml
generated
Normal file
@ -0,0 +1,9 @@
|
||||
<component name="ProjectDictionaryState">
|
||||
<dictionary name="illium">
|
||||
<words>
|
||||
<w>dataloader</w>
|
||||
<w>datasets</w>
|
||||
<w>isovists</w>
|
||||
</words>
|
||||
</dictionary>
|
||||
</component>
|
7
.idea/other.xml
generated
Normal file
7
.idea/other.xml
generated
Normal file
@ -0,0 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="PySciProjectComponent">
|
||||
<option name="PY_SCI_VIEW" value="true" />
|
||||
<option name="PY_SCI_VIEW_SUGGESTED" value="true" />
|
||||
</component>
|
||||
</project>
|
6
.idea/vcs.xml
generated
Normal file
6
.idea/vcs.xml
generated
Normal file
@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
164
.idea/workspace.xml
generated
164
.idea/workspace.xml
generated
@ -1,9 +1,171 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ChangeListManager">
|
||||
<list default="true" id="5955480a-c876-43d5-afd7-8717f51f413e" name="Default Changelist" comment="">
|
||||
<change afterPath="$PROJECT_DIR$/.idea/dictionaries/illium.xml" afterDir="false" />
|
||||
<change afterPath="$PROJECT_DIR$/.idea/other.xml" afterDir="false" />
|
||||
<change afterPath="$PROJECT_DIR$/.idea/vcs.xml" afterDir="false" />
|
||||
<change afterPath="$PROJECT_DIR$/basic_ae_lightning_torch.py" afterDir="false" />
|
||||
<change afterPath="$PROJECT_DIR$/data/dataset.py" afterDir="false" />
|
||||
<change afterPath="$PROJECT_DIR$/networks/basic_ae.py" afterDir="false" />
|
||||
<change afterPath="$PROJECT_DIR$/networks/modules.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/.idea/ae_toolbox_torch.iml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/ae_toolbox_torch.iml" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
||||
</list>
|
||||
<option name="EXCLUDED_CONVERTED_TO_IGNORED" value="true" />
|
||||
<option name="SHOW_DIALOG" value="false" />
|
||||
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
||||
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
|
||||
<option name="LAST_RESOLUTION" value="IGNORE" />
|
||||
</component>
|
||||
<component name="FileTemplateManagerImpl">
|
||||
<option name="RECENT_TEMPLATES">
|
||||
<list>
|
||||
<option value="Python Script" />
|
||||
</list>
|
||||
</option>
|
||||
</component>
|
||||
<component name="Git.Settings">
|
||||
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
|
||||
</component>
|
||||
<component name="ProjectId" id="1Omeu5sz43kySmz8qHfWIcA2dn0" />
|
||||
<component name="ProjectLevelVcsManager" settingsEditedManually="true" />
|
||||
<component name="PropertiesComponent">
|
||||
<property name="ASKED_ADD_EXTERNAL_FILES" value="true" />
|
||||
<property name="SHARE_PROJECT_CONFIGURATION_FILES" value="true" />
|
||||
<property name="WebServerToolWindowFactoryState" value="false" />
|
||||
<property name="nodejs_interpreter_path.stuck_in_default_project" value="undefined stuck path" />
|
||||
<property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
|
||||
<property name="settings.editor.selected.configurable" value="pyconsole" />
|
||||
</component>
|
||||
<component name="PyConsoleOptionsProvider">
|
||||
<option name="myPythonConsoleState">
|
||||
<console-settings module-name="ae_toolbox_torch" is-module-sdk="true">
|
||||
<option name="myUseModuleSdk" value="true" />
|
||||
<option name="myModuleName" value="ae_toolbox_torch" />
|
||||
</console-settings>
|
||||
</option>
|
||||
<option name="myShowDebugConsoleByDefault" value="true" />
|
||||
</component>
|
||||
<component name="RunDashboard">
|
||||
<option name="ruleStates">
|
||||
<list>
|
||||
<RuleState>
|
||||
<option name="name" value="ConfigurationTypeDashboardGroupingRule" />
|
||||
</RuleState>
|
||||
<RuleState>
|
||||
<option name="name" value="StatusDashboardGroupingRule" />
|
||||
</RuleState>
|
||||
</list>
|
||||
</option>
|
||||
</component>
|
||||
<component name="RunManager" selected="Python.basic_ae_lightning_torch">
|
||||
<configuration default="true" type="PythonConfigurationType" factoryName="Python">
|
||||
<module name="ae_toolbox_torch" />
|
||||
<option name="INTERPRETER_OPTIONS" value="" />
|
||||
<option name="PARENT_ENVS" value="true" />
|
||||
<envs>
|
||||
<env name="PYTHONUNBUFFERED" value="1" />
|
||||
</envs>
|
||||
<option name="SDK_HOME" value="" />
|
||||
<option name="WORKING_DIRECTORY" value="" />
|
||||
<option name="IS_MODULE_SDK" value="false" />
|
||||
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
||||
<option name="SCRIPT_NAME" value="" />
|
||||
<option name="PARAMETERS" value="" />
|
||||
<option name="SHOW_COMMAND_LINE" value="true" />
|
||||
<option name="EMULATE_TERMINAL" value="false" />
|
||||
<option name="MODULE_MODE" value="false" />
|
||||
<option name="REDIRECT_INPUT" value="false" />
|
||||
<option name="INPUT_FILE" value="" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
<configuration name="basic_ae_lightning_torch" type="PythonConfigurationType" factoryName="Python" temporary="true">
|
||||
<module name="ae_toolbox_torch" />
|
||||
<option name="INTERPRETER_OPTIONS" value="" />
|
||||
<option name="PARENT_ENVS" value="true" />
|
||||
<envs>
|
||||
<env name="PYTHONUNBUFFERED" value="1" />
|
||||
</envs>
|
||||
<option name="SDK_HOME" value="" />
|
||||
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
|
||||
<option name="IS_MODULE_SDK" value="true" />
|
||||
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
||||
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/basic_ae_lightning_torch.py" />
|
||||
<option name="PARAMETERS" value="" />
|
||||
<option name="SHOW_COMMAND_LINE" value="true" />
|
||||
<option name="EMULATE_TERMINAL" value="false" />
|
||||
<option name="MODULE_MODE" value="false" />
|
||||
<option name="REDIRECT_INPUT" value="false" />
|
||||
<option name="INPUT_FILE" value="" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
<configuration name="dataset" type="PythonConfigurationType" factoryName="Python" temporary="true">
|
||||
<module name="ae_toolbox_torch" />
|
||||
<option name="INTERPRETER_OPTIONS" value="" />
|
||||
<option name="PARENT_ENVS" value="true" />
|
||||
<envs>
|
||||
<env name="PYTHONUNBUFFERED" value="1" />
|
||||
</envs>
|
||||
<option name="SDK_HOME" value="" />
|
||||
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/data" />
|
||||
<option name="IS_MODULE_SDK" value="true" />
|
||||
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
||||
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/data/dataset.py" />
|
||||
<option name="PARAMETERS" value="" />
|
||||
<option name="SHOW_COMMAND_LINE" value="true" />
|
||||
<option name="EMULATE_TERMINAL" value="false" />
|
||||
<option name="MODULE_MODE" value="false" />
|
||||
<option name="REDIRECT_INPUT" value="false" />
|
||||
<option name="INPUT_FILE" value="" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
<recent_temporary>
|
||||
<list>
|
||||
<item itemvalue="Python.basic_ae_lightning_torch" />
|
||||
<item itemvalue="Python.dataset" />
|
||||
</list>
|
||||
</recent_temporary>
|
||||
</component>
|
||||
<component name="SvnConfiguration">
|
||||
<configuration />
|
||||
</component>
|
||||
<component name="TaskManager">
|
||||
<task active="true" id="Default" summary="Default task">
|
||||
<changelist id="5955480a-c876-43d5-afd7-8717f51f413e" name="Default Changelist" comment="" />
|
||||
<created>1564587418949</created>
|
||||
<option name="number" value="Default" />
|
||||
<option name="presentableId" value="Default" />
|
||||
<updated>1564587418949</updated>
|
||||
<workItem from="1564587420277" duration="6891000" />
|
||||
<workItem from="1565364574595" duration="1092000" />
|
||||
<workItem from="1565592214301" duration="53660000" />
|
||||
</task>
|
||||
<servers />
|
||||
</component>
|
||||
<component name="TypeScriptGeneratedFilesManager">
|
||||
<option name="version" value="1" />
|
||||
</component>
|
||||
<component name="Vcs.Log.Tabs.Properties">
|
||||
<option name="TAB_STATES">
|
||||
<map>
|
||||
<entry key="MAIN">
|
||||
<value>
|
||||
<State>
|
||||
<option name="COLUMN_ORDER" />
|
||||
</State>
|
||||
</value>
|
||||
</entry>
|
||||
</map>
|
||||
</option>
|
||||
</component>
|
||||
<component name="com.intellij.coverage.CoverageDataManagerImpl">
|
||||
<SUITE FILE_PATH="coverage/ae_toolbox_torch$basic_ae_lightning_torch.coverage" NAME="basic_ae_lightning_torch Coverage Results" MODIFIED="1565790288699" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
|
||||
<SUITE FILE_PATH="coverage/ae_toolbox_torch$dataset.coverage" NAME="dataset Coverage Results" MODIFIED="1565772669750" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/data" />
|
||||
</component>
|
||||
</project>
|
42
basic_ae_lightning_torch.py
Normal file
42
basic_ae_lightning_torch.py
Normal file
@ -0,0 +1,42 @@
|
||||
from networks.basic_ae import BasicAE
|
||||
from networks.modules import LightningModule
|
||||
import pytorch_lightning as pl
|
||||
from torch.nn.functional import mse_loss
|
||||
from torch.optim import Adam
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
from data.dataset import DataContainer
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
|
||||
|
||||
class AEModel(LightningModule):
|
||||
|
||||
def __init__(self, dataParams: dict):
|
||||
super(AEModel, self).__init__()
|
||||
self.dataParams = dataParams
|
||||
self.network = BasicAE(self.dataParams)
|
||||
|
||||
def forward(self, x):
|
||||
return self.network.forward(x)
|
||||
|
||||
def training_step(self, x, batch_nb):
|
||||
z, x_hat = self.forward(x)
|
||||
return {'loss': mse_loss(x, x_hat)}
|
||||
|
||||
def configure_optimizers(self):
|
||||
# ToDo: Where do i get the Paramers from?
|
||||
return [Adam(self.parameters(), lr=0.02)]
|
||||
|
||||
@pl.data_loader
|
||||
def tng_dataloader(self):
|
||||
return DataLoader(DataContainer('data', **self.dataParams), shuffle=True, batch_size=100)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ae = AEModel(
|
||||
dict(refresh=False, size=5, step=5, features=6)
|
||||
)
|
||||
|
||||
trainer = Trainer()
|
||||
trainer.fit(ae)
|
252
data/dataset.py
Normal file
252
data/dataset.py
Normal file
@ -0,0 +1,252 @@
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
from distutils.util import strtobool
|
||||
import os
|
||||
import ast
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset, ConcatDataset
|
||||
|
||||
|
||||
# Command line argument parsing
|
||||
def build_parse_commands():
|
||||
# Init the Command Line Arguments Parser
|
||||
arg_parser = argparse.ArgumentParser(description='VAE and GSE Autoencoder with latent Space Clustering Approaches')
|
||||
|
||||
# Specify a pretrained weight file to load.
|
||||
arg_parser.add_argument('--model_file', nargs='?', default='',
|
||||
help='Specify a pretrained model file to load.')
|
||||
# Specify a pretrained weight file to load.
|
||||
arg_parser.add_argument('--files', nargs='?', default='',
|
||||
help='Set the raw data location. Should be filled with maps and trajectories')
|
||||
|
||||
# Set a fixed prng seed.
|
||||
arg_parser.add_argument('--seed', nargs='?', default=-999, help='Set a fixed prng seed.')
|
||||
|
||||
# DataSet parameters
|
||||
arg_parser.add_argument('--size', nargs='?', default=9, help='Set a trajectory length; the number of isovists.')
|
||||
arg_parser.add_argument('--step', nargs='?', default=5, help='Set a fixed stepsize between isovist centers.')
|
||||
arg_parser.add_argument('--overlapping', nargs='?', default=True, help='Whether the Isovists should overlap.')
|
||||
|
||||
# Specify the Map to use in Training and visualization
|
||||
arg_parser.add_argument('-p', '--print_on_map', default=False, type=strtobool,
|
||||
help='Whether trajcetories should be colored and displayed on a map.')
|
||||
arg_parser.add_argument('-l', '--print_latent', default=False, type=strtobool,
|
||||
help='Whether latent encoding space should be colored and displayed.')
|
||||
arg_parser.add_argument('-d', '--divided_latent_viz', default=False, type=strtobool,
|
||||
help='Whether latent encoding space should be colored and displayed seperatein saae case.')
|
||||
return arg_parser.parse_args()
|
||||
|
||||
|
||||
class AbstractDataset(ConcatDataset, ABC):
|
||||
|
||||
# maps = ['hotel', 'tum','gallery', 'queens', 'oet']
|
||||
@property
|
||||
def maps(self):
|
||||
return ['hotel', 'tum','gallery', 'queens', 'oet']
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def raw_filenames(self):
|
||||
raise NotImplementedError('Specify the file ending here')
|
||||
|
||||
@property
|
||||
def raw_paths(self):
|
||||
return [os.path.join(self.path, 'raw', x) for x in self.raw_filenames]
|
||||
|
||||
@property
|
||||
def processed_filenames(self):
|
||||
return [f'{x}_{self.__class__.__name__}.to' for x in self.maps]
|
||||
|
||||
@property
|
||||
def processed_paths(self):
|
||||
return [os.path.join(self.path, 'processed', x) for x in self.processed_filenames]
|
||||
|
||||
def __init__(self, path, refresh=False, **kwargs):
|
||||
self.path = path
|
||||
self.refresh = refresh
|
||||
super(AbstractDataset, self).__init__(datasets=self._load_datasets())
|
||||
|
||||
@abstractmethod
|
||||
def process(self, filepath):
|
||||
raise NotImplementedError
|
||||
|
||||
def _load_datasets(self):
|
||||
if self.refresh:
|
||||
for filepath in self.processed_paths:
|
||||
try:
|
||||
os.remove(filepath)
|
||||
print('Processed Location "Refreshed" (We deleted the Files)')
|
||||
except FileNotFoundError:
|
||||
print('You meant to refresh the allready processed dataset, but there were none...')
|
||||
print('continue processing')
|
||||
pass
|
||||
datasets = []
|
||||
# ToDo: Make this nicer
|
||||
for map_idx, _ in tqdm(enumerate(self.maps),
|
||||
total=len(self.maps), unit="files"
|
||||
):
|
||||
while True:
|
||||
try:
|
||||
datasets.append(torch.load(self.processed_paths[map_idx]))
|
||||
print(f'Dataset "{self.processed_paths[map_idx]}" loaded')
|
||||
break
|
||||
except FileNotFoundError:
|
||||
os.makedirs(os.path.join(*os.path.split(self.processed_paths[map_idx])[:-1]), exist_ok=True)
|
||||
processed = self.process(self.raw_paths[map_idx])
|
||||
torch.save(processed, self.processed_paths[map_idx])
|
||||
continue
|
||||
return datasets
|
||||
|
||||
|
||||
class DataContainer(AbstractDataset):
|
||||
|
||||
@staticmethod
|
||||
def calculate_model_shapes(size, step, **kwargs):
|
||||
|
||||
return
|
||||
|
||||
@property
|
||||
def raw_filenames(self):
|
||||
return [f'{x}_trajec.csv' for x in self.maps]
|
||||
|
||||
def __init__(self, path, size, step, **kwargs):
|
||||
self.size = size
|
||||
self.step = step
|
||||
super(DataContainer, self).__init__(path, **kwargs)
|
||||
pass
|
||||
|
||||
def process(self, filepath):
|
||||
dataDict = defaultdict(list)
|
||||
with open(filepath, 'r') as f:
|
||||
delimiter = ','
|
||||
# Separate the header
|
||||
headers = f.readline().rstrip().split(delimiter)
|
||||
headers.remove('inDoor')
|
||||
# Iterate over every line and convert it to float / value
|
||||
# ToDo: Make this nicer
|
||||
for line in tqdm(f, total=len(self.maps), unit="lines"):
|
||||
if line == '':
|
||||
continue
|
||||
else:
|
||||
for attr, x in zip(headers, line.rstrip().split(delimiter)[None:None]):
|
||||
if attr not in ['inDoor']:
|
||||
dataDict[attr].append(ast.literal_eval(x))
|
||||
return Trajectories(self.size, self.step, headers, **dataDict)
|
||||
|
||||
|
||||
class Trajectories(Dataset):
|
||||
|
||||
# As in "To take hold of isovists and isovist fields" - M. L. Benedikt, read only measures specified by Benedikt
|
||||
@property
|
||||
def isovistMeasures(self):
|
||||
return ['X', 'Z', 'realSurfacePerimeter', 'occlusionValue', 'area', 'variance', 'skewness', 'circularity_ben']
|
||||
|
||||
@property
|
||||
def features(self):
|
||||
return len(self.isovistMeasures)
|
||||
|
||||
def __init__(self, size, step, headers, **kwargs):
|
||||
super(Trajectories, self).__init__()
|
||||
self.size: int = size
|
||||
self.step: int = step
|
||||
self.headers: list = headers
|
||||
|
||||
dataDict = dict()
|
||||
for key, val in kwargs.items():
|
||||
if key in self.isovistMeasures:
|
||||
dataDict[key] = torch.tensor(val)
|
||||
# Check if all keys are of same length
|
||||
assert len(set(x.size()[0] for x in dataDict.values() if torch.is_tensor(x))) <= 1
|
||||
self.data = torch.stack([dataDict[key] for key in self.isovistMeasures], dim=-1)
|
||||
pass
|
||||
|
||||
def __iter_tenors__(self):
|
||||
return
|
||||
|
||||
def __iter__(self):
|
||||
for i in range(len(self)):
|
||||
yield self[i]
|
||||
|
||||
def __getitem__(self, item, coords=False):
|
||||
"""
|
||||
Return a trajectory sample from the dataset by a specific key.
|
||||
:param item: The index number of the trajectory to return.
|
||||
:return:
|
||||
"""
|
||||
subList = self.data[item:item + self.size * self.step or None:self.step]
|
||||
xy, tensor = subList[:, 2], subList[:, 2:]
|
||||
return (xy, tensor) if coords else tensor
|
||||
|
||||
def __len__(self):
|
||||
total_len = self.data.size()[0]
|
||||
return total_len - (self.size * self.step - (self.step - 1))
|
||||
|
||||
|
||||
class MapContainer(AbstractDataset):
|
||||
|
||||
@property
|
||||
def raw_filenames(self):
|
||||
return [f'{x}_map.csv' for x in self.maps]
|
||||
|
||||
def __init__(self, path, **kwargs):
|
||||
super(MapContainer, self).__init__(path, **kwargs)
|
||||
pass
|
||||
|
||||
def process(self, filepath):
|
||||
dataDict = defaultdict(list)
|
||||
with open(filepath, 'r') as f:
|
||||
delimiter = ','
|
||||
# Separate the header
|
||||
headers = f.readline().rstrip().split(delimiter)
|
||||
# Iterate over every line and convert it to float / value
|
||||
# ToDo: Make this nicer
|
||||
for line in tqdm(f):
|
||||
if line == '':
|
||||
continue
|
||||
else:
|
||||
for attr, x in zip(headers, line.rstrip().split(delimiter)[None:None]):
|
||||
dataDict[attr].append(ast.literal_eval(x))
|
||||
|
||||
return Map(np.asarray([dataDict[head] for head in headers]))
|
||||
|
||||
|
||||
class Map(object):
|
||||
|
||||
def __init__(self, mapData: np.ndarray):
|
||||
"""
|
||||
This is a Container Class for triangulated basemaps in csv format.
|
||||
:param mapData: The map as np.ndarray, already read from disk.
|
||||
"""
|
||||
|
||||
self.map: np.ndarray = mapData
|
||||
|
||||
self.minx, self.maxx = np.min(self.map[[0, 2, 4]]), np.max(self.map[[0, 2, 4]])
|
||||
self.miny, self.maxy = np.min(self.map[[1, 3, 5]]), np.max(self.map[[1, 3, 5]])
|
||||
|
||||
print('BaseMap Initialized')
|
||||
|
||||
def __len__(self):
|
||||
return self.map.shape[0]
|
||||
|
||||
def vertices(self):
|
||||
vertices = self.map.reshape((-1, 2, 3))
|
||||
return vertices
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.map[item].reshape(3, 2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = build_parse_commands()
|
||||
if args.seed != -999:
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
# d = DataContainer(args.files, args.size, args.step)
|
||||
m = MapContainer(args.files, refresh=True)
|
||||
print(len(m[1]))
|
52
networks/basic_ae.py
Normal file
52
networks/basic_ae.py
Normal file
@ -0,0 +1,52 @@
|
||||
from torch.nn import Sequential, Linear, GRU
|
||||
from data.dataset import DataContainer
|
||||
|
||||
from .modules import *
|
||||
|
||||
|
||||
#######################
|
||||
# Basic AE-Implementation
|
||||
class BasicAE(Module, ABC):
|
||||
|
||||
def __init__(self, dataParams, **kwargs):
|
||||
super(BasicAE, self).__init__()
|
||||
self.dataParams = dataParams
|
||||
self.latent_dim = kwargs.get('latent_dim', 2)
|
||||
self.encoder = self._build_encoder()
|
||||
self.decoder = self._build_decoder()
|
||||
|
||||
|
||||
def _build_encoder(self):
|
||||
encoder = Sequential()
|
||||
encoder.add_module(f'EncoderLinear_{1}', Linear(6, 10, bias=True))
|
||||
encoder.add_module(f'EncoderLinear_{2}', Linear(10, 10, bias=True))
|
||||
gru = Sequential()
|
||||
gru.add_module('Encoder', TimeDistributed(encoder))
|
||||
gru.add_module('GRU', GRU(10, self.latent_dim))
|
||||
return gru
|
||||
|
||||
def _build_decoder(self):
|
||||
decoder = Sequential()
|
||||
decoder.add_module(f'DecoderLinear_{1}', Linear(10, 10, bias=True))
|
||||
decoder.add_module(f'DecoderLinear_{2}', Linear(10, self.dataParams['features'], bias=True))
|
||||
|
||||
gru = Sequential()
|
||||
# There needs to be ab propper bat
|
||||
gru.add_module('Repeater', Repeater((1, self.dataParams['size'], -1)))
|
||||
gru.add_module('GRU', GRU(self.latent_dim, 10))
|
||||
gru.add_module('GRU Filter', RNNOutputFilter())
|
||||
gru.add_module('Decoder', TimeDistributed(decoder))
|
||||
return gru
|
||||
|
||||
def forward(self, batch):
|
||||
batch_size = batch.shape[0]
|
||||
self.decoder.Repeater.shape = (batch_size, ) + self.decoder.Repeater.shape[-2:]
|
||||
# outputs, hidden (Batch, Timesteps aka. Size, Features / Latent Dim Size)
|
||||
outputs, _ = self.encoder(batch)
|
||||
z = outputs[:, -1]
|
||||
x_hat = self.decoder(z)
|
||||
return z, x_hat
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise PermissionError('Get out of here - never run this module')
|
103
networks/modules.py
Normal file
103
networks/modules.py
Normal file
@ -0,0 +1,103 @@
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
from torch.nn import Module
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
######################
|
||||
# Abstract Network class following the Lightning Syntax
|
||||
class LightningModule(pl.LightningModule, ABC):
|
||||
|
||||
def __init__(self):
|
||||
super(LightningModule, self).__init__()
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, x):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def training_step(self, batch, batch_nb):
|
||||
# REQUIRED
|
||||
raise NotImplementedError
|
||||
|
||||
def validation_step(self, batch, batch_nb):
|
||||
# OPTIONAL
|
||||
pass
|
||||
|
||||
def validation_end(self, outputs):
|
||||
# OPTIONAL
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def configure_optimizers(self):
|
||||
# REQUIRED
|
||||
raise NotImplementedError
|
||||
|
||||
@pl.data_loader
|
||||
def tng_dataloader(self):
|
||||
# REQUIRED
|
||||
raise NotImplementedError
|
||||
# return DataLoader(MNIST(os.getcwd(), train=True, download=True,
|
||||
# transform=transforms.ToTensor()), batch_size=32)
|
||||
|
||||
@pl.data_loader
|
||||
def val_dataloader(self):
|
||||
# OPTIONAL
|
||||
pass
|
||||
|
||||
@pl.data_loader
|
||||
def test_dataloader(self):
|
||||
# OPTIONAL
|
||||
pass
|
||||
|
||||
|
||||
#######################
|
||||
# Utility Modules
|
||||
class TimeDistributed(Module):
|
||||
def __init__(self, module, batch_first=True):
|
||||
super(TimeDistributed, self).__init__()
|
||||
self.module = module
|
||||
self.batch_first = batch_first
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
if len(x.size()) <= 2:
|
||||
return self.module(x)
|
||||
|
||||
# Squash samples and timesteps into a single axis
|
||||
x_reshape = x.contiguous().view(-1, x.size(-1)) # (samples * timesteps, input_size)
|
||||
|
||||
y = self.module(x_reshape)
|
||||
|
||||
# We have to reshape Y
|
||||
if self.batch_first:
|
||||
y = y.contiguous().view(x.size(0), -1, y.size(-1)) # (samples, timesteps, output_size)
|
||||
else:
|
||||
y = y.view(-1, x.size(1), y.size(-1)) # (timesteps, samples, output_size)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class Repeater(Module):
|
||||
def __init__(self, shape):
|
||||
super(Repeater, self).__init__()
|
||||
self.shape = shape
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x.unsqueeze_(-2)
|
||||
return x.expand(self.shape)
|
||||
|
||||
class RNNOutputFilter(Module):
|
||||
|
||||
def __init__(self, return_output=True):
|
||||
super(RNNOutputFilter, self).__init__()
|
||||
self.return_output = return_output
|
||||
|
||||
def forward(self, x: tuple):
|
||||
outputs, hidden = x
|
||||
return outputs if self.return_output else hidden
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise PermissionError('Get out of here - never run this module')
|
Loading…
x
Reference in New Issue
Block a user