dataset fixing

This commit is contained in:
Si11ium 2020-06-19 15:37:44 +02:00
parent 49b373a8a1
commit b3c67bab40
3 changed files with 20 additions and 11 deletions

View File

@ -56,14 +56,18 @@ class CustomShapeNet(InMemoryDataset):
def processed_file_names(self): def processed_file_names(self):
return [f'{self.mode}.pt'] return [f'{self.mode}.pt']
def __download(self): def check_and_resolve_cloud_count(self):
dir_count = len([name for name in os.listdir(self.raw_dir) if os.path.isdir(os.path.join(self.raw_dir, name))]) if self.raw_dir.exists():
dir_count = len([name for name in os.listdir(self.raw_dir) if os.path.isdir(os.path.join(self.raw_dir, name))])
if dir_count: if dir_count:
print(f'{dir_count} folders have been found....') print(f'{dir_count} folders have been found....')
return dir_count return dir_count
warn(ResourceWarning("No raw pointclouds have been found. Was this intentional?")) else:
return dir_count warn(ResourceWarning("No raw pointclouds have been found. Was this intentional?"))
return dir_count
warn(ResourceWarning("The raw data folder does not exist. Was this intentional?"))
return -1
@property @property
def num_classes(self): def num_classes(self):
@ -87,6 +91,10 @@ class CustomShapeNet(InMemoryDataset):
print('Dataset Loaded') print('Dataset Loaded')
break break
except FileNotFoundError: except FileNotFoundError:
status = self.check_and_resolve_cloud_count()
if status in [0, -1]:
print(f'No dataset was loaded, status: {status}')
break
self.process() self.process()
continue continue
return data, slices return data, slices

View File

@ -42,9 +42,8 @@ def restore_logger_and_model(log_dir):
def predict_prim_type(input_pc, model): def predict_prim_type(input_pc, model):
input_data = dict(norm=torch.tensor(np.array([input_pc[:, 3:6]], np.float)), input_data = dict(norm=torch.tensor(np.array([input_pc[:, 3:6]], np.float)).unsqueeze(0),
pos=torch.tensor(input_pc[:, 0:3]), pos=torch.tensor(input_pc[:, 0:3]).unsqueeze(0),
y=np.zeros(input_pc.shape[0])
) )
batch_to_data = BatchToData() batch_to_data = BatchToData()
@ -71,7 +70,7 @@ if __name__ == '__main__':
input_pc = normalize_pointcloud(input_pc) input_pc = normalize_pointcloud(input_pc)
grid_clusters = cluster_cubes(input_pc, [1,1,1], 2048) grid_clusters = cluster_cubes(input_pc, [1,1,1], 1024)
ps.init() ps.init()

View File

@ -1,6 +1,8 @@
import numpy as np import numpy as np
from sklearn.cluster import DBSCAN from sklearn.cluster import DBSCAN
import open3d as o3d
from pyod.models.knn import KNN from pyod.models.knn import KNN
from pyod.models.sod import SOD from pyod.models.sod import SOD
from pyod.models.abod import ABOD from pyod.models.abod import ABOD