mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
Merge remote-tracking branch 'origin/main'
This commit is contained in:
commit
33ba9e817e
Binary file not shown.
Before Width: | Height: | Size: 14 KiB |
Binary file not shown.
Before Width: | Height: | Size: 94 KiB After Width: | Height: | Size: 7.5 KiB |
@ -6,18 +6,22 @@ from pathlib import Path
|
|||||||
class Renderer:
|
class Renderer:
|
||||||
BG_COLOR = (99, 110, 114)
|
BG_COLOR = (99, 110, 114)
|
||||||
WHITE = (200, 200, 200)
|
WHITE = (200, 200, 200)
|
||||||
|
PINK = (0.5, 255, 118, 117)
|
||||||
|
|
||||||
def __init__(self, grid_w=16, grid_h=16, cell_size=25, fps=4, grid_lines=True, assets=['wall', 'agent']):
|
def __init__(self, grid_w=16, grid_h=16, cell_size=30, fps=4, grid_lines=True, view_radius=2):
|
||||||
self.grid_h = grid_h
|
self.grid_h = grid_h
|
||||||
self.grid_w = grid_w
|
self.grid_w = grid_w
|
||||||
self.cell_size = cell_size
|
self.cell_size = cell_size
|
||||||
self.fps = fps#
|
self.fps = fps
|
||||||
self.grid_lines = grid_lines
|
self.grid_lines = grid_lines
|
||||||
|
self.view_radius = view_radius
|
||||||
pygame.init()
|
pygame.init()
|
||||||
self.screen_size = (grid_h*cell_size, grid_w*cell_size)
|
self.screen_size = (grid_h*cell_size, grid_w*cell_size)
|
||||||
self.screen = pygame.display.set_mode(self.screen_size)
|
self.screen = pygame.display.set_mode(self.screen_size)
|
||||||
self.clock = pygame.time.Clock()
|
self.clock = pygame.time.Clock()
|
||||||
self.assets = {name: self.load_asset(name, 0.97) for name in assets}
|
assets = list((Path(__file__).parent / 'assets').glob('*.png'))
|
||||||
|
self.assets = {path.stem: self.load_asset(str(path), 0.95) for path in assets}
|
||||||
|
print(self.assets)
|
||||||
self.fill_bg()
|
self.fill_bg()
|
||||||
|
|
||||||
def fill_bg(self):
|
def fill_bg(self):
|
||||||
@ -29,18 +33,18 @@ class Renderer:
|
|||||||
rect = pygame.Rect(x, y, self.cell_size, self.cell_size)
|
rect = pygame.Rect(x, y, self.cell_size, self.cell_size)
|
||||||
pygame.draw.rect(self.screen, Renderer.WHITE, rect, 1)
|
pygame.draw.rect(self.screen, Renderer.WHITE, rect, 1)
|
||||||
|
|
||||||
def render_asset(self, r, c, name):
|
def blit_params(self, r, c, name):
|
||||||
img = self.assets[name]
|
img = self.assets[name]
|
||||||
o = self.cell_size//2
|
o = self.cell_size//2
|
||||||
r_, c_ = r*self.cell_size + o, c*self.cell_size + o
|
r_, c_ = r*self.cell_size + o, c*self.cell_size + o
|
||||||
rect = img.get_rect()
|
rect = img.get_rect()
|
||||||
rect.centerx, rect.centery = c_, r_
|
rect.centerx, rect.centery = c_, r_
|
||||||
self.screen.blit(img, rect)
|
|
||||||
return c_, r_
|
|
||||||
|
|
||||||
def load_asset(self, name, factor=1.0):
|
return img, rect
|
||||||
|
|
||||||
|
def load_asset(self, path, factor=1.0):
|
||||||
s = int(factor*self.cell_size)
|
s = int(factor*self.cell_size)
|
||||||
wall_img = pygame.image.load(str(Path(__file__).parent / 'assets' / f'{name}.png')).convert_alpha()
|
wall_img = pygame.image.load(path).convert_alpha()
|
||||||
wall_img = pygame.transform.scale(wall_img, (s, s))
|
wall_img = pygame.transform.scale(wall_img, (s, s))
|
||||||
return wall_img
|
return wall_img
|
||||||
|
|
||||||
@ -48,7 +52,13 @@ class Renderer:
|
|||||||
self.fill_bg()
|
self.fill_bg()
|
||||||
for asset, positions in pos_dict.items():
|
for asset, positions in pos_dict.items():
|
||||||
for x, y in positions:
|
for x, y in positions:
|
||||||
self.render_asset(x, y, asset)
|
img, rect = self.blit_params(x, y, asset)
|
||||||
|
if 'agent' in asset and self.view_radius > 0:
|
||||||
|
visibility_rect = rect.inflate((self.view_radius*2)*self.cell_size, (self.view_radius*2)*self.cell_size)
|
||||||
|
shape_surf = pygame.Surface(visibility_rect.size, pygame.SRCALPHA)
|
||||||
|
pygame.draw.rect(shape_surf, self.PINK, shape_surf.get_rect())
|
||||||
|
self.screen.blit(shape_surf, visibility_rect)
|
||||||
|
self.screen.blit(img, rect)
|
||||||
pygame.display.flip()
|
pygame.display.flip()
|
||||||
self.clock.tick(self.fps)
|
self.clock.tick(self.fps)
|
||||||
|
|
||||||
@ -56,5 +66,5 @@ class Renderer:
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
renderer = Renderer(fps=2, cell_size=40, assets=['wall', 'agent', 'dirt'])
|
renderer = Renderer(fps=2, cell_size=40, assets=['wall', 'agent', 'dirt'])
|
||||||
for i in range(15):
|
for i in range(15):
|
||||||
renderer.render({'agent': [(5, 5)], 'wall': [(0, i), (i, 0)], 'dirt': [(3,3), (3,4)]})
|
renderer.render({'agent': [(5, i)], 'wall': [(0, i), (i, 0)], 'dirt': [(3,3), (3,4)]})
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict, OrderedDict
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -7,6 +7,8 @@ from attr import dataclass
|
|||||||
from environments.factory.base_factory import BaseFactory, AgentState
|
from environments.factory.base_factory import BaseFactory, AgentState
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
|
|
||||||
|
from environments.factory.renderer import Renderer
|
||||||
|
|
||||||
DIRT_INDEX = -1
|
DIRT_INDEX = -1
|
||||||
@dataclass
|
@dataclass
|
||||||
class DirtProperties:
|
class DirtProperties:
|
||||||
@ -24,6 +26,18 @@ class GettingDirty(BaseFactory):
|
|||||||
self._dirt_properties = dirt_properties
|
self._dirt_properties = dirt_properties
|
||||||
super(GettingDirty, self).__init__(*args, **kwargs)
|
super(GettingDirty, self).__init__(*args, **kwargs)
|
||||||
self.slice_strings.update({self.state.shape[0]-1: 'dirt'})
|
self.slice_strings.update({self.state.shape[0]-1: 'dirt'})
|
||||||
|
self.renderer = None # expensive - dont use it when not required !
|
||||||
|
|
||||||
|
def render(self):
|
||||||
|
if not self.renderer: # lazy init
|
||||||
|
h, w = self.state.shape[1:]
|
||||||
|
self.renderer = Renderer(w, h, view_radius=0)
|
||||||
|
self.renderer.render( # todo: nur fuers prinzip, ist hardgecoded Dreck aktuell
|
||||||
|
OrderedDict(wall=np.argwhere(self.state[0] > 0), # Ordered dict defines the drawing order! important
|
||||||
|
dirt=np.argwhere(self.state[DIRT_INDEX] > 0),
|
||||||
|
agent=np.argwhere(self.state[1] > 0)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def spawn_dirt(self) -> None:
|
def spawn_dirt(self) -> None:
|
||||||
free_for_dirt = self.free_cells(excluded_slices=DIRT_INDEX)
|
free_for_dirt = self.free_cells(excluded_slices=DIRT_INDEX)
|
||||||
@ -91,6 +105,9 @@ class GettingDirty(BaseFactory):
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
render = True
|
||||||
|
|
||||||
dirt_props = DirtProperties()
|
dirt_props = DirtProperties()
|
||||||
factory = GettingDirty(n_agents=1, dirt_properties=dirt_props)
|
factory = GettingDirty(n_agents=1, dirt_properties=dirt_props)
|
||||||
monitor_list = list()
|
monitor_list = list()
|
||||||
@ -99,6 +116,7 @@ if __name__ == '__main__':
|
|||||||
state, r, done, _ = factory.reset()
|
state, r, done, _ = factory.reset()
|
||||||
for action in random_actions:
|
for action in random_actions:
|
||||||
state, r, done, info = factory.step(action)
|
state, r, done, info = factory.step(action)
|
||||||
|
if render: factory.render()
|
||||||
monitor_list.append(factory.monitor.to_pd_dataframe())
|
monitor_list.append(factory.monitor.to_pd_dataframe())
|
||||||
print(f'Factory run {epoch} done, reward is:\n {r}')
|
print(f'Factory run {epoch} done, reward is:\n {r}')
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user