Refactor
This commit is contained in:
parent 0159363642
commit 3a1eebf13b

.idea/dictionaries/illium.xml (generated, new file, +7)
@@ -0,0 +1,7 @@
+<component name="ProjectDictionaryState">
+  <dictionary name="illium">
+    <words>
+      <w>cudnn</w>
+    </words>
+  </dictionary>
+</component>

.idea/inspectionProfiles/Project_Default.xml (generated, new file, +13)
@@ -0,0 +1,13 @@
+<component name="InspectionProjectProfileManager">
+  <profile version="1.0">
+    <option name="myName" value="Project Default" />
+    <inspection_tool class="PyUnresolvedReferencesInspection" enabled="true" level="WARNING" enabled_by_default="true">
+      <option name="ignoredIdentifiers">
+        <list>
+          <option value="torch.cuda.manual_seed" />
+          <option value="torch.backends.cudnn.enabled" />
+        </list>
+      </option>
+    </inspection_tool>
+  </profile>
+</component>
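
The two whitelisted identifiers exist at runtime but are attached to torch's namespace dynamically, which is why PyCharm's static analysis flags them. A minimal sketch of the kind of reproducibility setup that touches both (the helper name seed_everything is illustrative, not from this repo):

import random

import numpy as np
import torch

def seed_everything(seed=0):
    # Seed every RNG the training script can touch.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)  # first whitelisted identifier
    # Second whitelisted identifier; toggling cudnn trades speed
    # against reproducibility on some ops.
    torch.backends.cudnn.enabled = True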

main.py (131 changes)
@@ -1,7 +1,12 @@
 """
 Modified from https://github.com/fxia22/pointnet.pytorch/blob/master/utils/train_segmentation.py
+
+Cloned from Git:
+https://github.com/dragonbook/pointnet2-pytorch/blob/master/main.py
+
 """
-import os, sys
+import os
+import sys
 import random
 import numpy as np
 import argparse
@@ -10,14 +15,12 @@ from torch.utils.data import DataLoader
 import torch.nn as nn
 import torch.optim as optim
 import torch.nn.functional as F
-from torch import autograd
 import torch.backends.cudnn as cudnn
 
 from dataset.shapenet import ShapeNetPartSegDataset
 from model.pointnet2_part_seg import PointNet2PartSegmentNet
 import torch_geometric.transforms as GT
 
-import time
 
 fs_root = os.path.splitdrive(sys.executable)[0]
 
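
The surviving fs_root line keeps only the drive component of the interpreter path, which is meaningful on Windows; os.path.splitdrive returns an empty drive on POSIX. For illustration (paths are examples):

import os
import sys

# Windows: ('C:', '\\Python37\\python.exe'); POSIX: ('', '/usr/bin/python3')
drive, tail = os.path.splitdrive(sys.executable)
fs_root = drive  # '' on POSIX, e.g. 'C:' on Windows
print(repr(fs_root), repr(tail))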
@@ -62,10 +65,10 @@ if __name__ == '__main__':
     test_transform = GT.Compose([GT.NormalizeScale(), ])
 
     dataset = ShapeNetPartSegDataset(root_dir=opt.dataset, train=True, transform=train_transform, npoints=opt.npoints)
-    dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)
+    dataLoader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)
 
     test_dataset = ShapeNetPartSegDataset(root_dir=opt.dataset, train=False, transform=test_transform, npoints=opt.npoints)
-    test_dataloader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)
+    test_dataLoader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)
 
     num_classes = dataset.num_classes()
 
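
The rename to dataLoader/test_dataLoader departs from PEP 8 snake_case but is applied consistently throughout the loops below. Each sample the loaders yield is a dict with 'points' of shape (n, 3) and 'labels' of shape (n,), which the DataLoader collates to (batch_size, n, 3) and (batch_size, n). A stand-in dataset with that contract (shapes inferred from the training loop, not from dataset/shapenet.py):

import torch
from torch.utils.data import Dataset, DataLoader

class FakePartSegDataset(Dataset):
    # Mimics the sample contract the training loop relies on.
    def __init__(self, length=8, npoints=1024, num_classes=4):
        self.length, self.npoints, self.nc = length, npoints, num_classes

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return {'points': torch.rand(self.npoints, 3),                 # (n, 3)
                'labels': torch.randint(0, self.nc, (self.npoints,))}  # (n,)

loader = DataLoader(FakePartSegDataset(), batch_size=2, shuffle=True)
batch = next(iter(loader))
assert batch['points'].shape == (2, 1024, 3)
assert batch['labels'].shape == (2, 1024)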
@@ -76,17 +79,15 @@ if __name__ == '__main__':
     try:
         os.mkdir(opt.outf)
     except OSError:
-        #FIXME: Why is this just a pass? What about missing permissions? LOL
+        # FIXME: Why is this just a pass? What about missing permissions? LOL
        pass
 
-    ## Model, criterion and optimizer
-
+    # Model, criterion and optimizer
     print('Construct model ..')
     device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
     dtype = torch.float
     print('cudnn.enabled: ', torch.backends.cudnn.enabled)
 
-
     net = PointNet2PartSegmentNet(num_classes)
 
     if opt.model != '':
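
The FIXME is fair: the bare except swallows every OSError, so a permission failure passes silently and the script only crashes later at torch.save. A narrower idiom, sketched outside this commit (outf stands in for opt.outf):

import errno
import os

outf = 'seg'  # stands in for opt.outf

# Raises on real failures (e.g. missing permissions) but tolerates an
# already-existing directory.
os.makedirs(outf, exist_ok=True)

# Equivalent explicit form if os.mkdir is kept:
try:
    os.mkdir(outf)
except OSError as e:
    if e.errno != errno.EEXIST:
        raise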
@@ -95,89 +96,93 @@ if __name__ == '__main__':
 
     criterion = nn.NLLLoss()
     optimizer = optim.Adam(net.parameters())
-    if True:
-        ## Train
-        print('Training ..')
-        blue = lambda x: '\033[94m' + x + '\033[0m'
-        num_batch = len(dataset) // opt.batch_size
-        test_per_batches = opt.test_per_batches
 
-        print('number of epoches: ', opt.nepoch)
-        print('number of batches per epoch: ', num_batch)
-        print('run test per batches: ', test_per_batches)
+    # Train
+    print('Training ..')
+    blue = lambda x: f'\033[94m {x} \033[0m'
+    num_batch = len(dataset) // opt.batch_size
+    test_per_batches = opt.test_per_batches
 
-        for epoch in range(opt.nepoch):
-            print('Epoch {}, total epoches {}'.format(epoch+1, opt.nepoch))
+    print('number of epoches: ', opt.nepoch)
+    print('number of batches per epoch: ', num_batch)
+    print('run test per batches: ', test_per_batches)
 
-            net.train()
+    for epoch in range(opt.nepoch):
+        print('Epoch {}, total epoches {}'.format(epoch+1, opt.nepoch))
 
-            for batch_idx, sample in enumerate(dataloader):
-                # points: (batch_size, n, 3)
-                # labels: (batch_size, n)
-                points, labels = sample['points'], sample['labels']
-                points = points.transpose(1, 2).contiguous() # (batch_size, 3, n)
-                points, labels = points.to(device, dtype), labels.to(device, torch.long)
+        net.train()
 
-                optimizer.zero_grad()
+        for batch_idx, sample in enumerate(dataLoader):
+            # points: (batch_size, n, 3)
+            # labels: (batch_size, n)
+            points, labels = sample['points'], sample['labels']
+            points = points.transpose(1, 2).contiguous() # (batch_size, 3, n)
+            points, labels = points.to(device, dtype), labels.to(device, torch.long)
 
-                pred = net(points) # (batch_size, n, num_classes)
-                pred = pred.view(-1, num_classes) # (batch_size * n, num_classes)
-                target = labels.view(-1, 1)[:, 0]
+            optimizer.zero_grad()
 
-                loss = F.nll_loss(pred, target)
-                loss.backward()
+            pred = net(points) # (batch_size, n, num_classes)
+            pred = pred.view(-1, num_classes) # (batch_size * n, num_classes)
+            target = labels.view(-1, 1)[:, 0]
 
-                optimizer.step()
+            loss = F.nll_loss(pred, target)
+            loss.backward()
 
-                ##
-                pred_label = pred.detach().max(1)[1]
-                correct = pred_label.eq(target.detach()).cpu().sum()
-                total = pred_label.shape[0]
+            optimizer.step()
 
-                print('[{}: {}/{}] train loss: {} accuracy: {}'.format(epoch, batch_idx, num_batch, loss.item(), float(correct.item())/total))
+            ##
+            pred_label = pred.detach().max(1)[1]
+            correct = pred_label.eq(target.detach()).cpu().sum()
+            total = pred_label.shape[0]
 
-                ##
-                if batch_idx % test_per_batches == 0:
-                    print('Run a test batch')
-                    net.eval()
+            print(f'[{epoch}: {batch_idx}/{num_batch}] train loss: {loss.item()} '
+                  f'accuracy: {float(correct.item())/total}')
 
-                    with torch.no_grad():
-                        batch_idx, sample = next(enumerate(test_dataloader))
+            ##
+            if batch_idx % test_per_batches == 0:
+                print('Run a test batch')
+                net.eval()
 
-                        points, labels = sample['points'], sample['labels']
-                        points = points.transpose(1, 2).contiguous()
-                        points, labels = points.to(device, dtype), labels.to(device, torch.long)
+                with torch.no_grad():
+                    batch_idx, sample = next(enumerate(test_dataLoader))
 
-                        pred = net(points)
-                        pred = pred.view(-1, num_classes)
-                        target = labels.view(-1, 1)[:, 0]
+                    points, labels = sample['points'], sample['labels']
+                    points = points.transpose(1, 2).contiguous()
+                    points, labels = points.to(device, dtype), labels.to(device, torch.long)
 
-                        target += 1 if -1 in target else 0
-                        loss = F.nll_loss(pred, target)
+                    pred = net(points)
+                    pred = pred.view(-1, num_classes)
+                    target = labels.view(-1, 1)[:, 0]
 
-                        pred_label = pred.detach().max(1)[1]
-                        correct = pred_label.eq(target.detach()).cpu().sum()
-                        total = pred_label.shape[0]
-                        print('[{}: {}/{}] {} loss: {} accuracy: {}'.format(epoch, batch_idx, num_batch, blue('test'), loss.item(), float(correct.item())/total))
+                    # FixMe: Hacky Fix to get the labels right. But this won't fix the original problem
+                    target += 1 if -1 in target else 0
+                    loss = F.nll_loss(pred, target)
 
-                    # Back to training mode
-                    net.train()
+                    pred_label = pred.detach().max(1)[1]
+                    correct = pred_label.eq(target.detach()).cpu().sum()
+                    total = pred_label.shape[0]
+                    print(f'[{epoch}: {batch_idx}/{num_batch}] {blue("test")} loss: {loss.item()} '
+                          f'accuracy: {float(correct.item())/total}')
 
-            torch.save(net.state_dict(), f'{opt.outf}/seg_model_custom_{epoch}.pth')
+                # Back to training mode
+                net.train()
+
+        torch.save(net.state_dict(), f'{opt.outf}/seg_model_custom_{epoch}.pth')
 
 
-    ## Benchmarm mIOU
+    # Benchmarm mIOU
+    # Link to relvant Paper
+    # https://arxiv.org/abs/1806.01896
     net.eval()
     shape_ious = []
 
     with torch.no_grad():
-        for batch_idx, sample in enumerate(test_dataloader):
+        for batch_idx, sample in enumerate(test_dataLoader):
             points, labels = sample['points'], sample['labels']
             points = points.transpose(1, 2).contiguous()
             points = points.to(device, dtype)
 
             # start_t = time.time()
             pred = net(points) # (batch_size, n, num_classes)
             # print('batch inference forward time used: {} ms'.format(time.time() - start_t))
 
             pred_label = pred.max(2)[1]
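
A note on the hacky label fix above: precedence makes it target += (1 if -1 in target else 0), i.e. every label is shifted up by one whenever any -1 is present. That puts targets in the [0, num_classes-1] range nll_loss requires but, as the FixMe says, does not address whatever emits the -1 labels upstream. Isolated:

import torch
import torch.nn.functional as F

num_classes = 4
pred = torch.log_softmax(torch.rand(6, num_classes), dim=1)
target = torch.tensor([-1, 0, 1, 2, 0, 1])

# F.nll_loss(pred, target) would fail here: class index -1 is out of range.
target += 1 if -1 in target else 0  # shifts EVERY label, not just the -1s
loss = F.nll_loss(pred, target)     # valid: labels now in [0, num_classes-1]
print(loss.item())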
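
The diff view is truncated at this point, before the body of the mIoU benchmark. For context on what the comment announces: per-shape mIoU is conventionally the mean of per-part IoUs, with a part absent from both prediction and ground truth counted as IoU 1. A sketch of that standard metric (not the code from this commit; see the linked paper, https://arxiv.org/abs/1806.01896):

import numpy as np

def shape_iou(pred_label, gt_label, num_classes):
    # pred_label, gt_label: integer arrays of shape (n,) for one shape.
    part_ious = []
    for part in range(num_classes):
        pred_mask = pred_label == part
        gt_mask = gt_label == part
        union = np.logical_or(pred_mask, gt_mask).sum()
        if union == 0:
            part_ious.append(1.0)  # part absent from both: count as perfect
        else:
            inter = np.logical_and(pred_mask, gt_mask).sum()
            part_ious.append(inter / union)
    return float(np.mean(part_ious))

# Per test shape: shape_ious.append(shape_iou(pred, gt, num_classes));
# the benchmark then reports np.mean(shape_ious).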