diff --git a/environments/factory/factory_cleaning.py b/environments/factory/factory_cleaning.py index 4372579..8b35167 100644 --- a/environments/factory/factory_cleaning.py +++ b/environments/factory/factory_cleaning.py @@ -31,7 +31,8 @@ class Factory(object): actions = [actions] # level, agent 1,..., agent n, for i, a in enumerate(actions): - h.check_agent_move(state=self.state, dim=i+1, action=a) + old_pos, new_pos, valid = h.check_agent_move(state=self.state, dim=i+1, action=a) + if __name__ == '__main__': factory = Factory(n_agents=1) diff --git a/environments/helpers.py b/environments/helpers.py index 57f593d..87abc56 100644 --- a/environments/helpers.py +++ b/environments/helpers.py @@ -28,6 +28,7 @@ def check_agent_move(state, dim, action): raise AssertionError('Only one agent per slice is allowed.') x, y = agent_pos[0] x_new, y_new = x, y + # Actions if action == 0: # North x_new -= 1 elif action == 1: # East @@ -48,6 +49,14 @@ def check_agent_move(state, dim, action): elif action == 7: # NW x_new -= 1 y_new -= 1 + # Check validity + valid = (x_new < 0 or y_new < 0 + or x_new >= agent_slice.shape[0] + or y_new >= agent_slice.shape[0] + ) + return (x, y), (x_new, y_new), valid + +