diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index 58dfedd..25e089c 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -63,9 +63,24 @@ class BaseFactory: self.done = True return self.state, r, self.done, info - def check_collisions(self, agent_i, pos, valid): + def _is_moving_action(self, action): + movement_actions = (int(self.allow_vertical_movement) + int(self.allow_horizontal_movement)) * 4 + if action < movement_actions: + return True + else: + return False + + def check_all_collisions(self, agent_action_pos_valid_tuples: (int, int, (int, int), bool), collisions: int) -> np.ndarray: + collision_vecs = np.zeros((len(agent_action_pos_valid_tuples), collisions)) # n_agents x n_slices + for agent_i, action, pos, valid in agent_action_pos_valid_tuples: + if self._is_moving_action(action): + collision_vecs[agent_i] = self.check_collisions(agent_i, pos, valid) + return collision_vecs + + def check_collisions(self, agent_i: int, pos: (int, int), valid: bool) -> np.ndarray: pos_x, pos_y = pos - collisions_vec = self.state[:, pos_x, pos_y].copy() # "vertical fiber" at position of agent i + # FixMe: We need to find a way to spare out some dimensions, eg. an info dimension etc... a[?,] + collisions_vec = self.state[:, pos_x, pos_y].copy() # "vertical fiber" at position of agent i collisions_vec[h.AGENT_START_IDX + agent_i] = h.IS_FREE_CELL # no self-collisions if valid: # ToDo: Place a function hook here @@ -84,10 +99,12 @@ class BaseFactory: dim=agent_i + h.AGENT_START_IDX, action=action) if valid: - # Does not collide width level boundaries + # Does not collide width level boundrys self.do_move(agent_i, old_pos, new_pos) return new_pos, valid - return old_pos, valid + else: + # Agent seems to be trying to collide in this step + return old_pos, valid @property def free_cells(self) -> np.ndarray: diff --git a/environments/factory/simple_factory_getting_dirty.py b/environments/factory/simple_factory_getting_dirty.py index c018130..c3a8b01 100644 --- a/environments/factory/simple_factory_getting_dirty.py +++ b/environments/factory/simple_factory_getting_dirty.py @@ -2,15 +2,13 @@ import numpy as np from environments.factory.base_factory import BaseFactory from collections import namedtuple - +DIRT_INDEX = -1 DirtProperties = namedtuple('DirtProperties', ['clean_amount', 'max_spawn_ratio', 'gain_amount'], defaults=[0.25, 0.1, 0.1]) class GettingDirty(BaseFactory): - _dirt_indx = -1 - def __init__(self, *args, dirt_properties:DirtProperties, **kwargs): super(GettingDirty, self).__init__(*args, **kwargs) self._dirt_properties = dirt_properties @@ -21,7 +19,10 @@ class GettingDirty(BaseFactory): # randomly distribute dirt across the grid n_dirt_tiles = self._dirt_properties.max_spawn_ratio * len(free_for_dirt) for x, y in free_for_dirt[:n_dirt_tiles]: - self.state[self._dirt_indx, x, y] += self._dirt_properties.gain_amount + self.state[DIRT_INDEX, x, y] += self._dirt_properties.gain_amount + + def additional_actions(self, agent_i, action) -> ((int, int), bool): + if action == def reset(self): # ToDo: When self.reset returns the new states and stuff, use it here!