"""Train and evaluate PointNet++ part segmentation on a ShapeNet-style dataset.

Modified from
https://github.com/fxia22/pointnet.pytorch/blob/master/utils/train_segmentation.py
"""
import os
import sys
import random
import numpy as np
import argparse
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import autograd
import torch.backends.cudnn as cudnn

from dataset.shapenet import ShapeNetPartSegDataset
from model.pointnet2_part_seg import PointNet2PartSegmentNet
import torch_geometric.transforms as GT

import time

fs_root = os.path.splitdrive(sys.executable)[0]

# Argument parser
parser = argparse.ArgumentParser()
default_data_dir = os.path.join(os.getcwd(), 'data')
parser.add_argument('--dataset', type=str, default=default_data_dir, help='dataset path')
parser.add_argument('--npoints', type=int, default=50, help='resample points number')
# An empty string ('') skips checkpoint loading and trains from scratch.
parser.add_argument('--model', type=str, default='checkpoint//seg_model_custom_24.pth', help='model path')
parser.add_argument('--nepoch', type=int, default=10, help='number of epochs to train for')
parser.add_argument('--outf', type=str, default='checkpoint', help='output folder')
parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
# A value <= 0 disables the periodic test batch.
parser.add_argument('--test_per_batches', type=int, default=1000, help='run a test batch per training batches number')
parser.add_argument('--num_workers', type=int, default=0, help='number of data loading workers')

opt = parser.parse_args()
print(opt)

# Random seed
opt.manual_seed = 123
print('Random seed: ', opt.manual_seed)
random.seed(opt.manual_seed)
np.random.seed(opt.manual_seed)
torch.manual_seed(opt.manual_seed)
torch.cuda.manual_seed(opt.manual_seed)


def _blue(text):
    """Wrap *text* in ANSI escape codes so a terminal renders it blue."""
    return '\033[94m' + text + '\033[0m'


def _prepare_points(points, device, dtype):
    """Transpose a (batch_size, n, 3) point batch to (batch_size, 3, n) and move it to *device*."""
    return points.transpose(1, 2).contiguous().to(device, dtype)


def _batch_accuracy(pred, target):
    """Return the fraction of rows of *pred* whose argmax equals *target*.

    pred:   (N, num_classes) scores; target: (N,) class ids.
    """
    pred_label = pred.detach().max(1)[1]
    correct = pred_label.eq(target.detach()).cpu().sum()
    return float(correct.item()) / pred_label.shape[0]


def _run_test_batch(net, criterion, test_dataloader, num_classes, device, dtype):
    """Evaluate *net* on the first batch of a fresh pass over *test_dataloader*.

    Returns (loss, accuracy) as Python floats. The net is switched to eval mode
    for the forward pass and put back into train mode before returning.
    """
    net.eval()
    with torch.no_grad():
        # A fresh iterator each call; with shuffle=True this draws a random batch.
        sample = next(iter(test_dataloader))
        points = _prepare_points(sample['points'], device, dtype)
        labels = sample['labels'].to(device, torch.long)
        pred = net(points).view(-1, num_classes)
        target = labels.view(-1)
        # NOTE(review): this -1 -> 0 style shift is only applied on the test path;
        # the training loop uses labels as-is. Confirm the dataset's label
        # convention -- if -1 can occur, train and test disagree.
        target += 1 if -1 in target else 0
        loss = criterion(pred, target)
        accuracy = _batch_accuracy(pred, target)
    # Back to training mode
    net.train()
    return loss.item(), accuracy


def _mean_shape_iou(net, test_dataloader, num_classes, device, dtype):
    """Return the mean over all test shapes of the per-shape mean IoU.

    For each shape, IoU is computed for every class id in range(num_classes);
    a class absent from both prediction and ground truth counts as IoU 1.
    Returns NaN (with a numpy warning) if the loader yields no shapes.
    """
    net.eval()
    shape_ious = []
    with torch.no_grad():
        for sample in test_dataloader:
            points = _prepare_points(sample['points'], device, dtype)
            pred = net(points)  # (batch_size, n, num_classes)
            pred_label = pred.max(2)[1].cpu().numpy()
            target_label = sample['labels'].numpy()
            for shape_pred, shape_target in zip(pred_label, target_label):
                part_ious = []
                for part in range(num_classes):
                    pred_mask = shape_pred == part
                    target_mask = shape_target == part
                    intersection = np.sum(np.logical_and(pred_mask, target_mask))
                    union = np.sum(np.logical_or(pred_mask, target_mask))
                    part_ious.append(1 if union == 0 else float(intersection) / union)
                shape_ious.append(np.mean(part_ious))
    return np.mean(shape_ious)


if __name__ == '__main__':
    # Dataset and transform
    print('Construct dataset ..')
    rot_max_angle = 15
    trans_max_distance = 0.01

    # Training augmentation: small random rotation about each axis, then a small translation.
    RotTransform = GT.Compose([GT.RandomRotate(rot_max_angle, 0),
                               GT.RandomRotate(rot_max_angle, 1),
                               GT.RandomRotate(rot_max_angle, 2)])
    TransTransform = GT.RandomTranslate(trans_max_distance)

    train_transform = GT.Compose([GT.NormalizeScale(), RotTransform, TransTransform])
    test_transform = GT.Compose([GT.NormalizeScale(), ])

    dataset = ShapeNetPartSegDataset(root_dir=opt.dataset, train=True, transform=train_transform, npoints=opt.npoints)
    dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)

    test_dataset = ShapeNetPartSegDataset(root_dir=opt.dataset, train=False, transform=test_transform, npoints=opt.npoints)
    test_dataloader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)

    num_classes = dataset.num_classes()

    print('dataset size: ', len(dataset))
    print('test_dataset size: ', len(test_dataset))
    print('num_classes: ', num_classes)

    # exist_ok tolerates an existing folder but still raises on real failures
    # (missing permissions, or outf existing as a regular file).
    os.makedirs(opt.outf, exist_ok=True)

    ## Model, criterion and optimizer
    print('Construct model ..')
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    dtype = torch.float
    print('cudnn.enabled: ', torch.backends.cudnn.enabled)

    net = PointNet2PartSegmentNet(num_classes)

    if opt.model != '':
        # map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
        net.load_state_dict(torch.load(opt.model, map_location=device))
    net = net.to(device, dtype)

    criterion = nn.NLLLoss()
    optimizer = optim.Adam(net.parameters())

    ## Train
    print('Training ..')
    # len(dataloader) counts the final partial batch (drop_last is False).
    num_batch = len(dataloader)
    test_per_batches = opt.test_per_batches
    run_test_batches = test_per_batches > 0 and len(test_dataset) > 0

    print('number of epoches: ', opt.nepoch)
    print('number of batches per epoch: ', num_batch)
    print('run test per batches: ', test_per_batches)

    for epoch in range(opt.nepoch):
        print('Epoch {}, total epoches {}'.format(epoch+1, opt.nepoch))

        net.train()

        for batch_idx, sample in enumerate(dataloader):
            # points: (batch_size, n, 3) -> (batch_size, 3, n); labels: (batch_size, n)
            points = _prepare_points(sample['points'], device, dtype)
            labels = sample['labels'].to(device, torch.long)

            optimizer.zero_grad()
            pred = net(points).view(-1, num_classes)  # (batch_size * n, num_classes)
            target = labels.view(-1)
            loss = criterion(pred, target)
            loss.backward()
            optimizer.step()

            print('[{}: {}/{}] train loss: {} accuracy: {}'.format(
                epoch, batch_idx, num_batch, loss.item(), _batch_accuracy(pred, target)))

            if run_test_batches and batch_idx % test_per_batches == 0:
                print('Run a test batch')
                test_loss, test_accuracy = _run_test_batch(
                    net, criterion, test_dataloader, num_classes, device, dtype)
                print('[{}: {}/{}] {} loss: {} accuracy: {}'.format(
                    epoch, batch_idx, num_batch, _blue('test'), test_loss, test_accuracy))

        torch.save(net.state_dict(), os.path.join(opt.outf, f'seg_model_custom_{epoch}.pth'))

    ## Benchmark mIOU
    miou = _mean_shape_iou(net, test_dataloader, num_classes, device, dtype)
    print(f'mIOU for us Custom: {miou}')

    print('Done.')