Train Active

This commit is contained in:
Si11ium 2020-03-03 15:10:17 +01:00
parent 44f6589259
commit 1f612a968c
13 changed files with 102 additions and 98 deletions

View File

@ -7,7 +7,7 @@ from random import choice
from torch.utils.data import ConcatDataset, Dataset from torch.utils.data import ConcatDataset, Dataset
from lib.objects.map import Map from lib.objects.map import Map
from lib.preprocessing.generator import Generator from PIL import Image
class TrajDataset(Dataset): class TrajDataset(Dataset):
@ -17,14 +17,14 @@ class TrajDataset(Dataset):
return self.map.as_array.shape return self.map.as_array.shape
def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw', def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
length=100000, all_in_map=True, **kwargs): length=100000, all_in_map=True, embedding_size=None, **kwargs):
super(TrajDataset, self).__init__() super(TrajDataset, self).__init__()
self.all_in_map = all_in_map self.all_in_map = all_in_map
self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp' self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
self.maps_root = maps_root self.maps_root = maps_root
self._len = length self._len = length
self.map = Map(self.mapname).from_image(self.maps_root / self.mapname) self.map = Map(self.mapname).from_image(self.maps_root / self.mapname, embedding_size=embedding_size)
def __len__(self): def __len__(self):
return self._len return self._len
@ -35,15 +35,14 @@ class TrajDataset(Dataset):
label = choice([0, 1]) label = choice([0, 1])
if self.all_in_map: if self.all_in_map:
blank_trajectory_space = torch.zeros(self.map.shape) blank_trajectory_space = torch.zeros(self.map.shape)
blank_trajectory_space[trajectory.vertices] = 1
blank_alternative_space = torch.zeros(self.map.shape) blank_alternative_space = torch.zeros(self.map.shape)
blank_alternative_space[trajectory.np_vertices] = 1 for index in trajectory.vertices:
blank_trajectory_space[index] = 1
blank_alternative_space[index] = 1
map_array = torch.as_tensor(self.map.as_array) map_array = torch.as_tensor(self.map.as_array).float()
label = self.map.are_homotopic(trajectory, alternative) label = self.map.are_homotopic(trajectory, alternative)
return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), label
else: else:
return trajectory.vertices, alternative.vertices, label, self.mapname return trajectory.vertices, alternative.vertices, label, self.mapname
@ -56,7 +55,10 @@ class TrajData(object):
@property @property
def map_shapes_max(self): def map_shapes_max(self):
shapes = self.map_shapes shapes = self.map_shapes
return list(map(max, zip(*shapes))) shape_list = list(map(max, zip(*shapes)))
if self.all_in_map:
shape_list[0] += 2
return shape_list
@property @property
def name(self): def name(self):
@ -72,8 +74,12 @@ class TrajData(object):
def _load_datasets(self): def _load_datasets(self):
map_files = list(self.maps_root.glob('*.bmp')) map_files = list(self.maps_root.glob('*.bmp'))
equal_split = int(self.length // len(map_files)) equal_split = int(self.length // len(map_files))
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_image.name, length=equal_split,
all_in_map=self.all_in_map) for map_image in map_files]) # find max image size among available maps:
max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
all_in_map=self.all_in_map, embedding_size=max_map_size)
for map_file in map_files])
@property @property
def train_dataset(self): def train_dataset(self):

View File

@ -18,21 +18,12 @@ class ConvHomDetector(LightningBaseModule):
def configure_optimizers(self): def configure_optimizers(self):
return Adam(self.parameters(), lr=self.hparams.lr) return Adam(self.parameters(), lr=self.hparams.lr)
def validation_step(self, *args, **kwargs):
pass
def validation_end(self, outputs):
pass
def training_step(self, batch_xy, batch_nb, *args, **kwargs): def training_step(self, batch_xy, batch_nb, *args, **kwargs):
batch_x, batch_y = batch_xy batch_x, batch_y = batch_xy
pred_y = self(batch_x) pred_y = self(batch_x)
loss = F.binary_cross_entropy(pred_y, batch_y) loss = F.binary_cross_entropy(pred_y, batch_y.float())
return {'loss': loss, 'log': dict(loss=loss)} return {'loss': loss, 'log': dict(loss=loss)}
def test_step(self, *args, **kwargs):
pass
def __init__(self, *params): def __init__(self, *params):
super(ConvHomDetector, self).__init__(*params) super(ConvHomDetector, self).__init__(*params)
@ -75,8 +66,9 @@ class ConvHomDetector(LightningBaseModule):
# #
self.linear = nn.Linear(reduce(mul, self.flatten.shape), self.hparams.model_param.classes * 10) self.linear = nn.Linear(reduce(mul, self.flatten.shape), self.hparams.model_param.classes * 10)
self.classifier = nn.Linear(self.hparams.model_param.classes * 10, self.hparams.model_param.classes) # Comments on Multi Class labels
self.softmax = nn.Softmax() self.classifier = nn.Linear(self.hparams.model_param.classes * 10, 1) # self.hparams.model_param.classes)
self.out_activation = nn.Sigmoid() # nn.Softmax
def forward(self, x): def forward(self, x):
tensor = self.map_conv_0(x) tensor = self.map_conv_0(x)
@ -88,5 +80,5 @@ class ConvHomDetector(LightningBaseModule):
tensor = self.flatten(tensor) tensor = self.flatten(tensor)
tensor = self.linear(tensor) tensor = self.linear(tensor)
tensor = self.classifier(tensor) tensor = self.classifier(tensor)
tensor = self.softmax(tensor) tensor = self.out_activation(tensor)
return tensor return tensor

View File

@ -106,7 +106,7 @@ class ResidualModule(nn.Module):
self.in_shape = in_shape self.in_shape = in_shape
module_paramters.update(in_shape=in_shape) module_paramters.update(in_shape=in_shape)
self.activation = activation() if activation else lambda x: x self.activation = activation() if activation else lambda x: x
self.residual_block = [module_class(**module_paramters) for _ in range(n)] self.residual_block = nn.ModuleList([module_class(**module_paramters) for _ in range(n)])
assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.' assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'
def forward(self, x): def forward(self, x):

View File

@ -133,12 +133,6 @@ class LightningBaseModule(pl.LightningModule, ABC):
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
raise NotImplementedError raise NotImplementedError
def validation_step(self, *args, **kwargs):
raise NotImplementedError
def validation_end(self, outputs):
raise NotImplementedError
def training_step(self, batch_xy, batch_nb, *args, **kwargs): def training_step(self, batch_xy, batch_nb, *args, **kwargs):
raise NotImplementedError raise NotImplementedError
@ -146,21 +140,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
raise NotImplementedError raise NotImplementedError
def test_end(self, outputs): def test_end(self, outputs):
from sklearn.metrics import roc_auc_score raise NotImplementedError
y_scores, y_true = [], []
for output in outputs:
y_scores.append(output['y_pred'])
y_true.append(output['y_true'])
y_true = torch.cat(y_true, dim=0)
# FIXME: What did this do do i need it?
# y_true = (y_true != V.HOMOTOPIC).long()
y_scores = torch.cat(y_scores, dim=0)
roc_auc_scores = roc_auc_score(y_true.cpu().numpy(), y_scores.cpu().numpy())
print(f'AUC Score: {roc_auc_scores}')
return {'roc_auc_scores': roc_auc_scores}
def init_weights(self): def init_weights(self):
def _weight_init(m): def _weight_init(m):

View File

@ -29,11 +29,11 @@ class Map(object):
@property @property
def width(self): def width(self):
return self.shape[0] return self.shape[-2]
@property @property
def height(self): def height(self):
return self.shape[1] return self.shape[-1]
@property @property
def as_graph(self): def as_graph(self):
@ -43,6 +43,10 @@ class Map(object):
def as_array(self): def as_array(self):
return self.map_array return self.map_array
@property
def as_2d_array(self):
return self.map_array[1:]
def __init__(self, name='', array_like_map_representation=None): def __init__(self, name='', array_like_map_representation=None):
if array_like_map_representation is not None: if array_like_map_representation is not None:
if array_like_map_representation.ndim == 2: if array_like_map_representation.ndim == 2:
@ -72,21 +76,25 @@ class Map(object):
# Differentiate between 8 and 4 neighbors # Differentiate between 8 and 4 neighbors
if not full_neighbors and n >= 2: if not full_neighbors and n >= 2:
break break
# ToDO: make this explicite and less ugly # ToDO: make this explicite and less ugly
query_node = idx[:1] + (idx[1] + ydif,) + (idx[2] + xdif,) query_node = idx[:1] + (idx[1] + ydif,) + (idx[2] + xdif,)
if graph.has_node(query_node): if graph.has_node(query_node):
graph.add_edge(idx, query_node, weight=weight) graph.add_edge(idx, query_node, weight=weight)
return graph return graph
@classmethod @classmethod
def from_image(cls, imagepath: Path): def from_image(cls, imagepath: Path, embedding_size=None):
with Image.open(imagepath) as image: with Image.open(imagepath) as image:
# Turn the image to single Channel Greyscale # Turn the image to single Channel Greyscale
if image.mode != 'L': if image.mode != 'L':
image = image.convert('L') image = image.convert('L')
map_array = np.expand_dims(np.array(image), axis=0) map_array = np.expand_dims(np.array(image), axis=0)
if embedding_size:
assert isinstance(embedding_size, tuple), f'embedding_size was of type: {type(embedding_size)}'
embedding = np.zeros(embedding_size)
embedding[:map_array.shape[0], :map_array.shape[1], :map_array.shape[2]] = map_array
map_array = embedding
return cls(name=imagepath.name, array_like_map_representation=map_array) return cls(name=imagepath.name, array_like_map_representation=map_array)
def simple_trajectory_between(self, start, dest): def simple_trajectory_between(self, start, dest):
@ -105,36 +113,46 @@ class Map(object):
return Trajectory(coords) return Trajectory(coords)
def get_random_trajectory(self): def get_random_trajectory(self):
simple_trajectory = None
while simple_trajectory is None:
try:
start = self.get_valid_position() start = self.get_valid_position()
dest = self.get_valid_position() dest = self.get_valid_position()
return self.simple_trajectory_between(start, dest) simple_trajectory = self.simple_trajectory_between(start, dest)
except nx.exception.NetworkXNoPath:
pass
return simple_trajectory
def generate_alternative(self, trajectory, mode='one_patching'): def generate_alternative(self, trajectory, mode='one_patching'):
start, dest = trajectory.endpoints start, dest = trajectory.endpoints
alternative = None
while alternative is None:
try:
if mode == 'one_patching': if mode == 'one_patching':
patch = self.get_valid_position() patch = self.get_valid_position()
alternative = self.get_trajectory_from_vertices(start, patch, dest) alternative = self.get_trajectory_from_vertices(start, patch, dest)
else: else:
raise RuntimeError(f'mode checking went wrong...') raise RuntimeError(f'mode checking went wrong...')
except nx.exception.NetworkXNoPath:
pass
return alternative return alternative
def are_homotopic(self, trajectory, other_trajectory): def are_homotopic(self, trajectory, other_trajectory):
if not all(isinstance(x, Trajectory) for x in [trajectory, other_trajectory]): if not all(isinstance(x, Trajectory) for x in [trajectory, other_trajectory]):
raise TypeError raise TypeError
polyline = trajectory.vertices.copy() polyline = trajectory.xy_vertices
polyline.extend(reversed(other_trajectory.vertices)) polyline.extend(reversed(other_trajectory.xy_vertices))
img = Image.new('L', (self.height, self.width), 0) img = Image.new('L', (self.height, self.width), 0)
draw = ImageDraw.Draw(img) draw = ImageDraw.Draw(img)
draw.polygon(polyline, outline=255, fill=255) draw.polygon(polyline, outline=255, fill=255)
a = (np.array(img) * np.where(self.map_array == self.white, 0, 1)).sum() a = (np.asarray(img) * np.where(self.as_2d_array == self.white, 0, 1)).sum()
if a >= 1: if a:
return False return False # Non-Homotoph
else: else:
return True return True # Homotoph
def draw(self): def draw(self):
fig, ax = plt.gcf(), plt.gca() fig, ax = plt.gcf(), plt.gca()

View File

@ -8,43 +8,51 @@ import numpy as np
class Trajectory(object): class Trajectory(object):
@property
def vertices(self):
return self._vertices
@property
def xy_vertices(self):
return [(x,y) for _, x,y in self._vertices]
@property @property
def endpoints(self): def endpoints(self):
return self.start, self.dest return self.start, self.dest
@property @property
def start(self): def start(self):
return self.vertices[0] return self._vertices[0]
@property @property
def dest(self): def dest(self):
return self.vertices[-1] return self._vertices[-1]
@property @property
def xs(self): def xs(self):
return [x[1] for x in self.vertices] return [x[1] for x in self._vertices]
@property @property
def ys(self): def ys(self):
return [x[0] for x in self.vertices] return [x[0] for x in self._vertices]
@property @property
def as_paired_list(self): def as_paired_list(self):
return list(zip(self.vertices[:-1], self.vertices[1:])) return list(zip(self._vertices[:-1], self._vertices[1:]))
@property @property
def np_vertices(self): def np_vertices(self):
return [np.array(vertice) for vertice in self.vertices] return [np.array(vertice) for vertice in self._vertices]
def __init__(self, vertices: Union[List[Tuple[int]], None] = None): def __init__(self, vertices: Union[List[Tuple[int]], None] = None):
assert any((isinstance(vertices, list), vertices is None)) assert any((isinstance(vertices, list), vertices is None))
if vertices is not None: if vertices is not None:
self.vertices = vertices self._vertices = vertices
pass pass
def is_equal_to(self, other_trajectory): def is_equal_to(self, other_trajectory):
# ToDo: do further equality Checks here # ToDo: do further equality Checks here
return self.vertices == other_trajectory.vertices return self._vertices == other_trajectory.vertices
def draw(self, highlights=True, label=None, **kwargs): def draw(self, highlights=True, label=None, **kwargs):
if label is not None: if label is not None:

View File

@ -31,7 +31,7 @@ main_arg_parser.add_argument("--main_eval", type=strtobool, default=False, help=
main_arg_parser.add_argument("--main_seed", type=int, default=69, help="") main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
# Data Parameters # Data Parameters
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="") main_arg_parser.add_argument("--data_worker", type=int, default=0, help="")
main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="") main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="")
main_arg_parser.add_argument("--data_root", type=str, default='/data/', help="") main_arg_parser.add_argument("--data_root", type=str, default='/data/', help="")
main_arg_parser.add_argument("--data_map_root", type=str, default='/res/maps', help="") main_arg_parser.add_argument("--data_map_root", type=str, default='/res/maps', help="")
@ -76,6 +76,9 @@ checkpoint_callback = ModelCheckpoint(
period=1 period=1
) )
if __name__ == "__main__":
# Model # Model
# ============================================================================= # =============================================================================
# Init # Init
@ -96,9 +99,6 @@ trainer = Trainer(max_nb_epochs=config.train.epochs,
early_stop_callback=None early_stop_callback=None
) )
if __name__ == "__main__":
# Check Cuda availability
print(f'Cuda is {"" if torch.cuda.is_available() else "not"} available!!!')
# Train it # Train it
trainer.fit(model) trainer.fit(model)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 194 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 248 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 198 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 102 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 66 KiB

After

Width:  |  Height:  |  Size: 3.8 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 129 KiB