Normalization and transforms for batch_to_data class
This commit is contained in:
@@ -25,12 +25,10 @@ class _Point_Dataset(ABC, Dataset):
|
||||
samplers = dict(fps=FarthestpointSampling, rnd=RandomSampling)
|
||||
|
||||
def __init__(self, root=Path('data'), norm_as_feature=True, sampling_k=2048, sampling='rnd',
|
||||
transforms=None, load_preprocessed=True, split='train', dense_output=False, *args, **kwargs):
|
||||
transforms=None, load_preprocessed=True, split='train', *args, **kwargs):
|
||||
super(_Point_Dataset, self).__init__()
|
||||
|
||||
self.setting: str
|
||||
|
||||
self.dense_output = dense_output
|
||||
self.split = split
|
||||
self.norm_as_feature = norm_as_feature
|
||||
self.load_preprocessed = load_preprocessed
|
||||
|
||||
@@ -72,8 +72,12 @@ class GridClusters(_Point_Dataset):
|
||||
while sample_idxs.shape[0] < self.sampling_k:
|
||||
sample_idxs = np.concatenate((sample_idxs, sample_idxs))[:self.sampling_k]
|
||||
|
||||
return (normal[sample_idxs].astype(np.float),
|
||||
position[sample_idxs].astype(np.float),
|
||||
normal = normal[sample_idxs].astype(np.float)
|
||||
position = position[sample_idxs].astype(np.float)
|
||||
|
||||
normal = self.transforms(normal)
|
||||
position = self.transforms(position)
|
||||
return (normal, position,
|
||||
label[sample_idxs].astype(np.int),
|
||||
cl_label[sample_idxs].astype(np.int)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user