mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-20 03:08:08 +02:00
Merge remote-tracking branch 'origin/main'
This commit is contained in:
@ -9,6 +9,8 @@ from environments.factory.base_factory import BaseFactory, AgentState
|
||||
from environments import helpers as h
|
||||
|
||||
from environments.factory.renderer import Renderer
|
||||
from environments.factory.renderer import Entity
|
||||
|
||||
|
||||
DIRT_INDEX = -1
|
||||
|
||||
@ -37,12 +39,12 @@ class GettingDirty(BaseFactory):
|
||||
if not self.renderer: # lazy init
|
||||
height, width = self.state.shape[1:]
|
||||
self.renderer = Renderer(width, height, view_radius=0)
|
||||
self.renderer.render(
|
||||
OrderedDict(dirt=np.argwhere(self.state[DIRT_INDEX] > h.IS_FREE_CELL),
|
||||
wall=np.argwhere(self.state[h.LEVEL_IDX] > h.IS_FREE_CELL),
|
||||
agent=np.argwhere(self.state[h.AGENT_START_IDX] > h.IS_FREE_CELL)
|
||||
)
|
||||
)
|
||||
|
||||
dirt = [Entity('dirt', [x, y], self.state[DIRT_INDEX, x, y]) for x, y in np.argwhere(self.state[DIRT_INDEX] > h.IS_FREE_CELL)]
|
||||
walls = [Entity('dirt', pos) for pos in np.argwhere(self.state[h.LEVEL_IDX] > h.IS_FREE_CELL)]
|
||||
agents = [Entity('agent', pos) for pos in np.argwhere(self.state[h.AGENT_START_IDX] > h.IS_FREE_CELL)]
|
||||
|
||||
self.renderer.render(OrderedDict(dirt=dirt, wall=walls, agent=agents))
|
||||
|
||||
def spawn_dirt(self) -> None:
|
||||
free_for_dirt = self.free_cells(excluded_slices=DIRT_INDEX)
|
||||
|
Reference in New Issue
Block a user