mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
Check for Dirt vs wall collisions
This commit is contained in:
parent
02cc90c559
commit
f90889ef42
@ -1,4 +1,3 @@
|
||||
import numpy as np
|
||||
import pygame
|
||||
from pathlib import Path
|
||||
|
||||
@ -66,5 +65,5 @@ class Renderer:
|
||||
if __name__ == '__main__':
|
||||
renderer = Renderer(fps=2, cell_size=40, assets=['wall', 'agent', 'dirt'])
|
||||
for i in range(15):
|
||||
renderer.render({'agent': [(5, i)], 'wall': [(0, i), (i, 0)], 'dirt': [(3,3), (3,4)]})
|
||||
renderer.render({'agent': [(5, i)], 'wall': [(0, i), (i, 0)], 'dirt': [(3, 3), (3, 4)]})
|
||||
|
||||
|
@ -1,14 +1,15 @@
|
||||
from collections import defaultdict, OrderedDict
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from attr import dataclass
|
||||
|
||||
from environments.factory.base_factory import BaseFactory, AgentState
|
||||
from environments import helpers as h
|
||||
|
||||
from environments.factory.renderer import Renderer
|
||||
|
||||
|
||||
DIRT_INDEX = -1
|
||||
@dataclass
|
||||
class DirtProperties:
|
||||
@ -20,6 +21,7 @@ class DirtProperties:
|
||||
class GettingDirty(BaseFactory):
|
||||
|
||||
def _is_clean_up_action(self, action):
|
||||
# Account for NoOP; remove -1 when activating NoOP
|
||||
return self.movement_actions + 1 - 1 == action
|
||||
|
||||
def __init__(self, *args, dirt_properties: DirtProperties, **kwargs):
|
||||
@ -30,12 +32,12 @@ class GettingDirty(BaseFactory):
|
||||
|
||||
def render(self):
|
||||
if not self.renderer: # lazy init
|
||||
h, w = self.state.shape[1:]
|
||||
self.renderer = Renderer(w, h, view_radius=0)
|
||||
self.renderer.render( # todo: nur fuers prinzip, ist hardgecoded Dreck aktuell
|
||||
OrderedDict(dirt=np.argwhere(self.state[DIRT_INDEX] > 0), # Ordered dict defines the drawing order! important
|
||||
wall=np.argwhere(self.state[0] > 0),
|
||||
agent=np.argwhere(self.state[1] > 0)
|
||||
height, width = self.state.shape[1:]
|
||||
self.renderer = Renderer(width, height, view_radius=0)
|
||||
self.renderer.render(
|
||||
OrderedDict(dirt=np.argwhere(self.state[DIRT_INDEX] > 0),
|
||||
wall=np.argwhere(self.state[h.LEVEL_IDX] > 0),
|
||||
agent=np.argwhere(self.state[h.AGENT_START_IDX] > 0)
|
||||
)
|
||||
)
|
||||
|
||||
@ -88,6 +90,12 @@ class GettingDirty(BaseFactory):
|
||||
|
||||
def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
|
||||
this_step_reward = 0
|
||||
|
||||
dirt_vs_level_collisions = np.argwhere(self.state[h.LEVEL_IDX, DIRT_INDEX].sum(0) == h.IS_FREE_CELL)
|
||||
for dirt_vs_level_collision in dirt_vs_level_collisions:
|
||||
print(f'Dirt was placed on Level at: {dirt_vs_level_collision.squeeze()}')
|
||||
pass
|
||||
|
||||
for agent_state in agent_states:
|
||||
collisions = agent_state.collisions
|
||||
print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
|
||||
@ -116,7 +124,8 @@ if __name__ == '__main__':
|
||||
state, r, done, _ = factory.reset()
|
||||
for action in random_actions:
|
||||
state, r, done, info = factory.step(action)
|
||||
if render: factory.render()
|
||||
if render:
|
||||
factory.render()
|
||||
monitor_list.append(factory.monitor.to_pd_dataframe())
|
||||
print(f'Factory run {epoch} done, reward is:\n {r}')
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user