Merge remote-tracking branch 'origin/main'

This commit is contained in:
steffen-illium 2021-05-18 11:06:49 +02:00
commit 33916c4aed
2 changed files with 29 additions and 16 deletions

View File

@ -1,6 +1,16 @@
import pygame
from pathlib import Path
import sys import sys
from dataclasses import dataclass
import numpy as np
from pathlib import Path
import pygame
@dataclass
class Entity:
name: str
pos: np.array
value: float = 1
class Renderer: class Renderer:
BG_COLOR = (99, 110, 114) BG_COLOR = (99, 110, 114)
@ -31,14 +41,15 @@ class Renderer:
rect = pygame.Rect(x, y, self.cell_size, self.cell_size) rect = pygame.Rect(x, y, self.cell_size, self.cell_size)
pygame.draw.rect(self.screen, Renderer.WHITE, rect, 1) pygame.draw.rect(self.screen, Renderer.WHITE, rect, 1)
def blit_params(self, r, c, name): def blit_params(self, entity, name):
r, c = entity.pos
img = self.assets[name] img = self.assets[name]
img.set_alpha(255*entity.value)
o = self.cell_size//2 o = self.cell_size//2
r_, c_ = r*self.cell_size + o, c*self.cell_size + o r_, c_ = r*self.cell_size + o, c*self.cell_size + o
rect = img.get_rect() rect = img.get_rect()
rect.centerx, rect.centery = c_, r_ rect.centerx, rect.centery = c_, r_
return dict(source=img, dest=rect)
return img, rect
def load_asset(self, path, factor=1.0): def load_asset(self, path, factor=1.0):
s = int(factor*self.cell_size) s = int(factor*self.cell_size)
@ -52,15 +63,15 @@ class Renderer:
pygame.quit() pygame.quit()
sys.exit() sys.exit()
self.fill_bg() self.fill_bg()
for asset, positions in pos_dict.items(): for asset, entities in pos_dict.items():
for x, y in positions: for entity in entities:
img, rect = self.blit_params(x, y, asset) bp = self.blit_params(entity, asset)
if 'agent' in asset and self.view_radius > 0: if 'agent' in asset and self.view_radius > 0:
visibility_rect = rect.inflate((self.view_radius*2)*self.cell_size, (self.view_radius*2)*self.cell_size) visibility_rect = bp['dest'].inflate((self.view_radius*2)*self.cell_size, (self.view_radius*2)*self.cell_size)
shape_surf = pygame.Surface(visibility_rect.size, pygame.SRCALPHA) shape_surf = pygame.Surface(visibility_rect.size, pygame.SRCALPHA)
pygame.draw.rect(shape_surf, self.PINK, shape_surf.get_rect()) pygame.draw.rect(shape_surf, self.PINK, shape_surf.get_rect())
self.screen.blit(shape_surf, visibility_rect) self.screen.blit(shape_surf, visibility_rect)
self.screen.blit(img, rect) self.screen.blit(**bp)
pygame.display.flip() pygame.display.flip()
self.clock.tick(self.fps) self.clock.tick(self.fps)

View File

@ -9,6 +9,8 @@ from environments.factory.base_factory import BaseFactory, AgentState
from environments import helpers as h from environments import helpers as h
from environments.factory.renderer import Renderer from environments.factory.renderer import Renderer
from environments.factory.renderer import Entity
DIRT_INDEX = -1 DIRT_INDEX = -1
@ -37,12 +39,12 @@ class GettingDirty(BaseFactory):
if not self.renderer: # lazy init if not self.renderer: # lazy init
height, width = self.state.shape[1:] height, width = self.state.shape[1:]
self.renderer = Renderer(width, height, view_radius=0) self.renderer = Renderer(width, height, view_radius=0)
self.renderer.render(
OrderedDict(dirt=np.argwhere(self.state[DIRT_INDEX] > h.IS_FREE_CELL), dirt = [Entity('dirt', [x, y], self.state[DIRT_INDEX, x, y]) for x, y in np.argwhere(self.state[DIRT_INDEX] > h.IS_FREE_CELL)]
wall=np.argwhere(self.state[h.LEVEL_IDX] > h.IS_FREE_CELL), walls = [Entity('dirt', pos) for pos in np.argwhere(self.state[h.LEVEL_IDX] > h.IS_FREE_CELL)]
agent=np.argwhere(self.state[h.AGENT_START_IDX] > h.IS_FREE_CELL) agents = [Entity('agent', pos) for pos in np.argwhere(self.state[h.AGENT_START_IDX] > h.IS_FREE_CELL)]
)
) self.renderer.render(OrderedDict(dirt=dirt, wall=walls, agent=agents))
def spawn_dirt(self) -> None: def spawn_dirt(self) -> None:
free_for_dirt = self.free_cells(excluded_slices=DIRT_INDEX) free_for_dirt = self.free_cells(excluded_slices=DIRT_INDEX)