mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 15:26:43 +02:00
Merge remote-tracking branch 'origin/main'
This commit is contained in:
commit
33916c4aed
@ -1,6 +1,16 @@
|
|||||||
import pygame
|
|
||||||
from pathlib import Path
|
|
||||||
import sys
|
import sys
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
import pygame
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Entity:
|
||||||
|
name: str
|
||||||
|
pos: np.array
|
||||||
|
value: float = 1
|
||||||
|
|
||||||
|
|
||||||
class Renderer:
|
class Renderer:
|
||||||
BG_COLOR = (99, 110, 114)
|
BG_COLOR = (99, 110, 114)
|
||||||
@ -31,14 +41,15 @@ class Renderer:
|
|||||||
rect = pygame.Rect(x, y, self.cell_size, self.cell_size)
|
rect = pygame.Rect(x, y, self.cell_size, self.cell_size)
|
||||||
pygame.draw.rect(self.screen, Renderer.WHITE, rect, 1)
|
pygame.draw.rect(self.screen, Renderer.WHITE, rect, 1)
|
||||||
|
|
||||||
def blit_params(self, r, c, name):
|
def blit_params(self, entity, name):
|
||||||
|
r, c = entity.pos
|
||||||
img = self.assets[name]
|
img = self.assets[name]
|
||||||
|
img.set_alpha(255*entity.value)
|
||||||
o = self.cell_size//2
|
o = self.cell_size//2
|
||||||
r_, c_ = r*self.cell_size + o, c*self.cell_size + o
|
r_, c_ = r*self.cell_size + o, c*self.cell_size + o
|
||||||
rect = img.get_rect()
|
rect = img.get_rect()
|
||||||
rect.centerx, rect.centery = c_, r_
|
rect.centerx, rect.centery = c_, r_
|
||||||
|
return dict(source=img, dest=rect)
|
||||||
return img, rect
|
|
||||||
|
|
||||||
def load_asset(self, path, factor=1.0):
|
def load_asset(self, path, factor=1.0):
|
||||||
s = int(factor*self.cell_size)
|
s = int(factor*self.cell_size)
|
||||||
@ -52,15 +63,15 @@ class Renderer:
|
|||||||
pygame.quit()
|
pygame.quit()
|
||||||
sys.exit()
|
sys.exit()
|
||||||
self.fill_bg()
|
self.fill_bg()
|
||||||
for asset, positions in pos_dict.items():
|
for asset, entities in pos_dict.items():
|
||||||
for x, y in positions:
|
for entity in entities:
|
||||||
img, rect = self.blit_params(x, y, asset)
|
bp = self.blit_params(entity, asset)
|
||||||
if 'agent' in asset and self.view_radius > 0:
|
if 'agent' in asset and self.view_radius > 0:
|
||||||
visibility_rect = rect.inflate((self.view_radius*2)*self.cell_size, (self.view_radius*2)*self.cell_size)
|
visibility_rect = bp['dest'].inflate((self.view_radius*2)*self.cell_size, (self.view_radius*2)*self.cell_size)
|
||||||
shape_surf = pygame.Surface(visibility_rect.size, pygame.SRCALPHA)
|
shape_surf = pygame.Surface(visibility_rect.size, pygame.SRCALPHA)
|
||||||
pygame.draw.rect(shape_surf, self.PINK, shape_surf.get_rect())
|
pygame.draw.rect(shape_surf, self.PINK, shape_surf.get_rect())
|
||||||
self.screen.blit(shape_surf, visibility_rect)
|
self.screen.blit(shape_surf, visibility_rect)
|
||||||
self.screen.blit(img, rect)
|
self.screen.blit(**bp)
|
||||||
pygame.display.flip()
|
pygame.display.flip()
|
||||||
self.clock.tick(self.fps)
|
self.clock.tick(self.fps)
|
||||||
|
|
||||||
|
@ -9,6 +9,8 @@ from environments.factory.base_factory import BaseFactory, AgentState
|
|||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
|
|
||||||
from environments.factory.renderer import Renderer
|
from environments.factory.renderer import Renderer
|
||||||
|
from environments.factory.renderer import Entity
|
||||||
|
|
||||||
|
|
||||||
DIRT_INDEX = -1
|
DIRT_INDEX = -1
|
||||||
|
|
||||||
@ -37,12 +39,12 @@ class GettingDirty(BaseFactory):
|
|||||||
if not self.renderer: # lazy init
|
if not self.renderer: # lazy init
|
||||||
height, width = self.state.shape[1:]
|
height, width = self.state.shape[1:]
|
||||||
self.renderer = Renderer(width, height, view_radius=0)
|
self.renderer = Renderer(width, height, view_radius=0)
|
||||||
self.renderer.render(
|
|
||||||
OrderedDict(dirt=np.argwhere(self.state[DIRT_INDEX] > h.IS_FREE_CELL),
|
dirt = [Entity('dirt', [x, y], self.state[DIRT_INDEX, x, y]) for x, y in np.argwhere(self.state[DIRT_INDEX] > h.IS_FREE_CELL)]
|
||||||
wall=np.argwhere(self.state[h.LEVEL_IDX] > h.IS_FREE_CELL),
|
walls = [Entity('dirt', pos) for pos in np.argwhere(self.state[h.LEVEL_IDX] > h.IS_FREE_CELL)]
|
||||||
agent=np.argwhere(self.state[h.AGENT_START_IDX] > h.IS_FREE_CELL)
|
agents = [Entity('agent', pos) for pos in np.argwhere(self.state[h.AGENT_START_IDX] > h.IS_FREE_CELL)]
|
||||||
)
|
|
||||||
)
|
self.renderer.render(OrderedDict(dirt=dirt, wall=walls, agent=agents))
|
||||||
|
|
||||||
def spawn_dirt(self) -> None:
|
def spawn_dirt(self) -> None:
|
||||||
free_for_dirt = self.free_cells(excluded_slices=DIRT_INDEX)
|
free_for_dirt = self.free_cells(excluded_slices=DIRT_INDEX)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user