mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-22 23:06:43 +02:00
in progress
This commit is contained in:
parent
dbfa97aaba
commit
2589a06d02
@ -17,6 +17,32 @@ class MovementProperties(NamedTuple):
|
|||||||
allow_no_op: bool = False
|
allow_no_op: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class Entity():
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pos(self):
|
||||||
|
return self._pos
|
||||||
|
|
||||||
|
def __init__(self, pos):
|
||||||
|
self._pos = pos
|
||||||
|
|
||||||
|
def check_agent_move(state: np.ndarray, dim: int, action: str):
|
||||||
|
agent_slice = state[dim] # horizontal slice from state tensor
|
||||||
|
agent_pos = np.argwhere(agent_slice == 1)
|
||||||
|
if len(agent_pos) > 1:
|
||||||
|
raise AssertionError('Only one agent per slice is allowed.')
|
||||||
|
x, y = agent_pos[0]
|
||||||
|
|
||||||
|
# Actions
|
||||||
|
x_diff, y_diff = ACTIONMAP[action]
|
||||||
|
x_new = x + x_diff
|
||||||
|
y_new = y + y_diff
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
return (x, y), (x_new, y_new), valid
|
||||||
|
|
||||||
|
|
||||||
class AgentState:
|
class AgentState:
|
||||||
|
|
||||||
def __init__(self, i: int, action: int):
|
def __init__(self, i: int, action: int):
|
||||||
|
@ -1,3 +1,6 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@ -11,6 +14,13 @@ IS_OCCUPIED_CELL = 1
|
|||||||
TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
|
TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
|
||||||
IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amount', 'dirty_tile_count']
|
IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amount', 'dirty_tile_count']
|
||||||
|
|
||||||
|
ACTIONMAP = defaultdict(lambda: (0, 0), dict(north=(-1, 0), east=(0, 1),
|
||||||
|
south=(1, 0), west=(0, -1),
|
||||||
|
north_east=(-1, +1), south_east=(1, 1),
|
||||||
|
south_west=(+1, -1), north_west=(-1, -1)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Utility functions
|
# Utility functions
|
||||||
def parse_level(path):
|
def parse_level(path):
|
||||||
@ -28,48 +38,19 @@ def one_hot_level(level, wall_char=WALL):
|
|||||||
return binary_grid
|
return binary_grid
|
||||||
|
|
||||||
|
|
||||||
def check_agent_move(state, dim, action):
|
def check_position(state: np.ndarray, position_to_check: Tuple[int, int], dim: int = 0):
|
||||||
agent_slice = state[dim] # horizontal slice from state tensor
|
x, y = position_to_check
|
||||||
agent_pos = np.argwhere(agent_slice == 1)
|
agent_slice = state[dim]
|
||||||
if len(agent_pos) > 1:
|
|
||||||
raise AssertionError('Only one agent per slice is allowed.')
|
|
||||||
x, y = agent_pos[0]
|
|
||||||
x_new, y_new = x, y
|
|
||||||
# Actions
|
|
||||||
if action == 0: # North
|
|
||||||
x_new -= 1
|
|
||||||
elif action == 1: # East
|
|
||||||
y_new += 1
|
|
||||||
elif action == 2: # South
|
|
||||||
x_new += 1
|
|
||||||
elif action == 3: # West
|
|
||||||
y_new -= 1
|
|
||||||
elif action == 4: # NE
|
|
||||||
x_new -= 1
|
|
||||||
y_new += 1
|
|
||||||
elif action == 5: # SE
|
|
||||||
x_new += 1
|
|
||||||
y_new += 1
|
|
||||||
elif action == 6: # SW
|
|
||||||
x_new += 1
|
|
||||||
y_new -= 1
|
|
||||||
elif action == 7: # NW
|
|
||||||
x_new -= 1
|
|
||||||
y_new -= 1
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Check if agent colides with grid boundrys
|
# Check if agent colides with grid boundrys
|
||||||
valid = not (
|
valid = not (
|
||||||
x_new < 0 or y_new < 0
|
x < 0 or y < 0
|
||||||
or x_new >= agent_slice.shape[0]
|
or x >= agent_slice.shape[0]
|
||||||
or y_new >= agent_slice.shape[0]
|
or y >= agent_slice.shape[0]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check for collision with level walls
|
# Check for collision with level walls
|
||||||
valid = valid and not state[LEVEL_IDX][x_new, y_new]
|
valid = valid and not state[LEVEL_IDX][x, y]
|
||||||
|
|
||||||
return (x, y), (x_new, y_new), valid
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
15
main.py
15
main.py
@ -35,8 +35,8 @@ def combine_runs(run_path: Union[str, PathLike]):
|
|||||||
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'})
|
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'})
|
||||||
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
|
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
|
||||||
|
|
||||||
roll_n = 30
|
roll_n = 50
|
||||||
skip_n = 20
|
skip_n = 40
|
||||||
|
|
||||||
non_overlapp_window = df.groupby(['Run', 'Episode']).rolling(roll_n, min_periods=1).mean()
|
non_overlapp_window = df.groupby(['Run', 'Episode']).rolling(roll_n, min_periods=1).mean()
|
||||||
|
|
||||||
@ -68,8 +68,8 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List
|
|||||||
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run', 'model': 'Model'})
|
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run', 'model': 'Model'})
|
||||||
columns = [col for col in df.columns if col in parameter]
|
columns = [col for col in df.columns if col in parameter]
|
||||||
|
|
||||||
roll_n = 30
|
roll_n = 40
|
||||||
skip_n = 10
|
skip_n = 20
|
||||||
|
|
||||||
non_overlapp_window = df.groupby(['Model', 'Run', 'Episode']).rolling(roll_n, min_periods=1).mean()
|
non_overlapp_window = df.groupby(['Model', 'Run', 'Episode']).rolling(roll_n, min_periods=1).mean()
|
||||||
|
|
||||||
@ -85,14 +85,15 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
compare_runs(Path('debug_out'), 1623052687, ['agent_0_vs_level'])
|
compare_runs(Path('debug_out'), 1623052687, ['step_reward'])
|
||||||
exit()
|
exit()
|
||||||
|
|
||||||
from stable_baselines3 import PPO, DQN, A2C
|
from stable_baselines3 import PPO, DQN, A2C
|
||||||
from algorithms.reg_dqn import RegDQN
|
from algorithms.reg_dqn import RegDQN
|
||||||
# from sb3_contrib import QRDQN
|
# from sb3_contrib import QRDQN
|
||||||
|
|
||||||
dirt_props = DirtProperties()
|
dirt_props = DirtProperties(clean_amount=3, gain_amount=0.2, max_global_amount=30,
|
||||||
|
max_local_amount=5, spawn_frequency=3)
|
||||||
move_props = MovementProperties(allow_diagonal_movement=False,
|
move_props = MovementProperties(allow_diagonal_movement=False,
|
||||||
allow_square_movement=True,
|
allow_square_movement=True,
|
||||||
allow_no_op=False)
|
allow_no_op=False)
|
||||||
@ -100,7 +101,7 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
out_path = None
|
out_path = None
|
||||||
|
|
||||||
for modeL_type in [PPO, A2C, RegDQN, DQN]:
|
for modeL_type in [PPO, A2C]: # , RegDQN, DQN]:
|
||||||
for seed in range(3):
|
for seed in range(3):
|
||||||
|
|
||||||
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=3, max_steps=400,
|
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=3, max_steps=400,
|
||||||
|
@ -28,7 +28,7 @@ if __name__ == '__main__':
|
|||||||
this_model = model_files[0]
|
this_model = model_files[0]
|
||||||
|
|
||||||
model = PPO.load(this_model)
|
model = PPO.load(this_model)
|
||||||
evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=True, render=True)
|
evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=False, render=True)
|
||||||
print(evaluation_result)
|
print(evaluation_result)
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user