mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
Getting Dirty
Viz
This commit is contained in:
parent
27f5abad64
commit
7704a98dcc
@ -147,12 +147,13 @@ class BaseFactory:
|
||||
pos_x, pos_y = positions[0] # a.flatten()
|
||||
return pos_x, pos_y
|
||||
|
||||
def free_cells(self, excluded_slices: Union[None, List, int] = None) -> np.ndarray:
|
||||
def free_cells(self, excluded_slices: Union[None, List[int], int] = None) -> np.ndarray:
|
||||
excluded_slices = excluded_slices or []
|
||||
assert isinstance(excluded_slices, (int, list))
|
||||
excluded_slices = excluded_slices if isinstance(excluded_slices, list) else [excluded_slices]
|
||||
|
||||
state = self.state
|
||||
|
||||
if excluded_slices:
|
||||
# Todo: Is there a cleaner way?
|
||||
inds = list(range(self.state.shape[0]))
|
||||
@ -160,6 +161,7 @@ class BaseFactory:
|
||||
state = self.state[[x for x in inds if x not in excluded_slices]]
|
||||
|
||||
free_cells = state.sum(0)
|
||||
free_cells[excluded_slices] = 0
|
||||
free_cells = np.argwhere(free_cells == h.IS_FREE_CELL)
|
||||
np.random.shuffle(free_cells)
|
||||
return free_cells
|
||||
|
Loading…
x
Reference in New Issue
Block a user