From f98f689f5ed73f89bddaaf05afb0d12edc948f4c Mon Sep 17 00:00:00 2001 From: romue Date: Mon, 10 May 2021 15:02:17 +0200 Subject: [PATCH] updated simple factory --- environments/factory/base_factory.py | 1 - environments/factory/simple_factory.py | 13 ++++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index df4827f..cc7efee 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -46,7 +46,6 @@ class BaseFactory(object): x, y = np.argwhere(agent_slice == 1)[0] collisions_vec = self.state[:, x, y].copy() # otherwise you overwrite the grid/state collisions_vec[i+1] = 0 # no self-collisions - #collision_vecs.append(collisions_vec) collision_vecs[i] += collisions_vec reward, info = self.step_core(np.array(collision_vecs), actions, r) r += reward diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py index f5b1604..abac209 100644 --- a/environments/factory/simple_factory.py +++ b/environments/factory/simple_factory.py @@ -8,18 +8,21 @@ class SimpleFactory(BaseFactory): super(SimpleFactory, self).__init__(*args, **kwargs) self.slice_strings.update({self.state.shape[0]-1: 'dirt'}) + def spawn_dirt(self): + free_for_dirt = self.free_cells() + for x, y in free_for_dirt[:self.max_dirt]: # randomly distribute dirt across the grid + self.state[-1, x, y] = 1 + def reset(self): super().reset() dirt_slice = np.zeros((1, *self.state.shape[1:])) self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice - free_for_dirt = self.free_cells() - for x, y in free_for_dirt[:self.max_dirt]: - self.state[-1, x, y] = 1 + self.spawn_dirt() def step_core(self, collisions_vecs, actions, r): for agent_i, cols in enumerate(collisions_vecs): cols = np.argwhere(cols != 0).flatten() - print(f'Agent #{agent_i} has collisions with ' + print(f't = {self.steps}\tAgent {agent_i} has collisions with ' f'{[self.slice_strings[entity] for entity in cols]}') return 0, {} @@ -28,6 +31,6 @@ class SimpleFactory(BaseFactory): if __name__ == '__main__': import random factory = SimpleFactory(n_agents=1, max_dirt=8) - random_actions = [random.randint(0,8) for _ in range(200)] + random_actions = [random.randint(0, 8) for _ in range(200)] for action in random_actions: state, r, done, _ = factory.step(action) \ No newline at end of file