diff --git a/environments/factory/renderer.py b/environments/factory/renderer.py index ad44a56..0109084 100644 --- a/environments/factory/renderer.py +++ b/environments/factory/renderer.py @@ -1,6 +1,16 @@ -import pygame -from pathlib import Path import sys +from dataclasses import dataclass +import numpy as np +from pathlib import Path +import pygame + + +@dataclass +class Entity: + name: str + pos: np.array + value: float = 1 + class Renderer: BG_COLOR = (99, 110, 114) @@ -31,14 +41,15 @@ class Renderer: rect = pygame.Rect(x, y, self.cell_size, self.cell_size) pygame.draw.rect(self.screen, Renderer.WHITE, rect, 1) - def blit_params(self, r, c, name): + def blit_params(self, entity, name): + r, c = entity.pos img = self.assets[name] + img.set_alpha(255*entity.value) o = self.cell_size//2 r_, c_ = r*self.cell_size + o, c*self.cell_size + o rect = img.get_rect() rect.centerx, rect.centery = c_, r_ - - return img, rect + return dict(source=img, dest=rect) def load_asset(self, path, factor=1.0): s = int(factor*self.cell_size) @@ -52,15 +63,15 @@ class Renderer: pygame.quit() sys.exit() self.fill_bg() - for asset, positions in pos_dict.items(): - for x, y in positions: - img, rect = self.blit_params(x, y, asset) + for asset, entities in pos_dict.items(): + for entity in entities: + bp = self.blit_params(entity, asset) if 'agent' in asset and self.view_radius > 0: - visibility_rect = rect.inflate((self.view_radius*2)*self.cell_size, (self.view_radius*2)*self.cell_size) + visibility_rect = bp['dest'].inflate((self.view_radius*2)*self.cell_size, (self.view_radius*2)*self.cell_size) shape_surf = pygame.Surface(visibility_rect.size, pygame.SRCALPHA) pygame.draw.rect(shape_surf, self.PINK, shape_surf.get_rect()) self.screen.blit(shape_surf, visibility_rect) - self.screen.blit(img, rect) + self.screen.blit(**bp) pygame.display.flip() self.clock.tick(self.fps) diff --git a/environments/factory/simple_factory_getting_dirty.py b/environments/factory/simple_factory_getting_dirty.py index e5a85db..92eef56 100644 --- a/environments/factory/simple_factory_getting_dirty.py +++ b/environments/factory/simple_factory_getting_dirty.py @@ -9,6 +9,8 @@ from environments.factory.base_factory import BaseFactory, AgentState from environments import helpers as h from environments.factory.renderer import Renderer +from environments.factory.renderer import Entity + DIRT_INDEX = -1 @@ -37,12 +39,12 @@ class GettingDirty(BaseFactory): if not self.renderer: # lazy init height, width = self.state.shape[1:] self.renderer = Renderer(width, height, view_radius=0) - self.renderer.render( - OrderedDict(dirt=np.argwhere(self.state[DIRT_INDEX] > h.IS_FREE_CELL), - wall=np.argwhere(self.state[h.LEVEL_IDX] > h.IS_FREE_CELL), - agent=np.argwhere(self.state[h.AGENT_START_IDX] > h.IS_FREE_CELL) - ) - ) + + dirt = [Entity('dirt', [x, y], self.state[DIRT_INDEX, x, y]) for x, y in np.argwhere(self.state[DIRT_INDEX] > h.IS_FREE_CELL)] + walls = [Entity('dirt', pos) for pos in np.argwhere(self.state[h.LEVEL_IDX] > h.IS_FREE_CELL)] + agents = [Entity('agent', pos) for pos in np.argwhere(self.state[h.AGENT_START_IDX] > h.IS_FREE_CELL)] + + self.renderer.render(OrderedDict(dirt=dirt, wall=walls, agent=agents)) def spawn_dirt(self) -> None: free_for_dirt = self.free_cells(excluded_slices=DIRT_INDEX)