Adjustments and Documentation, recording and new environments, refactoring
@@ -180,6 +180,7 @@ class BaseFactory(gym.Env):
         self._entities.add_additional_items({c.DOORS: doors})

         # Actions
+        # TODO: Move this to Agent init, so that agents can have individual action sets.
         self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors)
         if additional_actions := self.actions_hook:
             self._actions.add_additional_items(additional_actions)
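
The actions_hook pattern above lets subclasses contribute extra actions at construction time. A minimal, self-contained sketch of that pattern; Action, Actions and the env classes below are simplified stand-ins, not the project's real implementations:

from dataclasses import dataclass
from typing import List


@dataclass
class Action:
    name: str


class Actions:
    def __init__(self, *initial: Action):
        self._items: List[Action] = list(initial)

    def add_additional_items(self, items: List[Action]):
        self._items.extend(items)

    def __len__(self):
        return len(self._items)


class ExampleBaseEnv:
    @property
    def actions_hook(self) -> List[Action]:
        # The base class contributes nothing; subclasses may override.
        return []

    def __init__(self):
        self._actions = Actions(Action('move_north'), Action('move_south'))
        # Same walrus pattern as in the diff: only extend when the hook
        # actually returns something.
        if additional_actions := self.actions_hook:
            self._actions.add_additional_items(additional_actions)


class ExampleDoorEnv(ExampleBaseEnv):
    @property
    def actions_hook(self) -> List[Action]:
        return [Action('use_door')]


assert len(ExampleDoorEnv()._actions) == 3
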
@@ -308,7 +309,8 @@ class BaseFactory(gym.Env):
         info.update(self._summarize_state())

         # Post step Hook for later use
-        info.update(self.post_step_hook())
+        for post_step_info in self.post_step_hook():
+            info.update(post_step_info)

         obs, _ = self._build_observations()
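
The loop replaces a plain dict merge because post_step_hook() now returns a sequence of info dicts (see the List[dict] signature change further down) instead of a single dict. A hedged sketch of how an overriding hook and this merging loop fit together; the class and key names are illustrative only:

from typing import Dict, List


class ExampleEnv:
    """Illustrative subclass; only the hook/loop pattern mirrors the diff."""

    def post_step_hook(self) -> List[dict]:
        # Each entry is an independent info dict, so several sub-systems can
        # contribute without overwriting one shared return value.
        return [{'doors_opened': 1}, {'items_dropped': 0}]

    def step(self) -> Dict[str, int]:
        info: Dict[str, int] = {}
        # Same merging loop as in the diff body above.
        for post_step_info in self.post_step_hook():
            info.update(post_step_info)
        return info


assert ExampleEnv().step() == {'doors_opened': 1, 'items_dropped': 0}
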
@@ -367,14 +369,16 @@ class BaseFactory(gym.Env):
                 agent_obs = global_agent_obs.copy()
                 agent_obs[(0, *agent.pos)] -= agent.encoding
             else:
-                agent_obs = global_agent_obs
+                agent_obs = global_agent_obs.copy()
         else:
             # agent_obs == None!!!!!
             agent_obs = global_agent_obs

         # Build Level Observations
         if self.obs_prop.render_agents == a_obs.LEVEL:
-            lvl_obs += global_agent_obs
+            assert agent_obs is not None
+            lvl_obs = lvl_obs.copy()
+            lvl_obs += agent_obs

         obs_dict[c.WALLS] = lvl_obs
         if self.obs_prop.render_agents in [a_obs.SEPERATE, a_obs.COMBINED] and agent_obs is not None:
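
Taking .copy() of both global_agent_obs and lvl_obs matters here because the subsequent -= / += operations are in-place on NumPy arrays: without the copy, one agent's private adjustment would leak into the shared buffers and thus into every other agent's view. A small sketch of the aliasing issue (the array shape and encoding value are made up, not the env's real ones):

import numpy as np

# Made-up 1x3x3 "agent layer"; real observation shapes in the env differ.
global_agent_obs = np.zeros((1, 3, 3))
global_agent_obs[0, 1, 1] = 1.0  # some agent encoding

# With .copy(), the in-place subtraction only touches the private view;
# without it, the shared buffer would be corrupted for every other agent.
agent_obs = global_agent_obs.copy()
agent_obs[(0, 1, 1)] -= 1.0

assert global_agent_obs[0, 1, 1] == 1.0  # shared buffer untouched
assert agent_obs[0, 1, 1] == 0.0         # only the copy changed
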
@@ -600,7 +604,9 @@ class BaseFactory(gym.Env):
             for reward in agent.step_result['rewards']:
                 combined_info_dict.update(reward['info'])

         # Combine Info dicts into a global one
         combined_info_dict = dict(combined_info_dict)

+        combined_info_dict.update(info)
+
         global_reward_sum = sum(global_env_rewards)
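
The dict(combined_info_dict) conversion suggests the accumulator is not a plain dict (e.g. a Counter or defaultdict that sums numeric info entries); that is an assumption about the surrounding code, but the aggregation pattern itself can be sketched with a Counter:

from collections import Counter

# Hypothetical per-agent step results shaped like the loop in the diff:
# every reward entry carries its own 'info' dict.
step_results = [
    {'rewards': [{'value': 0.1, 'info': {'agent_0_door_used': 1}}]},
    {'rewards': [{'value': -0.5, 'info': {'agent_1_collision': 1}}]},
]

combined_info_dict = Counter()
for step_result in step_results:
    for reward in step_result['rewards']:
        combined_info_dict.update(reward['info'])

# Convert to a plain dict, as in the diff, before handing it onwards.
combined_info_dict = dict(combined_info_dict)
assert combined_info_dict == {'agent_0_door_used': 1, 'agent_1_collision': 1}
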
@@ -616,9 +622,11 @@ class BaseFactory(gym.Env):

     def start_recording(self):
         self._record_episodes = True
+        return self._record_episodes

     def stop_recording(self):
         self._record_episodes = False
+        return not self._record_episodes

     # noinspection PyGlobalUndefined
     def render(self, mode='human'):
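
Both toggles now report the resulting recorder state, so a caller can confirm that the switch took effect. A short usage sketch with a stand-in class that only reproduces the start/stop contract from the diff:

class RecordingToggle:
    """Stand-in with the same start/stop contract as the diff."""

    def __init__(self):
        self._record_episodes = False

    def start_recording(self):
        self._record_episodes = True
        return self._record_episodes       # True once recording is on

    def stop_recording(self):
        self._record_episodes = False
        return not self._record_episodes   # True once recording is off


env = RecordingToggle()
assert env.start_recording()  # both calls confirm the new state
assert env.stop_recording()
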
@@ -719,12 +727,12 @@ class BaseFactory(gym.Env):
         return {}

     @abc.abstractmethod
-    def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
-        return {}
+    def per_agent_reward_hook(self, agent: Agent) -> List[dict]:
+        return []

     @abc.abstractmethod
-    def post_step_hook(self) -> dict:
-        return {}
+    def post_step_hook(self) -> List[dict]:
+        return []

     @abc.abstractmethod
     def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
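
Both hooks now default to an empty list of dicts instead of a single dict, which matches how they are consumed earlier in the diff (the loop over post_step_hook() and the per-reward loop reading reward['info']). A sketch of what an overriding environment might return; only the 'info' key is taken from the diff, the other key names and classes are illustrative:

from typing import List


class ExampleAgent:
    """Illustrative agent stand-in."""
    name = 'agent_0'


class ExampleFactory:
    """Illustrative subclass showing the new hook return shapes."""

    def per_agent_reward_hook(self, agent: ExampleAgent) -> List[dict]:
        # One dict per reward source; 'info' matches the lookup in the diff,
        # 'value' and 'reason' are assumed field names.
        return [{'value': 0.5, 'reason': 'dirt_cleaned',
                 'info': {f'{agent.name}_dirt_cleaned': 1}}]

    def post_step_hook(self) -> List[dict]:
        return [{'global_dirt_amount': 3.0}]


env = ExampleFactory()
assert isinstance(env.per_agent_reward_hook(ExampleAgent()), list)
assert isinstance(env.post_step_hook(), list)
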
@@ -119,7 +119,6 @@ class Entity(EnvObject):

     def __repr__(self):
         return super(Entity, self).__repr__() + f'(@{self.pos})'
-    # With Position in Env


 # TODO: Missing Documentation
@@ -117,7 +117,7 @@ class EnvObjectCollection(ObjectCollection):
         return self._array

     def summarize_states(self, n_steps=None):
-        return [val.summarize_state(n_steps=n_steps) for val in self.values()]
+        return [entity.summarize_state(n_steps=n_steps) for entity in self.values()]

     def notify_change_to_free(self, env_object: EnvObject):
         self._array_change_notifyer(env_object, value=c.FREE_CELL)