multi_agent observation when n_agent more then 1
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Iterable, NamedTuple
|
||||
from typing import List, Union, Iterable
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
@ -9,115 +9,7 @@ from gym import spaces
|
||||
import yaml
|
||||
|
||||
from environments import helpers as h
|
||||
|
||||
|
||||
class MovementProperties(NamedTuple):
|
||||
allow_square_movement: bool = True
|
||||
allow_diagonal_movement: bool = False
|
||||
allow_no_op: bool = False
|
||||
|
||||
|
||||
class Entity():
|
||||
|
||||
@property
|
||||
def pos(self):
|
||||
return self._pos
|
||||
|
||||
def __init__(self, pos):
|
||||
self._pos = pos
|
||||
|
||||
|
||||
class AgentState:
|
||||
|
||||
def __init__(self, i: int, action: int):
|
||||
self.i = i
|
||||
self.action = action
|
||||
|
||||
self.collision_vector = None
|
||||
self.action_valid = None
|
||||
self.pos = None
|
||||
self.info = {}
|
||||
|
||||
@property
|
||||
def collisions(self):
|
||||
return np.argwhere(self.collision_vector != 0).flatten()
|
||||
|
||||
def update(self, **kwargs): # is this hacky?? o.0
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(self, key):
|
||||
self.__setattr__(key, value)
|
||||
else:
|
||||
raise AttributeError(f'"{key}" cannot be updated, this attr is not a part of {self.__class__.__name__}')
|
||||
|
||||
|
||||
class Register:
|
||||
|
||||
@property
|
||||
def n(self):
|
||||
return len(self)
|
||||
|
||||
def __init__(self):
|
||||
self._register = dict()
|
||||
|
||||
def __len__(self):
|
||||
return len(self._register)
|
||||
|
||||
def __add__(self, other: Union[str, List[str]]):
|
||||
other = other if isinstance(other, list) else [other]
|
||||
assert all([isinstance(x, str) for x in other]), f'All item names have to be of type {str}.'
|
||||
self._register.update({key+len(self._register): value for key, value in enumerate(other)})
|
||||
return self
|
||||
|
||||
def register_additional_items(self, other: Union[str, List[str]]):
|
||||
self_with_additional_items = self + other
|
||||
return self_with_additional_items
|
||||
|
||||
def keys(self):
|
||||
return self._register.keys()
|
||||
|
||||
def items(self):
|
||||
return self._register.items()
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self._register[item]
|
||||
|
||||
def by_name(self, item):
|
||||
return list(self._register.keys())[list(self._register.values()).index(item)]
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self._register})'
|
||||
|
||||
|
||||
class Actions(Register):
|
||||
|
||||
@property
|
||||
def movement_actions(self):
|
||||
return self._movement_actions
|
||||
|
||||
def __init__(self, movement_properties: MovementProperties):
|
||||
self.allow_no_op = movement_properties.allow_no_op
|
||||
self.allow_diagonal_movement = movement_properties.allow_diagonal_movement
|
||||
self.allow_square_movement = movement_properties.allow_square_movement
|
||||
# FIXME: There is a bug in helpers because there actions are ints. and the order matters.
|
||||
assert not(self.allow_square_movement is False and self.allow_diagonal_movement is True), \
|
||||
"There is a bug in helpers!!!"
|
||||
super(Actions, self).__init__()
|
||||
|
||||
if self.allow_square_movement:
|
||||
self + ['north', 'east', 'south', 'west']
|
||||
if self.allow_diagonal_movement:
|
||||
self + ['north-east', 'south-east', 'south-west', 'north-west']
|
||||
self._movement_actions = self._register.copy()
|
||||
if self.allow_no_op:
|
||||
self + 'no-op'
|
||||
|
||||
|
||||
class StateSlice(Register):
|
||||
|
||||
def __init__(self, n_agents: int):
|
||||
super(StateSlice, self).__init__()
|
||||
offset = 1
|
||||
self.register_additional_items(['level', *[f'agent#{i}' for i in range(offset, n_agents+offset)]])
|
||||
from environments.utility_classes import Actions, StateSlice, AgentState, MovementProperties
|
||||
|
||||
|
||||
# noinspection PyAttributeOutsideInit
|
||||
@ -148,9 +40,11 @@ class BaseFactory(gym.Env):
|
||||
def movement_actions(self):
|
||||
return self._actions.movement_actions
|
||||
|
||||
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None,
|
||||
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = 0,
|
||||
movement_properties: MovementProperties = MovementProperties(),
|
||||
combin_agent_slices_in_obs: bool = False,
|
||||
omit_agent_slice_in_obs=False, **kwargs):
|
||||
assert combin_agent_slices_in_obs != omit_agent_slice_in_obs, 'Both options are exclusive'
|
||||
|
||||
self.movement_properties = movement_properties
|
||||
self.level_name = level_name
|
||||
@ -158,6 +52,7 @@ class BaseFactory(gym.Env):
|
||||
self.n_agents = n_agents
|
||||
self.max_steps = max_steps
|
||||
self.pomdp_radius = pomdp_radius
|
||||
self.combin_agent_slices_in_obs = combin_agent_slices_in_obs
|
||||
self.omit_agent_slice_in_obs = omit_agent_slice_in_obs
|
||||
|
||||
self.done_at_collision = False
|
||||
@ -185,7 +80,7 @@ class BaseFactory(gym.Env):
|
||||
raise NotImplementedError('Please register additional actions ')
|
||||
|
||||
def reset(self) -> (np.ndarray, int, bool, dict):
|
||||
self.steps = 0
|
||||
self._steps = 0
|
||||
self._agent_states = []
|
||||
# Agent placement ...
|
||||
agents = np.zeros((self.n_agents, *self._level.shape), dtype=np.int8)
|
||||
@ -202,17 +97,25 @@ class BaseFactory(gym.Env):
|
||||
# Returns State
|
||||
return None
|
||||
|
||||
def _return_state(self):
|
||||
def _get_observations(self) -> np.ndarray:
|
||||
if self.n_agents == 1:
|
||||
obs = self._build_per_agent_obs(0)
|
||||
elif self.n_agents >= 2:
|
||||
obs = np.stack([self._build_per_agent_obs(agent_i) for agent_i in range(self.n_agents)])
|
||||
return obs
|
||||
|
||||
def _build_per_agent_obs(self, agent_i: int) -> np.ndarray:
|
||||
if self.pomdp_radius:
|
||||
pos = self._agent_states[0].pos
|
||||
# pos = [agent_state.pos for agent_state in self.agent_states]
|
||||
# obs = [] ... list comprehension... pos per agent
|
||||
x0, x1 = max(0, pos[0] - self.pomdp_radius), pos[0] + self.pomdp_radius + 1
|
||||
y0, y1 = max(0, pos[1] - self.pomdp_radius), pos[1] + self.pomdp_radius + 1
|
||||
global_pos = self._agent_states[agent_i].pos
|
||||
x0, x1 = max(0, global_pos[0] - self.pomdp_radius), global_pos[0] + self.pomdp_radius + 1
|
||||
y0, y1 = max(0, global_pos[1] - self.pomdp_radius), global_pos[1] + self.pomdp_radius + 1
|
||||
obs = self._state[:, x0:x1, y0:y1]
|
||||
if obs.shape[1] != self.pomdp_radius * 2 + 1 or obs.shape[2] != self.pomdp_radius * 2 + 1:
|
||||
obs_padded = np.full((obs.shape[0], self.pomdp_radius * 2 + 1, self.pomdp_radius * 2 + 1), 1)
|
||||
a_pos = np.argwhere(obs[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0]
|
||||
try:
|
||||
a_pos = np.argwhere(obs[h.AGENT_START_IDX + agent_i] == h.IS_OCCUPIED_CELL)[0]
|
||||
except IndexError:
|
||||
print('NO')
|
||||
obs_padded[:,
|
||||
abs(a_pos[0]-self.pomdp_radius):abs(a_pos[0]-self.pomdp_radius)+obs.shape[1],
|
||||
abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs
|
||||
@ -223,7 +126,13 @@ class BaseFactory(gym.Env):
|
||||
obs_new = obs[[key for key, val in self._state_slices.items() if 'agent' not in val]]
|
||||
return obs_new
|
||||
else:
|
||||
return obs
|
||||
if self.combin_agent_slices_in_obs:
|
||||
agent_obs = np.sum(obs[[key for key, val in self._state_slices.items() if 'agent' in val]],
|
||||
axis=0, keepdims=True)
|
||||
obs = np.concatenate((obs[:h.AGENT_START_IDX], agent_obs, obs[h.AGENT_START_IDX+self.n_agents:]))
|
||||
return obs
|
||||
else:
|
||||
return obs
|
||||
|
||||
def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
||||
raise NotImplementedError
|
||||
@ -231,16 +140,16 @@ class BaseFactory(gym.Env):
|
||||
def step(self, actions):
|
||||
actions = [actions] if isinstance(actions, int) or np.isscalar(actions) else actions
|
||||
assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]'
|
||||
self.steps += 1
|
||||
self._steps += 1
|
||||
done = False
|
||||
|
||||
# Move this in a seperate function?
|
||||
agent_states = list()
|
||||
for agent_i, action in enumerate(actions):
|
||||
agent_i_state = AgentState(agent_i, action)
|
||||
if self._is_moving_action(action):
|
||||
if self._actions.is_moving_action(action):
|
||||
pos, valid = self.move_or_colide(agent_i, action)
|
||||
elif self._is_no_op(action):
|
||||
elif self._actions.is_no_op(action):
|
||||
pos, valid = self.agent_i_position(agent_i), True
|
||||
else:
|
||||
pos, valid = self.do_additional_actions(agent_i, action)
|
||||
@ -256,24 +165,18 @@ class BaseFactory(gym.Env):
|
||||
self._agent_states = agent_states
|
||||
reward, info = self.calculate_reward(agent_states)
|
||||
|
||||
if self.steps >= self.max_steps:
|
||||
if self._steps >= self.max_steps:
|
||||
done = True
|
||||
|
||||
info.update(step_reward=reward, step=self.steps)
|
||||
info.update(step_reward=reward, step=self._steps)
|
||||
|
||||
return None, reward, done, info
|
||||
|
||||
def _is_moving_action(self, action):
|
||||
return action in self._actions.movement_actions
|
||||
|
||||
def _is_no_op(self, action):
|
||||
return self._actions[action] == 'no-op'
|
||||
|
||||
def check_all_collisions(self, agent_states: List[AgentState], collisions: int) -> np.ndarray:
|
||||
collision_vecs = np.zeros((len(agent_states), collisions)) # n_agents x n_slices
|
||||
for agent_state in agent_states:
|
||||
# Register only collisions of moving agents
|
||||
if self._is_moving_action(agent_state.action):
|
||||
if self._actions.is_moving_action(agent_state.action):
|
||||
collision_vecs[agent_state.i] = self.check_collisions(agent_state)
|
||||
return collision_vecs
|
||||
|
||||
|
Reference in New Issue
Block a user