in progress

This commit is contained in:
steffen-illium
2021-06-07 16:14:29 +02:00
parent dbfa97aaba
commit 2589a06d02
4 changed files with 52 additions and 44 deletions

View File

@ -1,3 +1,6 @@
from collections import defaultdict
from typing import Tuple
import numpy as np
from pathlib import Path
@ -11,6 +14,13 @@ IS_OCCUPIED_CELL = 1
TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amount', 'dirty_tile_count']
ACTIONMAP = defaultdict(lambda: (0, 0), dict(north=(-1, 0), east=(0, 1),
south=(1, 0), west=(0, -1),
north_east=(-1, +1), south_east=(1, 1),
south_west=(+1, -1), north_west=(-1, -1)
)
)
# Utility functions
def parse_level(path):
@ -28,48 +38,19 @@ def one_hot_level(level, wall_char=WALL):
return binary_grid
def check_agent_move(state, dim, action):
agent_slice = state[dim] # horizontal slice from state tensor
agent_pos = np.argwhere(agent_slice == 1)
if len(agent_pos) > 1:
raise AssertionError('Only one agent per slice is allowed.')
x, y = agent_pos[0]
x_new, y_new = x, y
# Actions
if action == 0: # North
x_new -= 1
elif action == 1: # East
y_new += 1
elif action == 2: # South
x_new += 1
elif action == 3: # West
y_new -= 1
elif action == 4: # NE
x_new -= 1
y_new += 1
elif action == 5: # SE
x_new += 1
y_new += 1
elif action == 6: # SW
x_new += 1
y_new -= 1
elif action == 7: # NW
x_new -= 1
y_new -= 1
else:
pass
def check_position(state: np.ndarray, position_to_check: Tuple[int, int], dim: int = 0):
x, y = position_to_check
agent_slice = state[dim]
# Check if agent colides with grid boundrys
valid = not (
x_new < 0 or y_new < 0
or x_new >= agent_slice.shape[0]
or y_new >= agent_slice.shape[0]
x < 0 or y < 0
or x >= agent_slice.shape[0]
or y >= agent_slice.shape[0]
)
# Check for collision with level walls
valid = valid and not state[LEVEL_IDX][x_new, y_new]
return (x, y), (x_new, y_new), valid
valid = valid and not state[LEVEL_IDX][x, y]
if __name__ == '__main__':