From 2acf91335f6a8cb45ec2d76699bebbc5748b7eba Mon Sep 17 00:00:00 2001
From: Si11ium <steffen.illium@ifi.lmu.de>
Date: Sun, 7 Jun 2020 16:47:51 +0200
Subject: [PATCH] Grid Clusters.

---
 point_toolset/point_io.py |  8 +++++---
 utils/tools.py            | 13 +++++--------
 2 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/point_toolset/point_io.py b/point_toolset/point_io.py
index 49e1ebc..45ce200 100644
--- a/point_toolset/point_io.py
+++ b/point_toolset/point_io.py
@@ -6,19 +6,21 @@ class BatchToData(object):
     def __init__(self):
         super(BatchToData, self).__init__()
 
-    def __call__(self, batch_x: torch.Tensor, batch_pos: torch.Tensor, batch_y: torch.Tensor):
+    def __call__(self, batch_x: torch.Tensor, batch_pos: torch.Tensor,
+                 batch_y_l: torch.Tensor, batch_y_c: torch.Tensor):
         # Convert to torch_geometric.data.Data type
         # data = data.transpose(1, 2).contiguous()
         batch_size, num_points, _ = batch_x.shape  # (batch_size, num_points, 3)
 
         x = batch_x.reshape(batch_size * num_points, -1)
         pos = batch_pos.reshape(batch_size * num_points, -1)
-        batch_y = batch_y.reshape(batch_size * num_points)
+        batch_y_l = batch_y_l.reshape(batch_size * num_points)
+        batch_y_c = batch_y_c.reshape(batch_size * num_points)
         batch = torch.zeros((batch_size, num_points), device=pos.device, dtype=torch.long)
         for i in range(batch_size):
             batch[i] = i
         batch = batch.view(-1)
 
         data = Data()
-        data.x, data.pos, data.batch, data.y = x, pos, batch, batch_y
+        data.x, data.pos, data.batch, data.yl, data.yc = x, pos, batch, batch_y_l, batch_y_c
         return data
diff --git a/utils/tools.py b/utils/tools.py
index 68f5a21..6e291f8 100644
--- a/utils/tools.py
+++ b/utils/tools.py
@@ -3,20 +3,17 @@ import shelve
 from pathlib import Path
 
 import numpy as np
-
-from utils.project_config import GlobalVar
+import torch
+import random
 
 
-def to_one_hot(idx_array):
-    one_hot = np.zeros((idx_array.size, len(GlobalVar.classes)))
+def to_one_hot(idx_array, max_classes):
+    one_hot = np.zeros((idx_array.size, max_classes))
     one_hot[np.arange(idx_array.size), idx_array] = 1
     return one_hot
 
 
 def fix_all_random_seeds(config_obj):
-    import numpy as np
-    import torch
-    import random
     np.random.seed(config_obj.main.seed)
     torch.manual_seed(config_obj.main.seed)
     random.seed(config_obj.main.seed)
@@ -39,4 +36,4 @@ def load_from_shelve(file_path, key):
 
 def check_path(file_path):
     assert isinstance(file_path, Path)
-    assert str(file_path).endswith('.pik')
\ No newline at end of file
+    assert str(file_path).endswith('.pik')