From 62c141aa1c8bdb2c4f7d4d86987d8d7c1237b8cd Mon Sep 17 00:00:00 2001 From: steffen-illium Date: Mon, 7 Jun 2021 16:52:06 +0200 Subject: [PATCH] actions are now checked by string --- environments/factory/base_factory.py | 36 +++++++++++++--------------- environments/helpers.py | 14 +++++------ main.py | 4 ++-- 3 files changed, 26 insertions(+), 28 deletions(-) diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index c708a9b..1c8b921 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -26,22 +26,6 @@ class Entity(): def __init__(self, pos): self._pos = pos - def check_agent_move(state: np.ndarray, dim: int, action: str): - agent_slice = state[dim] # horizontal slice from state tensor - agent_pos = np.argwhere(agent_slice == 1) - if len(agent_pos) > 1: - raise AssertionError('Only one agent per slice is allowed.') - x, y = agent_pos[0] - - # Actions - x_diff, y_diff = ACTIONMAP[action] - x_new = x + x_diff - y_new = y + y_diff - - - - return (x, y), (x_new, y_new), valid - class AgentState: @@ -312,9 +296,7 @@ class BaseFactory(gym.Env): self._state[agent_i + h.AGENT_START_IDX, x_new, y_new] = h.IS_OCCUPIED_CELL def move_or_colide(self, agent_i: int, action: int) -> ((int, int), bool): - old_pos, new_pos, valid = h.check_agent_move(state=self._state, - dim=agent_i + h.AGENT_START_IDX, - action=action) + old_pos, new_pos, valid = self._check_agent_move(agent_i=agent_i, action=self._actions[action]) if valid: # Does not collide width level boundaries self.do_move(agent_i, old_pos, new_pos) @@ -323,6 +305,22 @@ class BaseFactory(gym.Env): # Agent seems to be trying to collide in this step return old_pos, valid + def _check_agent_move(self, agent_i, action: str): + agent_slice = self._state[h.AGENT_START_IDX + agent_i] # horizontal slice from state tensor + agent_pos = np.argwhere(agent_slice == 1) + if len(agent_pos) > 1: + raise AssertionError('Only one agent per slice is allowed.') + x, y = agent_pos[0] + + # Actions + x_diff, y_diff = h.ACTIONMAP[action] + x_new = x + x_diff + y_new = y + y_diff + + valid = h.check_position(self._state[h.LEVEL_IDX], (x_new, y_new)) + + return (x, y), (x_new, y_new), valid + def agent_i_position(self, agent_i: int) -> (int, int): positions = np.argwhere(self._state[h.AGENT_START_IDX + agent_i] == h.IS_OCCUPIED_CELL) assert positions.shape[0] == 1 diff --git a/environments/helpers.py b/environments/helpers.py index 5b7224e..64670a3 100644 --- a/environments/helpers.py +++ b/environments/helpers.py @@ -38,19 +38,19 @@ def one_hot_level(level, wall_char=WALL): return binary_grid -def check_position(state: np.ndarray, position_to_check: Tuple[int, int], dim: int = 0): - x, y = position_to_check - agent_slice = state[dim] +def check_position(slice_to_check_against: np.ndarray, position_to_check: Tuple[int, int]): + x_pos, y_pos = position_to_check # Check if agent colides with grid boundrys valid = not ( - x < 0 or y < 0 - or x >= agent_slice.shape[0] - or y >= agent_slice.shape[0] + x_pos < 0 or y_pos < 0 + or x_pos >= slice_to_check_against.shape[0] + or y_pos >= slice_to_check_against.shape[0] ) # Check for collision with level walls - valid = valid and not state[LEVEL_IDX][x, y] + valid = valid and not slice_to_check_against[x_pos, y_pos] + return valid if __name__ == '__main__': diff --git a/main.py b/main.py index e60f998..f77bf8a 100644 --- a/main.py +++ b/main.py @@ -85,8 +85,8 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List if __name__ == '__main__': - compare_runs(Path('debug_out'), 1623052687, ['step_reward']) - exit() + # compare_runs(Path('debug_out'), 1623052687, ['step_reward']) + # exit() from stable_baselines3 import PPO, DQN, A2C from algorithms.reg_dqn import RegDQN