alles was ich hab
This commit is contained in:
@ -113,15 +113,19 @@ class BaseFactory(gym.Env):
|
||||
def movement_actions(self):
|
||||
return self._actions.movement_actions
|
||||
|
||||
def __init__(self, level='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None, **kwargs):
|
||||
def __init__(self, level='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None,
|
||||
allow_square_movement=True, allow_diagonal_movement=True, allow_no_op=True, **kwargs):
|
||||
self.allow_no_op = allow_no_op
|
||||
self.allow_diagonal_movement = allow_diagonal_movement
|
||||
self.allow_square_movement = allow_square_movement
|
||||
self.n_agents = n_agents
|
||||
self.max_steps = max_steps
|
||||
self.pomdp_radius = pomdp_radius
|
||||
|
||||
self.done_at_collision = False
|
||||
_actions = Actions(allow_square_movement=kwargs.get('allow_square_movement', True),
|
||||
allow_diagonal_movement=kwargs.get('allow_diagonal_movement', True),
|
||||
allow_no_op=kwargs.get('allow_no_op', True))
|
||||
_actions = Actions(allow_square_movement=self.allow_square_movement,
|
||||
allow_diagonal_movement=self.allow_diagonal_movement,
|
||||
allow_no_op=allow_no_op)
|
||||
self._actions = _actions + self.additional_actions
|
||||
|
||||
self._level = h.one_hot_level(
|
||||
@ -165,12 +169,16 @@ class BaseFactory(gym.Env):
|
||||
pos = self._agent_states[0].pos
|
||||
# pos = [agent_state.pos for agent_state in self.agent_states]
|
||||
# obs = [] ... list comprehension... pos per agent
|
||||
npad = [(0, 0)] + [(self.pomdp_radius, self.pomdp_radius)] * (self._state.ndim - 1)
|
||||
x_roll = self.pomdp_radius-pos[0]
|
||||
y_roll = self.pomdp_radius-pos[1]
|
||||
padded_state = np.pad(self._state, pad_width=npad, mode='constant', constant_values=0)
|
||||
padded_state = np.roll(np.roll(padded_state, x_roll, axis=1), y_roll, axis=2)
|
||||
obs = padded_state[:, :self.pomdp_radius * 2 + 1, :self.pomdp_radius * 2 + 1]
|
||||
x0, x1 = max(0, pos[0] - self.pomdp_radius), pos[0] + self.pomdp_radius + 1
|
||||
y0, y1 = max(0, pos[1] - self.pomdp_radius), pos[1] + self.pomdp_radius + 1
|
||||
obs = self._state[:, x0:x1, y0:y1]
|
||||
if obs.shape[1] != self.pomdp_radius ** 2 + 1 or obs.shape[2] != self.pomdp_radius ** 2 + 1:
|
||||
obs_padded = np.full((obs.shape[0], self.pomdp_radius ** 2 + 1, self.pomdp_radius ** 2 + 1), 1)
|
||||
a_pos = np.argwhere(obs[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0]
|
||||
obs_padded[:,
|
||||
abs(a_pos[0]-self.pomdp_radius):abs(a_pos[0]-self.pomdp_radius)+obs.shape[1],
|
||||
abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs
|
||||
obs = obs_padded
|
||||
else:
|
||||
obs = self._state
|
||||
return obs
|
||||
@ -211,8 +219,7 @@ class BaseFactory(gym.Env):
|
||||
|
||||
info.update(step_reward=reward, step=self.steps)
|
||||
|
||||
obs = self._return_state()
|
||||
return obs, reward, done, info
|
||||
return None, reward, done, info
|
||||
|
||||
def _is_moving_action(self, action):
|
||||
return action in self._actions.movement_actions
|
||||
|
Reference in New Issue
Block a user