diff --git a/marl_factory_grid/environment/factory.py b/marl_factory_grid/environment/factory.py
index 627351a..efea2b6 100644
--- a/marl_factory_grid/environment/factory.py
+++ b/marl_factory_grid/environment/factory.py
@@ -4,15 +4,17 @@ from collections import defaultdict
 from itertools import chain
 from os import PathLike
 from pathlib import Path
-from typing import Union
+from typing import Union, List
 
 import gymnasium as gym
+import numpy as np
 
 from marl_factory_grid.utils.level_parser import LevelParser
 from marl_factory_grid.utils.observation_builder import OBSBuilder
 from marl_factory_grid.utils.config_parser import FactoryConfigParser
 from marl_factory_grid.utils import helpers as h
 import marl_factory_grid.environment.constants as c
+from marl_factory_grid.utils.results import Result
 from marl_factory_grid.utils.states import Gamestate
 
 
@@ -101,11 +103,59 @@ class Factory(gym.Env):
         self.obs_builder = OBSBuilder(self.map.level_shape, self.state, self.map.pomdp_r)
         return self.obs_builder.refresh_and_build_for_all(self.state)
 
+    def manual_step_init(self) -> List[Result]:
+        """Start a manual step: advance the step counter, run the pre-step rules, reset the obs builder."""
+        self.state.curr_step += 1
+
+        # Main Agent Step
+        pre_step_result = self.state.rules.tick_pre_step_all(self)
+        self.obs_builder.reset_struc_obs_block(self.state)
+        return pre_step_result
+
+    def manual_get_named_agent_obs(self, agent_name: str) -> (np.ndarray, List[str]):
+        """Return what ``OBSBuilder.build_for_agent`` builds for the named agent (observation first, then names)."""
+        agent = self[c.AGENT][agent_name]
+        assert agent, f'"{agent_name}" could not be found. Check the spelling!'
+        return self.obs_builder.build_for_agent(agent, self.state)
+
+    def manual_get_agent_obs(self, agent_name: str) -> np.ndarray:
+        """Return only the observation of the named agent."""
+        # Element 0 is the observation (the OBSBuilder itself reads it via `[0]`); element 1 holds the names.
+        return self.manual_get_named_agent_obs(agent_name)[0]
+
+    def manual_agent_tick(self, agent_name: str, action: int) -> Result:
+        """Perform the action with index ``action`` for the named agent and store the result on that agent."""
+        agent = self[c.AGENT][agent_name].clear_temp_state()
+        action = agent.actions[action]
+        action_result = action.do(agent, self)
+        agent.set_state(action_result)
+        return action_result
+
+    def manual_finalize_init(self):
+        """Run the step rules and the post-step rules and return their results as one list."""
+        results = list()
+        results.extend(self.state.rules.tick_step_all(self))
+        results.extend(self.state.rules.tick_post_step_all(self))
+        return results
+
+    # Finalize
+    def manual_step_finalize(self, tick_result) -> (float, bool, dict):
+        """Check the done conditions and summarize ``tick_result`` into reward, done flag and info dict."""
+        # Check Done Conditions
+        done_results = self.state.check_done()
+        reward, reward_info, done = self.summarize_step_results(tick_result, done_results)
+
+        info = reward_info
+        info.update(step_reward=sum(reward), step=self.state.curr_step)
+        return reward, done, info
+
     def step(self, actions):
 
         if not isinstance(actions, list):
             actions = [int(actions)]
 
+        # --> Action
+
         # Apply rules, do actions, tick the state, etc...
         tick_result = self.state.tick(actions)
 
@@ -119,8 +169,7 @@ class Factory(gym.Env):
 
         info.update(step_reward=sum(reward), step=self.state.curr_step)
 
-        obs, reset_info = self.obs_builder.refresh_and_build_for_all(self.state)
-        info.update(reset_info)
+        obs = self.obs_builder.refresh_and_build_for_all(self.state)
         return None, [x for x in obs.values()], reward, done, info
 
     def summarize_step_results(self, tick_results: list, done_check_results: list) -> (int, dict, bool):
diff --git a/marl_factory_grid/utils/observation_builder.py b/marl_factory_grid/utils/observation_builder.py
index 4a4d4b2..eab9536 100644
--- a/marl_factory_grid/utils/observation_builder.py
+++ b/marl_factory_grid/utils/observation_builder.py
@@ -23,6 +23,7 @@ class OBSBuilder(object):
         return 0
 
     def __init__(self, level_shape: np.size, state: Gamestate, pomdp_r: int):
+        self._curr_env_step = None
         self.all_obs = dict()
         self.light_blockers = defaultdict(lambda: False)
         self.positional = defaultdict(lambda: False)
@@ -36,12 +37,18 @@ class OBSBuilder(object):
 
         self.obs_layers = dict()
 
-        self.build_structured_obs_block(state)
+        self.reset_struc_obs_block(state)
         self.curr_lightmaps = dict()
 
-    def build_structured_obs_block(self, state):
+    def reset_struc_obs_block(self, state):
+        """Rebuild the structured observation block and remember the env step it was built for."""
+        # curr_step is a counter (incremented with `+= 1`); presumably a plain int, which has no `.copy()`.
+        self._curr_env_step = state.curr_step
+        # Construct an empty obs (array) for possible placeholders
         self.all_obs[c.PLACEHOLDER] = np.full(self.obs_shape, 0, dtype=float)
+        # Fill the all_obs-dict with all available entities
         self.all_obs.update({key: obj for key, obj in state.entities.obs_pairs})
+        return True
 
     def observation_space(self, state):
         from gymnasium.spaces import Tuple, Box
@@ -56,12 +63,11 @@ class OBSBuilder(object):
         return self.refresh_and_build_for_all(state)
 
-    def refresh_and_build_for_all(self, state) -> (dict, dict):
-        self.build_structured_obs_block(state)
-        info = {}
-        return {agent.name: self.build_for_agent(agent, state)[0] for agent in state[c.AGENT]}, info
+    def refresh_and_build_for_all(self, state) -> dict:
+        self.reset_struc_obs_block(state)
+        return {agent.name: self.build_for_agent(agent, state)[0] for agent in state[c.AGENT]}
 
     def refresh_and_build_named_for_all(self, state) -> Dict[str, Dict[str, np.ndarray]]:
-        self.build_structured_obs_block(state)
+        self.reset_struc_obs_block(state)
         named_obs_dict = {}
         for agent in state[c.AGENT]:
             obs, names = self.build_for_agent(agent, state)
@@ -69,6 +75,9 @@ class OBSBuilder(object):
         return named_obs_dict
 
     def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
+        assert self._curr_env_step == state.curr_step, (
+            "The observation objekt has not been reset this state! Call 'reset_struc_obs_block(state)'"
+        )
         try:
             agent_want_obs = self.obs_layers[agent.name]
         except KeyError:
diff --git a/marl_factory_grid/utils/states.py b/marl_factory_grid/utils/states.py
index d8ccc10..1b334d4 100644
--- a/marl_factory_grid/utils/states.py
+++ b/marl_factory_grid/utils/states.py
@@ -88,14 +88,17 @@ class Gamestate(object):
 
         # Main Agent Step
         results.extend(self.rules.tick_pre_step_all(self))
+
         for idx, action_int in enumerate(actions):
             agent = self[c.AGENT][idx].clear_temp_state()
             action = agent.actions[action_int]
             action_result = action.do(agent, self)
             results.append(action_result)
             agent.set_state(action_result)
+
         results.extend(self.rules.tick_step_all(self))
         results.extend(self.rules.tick_post_step_all(self))
+
         return results
 
     def print(self, string):