import pickle from pathlib import Path from typing import List, Union, Iterable, NamedTuple import gym import numpy as np from gym import spaces import yaml from environments import helpers as h class MovementProperties(NamedTuple): allow_square_movement: bool = False allow_diagonal_movement: bool = False allow_no_op: bool = False class AgentState: def __init__(self, i: int, action: int): self.i = i self.action = action self.collision_vector = None self.action_valid = None self.pos = None self.info = {} @property def collisions(self): return np.argwhere(self.collision_vector != 0).flatten() def update(self, **kwargs): # is this hacky?? o.0 for key, value in kwargs.items(): if hasattr(self, key): self.__setattr__(key, value) else: raise AttributeError(f'"{key}" cannot be updated, this attr is not a part of {self.__class__.__name__}') class Register: @property def n(self): return len(self) def __init__(self): self._register = dict() def __len__(self): return len(self._register) def __add__(self, other: Union[str, List[str]]): other = other if isinstance(other, list) else [other] assert all([isinstance(x, str) for x in other]), f'All item names have to be of type {str}.' self._register.update({key+len(self._register): value for key, value in enumerate(other)}) return self def register_additional_items(self, other: Union[str, List[str]]): self_with_additional_items = self + other return self_with_additional_items def keys(self): return self._register.keys() def items(self): return self._register.items() def __getitem__(self, item): return self._register[item] def by_name(self, item): return list(self._register.keys())[list(self._register.values()).index(item)] def __repr__(self): return f'{self.__class__.__name__}({self._register})' class Actions(Register): @property def movement_actions(self): return self._movement_actions def __init__(self, movement_properties: MovementProperties): self.allow_no_op = movement_properties.allow_no_op self.allow_diagonal_movement = movement_properties.allow_diagonal_movement self.allow_square_movement = movement_properties.allow_square_movement # FIXME: There is a bug in helpers because there actions are ints. and the order matters. assert not(self.allow_square_movement is False and self.allow_diagonal_movement is True), "There is a bug in helpers!!!" super(Actions, self).__init__() if self.allow_square_movement: self + ['north', 'east', 'south', 'west'] if self.allow_diagonal_movement: self + ['north-east', 'south-east', 'south-west', 'north-west'] self._movement_actions = self._register.copy() if self.allow_no_op: self + 'no-op' class StateSlice(Register): def __init__(self, n_agents: int): super(StateSlice, self).__init__() offset = 1 self.register_additional_items(['level', *[f'agent#{i}' for i in range(offset, n_agents+offset)]]) class BaseFactory(gym.Env): @property def action_space(self): return spaces.Discrete(self._actions.n) @property def observation_space(self): agent_slice = self.n_agents if self.omit_agent_slice_in_obs else 0 if self.pomdp_radius: return spaces.Box(low=0, high=1, shape=(self._state.shape[0] - agent_slice, self.pomdp_radius * 2 + 1, self.pomdp_radius * 2 + 1), dtype=np.float32) else: shape = [x-agent_slice if idx == 0 else x for idx, x in enumerate(self._state.shape)] space = spaces.Box(low=0, high=1, shape=shape, dtype=np.float32) return space @property def movement_actions(self): return self._actions.movement_actions def __init__(self, level='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None, movement_properties: MovementProperties = MovementProperties(), omit_agent_slice_in_obs=False, **kwargs): self.movement_properties = movement_properties self.n_agents = n_agents self.max_steps = max_steps self.pomdp_radius = pomdp_radius self.omit_agent_slice_in_obs = omit_agent_slice_in_obs self.done_at_collision = False _actions = Actions(self.movement_properties) self._actions = _actions + self.additional_actions self._level = h.one_hot_level( h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt') ) self._state_slices = StateSlice(n_agents) if 'additional_slices' in kwargs: self._state_slices.register_additional_items(kwargs.get('additional_slices')) self.reset() @property def additional_actions(self) -> Union[str, List[str]]: """ When heriting from this Base Class, you musst implement this methode!!! Please return a dict with the given types -> {int: str}. The int should start at 0. :return: An Actions-object holding all actions with keys in range 0-n. :rtype: Actions """ raise NotImplementedError('Please register additional actions ') def reset(self) -> (np.ndarray, int, bool, dict): self.steps = 0 self._agent_states = [] # Agent placement ... agents = np.zeros((self.n_agents, *self._level.shape), dtype=np.int8) floor_tiles = np.argwhere(self._level == h.IS_FREE_CELL) # ... on random positions np.random.shuffle(floor_tiles) for i, (x, y) in enumerate(floor_tiles[:self.n_agents]): agents[i, x, y] = h.IS_OCCUPIED_CELL agent_state = AgentState(i, -1) agent_state.update(pos=[x, y]) self._agent_states.append(agent_state) # state.shape = level, agent 1,..., agent n, self._state = np.concatenate((np.expand_dims(self._level, axis=0), agents), axis=0) # Returns State return None def _return_state(self): if self.pomdp_radius: pos = self._agent_states[0].pos # pos = [agent_state.pos for agent_state in self.agent_states] # obs = [] ... list comprehension... pos per agent x0, x1 = max(0, pos[0] - self.pomdp_radius), pos[0] + self.pomdp_radius + 1 y0, y1 = max(0, pos[1] - self.pomdp_radius), pos[1] + self.pomdp_radius + 1 obs = self._state[:, x0:x1, y0:y1] if obs.shape[1] != self.pomdp_radius ** 2 + 1 or obs.shape[2] != self.pomdp_radius ** 2 + 1: obs_padded = np.full((obs.shape[0], self.pomdp_radius ** 2 + 1, self.pomdp_radius ** 2 + 1), 1) a_pos = np.argwhere(obs[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0] obs_padded[:, abs(a_pos[0]-self.pomdp_radius):abs(a_pos[0]-self.pomdp_radius)+obs.shape[1], abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs obs = obs_padded else: obs = self._state if self.omit_agent_slice_in_obs: obs_new = obs[[key for key, val in self._state_slices.items() if 'agent' not in val]] return obs_new else: return obs def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool): raise NotImplementedError def step(self, actions): actions = [actions] if isinstance(actions, int) or np.isscalar(actions) else actions assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]' self.steps += 1 done = False # Move this in a seperate function? agent_states = list() for agent_i, action in enumerate(actions): agent_i_state = AgentState(agent_i, action) if self._is_moving_action(action): pos, valid = self.move_or_colide(agent_i, action) elif self._is_no_op(action): pos, valid = self.agent_i_position(agent_i), True else: pos, valid = self.do_additional_actions(agent_i, action) # Update state accordingly agent_i_state.update(pos=pos, action_valid=valid) agent_states.append(agent_i_state) for i, collision_vec in enumerate(self.check_all_collisions(agent_states, self._state.shape[0])): agent_states[i].update(collision_vector=collision_vec) if self.done_at_collision and collision_vec.any(): done = True self._agent_states = agent_states reward, info = self.calculate_reward(agent_states) if self.steps >= self.max_steps: done = True info.update(step_reward=reward, step=self.steps) return None, reward, done, info def _is_moving_action(self, action): return action in self._actions.movement_actions def _is_no_op(self, action): return self._actions[action] == 'no-op' def check_all_collisions(self, agent_states: List[AgentState], collisions: int) -> np.ndarray: collision_vecs = np.zeros((len(agent_states), collisions)) # n_agents x n_slices for agent_state in agent_states: # Register only collisions of moving agents if self._is_moving_action(agent_state.action): collision_vecs[agent_state.i] = self.check_collisions(agent_state) return collision_vecs def check_collisions(self, agent_state: AgentState) -> np.ndarray: pos_x, pos_y = agent_state.pos # FixMe: We need to find a way to spare out some dimensions, eg. an info dimension etc... a[?,] collisions_vec = self._state[:, pos_x, pos_y].copy() # "vertical fiber" at position of agent i collisions_vec[h.AGENT_START_IDX + agent_state.i] = h.IS_FREE_CELL # no self-collisions if agent_state.action_valid: # ToDo: Place a function hook here pass else: # Place a marker to indicate a collision with the level boundrys collisions_vec[h.LEVEL_IDX] = h.IS_OCCUPIED_CELL return collisions_vec def do_move(self, agent_i: int, old_pos: (int, int), new_pos: (int, int)) -> None: (x, y), (x_new, y_new) = old_pos, new_pos self._state[agent_i + h.AGENT_START_IDX, x, y] = h.IS_FREE_CELL self._state[agent_i + h.AGENT_START_IDX, x_new, y_new] = h.IS_OCCUPIED_CELL def move_or_colide(self, agent_i: int, action: int) -> ((int, int), bool): old_pos, new_pos, valid = h.check_agent_move(state=self._state, dim=agent_i + h.AGENT_START_IDX, action=action) if valid: # Does not collide width level boundaries self.do_move(agent_i, old_pos, new_pos) return new_pos, valid else: # Agent seems to be trying to collide in this step return old_pos, valid def agent_i_position(self, agent_i: int) -> (int, int): positions = np.argwhere(self._state[h.AGENT_START_IDX + agent_i] == h.IS_OCCUPIED_CELL) assert positions.shape[0] == 1 pos_x, pos_y = positions[0] # a.flatten() return pos_x, pos_y def free_cells(self, excluded_slices: Union[None, List[int], int] = None) -> np.array: excluded_slices = excluded_slices or [] assert isinstance(excluded_slices, (int, list)) excluded_slices = excluded_slices if isinstance(excluded_slices, list) else [excluded_slices] state = self._state if excluded_slices: # Todo: Is there a cleaner way? inds = list(range(self._state.shape[0])) excluded_slices = [inds[x] if x < 0 else x for x in excluded_slices] state = self._state[[x for x in inds if x not in excluded_slices]] free_cells = np.argwhere(state.sum(0) == h.IS_FREE_CELL) np.random.shuffle(free_cells) return free_cells def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict): # Returns: Reward, Info raise NotImplementedError def render(self): raise NotImplementedError def save_params(self, filepath: Path): d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')} filepath.parent.mkdir(parents=True, exist_ok=True) with filepath.open('wb') as f: # yaml.dump(d, f) pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)