mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
Check for Dirt vs wall collisions
This commit is contained in:
parent
02cc90c559
commit
f90889ef42
@@ -1,4 +1,3 @@
|
|||||||
import numpy as np
|
|
||||||
import pygame
|
import pygame
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -66,5 +65,5 @@ class Renderer:
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
renderer = Renderer(fps=2, cell_size=40, assets=['wall', 'agent', 'dirt'])
|
renderer = Renderer(fps=2, cell_size=40, assets=['wall', 'agent', 'dirt'])
|
||||||
for i in range(15):
|
for i in range(15):
|
||||||
renderer.render({'agent': [(5, i)], 'wall': [(0, i), (i, 0)], 'dirt': [(3,3), (3,4)]})
|
renderer.render({'agent': [(5, i)], 'wall': [(0, i), (i, 0)], 'dirt': [(3, 3), (3, 4)]})
|
||||||
|
|
||||||
|
@@ -1,14 +1,15 @@
|
|||||||
from collections import defaultdict, OrderedDict
|
from collections import OrderedDict
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from attr import dataclass
|
|
||||||
|
|
||||||
from environments.factory.base_factory import BaseFactory, AgentState
|
from environments.factory.base_factory import BaseFactory, AgentState
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
|
|
||||||
from environments.factory.renderer import Renderer
|
from environments.factory.renderer import Renderer
|
||||||
|
|
||||||
|
|
||||||
DIRT_INDEX = -1
|
DIRT_INDEX = -1
|
||||||
@dataclass
|
@dataclass
|
||||||
class DirtProperties:
|
class DirtProperties:
|
||||||
@@ -20,6 +21,7 @@ class DirtProperties:
|
|||||||
class GettingDirty(BaseFactory):
|
class GettingDirty(BaseFactory):
|
||||||
|
|
||||||
def _is_clean_up_action(self, action):
|
def _is_clean_up_action(self, action):
|
||||||
|
# Account for NoOP; remove -1 when activating NoOP
|
||||||
return self.movement_actions + 1 - 1 == action
|
return self.movement_actions + 1 - 1 == action
|
||||||
|
|
||||||
def __init__(self, *args, dirt_properties: DirtProperties, **kwargs):
|
def __init__(self, *args, dirt_properties: DirtProperties, **kwargs):
|
||||||
@@ -30,12 +32,12 @@ class GettingDirty(BaseFactory):
|
|||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
if not self.renderer: # lazy init
|
if not self.renderer: # lazy init
|
||||||
h, w = self.state.shape[1:]
|
height, width = self.state.shape[1:]
|
||||||
self.renderer = Renderer(w, h, view_radius=0)
|
self.renderer = Renderer(width, height, view_radius=0)
|
||||||
self.renderer.render( # todo: nur fuers prinzip, ist hardgecoded Dreck aktuell
|
self.renderer.render(
|
||||||
OrderedDict(dirt=np.argwhere(self.state[DIRT_INDEX] > 0), # Ordered dict defines the drawing order! important
|
OrderedDict(dirt=np.argwhere(self.state[DIRT_INDEX] > 0),
|
||||||
wall=np.argwhere(self.state[0] > 0),
|
wall=np.argwhere(self.state[h.LEVEL_IDX] > 0),
|
||||||
agent=np.argwhere(self.state[1] > 0)
|
agent=np.argwhere(self.state[h.AGENT_START_IDX] > 0)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -88,6 +90,12 @@ class GettingDirty(BaseFactory):
|
|||||||
|
|
||||||
def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
|
def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
|
||||||
this_step_reward = 0
|
this_step_reward = 0
|
||||||
|
|
||||||
|
dirt_vs_level_collisions = np.argwhere(self.state[h.LEVEL_IDX, DIRT_INDEX].sum(0) == h.IS_FREE_CELL)
|
||||||
|
for dirt_vs_level_collision in dirt_vs_level_collisions:
|
||||||
|
print(f'Dirt was placed on Level at: {dirt_vs_level_collision.squeeze()}')
|
||||||
|
pass
|
||||||
|
|
||||||
for agent_state in agent_states:
|
for agent_state in agent_states:
|
||||||
collisions = agent_state.collisions
|
collisions = agent_state.collisions
|
||||||
print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
|
print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
|
||||||
@@ -116,7 +124,8 @@ if __name__ == '__main__':
|
|||||||
state, r, done, _ = factory.reset()
|
state, r, done, _ = factory.reset()
|
||||||
for action in random_actions:
|
for action in random_actions:
|
||||||
state, r, done, info = factory.step(action)
|
state, r, done, info = factory.step(action)
|
||||||
if render: factory.render()
|
if render:
|
||||||
|
factory.render()
|
||||||
monitor_list.append(factory.monitor.to_pd_dataframe())
|
monitor_list.append(factory.monitor.to_pd_dataframe())
|
||||||
print(f'Factory run {epoch} done, reward is:\n {r}')
|
print(f'Factory run {epoch} done, reward is:\n {r}')
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user