299 lines
11 KiB
Python

import shutil
from collections import defaultdict
from itertools import chain
from os import PathLike
from pathlib import Path
from typing import Union, List
import gymnasium as gym
import numpy as np
from marl_factory_grid.utils.level_parser import LevelParser
from marl_factory_grid.utils.observation_builder import OBSBuilder
from marl_factory_grid.utils.config_parser import FactoryConfigParser
from marl_factory_grid.utils import helpers as h
import marl_factory_grid.environment.constants as c
from marl_factory_grid.utils.results import Result
from marl_factory_grid.utils.states import Gamestate
class Factory(gym.Env):
    """
    Gym environment wrapper around the marl-factory-grid game state.

    The environment is assembled in ``__init__`` from a configuration file: the level is parsed, entities and rules
    are loaded, a :class:`Gamestate` is created and an :class:`OBSBuilder` is attached to it. ``reset`` and ``step``
    drive whole episodes; the ``manual_*`` methods expose the individual phases of a step so that they can be
    triggered one by one.
    """

    @property
    def action_space(self):
        """
        The action space defines the set of all possible actions that an agent can take in the environment.

        :return: Action space
        :rtype: gym.Space
        """
        return self.state[c.AGENT].action_space

    @property
    def named_action_space(self):
        """
        Returns the named action space for agents.

        :return: Named action space
        :rtype: dict[str, dict[str, list[int]]]
        """
        return self.state[c.AGENT].named_action_space

    @property
    def observation_space(self):
        """
        The observation space represents all the information that an agent can receive from the environment at a given
        time step.

        :return: Observation space.
        :rtype: gym.Space
        """
        return self.obs_builder.observation_space(self.state)

    @property
    def named_observation_space(self):
        """
        Returns the named observation space for the environment.

        :return: Named observation space.
        :rtype: (dict, dict)
        """
        return self.obs_builder.named_observation_space(self.state)

    @property
    def params(self) -> dict:
        """
        FIXME LEGACY

        Re-reads the configuration file this environment was created from and returns its parsed content.

        :return: The YAML document stored in the configuration file, as returned by ``yaml.safe_load``.
        :rtype: dict
        """
        import yaml
        config_path = Path(self._config_file)
        # Use a context manager so the file handle is closed deterministically.
        with config_path.open() as config_stream:
            config_dict = yaml.safe_load(config_stream)
        return config_dict

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Close the environment when a ``with`` block is left.

        The exception arguments are ignored; ``close`` is not defined in this class and is presumably inherited
        from ``gym.Env``.
        """
        self.close()

    def __init__(self, config_file: Union[str, PathLike], custom_modules_path: Union[None, PathLike] = None,
                 custom_level_path: Union[None, PathLike] = None):
        """
        Initializes the marl-factory-grid as Gym environment.

        :param config_file: Path to the configuration file.
        :type config_file: Union[str, PathLike]
        :param custom_modules_path: Path to custom modules directory. (Default: None)
        :type custom_modules_path: Union[None, PathLike]
        :param custom_level_path: Path to custom level file. (Default: None)
        :type custom_level_path: Union[None, PathLike]
        """
        self._config_file = config_file
        self.conf = FactoryConfigParser(self._config_file, custom_modules_path)

        # Attribute Assignment
        if custom_level_path is not None:
            self.level_filepath = Path(custom_level_path)
        else:
            # Fall back to the bundled level that carries the configured level name.
            self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt'

        parsed_entities = self.conf.load_entities()
        self.map = LevelParser(self.level_filepath, parsed_entities, self.conf.pomdp_r)

        # Init for later usage:
        # noinspection PyTypeChecker
        self.state: Gamestate = None
        # noinspection PyTypeChecker
        self.obs_builder: OBSBuilder = None

        # expensive - don't use; unless required !
        self._renderer = None

        # Init entities
        entities = self.map.do_init()

        # Init rules
        env_rules = self.conf.load_env_rules()
        entity_rules = self.conf.load_entity_spawn_rules(entities)
        env_rules.extend(entity_rules)

        # Parse the agent conf
        parsed_agents_conf = self.conf.parse_agents_conf()
        self.state = Gamestate(entities, parsed_agents_conf, env_rules, self.map.level_shape,
                               self.conf.env_seed, self.conf.verbose)

        # All is set up, trigger additional init (after agent entity spawn etc)
        self.state.rules.do_all_init(self.state, self.map)

        self.obs_builder = OBSBuilder(self.map.level_shape, self.state, self.map.pomdp_r)

    def __getitem__(self, item):
        """
        Shortcut to the entity collections of the current state.

        :param item: Key that is forwarded to ``self.state.entities``.
        :return: Whatever ``self.state.entities[item]`` yields.
        """
        return self.state.entities[item]

    @staticmethod
    def _total_reward(reward):
        """
        Collapse a step reward into a single number.

        ``summarize_step_results`` returns a list of per-agent rewards when individual rewards are configured and
        a plain scalar otherwise; calling ``sum`` on the scalar would raise ``TypeError``.

        :param reward: Per-agent rewards (list or tuple) or an already summed scalar.
        :return: The summed reward.
        """
        if isinstance(reward, (list, tuple)):
            return sum(reward)
        return reward

    def reset(self) -> (dict, dict):
        """
        Reset state, entities and rules and build the first observations of a new episode.

        :return: The result of ``OBSBuilder.build_for_all`` for the freshly reset state.
        """
        # Reset information the state holds
        self.state.reset()

        # Reset Information the GlobalEntity collection holds.
        self.state.entities.reset()

        # All is set up, trigger entity spawn with variable pos
        self.state.rules.do_all_reset(self.state)
        self.state.rules.do_all_post_spawn_reset(self.state)

        # Build initial observations for all agents
        self.obs_builder.reset(self.state)
        return self.obs_builder.build_for_all(self.state)

    def manual_step_init(self) -> List[Result]:
        """
        First phase of a manually driven step: advance the step counter and run all pre-step rules.

        :return: Results produced by the pre-step rules.
        :rtype: List[Result]
        """
        self.state.curr_step += 1

        # Main Agent Step
        pre_step_result = self.state.rules.tick_pre_step_all(self)
        self.obs_builder.reset(self.state)
        return pre_step_result

    def manual_get_named_agent_obs(self, agent_name: str) -> (List[str], np.ndarray):
        """
        Build the observation of a single agent, looked up by name.

        :param agent_name: Name of the agent whose observation is requested.
        :type agent_name: str
        :return: The result of ``OBSBuilder.build_for_agent`` for that agent.
        :raises AssertionError: If the looked-up agent is falsy.
        """
        agent = self[c.AGENT][agent_name]
        assert agent, f'"{agent_name}" could not be found. Check the spelling!'
        return self.obs_builder.build_for_agent(agent, self.state)

    def manual_get_agent_obs(self, agent_name: str) -> np.ndarray:
        """
        Return only the second element of :meth:`manual_get_named_agent_obs`.

        :param agent_name: Name of the agent whose observation is requested.
        :type agent_name: str
        :return: Element ``[1]`` of the named observation.
        """
        return self.manual_get_named_agent_obs(agent_name)[1]

    def manual_agent_tick(self, agent_name: str, action: int) -> Result:
        """
        Let a single agent perform one action and store the outcome on the agent.

        :param agent_name: Name of the acting agent.
        :type agent_name: str
        :param action: Index into the agent's ``actions``.
        :type action: int
        :return: The result returned by the action's ``do`` call.
        :rtype: Result
        """
        # NOTE: relies on ``clear_temp_state`` returning the agent itself.
        agent = self[c.AGENT][agent_name].clear_temp_state()
        action = agent.actions[action]
        action_result = action.do(agent, self)
        agent.set_state(action_result)
        return action_result

    def manual_finalize_init(self):
        """
        Run all step rules followed by all post-step rules.

        :return: The concatenated results of both rule passes.
        :rtype: list
        """
        results = list()
        results.extend(self.state.rules.tick_step_all(self))
        results.extend(self.state.rules.tick_post_step_all(self))
        return results

    # Finalize
    def manual_step_finalize(self, tick_result) -> (float, bool, dict):
        """
        Last phase of a manually driven step: check the done conditions and summarize rewards and infos.

        :param tick_result: Results collected during the earlier phases of this step.
        :return: reward, done, info
        """
        # Check Done Conditions
        done_results = self.state.check_done()
        reward, reward_info, done = self.summarize_step_results(tick_result, done_results)

        info = reward_info
        # ``reward`` is a scalar when individual rewards are disabled, so it must not be passed to ``sum`` directly.
        info.update(step_reward=self._total_reward(reward), step=self.state.curr_step)
        return reward, done, info

    def step(self, actions):
        """
        Run one timestep of the environment's dynamics using the agent actions.

        When the end of an episode is reached (``done``), it is necessary to call :meth:`reset` to
        reset this environment's state for the next episode.

        :param actions: An action or list of actions provided by the agent(s) to update the environment state.
                        Anything that is not a ``list`` is converted with ``int`` and wrapped in a list.
        :return: ``None``, observations (one entry per value of the observation dict), reward, done, info
        :rtype: tuple(None, list, Union[float, list], bool, dict)
        """
        if not isinstance(actions, list):
            actions = [int(actions)]

        # --> Action
        # Apply rules, do actions, tick the state, etc...
        tick_result = self.state.tick(actions)

        # Check Done Conditions
        done_results = self.state.check_done()

        # Finalize
        reward, reward_info, done = self.summarize_step_results(tick_result, done_results)

        info = dict(reward_info)
        # ``reward`` is a scalar when individual rewards are disabled, so it must not be passed to ``sum`` directly.
        info.update(step_reward=self._total_reward(reward), step=self.state.curr_step)

        obs = self.obs_builder.build_for_all(self.state)
        return None, [x for x in obs.values()], reward, done, info

    def summarize_step_results(self, tick_results: list, done_check_results: list) -> (int, dict, bool):
        """
        Aggregate rewards, info values and the done flag from the results of one step.

        :param tick_results: Results produced while ticking the state.
        :type tick_results: list
        :param done_check_results: Results produced by the done checks.
        :type done_check_results: list
        :return: reward (list of per-agent rewards if ``conf.individual_rewards`` is set, otherwise their sum),
                 the combined info dict, and the done flag.
        """
        # Returns: Reward, Info
        rewards = defaultdict(lambda: 0.0)

        # Gather per agent environment rewards and
        # Combine Info dicts into a global one
        combined_info_dict = defaultdict(lambda: 0.0)
        for result in chain(tick_results, done_check_results):
            assert result, 'Something returned None...'
            if result.reward is not None:
                try:
                    rewards[result.entity.name] += result.reward
                except AttributeError:
                    # No usable entity on the result: book the reward globally.
                    rewards['global'] += result.reward
            infos = result.get_infos()
            for info in infos:
                assert isinstance(info.value, (float, int))
                combined_info_dict[info.identifier] += info.value

        # Check Done Rule Results
        try:
            done_reason = next(x for x in done_check_results if x.validity)
            done = True
            self.state.print(f'Env done, Reason: {done_reason.identifier}.')
        except StopIteration:
            done = False

        if self.conf.individual_rewards:
            # Global rewards are added on top of every agent's own reward.
            global_rewards = rewards.pop('global', 0.0)
            reward = [rewards[agent.name] for agent in self.state[c.AGENT]]
            reward = [x + global_rewards for x in reward]
            self.state.print(f"Individual rewards are {dict(rewards)}")
            return reward, combined_info_dict, done
        else:
            reward = sum(rewards.values())
            self.state.print(f"reward is {reward}")
            return reward, combined_info_dict, done

    # noinspection PyGlobalUndefined
    def render(self, mode='human'):
        """
        Render the current state; the renderer is created lazily on the first call.

        :param mode: Not used by this method.
        :return: Whatever the renderer's ``render`` call returns.
        """
        # Declared before the import below so that the import binds the module-level name.
        global Renderer
        if not self._renderer:  # lazy init
            from marl_factory_grid.utils.renderer import Renderer
            self._renderer = Renderer(self.map.level_shape, view_radius=self.conf.pomdp_r, fps=10)

        render_entities = self.state.entities.render()
        if self.conf.pomdp_r:
            # Attach the current light map of each agent to its render entity.
            for render_entity in render_entities:
                if render_entity.name == c.AGENT:
                    render_entity.aux = self.obs_builder.curr_lightmaps[render_entity.real_name]
        return self._renderer.render(render_entities)

    def summarize_header(self):
        """
        Summarize the entity groups named 'Walls', 'DropOffLocations' and 'ChargePods'.

        :return: Dict with the current step (``rec_step``) and one ``rec<GroupName>`` entry per matching group.
        :rtype: dict
        """
        header = {'rec_step': self.state.curr_step}
        for entity_group in (x for x in self.state if x.name in ['Walls', 'DropOffLocations', 'ChargePods']):
            header.update({f'rec{entity_group.name}': entity_group.summarize_states()})
        return header

    def summarize_state(self):
        """
        Summarize the states of a fixed selection of entity groups, keyed by lower-cased group name.

        :return: Dict with the current step and the summaries of the whitelisted groups.
        :rtype: dict
        """
        summary = {'step': self.state.curr_step}

        # Todo: Protobuff Compatibility Section #######
        # for entity_group in (x for x in self.state if x.name not in [c.WALLS, c.FLOORS]):
        for entity_group in self.state:
            summary.update({entity_group.name.lower(): entity_group.summarize_states()})
        # TODO Section End ########

        # NOTE(review): keys are lower-cased above, so the 'dirtPiles' entry below can never match - confirm
        # whether 'dirtpiles' was intended.
        for key in list(summary.keys()):
            if key not in ['step', 'walls', 'doors', 'agents', 'items', 'dirtPiles', 'batteries']:
                del summary[key]
        return summary

    def save_params(self, filepath: Path):
        """
        Copy the configuration file of this environment to ``filepath``; missing parent directories are created.

        :param filepath: Destination of the copy.
        :type filepath: Path
        """
        # noinspection PyProtectedMember
        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)
        shutil.copyfile(self._config_file, filepath)