From c539e5dddd6026881d3f6fa1f27b35bec020b2d6 Mon Sep 17 00:00:00 2001 From: steffen-illium Date: Tue, 11 May 2021 17:30:06 +0200 Subject: [PATCH] Refactoring of movement logic and parallel collision checks --- environments/factory/base_factory.py | 38 +++++++++++++++++++------- environments/factory/simple_factory.py | 2 +- environments/helpers.py | 11 +------- 3 files changed, 30 insertions(+), 21 deletions(-) diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index 4fe7291..3bd1286 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -31,19 +31,27 @@ class BaseFactory: # Returns State, Reward, Done, Info return self.state, 0, self.done, {} + def additional_actions(self, agent_i, action) -> ((int, int), bool): + raise NotImplementedError + def step(self, actions): actions = [actions] if isinstance(actions, int) else actions assert isinstance(actions, list), f'"actions has to be in [{int, list}]' self.steps += 1 r = 0 - collision_vecs = np.zeros((self.n_agents, self.state.shape[0])) # n_agents x n_slices actions = list(enumerate(actions)) random.shuffle(actions) - for agent_i, action in actions: - new_pos, collision_vec, did_collide = self.move_or_colide(agent_i, action) - collision_vecs[agent_i] = collision_vec + if action <= 8: + pos, did_collide = self.move_or_colide(agent_i, action) + else: + pos, did_collide = self.additional_actions(agent_i, action) + actions[agent_i] = (pos, did_collide) + + collision_vecs = np.zeros((self.n_agents, self.state.shape[0])) # n_agents x n_slices + for agent_i, action in enumerate(actions): + collision_vecs[agent_i] = self.check_collisions(agent_i, *action) reward, info = self.step_core(collision_vecs, actions, r) r += reward @@ -51,22 +59,32 @@ class BaseFactory: self.done = True return self.state, r, self.done, info + def check_collisions(self, agent_i, pos, valid): + pos_x, pos_y = pos + collisions_vec = self.state[:, pos_x, pos_y].copy() # "vertical fiber" at position of agent i + collisions_vec[h.AGENT_START_IDX + agent_i] = h.IS_FREE_CELL # no self-collisions + if valid: + pass + else: + collisions_vec[h.LEVEL_IDX] = h.IS_OCCUPIED_CELL + return collisions_vec + def move(self, agent_i, old_pos, new_pos): (x, y), (x_new, y_new) = old_pos, new_pos self.state[agent_i + h.AGENT_START_IDX, x, y] = h.IS_FREE_CELL self.state[agent_i + h.AGENT_START_IDX, x_new, y_new] = h.IS_OCCUPIED_CELL def move_or_colide(self, agent_i, action) -> ((int, int), bool): - old_pos, new_pos, collision_vec, did_collide = h.check_agent_move(state=self.state, - dim=agent_i + h.AGENT_START_IDX, - action=action) - if not did_collide: + old_pos, new_pos, valid = h.check_agent_move(state=self.state, + dim=agent_i + h.AGENT_START_IDX, + action=action) + if valid: # Does not collide width level boundrys self.move(agent_i, old_pos, new_pos) - return new_pos, collision_vec, did_collide + return new_pos, valid else: # Agent seems to be trying to collide in this step - return old_pos, collision_vec, did_collide + return old_pos, valid @property def free_cells(self) -> np.ndarray: diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py index 3af863d..0588f82 100644 --- a/environments/factory/simple_factory.py +++ b/environments/factory/simple_factory.py @@ -30,6 +30,6 @@ class SimpleFactory(BaseFactory): if __name__ == '__main__': import random factory = SimpleFactory(n_agents=1, max_dirt=8) - random_actions = [random.randint(0, 8) for _ in range(200)] + random_actions = [random.randint(0, 7) for _ in range(200)] for action in random_actions: state, r, done, _ = factory.step(action) diff --git a/environments/helpers.py b/environments/helpers.py index 45870b0..be3d274 100644 --- a/environments/helpers.py +++ b/environments/helpers.py @@ -64,16 +64,7 @@ def check_agent_move(state, dim, action): or y_new >= agent_slice.shape[0] ) - if valid: - collisions_vec = state[:, x_new, y_new].copy() # "vertical fiber" at position of agent i - collisions_vec[dim] = IS_FREE_CELL # no self-collisions - pass - else: - collisions_vec = state[:, x, y].copy() # "vertical fiber" at position of agent i - collisions_vec[dim] = IS_FREE_CELL # no self-collisions - collisions_vec[LEVEL_IDX] = IS_OCCUPIED_CELL - did_collide = collisions_vec.sum(0) != IS_FREE_CELL - return (x, y), (x_new, y_new), collisions_vec, did_collide + return (x, y), (x_new, y_new), valid if __name__ == '__main__':