actions are now checked by string
This commit is contained in:
parent
2589a06d02
commit
62c141aa1c
@ -26,22 +26,6 @@ class Entity():
|
|||||||
def __init__(self, pos):
|
def __init__(self, pos):
|
||||||
self._pos = pos
|
self._pos = pos
|
||||||
|
|
||||||
def check_agent_move(state: np.ndarray, dim: int, action: str):
|
|
||||||
agent_slice = state[dim] # horizontal slice from state tensor
|
|
||||||
agent_pos = np.argwhere(agent_slice == 1)
|
|
||||||
if len(agent_pos) > 1:
|
|
||||||
raise AssertionError('Only one agent per slice is allowed.')
|
|
||||||
x, y = agent_pos[0]
|
|
||||||
|
|
||||||
# Actions
|
|
||||||
x_diff, y_diff = ACTIONMAP[action]
|
|
||||||
x_new = x + x_diff
|
|
||||||
y_new = y + y_diff
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return (x, y), (x_new, y_new), valid
|
|
||||||
|
|
||||||
|
|
||||||
class AgentState:
|
class AgentState:
|
||||||
|
|
||||||
@ -312,9 +296,7 @@ class BaseFactory(gym.Env):
|
|||||||
self._state[agent_i + h.AGENT_START_IDX, x_new, y_new] = h.IS_OCCUPIED_CELL
|
self._state[agent_i + h.AGENT_START_IDX, x_new, y_new] = h.IS_OCCUPIED_CELL
|
||||||
|
|
||||||
def move_or_colide(self, agent_i: int, action: int) -> ((int, int), bool):
|
def move_or_colide(self, agent_i: int, action: int) -> ((int, int), bool):
|
||||||
old_pos, new_pos, valid = h.check_agent_move(state=self._state,
|
old_pos, new_pos, valid = self._check_agent_move(agent_i=agent_i, action=self._actions[action])
|
||||||
dim=agent_i + h.AGENT_START_IDX,
|
|
||||||
action=action)
|
|
||||||
if valid:
|
if valid:
|
||||||
# Does not collide width level boundaries
|
# Does not collide width level boundaries
|
||||||
self.do_move(agent_i, old_pos, new_pos)
|
self.do_move(agent_i, old_pos, new_pos)
|
||||||
@ -323,6 +305,22 @@ class BaseFactory(gym.Env):
|
|||||||
# Agent seems to be trying to collide in this step
|
# Agent seems to be trying to collide in this step
|
||||||
return old_pos, valid
|
return old_pos, valid
|
||||||
|
|
||||||
|
def _check_agent_move(self, agent_i, action: str):
|
||||||
|
agent_slice = self._state[h.AGENT_START_IDX + agent_i] # horizontal slice from state tensor
|
||||||
|
agent_pos = np.argwhere(agent_slice == 1)
|
||||||
|
if len(agent_pos) > 1:
|
||||||
|
raise AssertionError('Only one agent per slice is allowed.')
|
||||||
|
x, y = agent_pos[0]
|
||||||
|
|
||||||
|
# Actions
|
||||||
|
x_diff, y_diff = h.ACTIONMAP[action]
|
||||||
|
x_new = x + x_diff
|
||||||
|
y_new = y + y_diff
|
||||||
|
|
||||||
|
valid = h.check_position(self._state[h.LEVEL_IDX], (x_new, y_new))
|
||||||
|
|
||||||
|
return (x, y), (x_new, y_new), valid
|
||||||
|
|
||||||
def agent_i_position(self, agent_i: int) -> (int, int):
|
def agent_i_position(self, agent_i: int) -> (int, int):
|
||||||
positions = np.argwhere(self._state[h.AGENT_START_IDX + agent_i] == h.IS_OCCUPIED_CELL)
|
positions = np.argwhere(self._state[h.AGENT_START_IDX + agent_i] == h.IS_OCCUPIED_CELL)
|
||||||
assert positions.shape[0] == 1
|
assert positions.shape[0] == 1
|
||||||
|
@ -38,19 +38,19 @@ def one_hot_level(level, wall_char=WALL):
|
|||||||
return binary_grid
|
return binary_grid
|
||||||
|
|
||||||
|
|
||||||
def check_position(state: np.ndarray, position_to_check: Tuple[int, int], dim: int = 0):
|
def check_position(slice_to_check_against: np.ndarray, position_to_check: Tuple[int, int]):
|
||||||
x, y = position_to_check
|
x_pos, y_pos = position_to_check
|
||||||
agent_slice = state[dim]
|
|
||||||
|
|
||||||
# Check if agent colides with grid boundrys
|
# Check if agent colides with grid boundrys
|
||||||
valid = not (
|
valid = not (
|
||||||
x < 0 or y < 0
|
x_pos < 0 or y_pos < 0
|
||||||
or x >= agent_slice.shape[0]
|
or x_pos >= slice_to_check_against.shape[0]
|
||||||
or y >= agent_slice.shape[0]
|
or y_pos >= slice_to_check_against.shape[0]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check for collision with level walls
|
# Check for collision with level walls
|
||||||
valid = valid and not state[LEVEL_IDX][x, y]
|
valid = valid and not slice_to_check_against[x_pos, y_pos]
|
||||||
|
return valid
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
4
main.py
4
main.py
@ -85,8 +85,8 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
compare_runs(Path('debug_out'), 1623052687, ['step_reward'])
|
# compare_runs(Path('debug_out'), 1623052687, ['step_reward'])
|
||||||
exit()
|
# exit()
|
||||||
|
|
||||||
from stable_baselines3 import PPO, DQN, A2C
|
from stable_baselines3 import PPO, DQN, A2C
|
||||||
from algorithms.reg_dqn import RegDQN
|
from algorithms.reg_dqn import RegDQN
|
||||||
|
Loading…
x
Reference in New Issue
Block a user