moved renderer.py to base, added initial salina experiments

This commit is contained in:
romue
2021-11-12 13:47:53 +01:00
parent f625b9d8a5
commit b6bda84033
7 changed files with 105 additions and 31 deletions

View File

@ -544,7 +544,7 @@ class BaseFactory(gym.Env):
def render(self, mode='human'):
if not self._renderer: # lazy init
from environments.factory.renderer import Renderer, RenderEntity
from environments.factory.base.renderer import Renderer, RenderEntity
global Renderer, RenderEntity
height, width = self._obs_cube.shape[1:]
self._renderer = Renderer(width, height, view_radius=self._pomdp_r, fps=5)
@ -562,7 +562,7 @@ class BaseFactory(gym.Env):
doors.append(RenderEntity(name, door.pos, 1, 'none', state, i + 1))
additional_assets = self.render_additional_assets()
self._renderer.render(walls + doors + additional_assets + agents)
return self._renderer.render(walls + doors + additional_assets + agents)
def save_params(self, filepath: Path):
# noinspection PyProtectedMember

View File

@ -0,0 +1,135 @@
import sys
import numpy as np
from pathlib import Path
from collections import deque
from itertools import product
import pygame
from typing import NamedTuple, Any
import time
import torch
class RenderEntity(NamedTuple):
name: str
pos: np.array
value: float = 1
value_operation: str = 'none'
state: str = None
id: int = 0
aux: Any = None
class Renderer:
BG_COLOR = (178, 190, 195) # (99, 110, 114)
WHITE = (223, 230, 233) # (200, 200, 200)
AGENT_VIEW_COLOR = (9, 132, 227)
ASSETS = Path(__file__).parent.parent / 'assets'
def __init__(self, grid_w=16, grid_h=16, cell_size=40, fps=7, grid_lines=True, view_radius=2):
self.grid_h = grid_h
self.grid_w = grid_w
self.cell_size = cell_size
self.fps = fps
self.grid_lines = grid_lines
self.view_radius = view_radius
pygame.init()
self.screen_size = (grid_w*cell_size, grid_h*cell_size)
self.screen = pygame.display.set_mode(self.screen_size)
self.clock = pygame.time.Clock()
assets = list(self.ASSETS.rglob('*.png'))
self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets}
self.fill_bg()
now = time.time()
self.font = pygame.font.Font(None, 20)
self.font.set_bold(1)
print('Loading System font with pygame.font.Font took', time.time() - now)
def fill_bg(self):
self.screen.fill(Renderer.BG_COLOR)
if self.grid_lines:
w, h = self.screen_size
for x in range(0, w, self.cell_size):
for y in range(0, h, self.cell_size):
rect = pygame.Rect(x, y, self.cell_size, self.cell_size)
pygame.draw.rect(self.screen, Renderer.WHITE, rect, 1)
def blit_params(self, entity):
r, c = entity.pos
img = self.assets[entity.name.lower()]
if entity.value_operation == 'opacity':
img.set_alpha(255*entity.value)
elif entity.value_operation == 'scale':
re = img.get_rect()
img = pygame.transform.smoothscale(
img, (int(entity.value*re.width), int(entity.value*re.height))
)
o = self.cell_size//2
r_, c_ = r*self.cell_size + o, c*self.cell_size + o
rect = img.get_rect()
rect.centerx, rect.centery = c_, r_
return dict(source=img, dest=rect)
def load_asset(self, path, factor=1.0):
s = int(factor*self.cell_size)
asset = pygame.image.load(path).convert_alpha()
asset = pygame.transform.smoothscale(asset, (s, s))
return asset
def visibility_rects(self, bp, view):
rects = []
for i, j in product(range(-self.view_radius, self.view_radius+1),
range(-self.view_radius, self.view_radius+1)):
if view is not None:
if bool(view[self.view_radius+j, self.view_radius+i]):
visibility_rect = bp['dest'].copy()
visibility_rect.centerx += i*self.cell_size
visibility_rect.centery += j*self.cell_size
shape_surf = pygame.Surface(visibility_rect.size, pygame.SRCALPHA)
pygame.draw.rect(shape_surf, self.AGENT_VIEW_COLOR, shape_surf.get_rect())
shape_surf.set_alpha(64)
rects.append(dict(source=shape_surf, dest=visibility_rect))
return rects
def render(self, entities):
for event in pygame.event.get():
if event.type == pygame.QUIT:
pygame.quit()
sys.exit()
self.fill_bg()
blits = deque()
for entity in [x for x in entities if 'door' in x.name]:
bp = self.blit_params(entity)
blits.append(bp)
for entity in [x for x in entities if 'door' not in x.name]:
bp = self.blit_params(entity)
blits.append(bp)
if entity.name.lower() == 'agent':
if self.view_radius > 0:
vis_rects = self.visibility_rects(bp, entity.aux)
blits.extendleft(vis_rects)
if entity.state != 'blank':
agent_state_blits = self.blit_params(
RenderEntity(entity.state, (entity.pos[0] + 0.12, entity.pos[1]), 0.48, 'scale')
)
textsurface = self.font.render(str(entity.id), False, (0, 0, 0))
text_blit = dict(source=textsurface, dest=(bp['dest'].center[0]-.07*self.cell_size,
bp['dest'].center[1]))
blits += [agent_state_blits, text_blit]
for blit in blits:
self.screen.blit(**blit)
pygame.display.flip()
self.clock.tick(self.fps)
rgb_obs = pygame.surfarray.array3d(self.screen)
return torch.from_numpy(rgb_obs).permute(2, 0, 1)
if __name__ == '__main__':
renderer = Renderer(fps=2, cell_size=40)
for i in range(15):
entity_1 = RenderEntity('agent_collision', [5, i], 1, 'idle', 'idle')
renderer.render([entity_1])