Normalization and transforms for batch_to_data class

This commit is contained in:
Si11ium
2020-06-15 15:14:08 +02:00
parent bc70f42c74
commit 4898e98851
8 changed files with 26 additions and 24 deletions

View File

@@ -25,12 +25,10 @@ class _Point_Dataset(ABC, Dataset):
samplers = dict(fps=FarthestpointSampling, rnd=RandomSampling)
def __init__(self, root=Path('data'), norm_as_feature=True, sampling_k=2048, sampling='rnd',
transforms=None, load_preprocessed=True, split='train', dense_output=False, *args, **kwargs):
transforms=None, load_preprocessed=True, split='train', *args, **kwargs):
super(_Point_Dataset, self).__init__()
self.setting: str
self.dense_output = dense_output
self.split = split
self.norm_as_feature = norm_as_feature
self.load_preprocessed = load_preprocessed

View File

@@ -72,8 +72,12 @@ class GridClusters(_Point_Dataset):
while sample_idxs.shape[0] < self.sampling_k:
sample_idxs = np.concatenate((sample_idxs, sample_idxs))[:self.sampling_k]
return (normal[sample_idxs].astype(np.float),
position[sample_idxs].astype(np.float),
normal = normal[sample_idxs].astype(np.float)
position = position[sample_idxs].astype(np.float)
normal = self.transforms(normal)
position = self.transforms(position)
return (normal, position,
label[sample_idxs].astype(np.int),
cl_label[sample_idxs].astype(np.int)
)