fix validity check

This commit is contained in:
romue 2021-05-07 14:04:59 +02:00
parent 3c86684586
commit 020288fd55
2 changed files with 6 additions and 4 deletions

View File

@ -32,6 +32,7 @@ class Factory(object):
# level, agent 1,..., agent n,
for i, a in enumerate(actions):
old_pos, new_pos, valid = h.check_agent_move(state=self.state, dim=i+1, action=a)
print(old_pos, new_pos, valid)
if __name__ == '__main__':

View File

@ -50,10 +50,11 @@ def check_agent_move(state, dim, action):
x_new -= 1
y_new -= 1
# Check validity
valid = (x_new < 0 or y_new < 0
or x_new >= agent_slice.shape[0]
or y_new >= agent_slice.shape[0]
)
valid = not (
x_new < 0 or y_new < 0
or x_new >= agent_slice.shape[0]
or y_new >= agent_slice.shape[0]
) # if agent tried to leave the grid
return (x, y), (x_new, y_new), valid