diff --git a/environments/factory/renderer.py b/environments/factory/renderer.py index 8e3ad92..302dacd 100644 --- a/environments/factory/renderer.py +++ b/environments/factory/renderer.py @@ -8,7 +8,7 @@ class Renderer: WHITE = (200, 200, 200) PINK = (0.5, 255, 118, 117) - def __init__(self, grid_w=16, grid_h=16, cell_size=25, fps=4, grid_lines=True, view_radius=2, assets=['wall', 'agent']): + def __init__(self, grid_w=16, grid_h=16, cell_size=30, fps=4, grid_lines=True, view_radius=2, assets=['wall', 'dirt', 'agent']): self.grid_h = grid_h self.grid_w = grid_w self.cell_size = cell_size diff --git a/environments/factory/simple_factory_getting_dirty.py b/environments/factory/simple_factory_getting_dirty.py index 99954b9..d047fbf 100644 --- a/environments/factory/simple_factory_getting_dirty.py +++ b/environments/factory/simple_factory_getting_dirty.py @@ -1,4 +1,4 @@ -from collections import defaultdict +from collections import defaultdict, OrderedDict from typing import List import numpy as np @@ -7,6 +7,8 @@ from attr import dataclass from environments.factory.base_factory import BaseFactory, AgentState from environments import helpers as h +from environments.factory.renderer import Renderer + DIRT_INDEX = -1 @dataclass class DirtProperties: @@ -24,6 +26,18 @@ class GettingDirty(BaseFactory): self._dirt_properties = dirt_properties super(GettingDirty, self).__init__(*args, **kwargs) self.slice_strings.update({self.state.shape[0]-1: 'dirt'}) + self.renderer = None # expensive - dont use it when not required ! + + def render(self): + if not self.renderer: # lazy init + h, w = self.state.shape[1:] + self.renderer = Renderer(w, h, view_radius=0, assets=['wall', 'agent', 'dirt']) + self.renderer.render( # todo: nur fuers prinzip, ist hardgecoded Dreck aktuell + OrderedDict(wall=np.argwhere(self.state[0] > 0), # Ordered dict defines the drawing order! important + dirt=np.argwhere(self.state[DIRT_INDEX] > 0), + agent=np.argwhere(self.state[1] > 0) + ) + ) def spawn_dirt(self) -> None: free_for_dirt = self.free_cells(excluded_slices=DIRT_INDEX) @@ -91,6 +105,9 @@ class GettingDirty(BaseFactory): if __name__ == '__main__': import random + + render = True + dirt_props = DirtProperties() factory = GettingDirty(n_agents=1, dirt_properties=dirt_props) monitor_list = list() @@ -99,6 +116,7 @@ if __name__ == '__main__': state, r, done, _ = factory.reset() for action in random_actions: state, r, done, info = factory.step(action) + if render: factory.render() monitor_list.append(factory.monitor.to_pd_dataframe()) print(f'Factory run {epoch} done, reward is:\n {r}')