mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-18 18:52:52 +02:00
in progress
This commit is contained in:
@ -1,3 +1,6 @@
|
||||
from collections import defaultdict
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
@ -11,6 +14,13 @@ IS_OCCUPIED_CELL = 1
|
||||
TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
|
||||
IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amount', 'dirty_tile_count']
|
||||
|
||||
ACTIONMAP = defaultdict(lambda: (0, 0), dict(north=(-1, 0), east=(0, 1),
|
||||
south=(1, 0), west=(0, -1),
|
||||
north_east=(-1, +1), south_east=(1, 1),
|
||||
south_west=(+1, -1), north_west=(-1, -1)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Utility functions
|
||||
def parse_level(path):
|
||||
@ -28,48 +38,19 @@ def one_hot_level(level, wall_char=WALL):
|
||||
return binary_grid
|
||||
|
||||
|
||||
def check_agent_move(state, dim, action):
|
||||
agent_slice = state[dim] # horizontal slice from state tensor
|
||||
agent_pos = np.argwhere(agent_slice == 1)
|
||||
if len(agent_pos) > 1:
|
||||
raise AssertionError('Only one agent per slice is allowed.')
|
||||
x, y = agent_pos[0]
|
||||
x_new, y_new = x, y
|
||||
# Actions
|
||||
if action == 0: # North
|
||||
x_new -= 1
|
||||
elif action == 1: # East
|
||||
y_new += 1
|
||||
elif action == 2: # South
|
||||
x_new += 1
|
||||
elif action == 3: # West
|
||||
y_new -= 1
|
||||
elif action == 4: # NE
|
||||
x_new -= 1
|
||||
y_new += 1
|
||||
elif action == 5: # SE
|
||||
x_new += 1
|
||||
y_new += 1
|
||||
elif action == 6: # SW
|
||||
x_new += 1
|
||||
y_new -= 1
|
||||
elif action == 7: # NW
|
||||
x_new -= 1
|
||||
y_new -= 1
|
||||
else:
|
||||
pass
|
||||
def check_position(state: np.ndarray, position_to_check: Tuple[int, int], dim: int = 0):
|
||||
x, y = position_to_check
|
||||
agent_slice = state[dim]
|
||||
|
||||
# Check if agent colides with grid boundrys
|
||||
valid = not (
|
||||
x_new < 0 or y_new < 0
|
||||
or x_new >= agent_slice.shape[0]
|
||||
or y_new >= agent_slice.shape[0]
|
||||
x < 0 or y < 0
|
||||
or x >= agent_slice.shape[0]
|
||||
or y >= agent_slice.shape[0]
|
||||
)
|
||||
|
||||
# Check for collision with level walls
|
||||
valid = valid and not state[LEVEL_IDX][x_new, y_new]
|
||||
|
||||
return (x, y), (x_new, y_new), valid
|
||||
valid = valid and not state[LEVEL_IDX][x, y]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Reference in New Issue
Block a user