# Provenance note (paste artifact from a git-hosting UI, kept as a comment so the file parses):
# snapshot 2019-08-09 12:35:55 +02:00 — 224 lines, 8.6 KiB, Python

"""
Modified from https://github.com/fxia22/pointnet.pytorch/blob/master/utils/train_segmentation.py
Cloned from Git:
https://github.com/dragonbook/pointnet2-pytorch/blob/master/main.py
"""
import os
import sys
from distutils.util import strtobool
import random
import numpy as np
import argparse
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from dataset.shapenet import ShapeNetPartSegDataset
from model.pointnet2_part_seg import PointNet2PartSegmentNet
import torch_geometric.transforms as GT
# Drive prefix of the running interpreter (e.g. 'C:' on Windows, '' on POSIX).
# NOTE(review): appears unused anywhere in this file — candidate for removal.
fs_root = os.path.splitdrive(sys.executable)[0]

# Argument parser
# NOTE(review): distutils.util.strtobool is deprecated and removed in Python 3.12;
# it also returns 0/1 ints rather than bools (see the explicit cast for `headers` below).
parser = argparse.ArgumentParser()
default_data_dir = os.path.join(os.getcwd(), 'data')
parser.add_argument('--dataset', type=str, default=default_data_dir, help='dataset path')
parser.add_argument('--npoints', type=int, default=1024, help='resample points number')
parser.add_argument('--model', type=str, default='', help='model path')
parser.add_argument('--nepoch', type=int, default=250, help='number of epochs to train for')
parser.add_argument('--outf', type=str, default='checkpoint', help='output folder')
parser.add_argument('--labels_within', type=strtobool, default=True, help='defines the label location')
parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
parser.add_argument('--test_per_batches', type=int, default=1000, help='run a test batch per training batches number')
parser.add_argument('--num_workers', type=int, default=0, help='number of data loading workers')
parser.add_argument('--headers', type=strtobool, default=True, help='if raw files come with headers')
parser.add_argument('--collate_per_segment', type=strtobool, default=True, help='whether to look at pointclouds or sub')
parser.add_argument('--has_variations', type=strtobool, default=False,
                    help='whether a single pointcloud has variations '
                         'named int(id)_pc.(xyz|dat) look at pointclouds or sub')
opt = parser.parse_args()
print(opt)

# Random seed
opt.manual_seed = 123
# NOTE(review): only `headers` is cast to bool here; the other strtobool flags
# (labels_within, collate_per_segment, has_variations) stay as 0/1 ints —
# truthy-equivalent, but inconsistent.
opt.headers = bool(opt.headers)
print('Random seed: ', opt.manual_seed)
# Seed every RNG in play for reproducibility (Python, NumPy, torch CPU + CUDA).
random.seed(opt.manual_seed)
np.random.seed(opt.manual_seed)
torch.manual_seed(opt.manual_seed)
torch.cuda.manual_seed(opt.manual_seed)
if __name__ == '__main__':
    # Dataset and transform
    print('Construct dataset ..')

    # Training-time augmentation: small random rotation around each axis plus a
    # small random translation, applied after scale normalization.
    rot_max_angle = 15          # degrees, per axis
    trans_max_distance = 0.01   # translation magnitude (post-NormalizeScale units)

    RotTransform = GT.Compose([GT.RandomRotate(rot_max_angle, 0),
                               GT.RandomRotate(rot_max_angle, 1),
                               GT.RandomRotate(rot_max_angle, 2)]
                              )
    TransTransform = GT.RandomTranslate(trans_max_distance)

    train_transform = GT.Compose([GT.NormalizeScale(), RotTransform, TransTransform])
    test_transform = GT.Compose([GT.NormalizeScale(), ])

    params = dict(root_dir=opt.dataset,
                  collate_per_segment=opt.collate_per_segment,
                  transform=train_transform,
                  npoints=opt.npoints,
                  labels_within=opt.labels_within,
                  has_variations=opt.has_variations,
                  headers=opt.headers
                  )

    dataset = ShapeNetPartSegDataset(mode='train', **params)
    dataLoader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)

    # FIX: the test split previously reused the *training* transform (random
    # rotation/translation) because test_transform was defined but never used.
    # Evaluate on deterministic, normalize-only data instead.
    test_params = {**params, 'transform': test_transform}
    test_dataset = ShapeNetPartSegDataset(mode='test', **test_params)
    test_dataLoader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)

    num_classes = dataset.num_classes()
    print('dataset size: ', len(dataset))
    print('test_dataset size: ', len(test_dataset))
    print('num_classes: ', num_classes)

    # FIX: os.mkdir + blanket `except OSError: pass` silently swallowed real
    # failures (e.g. missing permissions) and could not create nested paths.
    os.makedirs(opt.outf, exist_ok=True)
# Model, criterion and optimizer
print('Construct model ..')
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
dtype = torch.float
print('cudnn.enabled: ', torch.backends.cudnn.enabled)
net = PointNet2PartSegmentNet(num_classes)
if opt.model != '':
net.load_state_dict(torch.load(opt.model))
net = net.to(device, dtype)
criterion = nn.NLLLoss()
optimizer = optim.Adam(net.parameters())
# Train
print('Training ..')
blue = lambda x: f'\033[94m {x} \033[0m'
num_batch = len(dataset) // opt.batch_size
test_per_batches = opt.test_per_batches
print('number of epoches: ', opt.nepoch)
print('number of batches per epoch: ', num_batch)
print('run test per batches: ', test_per_batches)
for epoch in range(opt.nepoch):
print('Epoch {}, total epoches {}'.format(epoch+1, opt.nepoch))
net.train()
# ToDo: We need different dataloader here to train the network in multiple iterations, maybe move the loop down
for batch_idx, sample in enumerate(dataLoader):
# points: (batch_size, n, 6)
# pos: (batch_size, n, 3)
# labels: (batch_size, n)
points, labels = sample['points'], sample['labels']
points = points.transpose(1, 2).contiguous() # (batch_size, 3/6, n)
points, labels = points.to(device, dtype), labels.to(device, torch.long)
optimizer.zero_grad()
pred = net(points) # (batch_size, n, num_classes)
pred = pred.view(-1, num_classes) # (batch_size * n, num_classes)
target = labels.view(-1, 1)[:, 0]
loss = F.nll_loss(pred, target)
loss.backward()
optimizer.step()
##
pred_label = pred.detach().max(1)[1]
correct = pred_label.eq(target.detach()).cpu().sum()
total = pred_label.shape[0]
print(f'[{epoch}: {batch_idx}/{num_batch}] train loss: {loss.item()} '
f'accuracy: {float(correct.item())/total}')
##
if batch_idx % test_per_batches == 0:
print('Run a test batch')
net.eval()
with torch.no_grad():
batch_idx, sample = next(enumerate(test_dataLoader))
points, labels = sample['points'], sample['labels']
points = points.transpose(1, 2).contiguous()
points, labels = points.to(device, dtype), labels.to(device, torch.long)
pred = net(points)
pred = pred.view(-1, num_classes)
target = labels.view(-1, 1)[:, 0]
# FixMe: Hacky Fix to get the labels right. But this won't fix the original problem
target += 1 if -1 in target else 0
loss = F.nll_loss(pred, target)
pred_label = pred.detach().max(1)[1]
correct = pred_label.eq(target.detach()).cpu().sum()
total = pred_label.shape[0]
print(f'[{epoch}: {batch_idx}/{num_batch}] {blue("test")} loss: {loss.item()} '
f'accuracy: {float(correct.item())/total}')
# Back to training mode
net.train()
torch.save(net.state_dict(), f'{opt.outf}/seg_model_custom_{epoch}.pth')
# Benchmarm mIOU
# Link to relvant Paper
# https://arxiv.org/abs/1806.01896
net.eval()
shape_ious = []
with torch.no_grad():
for batch_idx, sample in enumerate(test_dataLoader):
points, labels = sample['points'], sample['labels']
points = points.transpose(1, 2).contiguous()
points = points.to(device, dtype)
# start_t = time.time()
pred = net(points) # (batch_size, n, num_classes)
# print('batch inference forward time used: {} ms'.format(time.time() - start_t))
pred_label = pred.max(2)[1]
pred_label = pred_label.cpu().numpy()
target_label = labels.numpy()
batch_size = target_label.shape[0]
for shape_idx in range(batch_size):
parts = range(num_classes) # np.unique(target_label[shape_idx])
part_ious = []
for part in parts:
I = np.sum(np.logical_and(pred_label[shape_idx] == part, target_label[shape_idx] == part))
U = np.sum(np.logical_or(pred_label[shape_idx] == part, target_label[shape_idx] == part))
iou = 1 if U == 0 else float(I) / U
part_ious.append(iou)
shape_ious.append(np.mean(part_ious))
print(f'mIOU for us Custom: {np.mean(shape_ious)}')
print('Done.')