Si11ium 2019-07-30 08:52:07 +02:00
parent 0159363642
commit 3a1eebf13b
3 changed files with 88 additions and 63 deletions

.idea/dictionaries/illium.xml (generated, new file)

@@ -0,0 +1,7 @@
+<component name="ProjectDictionaryState">
+  <dictionary name="illium">
+    <words>
+      <w>cudnn</w>
+    </words>
+  </dictionary>
+</component>

.idea/inspectionProfiles/Project_Default.xml (new file)

@@ -0,0 +1,13 @@
+<component name="InspectionProjectProfileManager">
+  <profile version="1.0">
+    <option name="myName" value="Project Default" />
+    <inspection_tool class="PyUnresolvedReferencesInspection" enabled="true" level="WARNING" enabled_by_default="true">
+      <option name="ignoredIdentifiers">
+        <list>
+          <option value="torch.cuda.manual_seed" />
+          <option value="torch.backends.cudnn.enabled" />
+        </list>
+      </option>
+    </inspection_tool>
+  </profile>
+</component>
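Note: the two whitelisted identifiers are exactly the ones PyCharm's PyUnresolvedReferencesInspection tends to flag in standard CUDA seeding boilerplate. For reference, a minimal sketch of the kind of code these suppressions cover (assumed; this commit does not show it):

import torch

torch.manual_seed(0)                 # CPU RNG
if torch.cuda.is_available():
    torch.cuda.manual_seed(0)        # flagged: torch.cuda.manual_seed
torch.backends.cudnn.enabled = True  # flagged: torch.backends.cudnn.enabled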

main.py

@@ -1,7 +1,12 @@
 """
 Modified from https://github.com/fxia22/pointnet.pytorch/blob/master/utils/train_segmentation.py
+Cloned from Git:
+https://github.com/dragonbook/pointnet2-pytorch/blob/master/main.py
 """
-import os, sys
+import os
+import sys
 import random
 import numpy as np
 import argparse
@@ -10,14 +15,12 @@ from torch.utils.data import DataLoader
 import torch.nn as nn
 import torch.optim as optim
 import torch.nn.functional as F
-from torch import autograd
 import torch.backends.cudnn as cudnn
 from dataset.shapenet import ShapeNetPartSegDataset
 from model.pointnet2_part_seg import PointNet2PartSegmentNet
 import torch_geometric.transforms as GT
-import time

 fs_root = os.path.splitdrive(sys.executable)[0]
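Aside on fs_root: os.path.splitdrive only returns a non-empty drive component on Windows, so on POSIX systems fs_root comes out as an empty string. Illustrative values (example paths, not from this repo):

import os
import sys

# Windows: os.path.splitdrive('C:\\Python37\\python.exe') -> ('C:', '\\Python37\\python.exe')
# POSIX:   os.path.splitdrive('/usr/bin/python3')         -> ('', '/usr/bin/python3')
fs_root = os.path.splitdrive(sys.executable)[0]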
@@ -62,10 +65,10 @@ if __name__ == '__main__':
     test_transform = GT.Compose([GT.NormalizeScale(), ])

     dataset = ShapeNetPartSegDataset(root_dir=opt.dataset, train=True, transform=train_transform, npoints=opt.npoints)
-    dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)
+    dataLoader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)

     test_dataset = ShapeNetPartSegDataset(root_dir=opt.dataset, train=False, transform=test_transform, npoints=opt.npoints)
-    test_dataloader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)
+    test_dataLoader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)

     num_classes = dataset.num_classes()
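Review note on the loaders: shuffle=True on the test loader looks odd at first, but it is what makes the periodic 'Run a test batch' step in the training loop below draw a different batch each time; next(enumerate(test_dataLoader)) builds a fresh iterator, so it always yields index 0 together with the first batch of a newly shuffled pass. A sketch of a less surprising spelling (hypothetical, not part of this commit):

# Hypothetical alternative: draw one random test batch without
# shadowing the training loop's batch_idx with a constant 0.
test_sample = next(iter(test_dataLoader))  # fresh iterator => reshuffled first batch
points, labels = test_sample['points'], test_sample['labels']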
@@ -76,17 +79,15 @@ if __name__ == '__main__':
     try:
         os.mkdir(opt.outf)
     except OSError:
-        #FIXME: Why is this just a pass? What about missing permissions? LOL
+        # FIXME: Why is this just a pass? What about missing permissions? LOL
        pass

-    ## Model, criterion and optimizer
+    # Model, criterion and optimizer
     print('Construct model ..')
     device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
     dtype = torch.float
     print('cudnn.enabled: ', torch.backends.cudnn.enabled)

     net = PointNet2PartSegmentNet(num_classes)

     if opt.model != '':
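On the FIXME above: a bare pass swallows real failures (missing permissions, read-only filesystem) along with the benign directory-already-exists case. A sketch of a stricter variant (assumes Python 3.2+; not part of this commit):

import os

# Tolerate an existing directory, but let genuine errors
# such as PermissionError propagate instead of passing silently.
os.makedirs(opt.outf, exist_ok=True)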
@@ -95,89 +96,93 @@ if __name__ == '__main__':
     criterion = nn.NLLLoss()
     optimizer = optim.Adam(net.parameters())

-    if True:
-        ## Train
-        print('Training ..')
-        blue = lambda x: '\033[94m' + x + '\033[0m'
-        num_batch = len(dataset) // opt.batch_size
-        test_per_batches = opt.test_per_batches
-
-        print('number of epoches: ', opt.nepoch)
-        print('number of batches per epoch: ', num_batch)
-        print('run test per batches: ', test_per_batches)
-
-        for epoch in range(opt.nepoch):
-            print('Epoch {}, total epoches {}'.format(epoch+1, opt.nepoch))
-
-            net.train()
-
-            for batch_idx, sample in enumerate(dataloader):
-                # points: (batch_size, n, 3)
-                # labels: (batch_size, n)
-                points, labels = sample['points'], sample['labels']
-                points = points.transpose(1, 2).contiguous()  # (batch_size, 3, n)
-                points, labels = points.to(device, dtype), labels.to(device, torch.long)
-
-                optimizer.zero_grad()
-
-                pred = net(points)  # (batch_size, n, num_classes)
-                pred = pred.view(-1, num_classes)  # (batch_size * n, num_classes)
-                target = labels.view(-1, 1)[:, 0]
-
-                loss = F.nll_loss(pred, target)
-                loss.backward()
-
-                optimizer.step()
-
-                ##
-                pred_label = pred.detach().max(1)[1]
-                correct = pred_label.eq(target.detach()).cpu().sum()
-                total = pred_label.shape[0]
-                print('[{}: {}/{}] train loss: {} accuracy: {}'.format(epoch, batch_idx, num_batch, loss.item(), float(correct.item())/total))
-
-                ##
-                if batch_idx % test_per_batches == 0:
-                    print('Run a test batch')
-                    net.eval()
-
-                    with torch.no_grad():
-                        batch_idx, sample = next(enumerate(test_dataloader))
-
-                        points, labels = sample['points'], sample['labels']
-                        points = points.transpose(1, 2).contiguous()
-                        points, labels = points.to(device, dtype), labels.to(device, torch.long)
-
-                        pred = net(points)
-                        pred = pred.view(-1, num_classes)
-                        target = labels.view(-1, 1)[:, 0]
-                        target += 1 if -1 in target else 0
-                        loss = F.nll_loss(pred, target)
-
-                        pred_label = pred.detach().max(1)[1]
-                        correct = pred_label.eq(target.detach()).cpu().sum()
-                        total = pred_label.shape[0]
-                        print('[{}: {}/{}] {} loss: {} accuracy: {}'.format(epoch, batch_idx, num_batch, blue('test'), loss.item(), float(correct.item())/total))
-
-                    # Back to training mode
-                    net.train()
-
-            torch.save(net.state_dict(), f'{opt.outf}/seg_model_custom_{epoch}.pth')
-
-    ## Benchmarm mIOU
+    # Train
+    print('Training ..')
+    blue = lambda x: f'\033[94m {x} \033[0m'
+    num_batch = len(dataset) // opt.batch_size
+    test_per_batches = opt.test_per_batches
+
+    print('number of epochs: ', opt.nepoch)
+    print('number of batches per epoch: ', num_batch)
+    print('run test per batches: ', test_per_batches)
+
+    for epoch in range(opt.nepoch):
+        print('Epoch {}, total epochs {}'.format(epoch+1, opt.nepoch))
+
+        net.train()
+
+        for batch_idx, sample in enumerate(dataLoader):
+            # points: (batch_size, n, 3)
+            # labels: (batch_size, n)
+            points, labels = sample['points'], sample['labels']
+            points = points.transpose(1, 2).contiguous()  # (batch_size, 3, n)
+            points, labels = points.to(device, dtype), labels.to(device, torch.long)
+
+            optimizer.zero_grad()
+
+            pred = net(points)  # (batch_size, n, num_classes)
+            pred = pred.view(-1, num_classes)  # (batch_size * n, num_classes)
+            target = labels.view(-1, 1)[:, 0]
+
+            loss = F.nll_loss(pred, target)
+            loss.backward()
+
+            optimizer.step()
+
+            ##
+            pred_label = pred.detach().max(1)[1]
+            correct = pred_label.eq(target.detach()).cpu().sum()
+            total = pred_label.shape[0]
+            print(f'[{epoch}: {batch_idx}/{num_batch}] train loss: {loss.item()} '
+                  f'accuracy: {float(correct.item())/total}')
+
+            ##
+            if batch_idx % test_per_batches == 0:
+                print('Run a test batch')
+                net.eval()
+
+                with torch.no_grad():
+                    batch_idx, sample = next(enumerate(test_dataLoader))
+
+                    points, labels = sample['points'], sample['labels']
+                    points = points.transpose(1, 2).contiguous()
+                    points, labels = points.to(device, dtype), labels.to(device, torch.long)
+
+                    pred = net(points)
+                    pred = pred.view(-1, num_classes)
+                    target = labels.view(-1, 1)[:, 0]
+                    # FIXME: Hacky fix to get the labels right, but it won't fix the original problem
+                    target += 1 if -1 in target else 0
+                    loss = F.nll_loss(pred, target)
+
+                    pred_label = pred.detach().max(1)[1]
+                    correct = pred_label.eq(target.detach()).cpu().sum()
+                    total = pred_label.shape[0]
+                    print(f'[{epoch}: {batch_idx}/{num_batch}] {blue("test")} loss: {loss.item()} '
+                          f'accuracy: {float(correct.item())/total}')
+
+                # Back to training mode
+                net.train()
+
+        torch.save(net.state_dict(), f'{opt.outf}/seg_model_custom_{epoch}.pth')
+
+    # Benchmark mIoU
+    # Link to relevant paper:
+    # https://arxiv.org/abs/1806.01896
     net.eval()
     shape_ious = []

     with torch.no_grad():
-        for batch_idx, sample in enumerate(test_dataloader):
+        for batch_idx, sample in enumerate(test_dataLoader):
             points, labels = sample['points'], sample['labels']
             points = points.transpose(1, 2).contiguous()
             points = points.to(device, dtype)

             # start_t = time.time()
             pred = net(points)  # (batch_size, n, num_classes)
             # print('batch inference forward time used: {} ms'.format(time.time() - start_t))

             pred_label = pred.max(2)[1]
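The diff is truncated mid-hunk here. For orientation, the shape_ious list above feeds the mIoU benchmark referenced in the added comment (https://arxiv.org/abs/1806.01896); per-shape IoU is typically accumulated as sketched below. This is a generic sketch, not this repository's implementation, and conventions differ on how to score parts absent from both prediction and ground truth:

import numpy as np

def shape_iou(pred_label, target_label, num_classes):
    """Mean IoU over the part classes of a single shape.

    pred_label, target_label: (n,) integer part labels for one point cloud.
    """
    part_ious = []
    for part in range(num_classes):
        pred_mask = pred_label == part
        target_mask = target_label == part
        union = np.logical_or(pred_mask, target_mask).sum()
        if union == 0:
            continue  # part absent on both sides; convention-dependent, skipped here
        intersection = np.logical_and(pred_mask, target_mask).sum()
        part_ious.append(intersection / union)
    return float(np.mean(part_ious)) if part_ious else 1.0

# shape_ious would then collect one value per shape,
# and the benchmark reports np.mean(shape_ious) as mIoU.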