Adjustments and documentation; episode recording, new environments, refactoring

Author: Steffen Illium
Date: 2022-08-04 14:57:48 +02:00
Parent: e7461d7dcf
Commit: 6a24e7b518
41 changed files with 1660 additions and 760 deletions


@@ -180,6 +180,7 @@ class BaseFactory(gym.Env):
             self._entities.add_additional_items({c.DOORS: doors})

         # Actions
+        # TODO: Move this to Agent init, so that agents can have individual action sets.
         self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors)
         if additional_actions := self.actions_hook:
             self._actions.add_additional_items(additional_actions)
@@ -308,7 +309,8 @@ class BaseFactory(gym.Env):
         info.update(self._summarize_state())

         # Post step Hook for later use
-        info.update(self.post_step_hook())
+        for post_step_info in self.post_step_hook():
+            info.update(post_step_info)

         obs, _ = self._build_observations()
@@ -367,14 +369,16 @@ class BaseFactory(gym.Env):
                     agent_obs = global_agent_obs.copy()
                     agent_obs[(0, *agent.pos)] -= agent.encoding
                 else:
-                    agent_obs = global_agent_obs
+                    agent_obs = global_agent_obs.copy()
             else:
+                # agent_obs == None!!!!!
                 agent_obs = global_agent_obs

             # Build Level Observations
             if self.obs_prop.render_agents == a_obs.LEVEL:
+                assert agent_obs is not None
                 lvl_obs = lvl_obs.copy()
-                lvl_obs += global_agent_obs
+                lvl_obs += agent_obs

             obs_dict[c.WALLS] = lvl_obs
             if self.obs_prop.render_agents in [a_obs.SEPERATE, a_obs.COMBINED] and agent_obs is not None:
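
The .copy() changes above guard against aliasing: below is a minimal, self-contained NumPy sketch (independent of the environment code; shapes and values are hypothetical) of the bug class they prevent when the per-agent view shares memory with the global agent observation.

import numpy as np

global_agent_obs = np.zeros((1, 3, 3))
global_agent_obs[0, 1, 1] = 1.0          # mark an agent position

agent_obs = global_agent_obs             # alias, no copy
lvl_obs = np.zeros((1, 3, 3))
lvl_obs += agent_obs                     # fine: lvl_obs owns its own buffer
agent_obs[(0, 1, 1)] -= 1.0              # in-place edit also hits the shared array
assert global_agent_obs[0, 1, 1] == 0.0  # the global observation was silently mutated

agent_obs = global_agent_obs.copy()      # the fix mirrored by the diff above
agent_obs[(0, 1, 1)] -= 1.0              # now only the private copy changes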
@@ -600,7 +604,9 @@ class BaseFactory(gym.Env):
             for reward in agent.step_result['rewards']:
                 combined_info_dict.update(reward['info'])

+        # Combine Info dicts into a global one
+        combined_info_dict = dict(combined_info_dict)

         combined_info_dict.update(info)

         global_reward_sum = sum(global_env_rewards)
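
Why the dict() conversion matters is easiest to see in isolation: a small sketch assuming combined_info_dict starts out as a collections.Counter-style accumulator (that type is an assumption for illustration, not confirmed by this hunk).

from collections import Counter

combined_info_dict = Counter()
# Hypothetical per-agent reward infos:
for reward in [{'info': {'collisions': 1}}, {'info': {'collisions': 1, 'valid': 1}}]:
    combined_info_dict.update(reward['info'])     # Counter.update() sums values per key

combined_info_dict = dict(combined_info_dict)     # freeze into a plain dict
combined_info_dict.update({'step_reward': -0.1})  # dict.update() now overwrites, not adds
print(combined_info_dict)  # {'collisions': 2, 'valid': 1, 'step_reward': -0.1}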
@@ -616,9 +622,11 @@ class BaseFactory(gym.Env):
     def start_recording(self):
         self._record_episodes = True
+        return self._record_episodes

     def stop_recording(self):
         self._record_episodes = False
+        return not self._record_episodes

     # noinspection PyGlobalUndefined
     def render(self, mode='human'):
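
The new return values simply confirm the toggle took effect. A tiny stand-in class (not the real BaseFactory) showing how a caller can rely on them:

class RecordingSketch:
    """Stand-in that mirrors only the two methods from the hunk above."""
    def __init__(self):
        self._record_episodes = False

    def start_recording(self):
        self._record_episodes = True
        return self._record_episodes

    def stop_recording(self):
        self._record_episodes = False
        return not self._record_episodes


env = RecordingSketch()
assert env.start_recording()   # True: recording is switched on
assert env.stop_recording()    # True: recording is confirmed off again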
@@ -719,12 +727,12 @@ class BaseFactory(gym.Env):
         return {}

     @abc.abstractmethod
-    def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
-        return {}
+    def per_agent_reward_hook(self, agent: Agent) -> List[dict]:
+        return []

     @abc.abstractmethod
-    def post_step_hook(self) -> dict:
-        return {}
+    def post_step_hook(self) -> List[dict]:
+        return []

     @abc.abstractmethod
     def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
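
A minimal, self-contained sketch of the revised hook contract (the class name and returned values are hypothetical; only the two signatures follow the diff): each hook now returns a list of dicts, so a single call can emit several rewards or info entries, which step() then merges one at a time.

from typing import List


class HypotheticalDirtEnv:
    def per_agent_reward_hook(self, agent) -> List[dict]:
        # One entry per reward-relevant event this agent triggered.
        return [
            {'event': 'clean_up', 'reward': 0.5, 'info': {'dirt_cleaned': 1}},
            {'event': 'collision', 'reward': -0.1, 'info': {'collisions': 1}},
        ]

    def post_step_hook(self) -> List[dict]:
        # Merged into the step info one dict at a time (see the loop added in step()).
        return [{'dirt_amount': 3.2}, {'doors_open': 1}]


env = HypotheticalDirtEnv()
info = {}
for post_step_info in env.post_step_hook():
    info.update(post_step_info)
print(info)  # {'dirt_amount': 3.2, 'doors_open': 1}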


@@ -119,7 +119,6 @@ class Entity(EnvObject):
     def __repr__(self):
         return super(Entity, self).__repr__() + f'(@{self.pos})'


 # With Position in Env
-# TODO: Missing Documentation


@@ -117,7 +117,7 @@ class EnvObjectCollection(ObjectCollection):
         return self._array

     def summarize_states(self, n_steps=None):
-        return [val.summarize_state(n_steps=n_steps) for val in self.values()]
+        return [entity.summarize_state(n_steps=n_steps) for entity in self.values()]

     def notify_change_to_free(self, env_object: EnvObject):
         self._array_change_notifyer(env_object, value=c.FREE_CELL)