diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py
index 565f5a6..c708a9b 100644
--- a/environments/factory/base_factory.py
+++ b/environments/factory/base_factory.py
@@ -17,6 +17,32 @@ class MovementProperties(NamedTuple):
     allow_no_op: bool = False
 
 
+class Entity():
+
+    @property
+    def pos(self):
+        return self._pos
+
+    def __init__(self, pos):
+        self._pos = pos
+
+    def check_agent_move(state: np.ndarray, dim: int, action: str):
+        agent_slice = state[dim]  # horizontal slice from state tensor
+        agent_pos = np.argwhere(agent_slice == 1)
+        if len(agent_pos) > 1:
+            raise AssertionError('Only one agent per slice is allowed.')
+        x, y = agent_pos[0]
+
+        # Actions
+        x_diff, y_diff = ACTIONMAP[action]
+        x_new = x + x_diff
+        y_new = y + y_diff
+
+
+
+        return (x, y), (x_new, y_new), valid
+
+
 class AgentState:
 
     def __init__(self, i: int, action: int):
diff --git a/environments/helpers.py b/environments/helpers.py
index 1df6cd2..5b7224e 100644
--- a/environments/helpers.py
+++ b/environments/helpers.py
@@ -1,3 +1,6 @@
+from collections import defaultdict
+from typing import Tuple
+
 import numpy as np
 from pathlib import Path
 
@@ -11,6 +14,13 @@ IS_OCCUPIED_CELL = 1
 TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
 IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amount', 'dirty_tile_count']
 
+ACTIONMAP = defaultdict(lambda: (0, 0), dict(north=(-1, 0), east=(0, 1),
+                                             south=(1, 0), west=(0, -1),
+                                             north_east=(-1, +1), south_east=(1, 1),
+                                             south_west=(+1, -1), north_west=(-1, -1)
+                                             )
+                        )
+
 
 # Utility functions
 def parse_level(path):
@@ -28,48 +38,19 @@ def one_hot_level(level, wall_char=WALL):
     return binary_grid
 
 
-def check_agent_move(state, dim, action):
-    agent_slice = state[dim]  # horizontal slice from state tensor
-    agent_pos = np.argwhere(agent_slice == 1)
-    if len(agent_pos) > 1:
-        raise AssertionError('Only one agent per slice is allowed.')
-    x, y = agent_pos[0]
-    x_new, y_new = x, y
-    # Actions
-    if action == 0:  # North
-        x_new -= 1
-    elif action == 1:  # East
-        y_new += 1
-    elif action == 2:  # South
-        x_new += 1
-    elif action == 3:  # West
-        y_new -= 1
-    elif action == 4:  # NE
-        x_new -= 1
-        y_new += 1
-    elif action == 5:  # SE
-        x_new += 1
-        y_new += 1
-    elif action == 6:  # SW
-        x_new += 1
-        y_new -= 1
-    elif action == 7:  # NW
-        x_new -= 1
-        y_new -= 1
-    else:
-        pass
+def check_position(state: np.ndarray, position_to_check: Tuple[int, int], dim: int = 0):
+    x, y = position_to_check
+    agent_slice = state[dim]
 
     # Check if agent colides with grid boundrys
     valid = not (
-        x_new < 0 or y_new < 0
-        or x_new >= agent_slice.shape[0]
-        or y_new >= agent_slice.shape[0]
+        x < 0 or y < 0
+        or x >= agent_slice.shape[0]
+        or y >= agent_slice.shape[0]
     )
 
     # Check for collision with level walls
-    valid = valid and not state[LEVEL_IDX][x_new, y_new]
-
-    return (x, y), (x_new, y_new), valid
+    valid = valid and not state[LEVEL_IDX][x, y]
 
 
 if __name__ == '__main__':
diff --git a/main.py b/main.py
index 4a983fd..e60f998 100644
--- a/main.py
+++ b/main.py
@@ -35,8 +35,8 @@ def combine_runs(run_path: Union[str, PathLike]):
     df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'})
     columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
 
-    roll_n = 30
-    skip_n = 20
+    roll_n = 50
+    skip_n = 40
 
     non_overlapp_window = df.groupby(['Run', 'Episode']).rolling(roll_n, min_periods=1).mean()
 
@@ -68,8 +68,8 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List
     df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run', 'model': 'Model'})
     columns = [col for col in df.columns if col in parameter]
 
-    roll_n = 30
-    skip_n = 10
+    roll_n = 40
+    skip_n = 20
 
     non_overlapp_window = df.groupby(['Model', 'Run', 'Episode']).rolling(roll_n, min_periods=1).mean()
 
@@ -85,14 +85,15 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List
 
 
 if __name__ == '__main__':
-    compare_runs(Path('debug_out'), 1623052687, ['agent_0_vs_level'])
+    compare_runs(Path('debug_out'), 1623052687, ['step_reward'])
     exit()
 
     from stable_baselines3 import PPO, DQN, A2C
     from algorithms.reg_dqn import RegDQN
     # from sb3_contrib import QRDQN
 
-    dirt_props = DirtProperties()
+    dirt_props = DirtProperties(clean_amount=3, gain_amount=0.2, max_global_amount=30,
+                                max_local_amount=5, spawn_frequency=3)
     move_props = MovementProperties(allow_diagonal_movement=False,
                                     allow_square_movement=True,
                                     allow_no_op=False)
@@ -100,7 +101,7 @@ if __name__ == '__main__':
 
     out_path = None
 
-    for modeL_type in [PPO, A2C, RegDQN, DQN]:
+    for modeL_type in [PPO, A2C]:  # , RegDQN, DQN]:
        for seed in range(3):
 
             env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=3, max_steps=400,
diff --git a/reload_agent.py b/reload_agent.py
index cfa3383..b5a8607 100644
--- a/reload_agent.py
+++ b/reload_agent.py
@@ -28,7 +28,7 @@ if __name__ == '__main__':
 
     this_model = model_files[0]
     model = PPO.load(this_model)
-    evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=True, render=True)
+    evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=False, render=True)
     print(evaluation_result)
 
     env.close()
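
Note on the movement refactor: the patch replaces the per-action if/elif chain with the ACTIONMAP lookup and splits the bounds/wall test into check_position, but as posted check_position never returns its valid flag and the new Entity.check_agent_move still references an unbound valid. Below is a minimal sketch of how the pieces could fit together once those gaps are filled; the state layout (LEVEL_IDX wall slice, one-hot agent slices) follows environments/helpers.py, while the wiring itself is illustrative and not part of the patch.

# Illustrative sketch only, not part of the diff above. Assumes check_position
# ends with `return valid` and that check_agent_move is wired to it.
from collections import defaultdict
from typing import Tuple

import numpy as np

LEVEL_IDX = 0  # index of the wall/level slice, as in environments/helpers.py

ACTIONMAP = defaultdict(lambda: (0, 0), dict(north=(-1, 0), east=(0, 1),
                                             south=(1, 0), west=(0, -1),
                                             north_east=(-1, 1), south_east=(1, 1),
                                             south_west=(1, -1), north_west=(-1, -1)))


def check_position(state: np.ndarray, position_to_check: Tuple[int, int], dim: int = 0) -> bool:
    """True if the cell lies inside the grid and is not a wall."""
    x, y = position_to_check
    agent_slice = state[dim]
    inside = 0 <= x < agent_slice.shape[0] and 0 <= y < agent_slice.shape[1]
    return bool(inside and not state[LEVEL_IDX][x, y])


def check_agent_move(state: np.ndarray, dim: int, action: str):
    """Resolve the named action for the single agent on slice `dim` and validate the target cell."""
    agent_slice = state[dim]  # horizontal slice from the state tensor
    agent_pos = np.argwhere(agent_slice == 1)
    if len(agent_pos) > 1:
        raise AssertionError('Only one agent per slice is allowed.')
    x, y = agent_pos[0]
    x_diff, y_diff = ACTIONMAP[action]  # unknown actions fall back to (0, 0), i.e. no-op
    x_new, y_new = x + x_diff, y + y_diff
    valid = check_position(state, (x_new, y_new), dim)
    return (x, y), (x_new, y_new), valid


# Example: old_pos, new_pos, ok = check_agent_move(state, dim=1, action='north')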