Agents not smear Dirt
This commit is contained in:
@ -10,7 +10,7 @@ from gym.wrappers import FrameStack
|
||||
|
||||
from environments.helpers import Constants as c, Constants
|
||||
from environments import helpers as h
|
||||
from environments.factory.base.objects import Slice, Agent, Tile, Action, MoveableEntity
|
||||
from environments.factory.base.objects import Slice, Agent, Tile, Action
|
||||
from environments.factory.base.registers import StateSlices, Actions, Entities, Agents, Doors, FloorTiles
|
||||
from environments.utility_classes import MovementProperties
|
||||
|
||||
@ -85,9 +85,6 @@ class BaseFactory(gym.Env):
|
||||
movement_properties: MovementProperties = MovementProperties(), parse_doors=False,
|
||||
combin_agent_slices_in_obs: bool = False, frames_to_stack=0, record_episodes=False,
|
||||
omit_agent_slice_in_obs=False, done_at_collision=False, **kwargs):
|
||||
assert (combin_agent_slices_in_obs != omit_agent_slice_in_obs) or \
|
||||
(not combin_agent_slices_in_obs and not omit_agent_slice_in_obs), \
|
||||
'Both options are exclusive'
|
||||
assert frames_to_stack != 1 and frames_to_stack >= 0, "'frames_to_stack' cannot be negative or 1."
|
||||
|
||||
# Attribute Assignment
|
||||
@ -125,7 +122,7 @@ class BaseFactory(gym.Env):
|
||||
|
||||
# Doors
|
||||
parsed_doors = h.one_hot_level(parsed_level, c.DOOR)
|
||||
doors = [Slice(c.DOORS.value, parsed_doors)] if parsed_doors.any() and self.parse_doors else []
|
||||
doors = [Slice(c.DOORS.name, parsed_doors)] if parsed_doors.any() and self.parse_doors else []
|
||||
|
||||
# Agents
|
||||
agents = []
|
||||
@ -283,15 +280,17 @@ class BaseFactory(gym.Env):
|
||||
obs = self._padded_obs_cube[:, x0:x1, y0:y1]
|
||||
else:
|
||||
obs = self._obs_cube
|
||||
if self.omit_agent_slice_in_obs:
|
||||
obs_new = obs[[key for key, val in self._slices.items() if c.AGENT.value not in val]]
|
||||
return obs_new
|
||||
|
||||
if self.combin_agent_slices_in_obs and self.n_agents >= 1:
|
||||
agent_obs = np.sum(obs[[key for key, slice in self._slices.items() if c.AGENT.name in slice.name and
|
||||
(not self.omit_agent_slice_in_obs and slice.name != agent.name)]],
|
||||
axis=0, keepdims=True)
|
||||
obs = np.concatenate((obs[:first_agent_slice], agent_obs, obs[first_agent_slice+self.n_agents:]))
|
||||
return obs
|
||||
else:
|
||||
if self.combin_agent_slices_in_obs:
|
||||
agent_obs = np.sum(obs[[key for key, slice in self._slices.items() if c.AGENT.name in slice.name]],
|
||||
axis=0, keepdims=True)
|
||||
obs = np.concatenate((obs[:first_agent_slice], agent_obs, obs[first_agent_slice+self.n_agents:]))
|
||||
return obs
|
||||
if self.omit_agent_slice_in_obs:
|
||||
obs_new = obs[[key for key, val in self._slices.items() if c.AGENT.value not in val.name]]
|
||||
return obs_new
|
||||
else:
|
||||
return obs
|
||||
|
||||
|
@ -196,7 +196,7 @@ class Door(Entity):
|
||||
def encoding(self):
|
||||
return 1 if self.is_closed else -1
|
||||
|
||||
def __init__(self, *args, context, closed_on_init=True, auto_close_interval=500):
|
||||
def __init__(self, *args, context, closed_on_init=True, auto_close_interval=10):
|
||||
super(Door, self).__init__(*args)
|
||||
self._state = c.IS_CLOSED_DOOR
|
||||
self.auto_close_interval = auto_close_interval
|
||||
|
Reference in New Issue
Block a user