Check for Dirt vs wall collisions

This commit is contained in:
steffen-illium 2021-05-18 10:01:56 +02:00
parent 02cc90c559
commit f90889ef42
2 changed files with 19 additions and 11 deletions

View File

@ -1,4 +1,3 @@
import numpy as np
import pygame
from pathlib import Path

View File

@ -1,14 +1,15 @@
from collections import defaultdict, OrderedDict
from collections import OrderedDict
from dataclasses import dataclass
from typing import List
import numpy as np
from attr import dataclass
from environments.factory.base_factory import BaseFactory, AgentState
from environments import helpers as h
from environments.factory.renderer import Renderer
DIRT_INDEX = -1
@dataclass
class DirtProperties:
@ -20,6 +21,7 @@ class DirtProperties:
class GettingDirty(BaseFactory):
def _is_clean_up_action(self, action):
# Account for NoOP; remove -1 when activating NoOP
return self.movement_actions + 1 - 1 == action
def __init__(self, *args, dirt_properties: DirtProperties, **kwargs):
@ -30,12 +32,12 @@ class GettingDirty(BaseFactory):
def render(self):
# Draw the current state: positions of dirt, walls and agents are collected with
# np.argwhere on the matching state slices and passed to the renderer.
if not self.renderer: # lazy init
# Renderer is created on the first call only; its size comes from self.state.shape[1:].
# NOTE(review): the local name `h` below shadows the `helpers as h` module alias that
# the h.LEVEL_IDX / h.AGENT_START_IDX lookups further down rely on.
h, w = self.state.shape[1:]
self.renderer = Renderer(w, h, view_radius=0)
self.renderer.render( # TODO: proof of concept only; this is hard-coded for now
OrderedDict(dirt=np.argwhere(self.state[DIRT_INDEX] > 0), # Ordered dict defines the drawing order! important
wall=np.argwhere(self.state[0] > 0),
agent=np.argwhere(self.state[1] > 0)
# NOTE(review): the six statements above and the six below appear to be the pre- and
# post-change variants of the same code (this listing is a diff without +/- markers);
# confirm which set is live before editing.
height, width = self.state.shape[1:]
self.renderer = Renderer(width, height, view_radius=0)
self.renderer.render(
# Keyword order (dirt, wall, agent) is the order handed to the renderer.
OrderedDict(dirt=np.argwhere(self.state[DIRT_INDEX] > 0),
wall=np.argwhere(self.state[h.LEVEL_IDX] > 0),
agent=np.argwhere(self.state[h.AGENT_START_IDX] > 0)
)
)
@ -88,6 +90,12 @@ class GettingDirty(BaseFactory):
def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
this_step_reward = 0
dirt_vs_level_collisions = np.argwhere(self.state[h.LEVEL_IDX, DIRT_INDEX].sum(0) == h.IS_FREE_CELL)
for dirt_vs_level_collision in dirt_vs_level_collisions:
print(f'Dirt was placed on Level at: {dirt_vs_level_collision.squeeze()}')
pass
for agent_state in agent_states:
collisions = agent_state.collisions
print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
@ -116,7 +124,8 @@ if __name__ == '__main__':
state, r, done, _ = factory.reset()
for action in random_actions:
state, r, done, info = factory.step(action)
if render: factory.render()
if render:
factory.render()
monitor_list.append(factory.monitor.to_pd_dataframe())
print(f'Factory run {epoch} done, reward is:\n {r}')