actions are now checked by string

This commit is contained in:
steffen-illium
2021-06-07 16:52:06 +02:00
parent 2589a06d02
commit 62c141aa1c
3 changed files with 26 additions and 28 deletions

View File

@ -26,22 +26,6 @@ class Entity():
def __init__(self, pos):
self._pos = pos
def check_agent_move(state: np.ndarray, dim: int, action: str):
agent_slice = state[dim] # horizontal slice from state tensor
agent_pos = np.argwhere(agent_slice == 1)
if len(agent_pos) > 1:
raise AssertionError('Only one agent per slice is allowed.')
x, y = agent_pos[0]
# Actions
x_diff, y_diff = ACTIONMAP[action]
x_new = x + x_diff
y_new = y + y_diff
return (x, y), (x_new, y_new), valid
class AgentState:
@ -312,9 +296,7 @@ class BaseFactory(gym.Env):
self._state[agent_i + h.AGENT_START_IDX, x_new, y_new] = h.IS_OCCUPIED_CELL
def move_or_colide(self, agent_i: int, action: int) -> ((int, int), bool):
old_pos, new_pos, valid = h.check_agent_move(state=self._state,
dim=agent_i + h.AGENT_START_IDX,
action=action)
old_pos, new_pos, valid = self._check_agent_move(agent_i=agent_i, action=self._actions[action])
if valid:
# Does not collide width level boundaries
self.do_move(agent_i, old_pos, new_pos)
@ -323,6 +305,22 @@ class BaseFactory(gym.Env):
# Agent seems to be trying to collide in this step
return old_pos, valid
def _check_agent_move(self, agent_i, action: str):
agent_slice = self._state[h.AGENT_START_IDX + agent_i] # horizontal slice from state tensor
agent_pos = np.argwhere(agent_slice == 1)
if len(agent_pos) > 1:
raise AssertionError('Only one agent per slice is allowed.')
x, y = agent_pos[0]
# Actions
x_diff, y_diff = h.ACTIONMAP[action]
x_new = x + x_diff
y_new = y + y_diff
valid = h.check_position(self._state[h.LEVEL_IDX], (x_new, y_new))
return (x, y), (x_new, y_new), valid
def agent_i_position(self, agent_i: int) -> (int, int):
positions = np.argwhere(self._state[h.AGENT_START_IDX + agent_i] == h.IS_OCCUPIED_CELL)
assert positions.shape[0] == 1