better agent visualization

This commit is contained in:
romue
2021-06-09 13:21:30 +02:00
parent dbfa97aaba
commit 76b97b126e
5 changed files with 31 additions and 23 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.5 KiB

View File

@ -1,22 +1,24 @@
import sys import sys
from dataclasses import dataclass
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
from collections import deque from collections import deque
import pygame import pygame
from typing import NamedTuple
@dataclass
class Entity: class Entity(NamedTuple):
name: str name: str
pos: np.array pos: np.array
value: float = 1 value: float = 1
value_operation: str = 'none' value_operation: str = 'none'
state: str = None
class Renderer: class Renderer:
BG_COLOR = (178, 190, 195) # (99, 110, 114) BG_COLOR = (178, 190, 195) # (99, 110, 114)
WHITE = (223, 230, 233) # (200, 200, 200) WHITE = (223, 230, 233) # (200, 200, 200)
AGENT_VIEW_COLOR = (9, 132, 227) AGENT_VIEW_COLOR = (9, 132, 227)
ASSETS = Path(__file__).parent / 'assets'
def __init__(self, grid_w=16, grid_h=16, cell_size=40, fps=4, grid_lines=True, view_radius=2): def __init__(self, grid_w=16, grid_h=16, cell_size=40, fps=4, grid_lines=True, view_radius=2):
self.grid_h = grid_h self.grid_h = grid_h
@ -29,7 +31,7 @@ class Renderer:
self.screen_size = (grid_w*cell_size, grid_h*cell_size) self.screen_size = (grid_w*cell_size, grid_h*cell_size)
self.screen = pygame.display.set_mode(self.screen_size) self.screen = pygame.display.set_mode(self.screen_size)
self.clock = pygame.time.Clock() self.clock = pygame.time.Clock()
assets = list((Path(__file__).parent / 'assets').rglob('*.png')) assets = list(self.ASSETS.rglob('*.png'))
self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets} self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets}
self.fill_bg() self.fill_bg()
@ -64,25 +66,29 @@ class Renderer:
wall_img = pygame.transform.smoothscale(wall_img, (s, s)) wall_img = pygame.transform.smoothscale(wall_img, (s, s))
return wall_img return wall_img
def render(self, pos_dict): def render(self, entities):
for event in pygame.event.get(): for event in pygame.event.get():
if event.type == pygame.QUIT: if event.type == pygame.QUIT:
pygame.quit() pygame.quit()
sys.exit() sys.exit()
self.fill_bg() self.fill_bg()
blits = deque() blits = deque()
for asset, entities in pos_dict.items(): for entity in entities:
for entity in entities: bp = self.blit_params(entity)
bp = self.blit_params(entity) blits.append(bp)
if 'agent' in entity.name and self.view_radius > 0: if 'agent' in entity.name and self.view_radius > 0:
visibility_rect = bp['dest'].inflate((self.view_radius*2)*self.cell_size, (self.view_radius*2)*self.cell_size) visibility_rect = bp['dest'].inflate((self.view_radius*2)*self.cell_size, (self.view_radius*2)*self.cell_size)
shape_surf = pygame.Surface(visibility_rect.size, pygame.SRCALPHA) shape_surf = pygame.Surface(visibility_rect.size, pygame.SRCALPHA)
pygame.draw.rect(shape_surf, self.AGENT_VIEW_COLOR, shape_surf.get_rect()) pygame.draw.rect(shape_surf, self.AGENT_VIEW_COLOR, shape_surf.get_rect())
shape_surf.set_alpha(64) shape_surf.set_alpha(64)
blits.appendleft(dict(source=shape_surf, dest=visibility_rect)) blits.appendleft(dict(source=shape_surf, dest=visibility_rect))
blits.append(bp) agent_state_blits = self.blit_params(Entity(entity.state, (entity.pos[0]+0.11, entity.pos[1]), 0.48, 'scale'))
blits.append(agent_state_blits)
for blit in blits: for blit in blits:
self.screen.blit(**blit) self.screen.blit(**blit)
pygame.display.flip() pygame.display.flip()
self.clock.tick(self.fps) self.clock.tick(self.fps)
@ -90,5 +96,6 @@ class Renderer:
if __name__ == '__main__': if __name__ == '__main__':
renderer = Renderer(fps=2, cell_size=40) renderer = Renderer(fps=2, cell_size=40)
for i in range(15): for i in range(15):
renderer.render({'agent': [(5, i)], 'wall': [(0, i), (i, 0)], 'dirt': [(3,3), (3,4)]}) entity = Entity('agent', [5, i], 1, 'idle', 'idle')
renderer.render([entity])

View File

@ -58,15 +58,16 @@ class SimpleFactory(BaseFactory):
if 'agent' in cols: if 'agent' in cols:
return 'agent_collision' return 'agent_collision'
elif not agent.action_valid or 'level' in cols or 'agent' in cols: elif not agent.action_valid or 'level' in cols or 'agent' in cols:
return f'agent{agent.i + 1}violation' return f'agent{agent.i + 1}', 'invalid'
elif self._is_clean_up_action(agent.action): elif self._is_clean_up_action(agent.action):
return f'agent{agent.i + 1}valid' return f'agent{agent.i + 1}', 'valid'
else: else:
return f'agent{agent.i + 1}' return f'agent{agent.i + 1}', 'idle'
agents = []
agents = {f'agent{i+1}': [Entity(asset_str(agent), agent.pos)] for i, agent in enumerate(self._agent_states):
for i, agent in enumerate(self._agent_states)} name, state = asset_str(agent)
self._renderer.render(OrderedDict(dirt=dirt, wall=walls, **agents)) agents.append(Entity(name, agent.pos, 1, 'none', state))
self._renderer.render(dirt+walls+agents)
def spawn_dirt(self) -> None: def spawn_dirt(self) -> None:
if not np.argwhere(self._state[DIRT_INDEX] != h.IS_FREE_CELL).shape[0] > self.dirt_properties.max_global_amount: if not np.argwhere(self._state[DIRT_INDEX] != h.IS_FREE_CELL).shape[0] > self.dirt_properties.max_global_amount: