diff --git a/environments/factory/assets/agents/idle.png b/environments/factory/assets/agents/idle.png new file mode 100644 index 0000000..152f27a Binary files /dev/null and b/environments/factory/assets/agents/idle.png differ diff --git a/environments/factory/assets/agents/invalid.png b/environments/factory/assets/agents/invalid.png new file mode 100644 index 0000000..4cf34c0 Binary files /dev/null and b/environments/factory/assets/agents/invalid.png differ diff --git a/environments/factory/assets/agents/valid.png b/environments/factory/assets/agents/valid.png new file mode 100644 index 0000000..ae7c768 Binary files /dev/null and b/environments/factory/assets/agents/valid.png differ diff --git a/environments/factory/renderer.py b/environments/factory/renderer.py index fdb56d2..b439eaa 100644 --- a/environments/factory/renderer.py +++ b/environments/factory/renderer.py @@ -1,22 +1,24 @@ import sys -from dataclasses import dataclass import numpy as np from pathlib import Path from collections import deque import pygame +from typing import NamedTuple -@dataclass -class Entity: + +class Entity(NamedTuple): name: str pos: np.array value: float = 1 value_operation: str = 'none' + state: str = None class Renderer: BG_COLOR = (178, 190, 195) # (99, 110, 114) WHITE = (223, 230, 233) # (200, 200, 200) AGENT_VIEW_COLOR = (9, 132, 227) + ASSETS = Path(__file__).parent / 'assets' def __init__(self, grid_w=16, grid_h=16, cell_size=40, fps=4, grid_lines=True, view_radius=2): self.grid_h = grid_h @@ -29,7 +31,7 @@ class Renderer: self.screen_size = (grid_w*cell_size, grid_h*cell_size) self.screen = pygame.display.set_mode(self.screen_size) self.clock = pygame.time.Clock() - assets = list((Path(__file__).parent / 'assets').rglob('*.png')) + assets = list(self.ASSETS.rglob('*.png')) self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets} self.fill_bg() @@ -64,25 +66,29 @@ class Renderer: wall_img = pygame.transform.smoothscale(wall_img, (s, s)) return wall_img - def render(self, pos_dict): + def render(self, entities): for event in pygame.event.get(): if event.type == pygame.QUIT: pygame.quit() sys.exit() self.fill_bg() blits = deque() - for asset, entities in pos_dict.items(): - for entity in entities: - bp = self.blit_params(entity) - if 'agent' in entity.name and self.view_radius > 0: - visibility_rect = bp['dest'].inflate((self.view_radius*2)*self.cell_size, (self.view_radius*2)*self.cell_size) - shape_surf = pygame.Surface(visibility_rect.size, pygame.SRCALPHA) - pygame.draw.rect(shape_surf, self.AGENT_VIEW_COLOR, shape_surf.get_rect()) - shape_surf.set_alpha(64) - blits.appendleft(dict(source=shape_surf, dest=visibility_rect)) - blits.append(bp) + for entity in entities: + bp = self.blit_params(entity) + blits.append(bp) + if 'agent' in entity.name and self.view_radius > 0: + visibility_rect = bp['dest'].inflate((self.view_radius*2)*self.cell_size, (self.view_radius*2)*self.cell_size) + shape_surf = pygame.Surface(visibility_rect.size, pygame.SRCALPHA) + pygame.draw.rect(shape_surf, self.AGENT_VIEW_COLOR, shape_surf.get_rect()) + shape_surf.set_alpha(64) + blits.appendleft(dict(source=shape_surf, dest=visibility_rect)) + agent_state_blits = self.blit_params(Entity(entity.state, (entity.pos[0]+0.11, entity.pos[1]), 0.48, 'scale')) + blits.append(agent_state_blits) + for blit in blits: self.screen.blit(**blit) + + pygame.display.flip() self.clock.tick(self.fps) @@ -90,5 +96,6 @@ class Renderer: if __name__ == '__main__': renderer = Renderer(fps=2, cell_size=40) for i in range(15): - renderer.render({'agent': [(5, i)], 'wall': [(0, i), (i, 0)], 'dirt': [(3,3), (3,4)]}) + entity = Entity('agent', [5, i], 1, 'idle', 'idle') + renderer.render([entity]) diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py index c1f052d..bcc7fcb 100644 --- a/environments/factory/simple_factory.py +++ b/environments/factory/simple_factory.py @@ -58,15 +58,16 @@ class SimpleFactory(BaseFactory): if 'agent' in cols: return 'agent_collision' elif not agent.action_valid or 'level' in cols or 'agent' in cols: - return f'agent{agent.i + 1}violation' + return f'agent{agent.i + 1}', 'invalid' elif self._is_clean_up_action(agent.action): - return f'agent{agent.i + 1}valid' + return f'agent{agent.i + 1}', 'valid' else: - return f'agent{agent.i + 1}' - - agents = {f'agent{i+1}': [Entity(asset_str(agent), agent.pos)] - for i, agent in enumerate(self._agent_states)} - self._renderer.render(OrderedDict(dirt=dirt, wall=walls, **agents)) + return f'agent{agent.i + 1}', 'idle' + agents = [] + for i, agent in enumerate(self._agent_states): + name, state = asset_str(agent) + agents.append(Entity(name, agent.pos, 1, 'none', state)) + self._renderer.render(dirt+walls+agents) def spawn_dirt(self) -> None: if not np.argwhere(self._state[DIRT_INDEX] != h.IS_FREE_CELL).shape[0] > self.dirt_properties.max_global_amount: