dataset fixing
This commit is contained in:
@ -11,9 +11,9 @@ class BatchToData(object):
|
||||
# Convert to torch_geometric.data.Data type
|
||||
|
||||
batch_pos = batch_dict['pos']
|
||||
batch_norm = batch_dict['norm']
|
||||
batch_y = batch_dict['y']
|
||||
batch_y_c = batch_dict['y_c']
|
||||
batch_norm = batch_dict.get('norm', None)
|
||||
batch_y = batch_dict.get('y', None)
|
||||
batch_y_c = batch_dict.get('y_c', None)
|
||||
|
||||
batch_size, num_points, _ = batch_pos.shape # (batch_size, num_points, 3)
|
||||
|
||||
|
Reference in New Issue
Block a user