diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index 0627aba..a8b4dda 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -147,12 +147,13 @@ class BaseFactory: pos_x, pos_y = positions[0] # a.flatten() return pos_x, pos_y - def free_cells(self, excluded_slices: Union[None, List, int] = None) -> np.ndarray: + def free_cells(self, excluded_slices: Union[None, List[int], int] = None) -> np.ndarray: excluded_slices = excluded_slices or [] assert isinstance(excluded_slices, (int, list)) excluded_slices = excluded_slices if isinstance(excluded_slices, list) else [excluded_slices] state = self.state + if excluded_slices: # Todo: Is there a cleaner way? inds = list(range(self.state.shape[0])) @@ -160,6 +161,7 @@ class BaseFactory: state = self.state[[x for x in inds if x not in excluded_slices]] free_cells = state.sum(0) + free_cells[excluded_slices] = 0 free_cells = np.argwhere(free_cells == h.IS_FREE_CELL) np.random.shuffle(free_cells) return free_cells