Destinations implemented and debugged
This commit is contained in:
@ -64,7 +64,7 @@ class BaseFactory(gym.Env):
|
||||
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2),
|
||||
mv_prop: MovementProperties = MovementProperties(),
|
||||
obs_prop: ObservationProperties = ObservationProperties(),
|
||||
parse_doors=False, done_at_collision=False,
|
||||
parse_doors=False, done_at_collision=False, inject_agents: Union[None, List] = None,
|
||||
verbose=False, doors_have_area=True, env_seed=time.time_ns(), individual_rewards=False,
|
||||
**kwargs):
|
||||
|
||||
@ -98,6 +98,7 @@ class BaseFactory(gym.Env):
|
||||
self.done_at_collision = done_at_collision
|
||||
self._record_episodes = False
|
||||
self.parse_doors = parse_doors
|
||||
self._injected_agents = inject_agents or []
|
||||
self.doors_have_area = doors_have_area
|
||||
self.individual_rewards = individual_rewards
|
||||
|
||||
@ -108,8 +109,10 @@ class BaseFactory(gym.Env):
|
||||
return self._entities[item]
|
||||
|
||||
def _base_init_env(self):
|
||||
|
||||
# All entities
|
||||
# Objects
|
||||
entities = {}
|
||||
self._entities = Entities()
|
||||
# Level
|
||||
level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.level_name}.txt'
|
||||
parsed_level = h.parse_level(level_filepath)
|
||||
@ -121,14 +124,14 @@ class BaseFactory(gym.Env):
|
||||
np.argwhere(level_array == c.OCCUPIED_CELL.value),
|
||||
self._level_shape
|
||||
)
|
||||
entities.update({c.WALLS: walls})
|
||||
self._entities.register_additional_items({c.WALLS: walls})
|
||||
|
||||
# Floor
|
||||
floor = FloorTiles.from_argwhere_coordinates(
|
||||
np.argwhere(level_array == c.FREE_CELL.value),
|
||||
self._level_shape
|
||||
)
|
||||
entities.update({c.FLOOR: floor})
|
||||
self._entities.register_additional_items({c.FLOOR: floor})
|
||||
|
||||
# NOPOS
|
||||
self._NO_POS_TILE = Tile(c.NO_POS.value)
|
||||
@ -141,7 +144,7 @@ class BaseFactory(gym.Env):
|
||||
doors = Doors.from_tiles(door_tiles, self._level_shape,
|
||||
entity_kwargs=dict(context=floor)
|
||||
)
|
||||
entities.update({c.DOORS: doors})
|
||||
self._entities.register_additional_items({c.DOORS: doors})
|
||||
|
||||
# Actions
|
||||
self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors)
|
||||
@ -149,12 +152,22 @@ class BaseFactory(gym.Env):
|
||||
self._actions.register_additional_items(additional_actions)
|
||||
|
||||
# Agents
|
||||
agents = Agents.from_tiles(floor.empty_tiles[:self.n_agents], self._level_shape,
|
||||
individual_slices=self.obs_prop.render_agents == a_obs.SEPERATE,
|
||||
hide_from_obs_builder=self.obs_prop.render_agents == a_obs.LEVEL,
|
||||
is_observable=self.obs_prop.render_agents != a_obs.NOT
|
||||
)
|
||||
entities.update({c.AGENT: agents})
|
||||
agents_to_spawn = self.n_agents-len(self._injected_agents)
|
||||
agents_kwargs = dict(level_shape=self._level_shape,
|
||||
individual_slices=self.obs_prop.render_agents == a_obs.SEPERATE,
|
||||
hide_from_obs_builder=self.obs_prop.render_agents == a_obs.LEVEL,
|
||||
is_observable=self.obs_prop.render_agents != a_obs.NOT)
|
||||
if agents_to_spawn:
|
||||
agents = Agents.from_tiles(floor.empty_tiles[:agents_to_spawn], **agents_kwargs)
|
||||
else:
|
||||
agents = Agents(**agents_kwargs)
|
||||
if self._injected_agents:
|
||||
initialized_injections = list()
|
||||
for i, injection in enumerate(self._injected_agents):
|
||||
agents.register_item(injection(self, floor.empty_tiles[agents_to_spawn+i+1], static_problem=False))
|
||||
initialized_injections.append(agents[-1])
|
||||
self._initialized_injections = initialized_injections
|
||||
self._entities.register_additional_items({c.AGENT: agents})
|
||||
|
||||
if self.obs_prop.additional_agent_placeholder is not None:
|
||||
# TODO: Make this accept Lists for multiple placeholders
|
||||
@ -165,11 +178,7 @@ class BaseFactory(gym.Env):
|
||||
fill_value=self.obs_prop.additional_agent_placeholder)
|
||||
)
|
||||
|
||||
entities.update({c.AGENT_PLACEHOLDER: placeholder})
|
||||
|
||||
# All entities
|
||||
self._entities = Entities()
|
||||
self._entities.register_additional_items(entities)
|
||||
self._entities.register_additional_items({c.AGENT_PLACEHOLDER: placeholder})
|
||||
|
||||
# Additional Entitites from SubEnvs
|
||||
if additional_entities := self.additional_entities:
|
||||
@ -182,6 +191,7 @@ class BaseFactory(gym.Env):
|
||||
arrays = self._entities.obs_arrays
|
||||
|
||||
obs_cube_z = sum([a.shape[0] if not self[key].is_per_agent else 1 for key, a in arrays.items()])
|
||||
obs_cube_z += 1 if self.obs_prop.show_global_position_info else 0
|
||||
self._obs_cube = np.zeros((obs_cube_z, *self._level_shape), dtype=np.float32)
|
||||
|
||||
def reset(self) -> (np.ndarray, int, bool, dict):
|
||||
@ -279,7 +289,7 @@ class BaseFactory(gym.Env):
|
||||
if self.n_agents == 1:
|
||||
obs = self._build_per_agent_obs(self[c.AGENT][0], state_array_dict)
|
||||
elif self.n_agents >= 2:
|
||||
obs = np.stack(self._build_per_agent_obs(agent, state_array_dict) for agent in self[c.AGENT])
|
||||
obs = np.stack([self._build_per_agent_obs(agent, state_array_dict) for agent in self[c.AGENT]])
|
||||
else:
|
||||
raise ValueError('n_agents cannot be smaller than 1!!')
|
||||
return obs
|
||||
@ -384,6 +394,7 @@ class BaseFactory(gym.Env):
|
||||
|
||||
if self.obs_prop.pomdp_r:
|
||||
oobs = self._do_pomdp_obs_cutout(agent, other_agent_obs)[0]
|
||||
# noinspection PyUnresolvedReferences
|
||||
mask = (oobs != c.SHADOWED_CELL.value).astype(int)
|
||||
obs[0] += oobs * mask
|
||||
|
||||
@ -497,7 +508,7 @@ class BaseFactory(gym.Env):
|
||||
if self._actions.is_moving_action(agent.temp_action):
|
||||
if agent.temp_valid:
|
||||
# info_dict.update(movement=1)
|
||||
per_agent_reward -= 0.01
|
||||
per_agent_reward -= 0.001
|
||||
pass
|
||||
else:
|
||||
per_agent_reward -= 0.05
|
||||
@ -553,6 +564,7 @@ class BaseFactory(gym.Env):
|
||||
self.print(f"reward is {reward}")
|
||||
return reward, combined_info_dict
|
||||
|
||||
# noinspection PyGlobalUndefined
|
||||
def render(self, mode='human'):
|
||||
if not self._renderer: # lazy init
|
||||
from environments.factory.base.renderer import Renderer, RenderEntity
|
||||
@ -560,6 +572,7 @@ class BaseFactory(gym.Env):
|
||||
height, width = self._obs_cube.shape[1:]
|
||||
self._renderer = Renderer(width, height, view_radius=self._pomdp_r, fps=5)
|
||||
|
||||
# noinspection PyUnboundLocalVariable
|
||||
walls = [RenderEntity('wall', wall.pos) for wall in self[c.WALLS]]
|
||||
|
||||
agents = []
|
||||
@ -582,6 +595,12 @@ class BaseFactory(gym.Env):
|
||||
with filepath.open('w') as f:
|
||||
simplejson.dump(d, f, indent=4, namedtuple_as_object=True)
|
||||
|
||||
def get_injected_agents(self) -> list:
|
||||
if hasattr(self, '_initialized_injections'):
|
||||
return self._initialized_injections
|
||||
else:
|
||||
return []
|
||||
|
||||
def _summarize_state(self):
|
||||
summary = {f'{REC_TAC}step': self._steps}
|
||||
|
||||
@ -621,9 +640,15 @@ class BaseFactory(gym.Env):
|
||||
def additional_obs_build(self) -> List[np.ndarray]:
|
||||
return []
|
||||
|
||||
@abc.abstractmethod
|
||||
def additional_per_agent_obs_build(self, agent) -> List[np.ndarray]:
|
||||
return []
|
||||
additional_per_agent_obs = []
|
||||
if self.obs_prop.show_global_position_info:
|
||||
pos_array = np.zeros(self.observation_space.shape[1:])
|
||||
for xy in range(1):
|
||||
pos_array[0, xy] = agent.pos[xy] / self._level_shape[xy]
|
||||
additional_per_agent_obs.append(pos_array)
|
||||
|
||||
return additional_per_agent_obs
|
||||
|
||||
@abc.abstractmethod
|
||||
def do_additional_reset(self) -> None:
|
||||
|
@ -50,6 +50,8 @@ class Register:
|
||||
|
||||
def __getitem__(self, item):
|
||||
if isinstance(item, (int, np.int64, np.int32)):
|
||||
if item < 0:
|
||||
item = len(self._register) - abs(item)
|
||||
try:
|
||||
return next(v for i, v in enumerate(self._register.values()) if i == item)
|
||||
except StopIteration:
|
||||
@ -147,10 +149,10 @@ class MovingEntityObjectRegister(EntityObjectRegister, ABC):
|
||||
if self.individual_slices:
|
||||
self._array = np.delete(self._array, idx, axis=0)
|
||||
|
||||
def delete_item(self, item):
|
||||
self.delete_item_by_name(item.name)
|
||||
def delete_entity(self, item):
|
||||
self.delete_entity_by_name(item.name)
|
||||
|
||||
def delete_item_by_name(self, name):
|
||||
def delete_entity_by_name(self, name):
|
||||
del self[name]
|
||||
|
||||
|
||||
@ -320,8 +322,11 @@ class Agents(MovingEntityObjectRegister):
|
||||
def positions(self):
|
||||
return [agent.pos for agent in self]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self._register[self[key].name] = value
|
||||
def replace_agent(self, key, agent):
|
||||
old_agent = self[key]
|
||||
self[key].tile.leave(self[key])
|
||||
agent._name = old_agent.name
|
||||
self._register[agent.name] = agent
|
||||
|
||||
|
||||
class Doors(EntityObjectRegister):
|
||||
|
Reference in New Issue
Block a user