Manual step and ticks for AOMAS

This commit is contained in:
Steffen Illium 2023-10-10 12:07:30 +02:00
parent 04af996232
commit e64fa84ef1
3 changed files with 61 additions and 9 deletions

View File

@ -4,15 +4,17 @@ from collections import defaultdict
from itertools import chain
from os import PathLike
from pathlib import Path
from typing import Union
from typing import Union, List
import gymnasium as gym
import numpy as np
from marl_factory_grid.utils.level_parser import LevelParser
from marl_factory_grid.utils.observation_builder import OBSBuilder
from marl_factory_grid.utils.config_parser import FactoryConfigParser
from marl_factory_grid.utils import helpers as h
import marl_factory_grid.environment.constants as c
from marl_factory_grid.utils.results import Result
from marl_factory_grid.utils.states import Gamestate
@ -101,11 +103,52 @@ class Factory(gym.Env):
self.obs_builder = OBSBuilder(self.map.level_shape, self.state, self.map.pomdp_r)
return self.obs_builder.refresh_and_build_for_all(self.state)
def manual_step_init(self) -> List[Result]:
    """Open a manually driven environment step.

    Advances the state's step counter by one, runs the pre-step tick of the
    rules and afterwards resets the structured observation block of the
    observation builder for the new step.

    :return: Whatever ``tick_pre_step_all`` returned for this step.
    """
    self.state.curr_step += 1
    # The pre-step rules run before the observation block is rebuilt.
    rules = self.state.rules
    pre_step_results = rules.tick_pre_step_all(self)
    self.obs_builder.reset_struc_obs_block(self.state)
    return pre_step_results
def manual_get_named_agent_obs(self, agent_name: str) -> (List[str], np.ndarray):
    """Build the observation for a single agent that is looked up by name.

    NOTE(review): the return annotation lists the names first, while
    ``refresh_and_build_named_for_all`` unpacks the same builder call as
    ``obs, names`` - confirm the actual order.

    :param agent_name: Name of the agent inside the ``c.AGENT`` collection.
    :return: The unmodified result of ``obs_builder.build_for_agent``.
    """
    named_agent = self[c.AGENT][agent_name]
    # A falsy lookup result is reported as an unknown agent name.
    assert named_agent, f'"{agent_name}" could not be found. Check the spelling!'
    built = self.obs_builder.build_for_agent(named_agent, self.state)
    return built
def manual_get_agent_obs(self, agent_name: str) -> np.ndarray:
    """Return only the observation array of the agent called ``agent_name``.

    The observation builder hands back a pair whose first element is the
    observation (the builder itself takes ``[0]`` when collecting per-agent
    observations and unpacks the pair as ``obs, names``). Element 1 would be
    the layer names, not the array promised by the return annotation.

    :param agent_name: Name of the agent whose observation is requested.
    :return: The observation array built for that agent.
    """
    observation = self.manual_get_named_agent_obs(agent_name)[0]
    return observation
def manual_agent_tick(self, agent_name: str, action: int) -> Result:
    """Let one named agent perform one action outside of the regular step loop.

    :param agent_name: Name of the acting agent inside the ``c.AGENT`` collection.
    :param action: Index into the agent's ``actions``.
    :return: The result produced by the selected action.
    """
    # The return value of clear_temp_state() is used as the agent from here on.
    acting_agent = self[c.AGENT][agent_name].clear_temp_state()
    chosen_action = acting_agent.actions[action]
    outcome = chosen_action.do(acting_agent, self)
    # Store the outcome on the agent before handing it to the caller.
    acting_agent.set_state(outcome)
    return outcome
def manual_finalize_init(self):
    """Run the step tick and the post-step tick of the rules.

    :return: A single list holding the step results followed by the
             post-step results.
    """
    # The step tick is fully consumed before the post-step tick is invoked.
    collected = [*self.state.rules.tick_step_all(self)]
    collected += self.state.rules.tick_post_step_all(self)
    return collected
# Finalize
def manual_step_finalize(self, tick_result) -> (float, bool, dict):
    """Close a manually driven step by checking done conditions and summarizing.

    :param tick_result: Results gathered during this step; forwarded to
                        ``summarize_step_results`` together with the done checks.
    :return: ``(reward, done, info)``. ``info`` is the info dict returned by the
             summary, extended in place with ``step_reward`` and ``step``.
    """
    done_checks = self.state.check_done()
    reward, info, done = self.summarize_step_results(tick_result, done_checks)
    # The summary's dict is reused (not copied) and enriched with step data.
    info.update(step_reward=sum(reward), step=self.state.curr_step)
    return reward, done, info
def step(self, actions):
if not isinstance(actions, list):
actions = [int(actions)]
# --> Action
# Apply rules, do actions, tick the state, etc...
tick_result = self.state.tick(actions)
@ -119,8 +162,7 @@ class Factory(gym.Env):
info.update(step_reward=sum(reward), step=self.state.curr_step)
obs, reset_info = self.obs_builder.refresh_and_build_for_all(self.state)
info.update(reset_info)
obs = self.obs_builder.refresh_and_build_for_all(self.state)
return None, [x for x in obs.values()], reward, done, info
def summarize_step_results(self, tick_results: list, done_check_results: list) -> (int, dict, bool):

View File

@ -23,6 +23,7 @@ class OBSBuilder(object):
return 0
def __init__(self, level_shape: np.size, state: Gamestate, pomdp_r: int):
self._curr_env_step = None
self.all_obs = dict()
self.light_blockers = defaultdict(lambda: False)
self.positional = defaultdict(lambda: False)
@ -36,12 +37,16 @@ class OBSBuilder(object):
self.obs_layers = dict()
self.build_structured_obs_block(state)
self.reset_struc_obs_block(state)
self.curr_lightmaps = dict()
def build_structured_obs_block(self, state):
def reset_struc_obs_block(self, state):
    """Reset the structured observation block for the current environment step.

    Remembers the step the block was reset for, so a later comparison against
    ``state.curr_step`` can tell whether the block is stale. It then
    re-registers the placeholder array and all entity observations.

    :param state: Game state providing ``curr_step`` and ``entities.obs_pairs``.
    :return: Always ``True``.
    """
    import copy

    # A plain int has no ``.copy()`` method; copy.copy() snapshots ints as well
    # as array-like counters without raising AttributeError.
    self._curr_env_step = copy.copy(state.curr_step)
    # Construct an empty obs (array) for possible placeholders
    self.all_obs[c.PLACEHOLDER] = np.full(self.obs_shape, 0, dtype=float)
    # Fill the all_obs-dict with all available entities
    self.all_obs.update({key: obj for key, obj in state.entities.obs_pairs})
    return True
def observation_space(self, state):
from gymnasium.spaces import Tuple, Box
@ -56,12 +61,11 @@ class OBSBuilder(object):
return self.refresh_and_build_for_all(state)
def refresh_and_build_for_all(self, state) -> (dict, dict):
self.build_structured_obs_block(state)
info = {}
return {agent.name: self.build_for_agent(agent, state)[0] for agent in state[c.AGENT]}, info
self.reset_struc_obs_block(state)
return {agent.name: self.build_for_agent(agent, state)[0] for agent in state[c.AGENT]}
def refresh_and_build_named_for_all(self, state) -> Dict[str, Dict[str, np.ndarray]]:
self.build_structured_obs_block(state)
self.reset_struc_obs_block(state)
named_obs_dict = {}
for agent in state[c.AGENT]:
obs, names = self.build_for_agent(agent, state)
@ -69,6 +73,9 @@ class OBSBuilder(object):
return named_obs_dict
def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
assert self._curr_env_step == state.curr_step, (
"The observation objekt has not been reset this state! Call 'reset_struc_obs_block(state)'"
)
try:
agent_want_obs = self.obs_layers[agent.name]
except KeyError:

View File

@ -88,14 +88,17 @@ class Gamestate(object):
# Main Agent Step
results.extend(self.rules.tick_pre_step_all(self))
for idx, action_int in enumerate(actions):
agent = self[c.AGENT][idx].clear_temp_state()
action = agent.actions[action_int]
action_result = action.do(agent, self)
results.append(action_result)
agent.set_state(action_result)
results.extend(self.rules.tick_step_all(self))
results.extend(self.rules.tick_post_step_all(self))
return results
def print(self, string):