Refactor
This commit is contained in:
		
							
								
								
									
										7
									
								
								.idea/dictionaries/illium.xml
									
									
									
										generated
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								.idea/dictionaries/illium.xml
									
									
									
										generated
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | ||||
| <component name="ProjectDictionaryState"> | ||||
|   <dictionary name="illium"> | ||||
|     <words> | ||||
|       <w>cudnn</w> | ||||
|     </words> | ||||
|   </dictionary> | ||||
| </component> | ||||
							
								
								
									
										13
									
								
								.idea/inspectionProfiles/Project_Default.xml
									
									
									
										generated
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								.idea/inspectionProfiles/Project_Default.xml
									
									
									
										generated
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,13 @@ | ||||
| <component name="InspectionProjectProfileManager"> | ||||
|   <profile version="1.0"> | ||||
|     <option name="myName" value="Project Default" /> | ||||
|     <inspection_tool class="PyUnresolvedReferencesInspection" enabled="true" level="WARNING" enabled_by_default="true"> | ||||
|       <option name="ignoredIdentifiers"> | ||||
|         <list> | ||||
|           <option value="torch.cuda.manual_seed" /> | ||||
|           <option value="torch.backends.cudnn.enabled" /> | ||||
|         </list> | ||||
|       </option> | ||||
|     </inspection_tool> | ||||
|   </profile> | ||||
| </component> | ||||
							
								
								
									
										131
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										131
									
								
								main.py
									
									
									
									
									
								
							| @@ -1,7 +1,12 @@ | ||||
| """ | ||||
| Modified from https://github.com/fxia22/pointnet.pytorch/blob/master/utils/train_segmentation.py | ||||
|  | ||||
| Cloned from Git: | ||||
| https://github.com/dragonbook/pointnet2-pytorch/blob/master/main.py | ||||
| """ | ||||
| import os, sys | ||||
|  | ||||
| import os | ||||
| import sys | ||||
| import random | ||||
| import numpy as np | ||||
| import argparse | ||||
| @@ -10,14 +15,12 @@ from torch.utils.data import DataLoader | ||||
| import torch.nn as nn | ||||
| import torch.optim as optim | ||||
| import torch.nn.functional as F | ||||
| from torch import autograd | ||||
| import torch.backends.cudnn as cudnn | ||||
|  | ||||
| from dataset.shapenet import ShapeNetPartSegDataset | ||||
| from model.pointnet2_part_seg import PointNet2PartSegmentNet | ||||
| import torch_geometric.transforms as GT | ||||
|  | ||||
| import time | ||||
|  | ||||
| fs_root = os.path.splitdrive(sys.executable)[0] | ||||
|  | ||||
| @@ -62,10 +65,10 @@ if __name__ == '__main__': | ||||
|         test_transform = GT.Compose([GT.NormalizeScale(), ]) | ||||
|  | ||||
|     dataset = ShapeNetPartSegDataset(root_dir=opt.dataset, train=True, transform=train_transform, npoints=opt.npoints) | ||||
|     dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers) | ||||
|     dataLoader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers) | ||||
|  | ||||
|     test_dataset = ShapeNetPartSegDataset(root_dir=opt.dataset, train=False, transform=test_transform, npoints=opt.npoints) | ||||
|     test_dataloader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers) | ||||
|     test_dataLoader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers) | ||||
|  | ||||
|     num_classes = dataset.num_classes() | ||||
|  | ||||
| @@ -76,17 +79,15 @@ if __name__ == '__main__': | ||||
|     try: | ||||
|         os.mkdir(opt.outf) | ||||
|     except OSError: | ||||
|         #FIXME: Why is this just a pass? What about missing permissions? LOL | ||||
|         # FIXME: Why is this just a pass? What about missing permissions? LOL | ||||
|         pass | ||||
|  | ||||
|  | ||||
|     ## Model, criterion and optimizer | ||||
|     # Model, criterion and optimizer | ||||
|     print('Construct model ..') | ||||
|     device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') | ||||
|     dtype = torch.float | ||||
|     print('cudnn.enabled: ', torch.backends.cudnn.enabled) | ||||
|  | ||||
|  | ||||
|     net = PointNet2PartSegmentNet(num_classes) | ||||
|  | ||||
|     if opt.model != '': | ||||
| @@ -95,89 +96,93 @@ if __name__ == '__main__': | ||||
|  | ||||
|     criterion = nn.NLLLoss() | ||||
|     optimizer = optim.Adam(net.parameters()) | ||||
|     if True: | ||||
|         ## Train | ||||
|         print('Training ..') | ||||
|         blue = lambda x: '\033[94m' + x + '\033[0m' | ||||
|         num_batch = len(dataset) // opt.batch_size | ||||
|         test_per_batches = opt.test_per_batches | ||||
|  | ||||
|         print('number of epoches: ', opt.nepoch) | ||||
|         print('number of batches per epoch: ', num_batch) | ||||
|         print('run test per batches: ', test_per_batches) | ||||
|     # Train | ||||
|     print('Training ..') | ||||
|     blue = lambda x: f'\033[94m {x} \033[0m' | ||||
|     num_batch = len(dataset) // opt.batch_size | ||||
|     test_per_batches = opt.test_per_batches | ||||
|  | ||||
|         for epoch in range(opt.nepoch): | ||||
|             print('Epoch {}, total epoches {}'.format(epoch+1, opt.nepoch)) | ||||
|     print('number of epoches: ', opt.nepoch) | ||||
|     print('number of batches per epoch: ', num_batch) | ||||
|     print('run test per batches: ', test_per_batches) | ||||
|  | ||||
|             net.train() | ||||
|     for epoch in range(opt.nepoch): | ||||
|         print('Epoch {}, total epoches {}'.format(epoch+1, opt.nepoch)) | ||||
|  | ||||
|             for batch_idx, sample in enumerate(dataloader): | ||||
|                 # points: (batch_size, n, 3) | ||||
|                 # labels: (batch_size, n) | ||||
|                 points, labels = sample['points'], sample['labels'] | ||||
|                 points = points.transpose(1, 2).contiguous()  # (batch_size, 3, n) | ||||
|                 points, labels = points.to(device, dtype), labels.to(device, torch.long) | ||||
|         net.train() | ||||
|  | ||||
|                 optimizer.zero_grad() | ||||
|         for batch_idx, sample in enumerate(dataLoader): | ||||
|             # points: (batch_size, n, 3) | ||||
|             # labels: (batch_size, n) | ||||
|             points, labels = sample['points'], sample['labels'] | ||||
|             points = points.transpose(1, 2).contiguous()  # (batch_size, 3, n) | ||||
|             points, labels = points.to(device, dtype), labels.to(device, torch.long) | ||||
|  | ||||
|                 pred = net(points)  # (batch_size, n, num_classes) | ||||
|                 pred = pred.view(-1, num_classes)  # (batch_size * n, num_classes) | ||||
|                 target = labels.view(-1, 1)[:, 0] | ||||
|             optimizer.zero_grad() | ||||
|  | ||||
|                 loss = F.nll_loss(pred, target) | ||||
|                 loss.backward() | ||||
|             pred = net(points)  # (batch_size, n, num_classes) | ||||
|             pred = pred.view(-1, num_classes)  # (batch_size * n, num_classes) | ||||
|             target = labels.view(-1, 1)[:, 0] | ||||
|  | ||||
|                 optimizer.step() | ||||
|             loss = F.nll_loss(pred, target) | ||||
|             loss.backward() | ||||
|  | ||||
|                 ## | ||||
|                 pred_label = pred.detach().max(1)[1] | ||||
|                 correct = pred_label.eq(target.detach()).cpu().sum() | ||||
|                 total = pred_label.shape[0] | ||||
|             optimizer.step() | ||||
|  | ||||
|                 print('[{}: {}/{}] train loss: {} accuracy: {}'.format(epoch, batch_idx, num_batch, loss.item(), float(correct.item())/total)) | ||||
|             ## | ||||
|             pred_label = pred.detach().max(1)[1] | ||||
|             correct = pred_label.eq(target.detach()).cpu().sum() | ||||
|             total = pred_label.shape[0] | ||||
|  | ||||
|                 ## | ||||
|                 if batch_idx % test_per_batches == 0: | ||||
|                     print('Run a test batch') | ||||
|                     net.eval() | ||||
|             print(f'[{epoch}: {batch_idx}/{num_batch}] train loss: {loss.item()} ' | ||||
|                   f'accuracy: {float(correct.item())/total}') | ||||
|  | ||||
|                     with torch.no_grad(): | ||||
|                         batch_idx, sample = next(enumerate(test_dataloader)) | ||||
|             ## | ||||
|             if batch_idx % test_per_batches == 0: | ||||
|                 print('Run a test batch') | ||||
|                 net.eval() | ||||
|  | ||||
|                         points, labels = sample['points'], sample['labels'] | ||||
|                         points = points.transpose(1, 2).contiguous() | ||||
|                         points, labels = points.to(device, dtype), labels.to(device, torch.long) | ||||
|                 with torch.no_grad(): | ||||
|                     batch_idx, sample = next(enumerate(test_dataLoader)) | ||||
|  | ||||
|                         pred = net(points) | ||||
|                         pred = pred.view(-1, num_classes) | ||||
|                         target = labels.view(-1, 1)[:, 0] | ||||
|                     points, labels = sample['points'], sample['labels'] | ||||
|                     points = points.transpose(1, 2).contiguous() | ||||
|                     points, labels = points.to(device, dtype), labels.to(device, torch.long) | ||||
|  | ||||
|                         target += 1 if -1 in target else 0 | ||||
|                         loss = F.nll_loss(pred, target) | ||||
|                     pred = net(points) | ||||
|                     pred = pred.view(-1, num_classes) | ||||
|                     target = labels.view(-1, 1)[:, 0] | ||||
|  | ||||
|                         pred_label = pred.detach().max(1)[1] | ||||
|                         correct = pred_label.eq(target.detach()).cpu().sum() | ||||
|                         total = pred_label.shape[0] | ||||
|                         print('[{}: {}/{}] {} loss: {} accuracy: {}'.format(epoch, batch_idx, num_batch, blue('test'), loss.item(), float(correct.item())/total)) | ||||
|                     # FixMe: Hacky Fix to get the labels right. But this won't fix the original problem | ||||
|                     target += 1 if -1 in target else 0 | ||||
|                     loss = F.nll_loss(pred, target) | ||||
|  | ||||
|                     # Back to training mode | ||||
|                     net.train() | ||||
|                     pred_label = pred.detach().max(1)[1] | ||||
|                     correct = pred_label.eq(target.detach()).cpu().sum() | ||||
|                     total = pred_label.shape[0] | ||||
|                     print(f'[{epoch}: {batch_idx}/{num_batch}] {blue("test")} loss: {loss.item()} ' | ||||
|                           f'accuracy: {float(correct.item())/total}') | ||||
|  | ||||
|             torch.save(net.state_dict(), f'{opt.outf}/seg_model_custom_{epoch}.pth') | ||||
|                 # Back to training mode | ||||
|                 net.train() | ||||
|  | ||||
|         torch.save(net.state_dict(), f'{opt.outf}/seg_model_custom_{epoch}.pth') | ||||
|  | ||||
|     ## Benchmarm mIOU | ||||
|     # Benchmarm mIOU | ||||
|     # Link to relvant Paper | ||||
|     # https://arxiv.org/abs/1806.01896 | ||||
|     net.eval() | ||||
|     shape_ious = [] | ||||
|  | ||||
|     with torch.no_grad(): | ||||
|         for batch_idx, sample in enumerate(test_dataloader): | ||||
|         for batch_idx, sample in enumerate(test_dataLoader): | ||||
|             points, labels = sample['points'], sample['labels'] | ||||
|             points = points.transpose(1, 2).contiguous() | ||||
|             points = points.to(device, dtype) | ||||
|  | ||||
|             # start_t = time.time() | ||||
|             pred = net(points) # (batch_size, n, num_classes) | ||||
|             pred = net(points)  # (batch_size, n, num_classes) | ||||
|             # print('batch inference forward time used: {} ms'.format(time.time() - start_t)) | ||||
|  | ||||
|             pred_label = pred.max(2)[1] | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Si11ium
					Si11ium