"""
Modified from https://github.com/fxia22/pointnet.pytorch/blob/master/utils/train_segmentation.py
Cloned from Git: https://github.com/dragonbook/pointnet2-pytorch/blob/master/main.py

Trains ``PointNet2PartSegmentNet`` on a ``ShapeNetPartSegDataset``, writes one
checkpoint per epoch into ``--outf`` and, once training has finished, prints
the mean IoU over the test split.
"""
import os
import sys
import random
import argparse

import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch_geometric.transforms as GT

from dataset.shapenet import ShapeNetPartSegDataset
from model.pointnet2_part_seg import PointNet2PartSegmentNet


fs_root = os.path.splitdrive(sys.executable)[0]

_TRUE_STRINGS = frozenset({'y', 'yes', 't', 'true', 'on', '1'})
_FALSE_STRINGS = frozenset({'n', 'no', 'f', 'false', 'off', '0'})


def strtobool(val):
    """Parse a command-line truth value into a ``bool``.

    Local replacement for ``distutils.util.strtobool`` (``distutils`` was
    removed from the standard library in Python 3.12). It accepts the same
    spellings but returns ``True``/``False`` instead of ``1``/``0``.

    Raises:
        ValueError: if ``val`` is not a recognised spelling; argparse turns
            this into a regular usage error.
    """
    normalized = val.lower()
    if normalized in _TRUE_STRINGS:
        return True
    if normalized in _FALSE_STRINGS:
        return False
    raise ValueError(f'invalid truth value {normalized!r}')


def _accuracy(pred, target):
    """Return the fraction of rows in ``pred`` whose arg-max equals ``target``."""
    pred_label = pred.detach().max(1)[1]
    correct = pred_label.eq(target.detach()).cpu().sum()
    return float(correct.item()) / pred_label.shape[0]


def _shape_ious(pred_label, target_label, num_classes):
    """Return one IoU value per shape, averaged over all ``num_classes`` parts.

    A part that occurs in neither the prediction nor the target of a shape
    (empty union) is scored as 1.
    """
    ious = []
    for shape_pred, shape_target in zip(pred_label, target_label):
        part_ious = []
        for part in range(num_classes):
            pred_mask = shape_pred == part
            target_mask = shape_target == part
            union = np.sum(np.logical_or(pred_mask, target_mask))
            if union == 0:
                part_ious.append(1)
            else:
                intersection = np.sum(np.logical_and(pred_mask, target_mask))
                part_ious.append(float(intersection) / union)
        ious.append(np.mean(part_ious))
    return ious


def blue(x):
    """Wrap ``x`` in the ANSI escape codes for bright blue terminal text."""
    return f'\033[94m {x} \033[0m'


# Argument parser
parser = argparse.ArgumentParser()
default_data_dir = os.path.join(os.getcwd(), 'data')
parser.add_argument('--dataset', type=str, default=default_data_dir, help='dataset path')
parser.add_argument('--npoints', type=int, default=1024, help='resample points number')
parser.add_argument('--model', type=str, default='', help='model path')
parser.add_argument('--nepoch', type=int, default=250, help='number of epochs to train for')
parser.add_argument('--outf', type=str, default='checkpoint', help='output folder')
parser.add_argument('--labels_within', type=strtobool, default=True, help='defines the label location')
parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
parser.add_argument('--test_per_batches', type=int, default=1000,
                    help='run a test batch per training batches number')
parser.add_argument('--num_workers', type=int, default=0, help='number of data loading workers')
parser.add_argument('--headers', type=strtobool, default=True, help='if raw files come with headers')
parser.add_argument('--with_normals', type=strtobool, default=True, help='if training will include normals')
parser.add_argument('--collate_per_segment', type=strtobool, default=True,
                    help='whether to look at pointclouds or sub')
parser.add_argument('--has_variations', type=strtobool, default=False,
                    help='whether a single pointcloud has variations '
                         'named int(id)_pc.(xyz|dat) look at pointclouds or sub')

opt = parser.parse_args()
print(opt)

# Random seed (fixed so that runs are repeatable)
opt.manual_seed = 123
print('Random seed: ', opt.manual_seed)
random.seed(opt.manual_seed)
np.random.seed(opt.manual_seed)
torch.manual_seed(opt.manual_seed)
torch.cuda.manual_seed(opt.manual_seed)


if __name__ == '__main__':

    # Dataset and transform
    print('Construct dataset ..')

    rot_max_angle = 15
    trans_max_distance = 0.01

    # NOTE: the rotation/translation augmentations are constructed but are not
    # part of either transform pipeline below.
    RotTransform = GT.Compose([GT.RandomRotate(rot_max_angle, 0),
                               GT.RandomRotate(rot_max_angle, 1),
                               GT.RandomRotate(rot_max_angle, 2)])
    TransTransform = GT.RandomTranslate(trans_max_distance)

    train_transform = GT.Compose([GT.NormalizeScale(), ])
    test_transform = GT.Compose([GT.NormalizeScale(), ])

    params = dict(root_dir=opt.dataset,
                  collate_per_segment=opt.collate_per_segment,
                  transform=train_transform,
                  npoints=opt.npoints,
                  labels_within=opt.labels_within,
                  has_variations=opt.has_variations,
                  headers=opt.headers,
                  with_normals=opt.with_normals)

    dataset = ShapeNetPartSegDataset(mode='train', **params)
    dataLoader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True,
                            num_workers=opt.num_workers)

    # The test split gets its own transform instead of silently inheriting
    # the training one through ``params``.
    test_params = {**params, 'transform': test_transform}
    test_dataset = ShapeNetPartSegDataset(mode='test', **test_params)
    test_dataLoader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=True,
                                 num_workers=opt.num_workers)

    num_classes = dataset.num_classes()

    print('dataset size: ', len(dataset))
    print('test_dataset size: ', len(test_dataset))
    print('num_classes: ', num_classes)

    # An already existing output folder is fine; any other failure (missing
    # permissions, invalid path, ...) must surface instead of being swallowed.
    os.makedirs(opt.outf, exist_ok=True)

    # Model, criterion and optimizer
    print('Construct model ..')
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    dtype = torch.float
    print('cudnn.enabled: ', torch.backends.cudnn.enabled)

    net = PointNet2PartSegmentNet(num_classes, with_normals=opt.with_normals)

    if opt.model != '':
        # map_location lets a checkpoint saved on a GPU be resumed on a
        # CPU-only machine (and vice versa).
        net.load_state_dict(torch.load(opt.model, map_location=device))
    net = net.to(device, dtype)

    criterion = nn.NLLLoss()
    optimizer = optim.Adam(net.parameters())

    # Train
    print('Training ..')
    # The loader does not drop the last partial batch, so ask it directly
    # rather than using floor division on the dataset size.
    num_batch = len(dataLoader)
    test_per_batches = opt.test_per_batches

    print('number of epoches: ', opt.nepoch)
    print('number of batches per epoch: ', num_batch)
    print('run test per batches: ', test_per_batches)

    for epoch in range(opt.nepoch):
        print(f'Epoch {epoch + 1}, total epoches {opt.nepoch}')

        net.train()
        # ToDo: We need different dataloader here to train the network in
        # multiple iterations, maybe move the loop down
        for batch_idx, sample in enumerate(dataLoader):
            # points: (batch_size, n, 6)
            # pos: (batch_size, n, 3)
            # labels: (batch_size, n)
            points, labels = sample['points'], sample['labels']

            points = points.transpose(1, 2).contiguous()  # (batch_size, 3/6, n)
            points, labels = points.to(device, dtype), labels.to(device, torch.long)

            optimizer.zero_grad()

            pred = net(points)  # (batch_size, n, num_classes)
            pred = pred.view(-1, num_classes)  # (batch_size * n, num_classes)
            target = labels.view(-1, 1)[:, 0]

            loss = criterion(pred, target)
            loss.backward()

            optimizer.step()

            print(f'[{epoch}: {batch_idx}/{num_batch}] train loss: {loss.item()} '
                  f'accuracy: {_accuracy(pred, target)}')

            if batch_idx % test_per_batches == 0:
                print('Run a test batch')
                net.eval()

                with torch.no_grad():
                    # Draw one (shuffled) test batch without overwriting the
                    # training ``batch_idx`` used in the log line below.
                    test_sample = next(iter(test_dataLoader))

                    test_points = test_sample['points'].transpose(1, 2).contiguous()
                    test_points = test_points.to(device, dtype)
                    test_labels = test_sample['labels'].to(device, torch.long)

                    test_pred = net(test_points).view(-1, num_classes)
                    test_target = test_labels.view(-1, 1)[:, 0]
                    # FixMe: Hacky Fix to get the labels right. But this won't
                    # fix the original problem.
                    # NOTE(review): the final mIOU pass below does not apply
                    # this shift - confirm which label convention is intended.
                    test_target += 1 if -1 in test_target else 0

                    test_loss = criterion(test_pred, test_target)

                    print(f'[{epoch}: {batch_idx}/{num_batch}] {blue("test")} loss: {test_loss.item()} '
                          f'accuracy: {_accuracy(test_pred, test_target)}')

                # Back to training mode
                net.train()

        torch.save(net.state_dict(), os.path.join(opt.outf, f'seg_model_custom_{epoch}.pth'))

    # Benchmark mIOU
    # Link to relevant paper:
    # https://arxiv.org/abs/1806.01896
    net.eval()
    shape_ious = []

    with torch.no_grad():
        for batch_idx, sample in enumerate(test_dataLoader):
            points, labels = sample['points'], sample['labels']

            points = points.transpose(1, 2).contiguous()
            points = points.to(device, dtype)

            pred = net(points)  # (batch_size, n, num_classes)

            pred_label = pred.max(2)[1].cpu().numpy()
            target_label = labels.numpy()

            shape_ious.extend(_shape_ious(pred_label, target_label, num_classes))

        print(f'mIOU for us Custom: {np.mean(shape_ious)}')

    print('Done.')