diff --git a/environments/factory/base/base_factory.py b/environments/factory/base/base_factory.py index 5ef9227..3e4f859 100644 --- a/environments/factory/base/base_factory.py +++ b/environments/factory/base/base_factory.py @@ -150,9 +150,9 @@ class BaseFactory(gym.Env): # Objects self._entities = Entities() # Level - level_array = h.one_hot_level(self._parsed_level) - level_array = np.pad(level_array, self.obs_prop.pomdp_r, 'constant', constant_values=1) + self._level_init_shape = level_array.shape + level_array = np.pad(level_array, self.obs_prop.pomdp_r, 'constant', constant_values=c.OCCUPIED_CELL) self._level_shape = level_array.shape self._obs_shape = self._level_shape if not self.obs_prop.pomdp_r else (self.pomdp_diameter, ) * 2 diff --git a/environments/factory/base/renderer.py b/environments/factory/base/renderer.py index 92eefc1..7388ef7 100644 --- a/environments/factory/base/renderer.py +++ b/environments/factory/base/renderer.py @@ -20,21 +20,33 @@ class RenderEntity(NamedTuple): aux: Any = None +class RenderNames: + AGENT: str = 'agent' + BLANK: str = 'blank' + DOOR: str = 'door' + OPACITY: str = 'opacity' + SCALE: str = 'scale' +rn = RenderNames + + class Renderer: BG_COLOR = (178, 190, 195) # (99, 110, 114) WHITE = (223, 230, 233) # (200, 200, 200) AGENT_VIEW_COLOR = (9, 132, 227) ASSETS = Path(__file__).parent.parent / 'assets' - def __init__(self, grid_w=16, grid_h=16, cell_size=40, fps=7, grid_lines=True, view_radius=2): - self.grid_h = grid_h - self.grid_w = grid_w + def __init__(self, lvl_shape=(16, 16), + lvl_padded_shape=None, + cell_size=40, fps=7, + grid_lines=True, view_radius=2): + self.grid_h, self.grid_w = lvl_shape + self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape self.cell_size = cell_size self.fps = fps self.grid_lines = grid_lines self.view_radius = view_radius pygame.init() - self.screen_size = (grid_w*cell_size, grid_h*cell_size) + self.screen_size = (self.grid_w*cell_size, self.grid_h*cell_size) self.screen = pygame.display.set_mode(self.screen_size) self.clock = pygame.time.Clock() assets = list(self.ASSETS.rglob('*.png')) @@ -43,7 +55,7 @@ class Renderer: now = time.time() self.font = pygame.font.Font(None, 20) - self.font.set_bold(1) + self.font.set_bold(True) print('Loading System font with pygame.font.Font took', time.time() - now) def fill_bg(self): @@ -56,11 +68,16 @@ class Renderer: pygame.draw.rect(self.screen, Renderer.WHITE, rect, 1) def blit_params(self, entity): + offset_r, offset_c = (self.lvl_padded_shape[0] - self.grid_h) // 2, \ + (self.lvl_padded_shape[1] - self.grid_w) // 2 + r, c = entity.pos + r, c = r - offset_r, c-offset_c + img = self.assets[entity.name.lower()] - if entity.value_operation == 'opacity': + if entity.value_operation == rn.OPACITY: img.set_alpha(255*entity.value) - elif entity.value_operation == 'scale': + elif entity.value_operation == rn.SCALE: re = img.get_rect() img = pygame.transform.smoothscale( img, (int(entity.value*re.width), int(entity.value*re.height)) @@ -99,19 +116,19 @@ class Renderer: sys.exit() self.fill_bg() blits = deque() - for entity in [x for x in entities if 'door' in x.name]: + for entity in [x for x in entities if rn.DOOR in x.name]: bp = self.blit_params(entity) blits.append(bp) - for entity in [x for x in entities if 'door' not in x.name]: + for entity in [x for x in entities if rn.DOOR not in x.name]: bp = self.blit_params(entity) blits.append(bp) - if entity.name.lower() == 'agent': + if entity.name.lower() == rn.AGENT: if self.view_radius > 0: vis_rects = self.visibility_rects(bp, entity.aux) blits.extendleft(vis_rects) - if entity.state != 'blank': + if entity.state != rn.BLANK: agent_state_blits = self.blit_params( - RenderEntity(entity.state, (entity.pos[0] + 0.12, entity.pos[1]), 0.48, 'scale') + RenderEntity(entity.state, (entity.pos[0] + 0.12, entity.pos[1]), 0.48, rn.SCALE) ) textsurface = self.font.render(str(entity.id), False, (0, 0, 0)) text_blit = dict(source=textsurface, dest=(bp['dest'].center[0]-.07*self.cell_size,