mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-22 14:56:43 +02:00
299 lines
11 KiB
Python
299 lines
11 KiB
Python
import shutil
|
|
|
|
from collections import defaultdict
|
|
from itertools import chain
|
|
from os import PathLike
|
|
from pathlib import Path
|
|
from typing import Union, List
|
|
|
|
import gymnasium as gym
|
|
import numpy as np
|
|
|
|
from marl_factory_grid.utils.level_parser import LevelParser
|
|
from marl_factory_grid.utils.observation_builder import OBSBuilder
|
|
from marl_factory_grid.utils.config_parser import FactoryConfigParser
|
|
from marl_factory_grid.utils import helpers as h
|
|
import marl_factory_grid.environment.constants as c
|
|
from marl_factory_grid.utils.results import Result
|
|
|
|
from marl_factory_grid.utils.states import Gamestate
|
|
|
|
|
|
class Factory(gym.Env):
|
|
|
|
@property
|
|
def action_space(self):
|
|
"""
|
|
The action space defines the set of all possible actions that an agent can take in the environment.
|
|
|
|
:return: Action space
|
|
:rtype: gym.Space
|
|
"""
|
|
return self.state[c.AGENT].action_space
|
|
|
|
@property
|
|
def named_action_space(self):
|
|
"""
|
|
Returns the named action space for agents.
|
|
|
|
:return: Named action space
|
|
:rtype: dict[str, dict[str, list[int]]]
|
|
"""
|
|
return self.state[c.AGENT].named_action_space
|
|
|
|
@property
|
|
def observation_space(self):
|
|
"""
|
|
The observation space represents all the information that an agent can receive from the environment at a given
|
|
time step.
|
|
|
|
:return: Observation space.
|
|
:rtype: gym.Space
|
|
"""
|
|
return self.obs_builder.observation_space(self.state)
|
|
|
|
@property
|
|
def named_observation_space(self):
|
|
"""
|
|
Returns the named observation space for the environment.
|
|
|
|
:return: Named observation space.
|
|
:rtype: (dict, dict)
|
|
"""
|
|
return self.obs_builder.named_observation_space(self.state)
|
|
|
|
@property
|
|
def params(self) -> dict:
|
|
"""
|
|
FIXME LEGACY
|
|
|
|
|
|
:return:
|
|
"""
|
|
import yaml
|
|
config_path = Path(self._config_file)
|
|
config_dict = yaml.safe_load(config_path.open())
|
|
return config_dict
|
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
self.close()
|
|
|
|
def __init__(self, config_file: Union[str, PathLike], custom_modules_path: Union[None, PathLike] = None,
|
|
custom_level_path: Union[None, PathLike] = None):
|
|
"""
|
|
Initializes the marl-factory-grid as Gym environment.
|
|
|
|
:param config_file: Path to the configuration file.
|
|
:type config_file: Union[str, PathLike]
|
|
:param custom_modules_path: Path to custom modules directory. (Default: None)
|
|
:type custom_modules_path: Union[None, PathLike]
|
|
:param custom_level_path: Path to custom level file. (Default: None)
|
|
:type custom_level_path: Union[None, PathLike]
|
|
"""
|
|
self._config_file = config_file
|
|
self.conf = FactoryConfigParser(self._config_file, custom_modules_path)
|
|
# Attribute Assignment
|
|
if custom_level_path is not None:
|
|
self.level_filepath = Path(custom_level_path)
|
|
else:
|
|
self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt'
|
|
|
|
parsed_entities = self.conf.load_entities()
|
|
self.map = LevelParser(self.level_filepath, parsed_entities, self.conf.pomdp_r)
|
|
|
|
# Init for later usage:
|
|
# noinspection PyTypeChecker
|
|
self.state: Gamestate = None
|
|
# noinspection PyTypeChecker
|
|
self.obs_builder: OBSBuilder = None
|
|
|
|
# expensive - don't use; unless required !
|
|
self._renderer = None
|
|
|
|
# Init entities
|
|
entities = self.map.do_init()
|
|
|
|
# Init rules
|
|
env_rules = self.conf.load_env_rules()
|
|
entity_rules = self.conf.load_entity_spawn_rules(entities)
|
|
env_rules.extend(entity_rules)
|
|
|
|
# Parse the agent conf
|
|
parsed_agents_conf = self.conf.parse_agents_conf()
|
|
self.state = Gamestate(entities, parsed_agents_conf, env_rules, self.map.level_shape,
|
|
self.conf.env_seed, self.conf.verbose)
|
|
|
|
# All is set up, trigger additional init (after agent entity spawn etc)
|
|
self.state.rules.do_all_init(self.state, self.map)
|
|
|
|
self.obs_builder = OBSBuilder(self.map.level_shape, self.state, self.map.pomdp_r)
|
|
|
|
def __getitem__(self, item):
|
|
return self.state.entities[item]
|
|
|
|
def reset(self) -> (dict, dict):
|
|
|
|
# Reset information the state holds
|
|
self.state.reset()
|
|
|
|
# Reset Information the GlobalEntity collection holds.
|
|
self.state.entities.reset()
|
|
|
|
# All is set up, trigger entity spawn with variable pos
|
|
self.state.rules.do_all_reset(self.state)
|
|
self.state.rules.do_all_post_spawn_reset(self.state)
|
|
|
|
# Build initial observations for all agents
|
|
self.obs_builder.reset(self.state)
|
|
return self.obs_builder.build_for_all(self.state)
|
|
|
|
def manual_step_init(self) -> List[Result]:
|
|
self.state.curr_step += 1
|
|
|
|
# Main Agent Step
|
|
pre_step_result = self.state.rules.tick_pre_step_all(self)
|
|
self.obs_builder.reset(self.state)
|
|
return pre_step_result
|
|
|
|
def manual_get_named_agent_obs(self, agent_name: str) -> (List[str], np.ndarray):
|
|
agent = self[c.AGENT][agent_name]
|
|
assert agent, f'"{agent_name}" could not be found. Check the spelling!'
|
|
return self.obs_builder.build_for_agent(agent, self.state)
|
|
|
|
def manual_get_agent_obs(self, agent_name: str) -> np.ndarray:
|
|
return self.manual_get_named_agent_obs(agent_name)[1]
|
|
|
|
def manual_agent_tick(self, agent_name: str, action: int) -> Result:
|
|
agent = self[c.AGENT][agent_name].clear_temp_state()
|
|
action = agent.actions[action]
|
|
action_result = action.do(agent, self)
|
|
agent.set_state(action_result)
|
|
return action_result
|
|
|
|
def manual_finalize_init(self):
|
|
results = list()
|
|
results.extend(self.state.rules.tick_step_all(self))
|
|
results.extend(self.state.rules.tick_post_step_all(self))
|
|
return results
|
|
|
|
# Finalize
|
|
def manual_step_finalize(self, tick_result) -> (float, bool, dict):
|
|
# Check Done Conditions
|
|
done_results = self.state.check_done()
|
|
reward, reward_info, done = self.summarize_step_results(tick_result, done_results)
|
|
|
|
info = reward_info
|
|
info.update(step_reward=sum(reward), step=self.state.curr_step)
|
|
return reward, done, info
|
|
|
|
def step(self, actions):
|
|
"""
|
|
Run one timestep of the environment's dynamics using the agent actions.
|
|
|
|
When the end of an episode is reached (``terminated or truncated``), it is necessary to call :meth:`reset` to
|
|
reset this environment's state for the next episode.
|
|
|
|
:param actions: An action or list of actions provided by the agent(s) to update the environment state.
|
|
:return: observation, reward, terminated, truncated, info, done
|
|
:rtype: tuple(list(np.ndarray), float, bool, bool, dict, bool)
|
|
"""
|
|
|
|
if not isinstance(actions, list):
|
|
actions = [int(actions)]
|
|
|
|
# --> Action
|
|
|
|
# Apply rules, do actions, tick the state, etc...
|
|
tick_result = self.state.tick(actions)
|
|
|
|
# Check Done Conditions
|
|
done_results = self.state.check_done()
|
|
|
|
# Finalize
|
|
reward, reward_info, done = self.summarize_step_results(tick_result, done_results)
|
|
|
|
info = dict(reward_info)
|
|
|
|
info.update(step_reward=sum(reward), step=self.state.curr_step)
|
|
|
|
obs = self.obs_builder.build_for_all(self.state)
|
|
return None, [x for x in obs.values()], reward, done, info
|
|
|
|
def summarize_step_results(self, tick_results: list, done_check_results: list) -> (int, dict, bool):
|
|
# Returns: Reward, Info
|
|
rewards = defaultdict(lambda: 0.0)
|
|
|
|
# Gather per agent environment rewards and
|
|
# Combine Info dicts into a global one
|
|
combined_info_dict = defaultdict(lambda: 0.0)
|
|
for result in chain(tick_results, done_check_results):
|
|
assert result, 'Something returned None...'
|
|
if result.reward is not None:
|
|
try:
|
|
rewards[result.entity.name] += result.reward
|
|
except AttributeError:
|
|
rewards['global'] += result.reward
|
|
infos = result.get_infos()
|
|
for info in infos:
|
|
assert isinstance(info.value, (float, int))
|
|
combined_info_dict[info.identifier] += info.value
|
|
|
|
# Check Done Rule Results
|
|
try:
|
|
done_reason = next(x for x in done_check_results if x.validity)
|
|
done = True
|
|
self.state.print(f'Env done, Reason: {done_reason.identifier}.')
|
|
except StopIteration:
|
|
done = False
|
|
|
|
if self.conf.individual_rewards:
|
|
global_rewards = rewards['global']
|
|
del rewards['global']
|
|
reward = [rewards[agent.name] for agent in self.state[c.AGENT]]
|
|
reward = [x + global_rewards for x in reward]
|
|
self.state.print(f"Individual rewards are {dict(rewards)}")
|
|
return reward, combined_info_dict, done
|
|
else:
|
|
reward = sum(rewards.values())
|
|
self.state.print(f"reward is {reward}")
|
|
return reward, combined_info_dict, done
|
|
|
|
# noinspection PyGlobalUndefined
|
|
def render(self, mode='human'):
|
|
if not self._renderer: # lazy init
|
|
from marl_factory_grid.utils.renderer import Renderer
|
|
global Renderer
|
|
self._renderer = Renderer(self.map.level_shape, view_radius=self.conf.pomdp_r, fps=10)
|
|
|
|
render_entities = self.state.entities.render()
|
|
if self.conf.pomdp_r:
|
|
for render_entity in render_entities:
|
|
if render_entity.name == c.AGENT:
|
|
render_entity.aux = self.obs_builder.curr_lightmaps[render_entity.real_name]
|
|
return self._renderer.render(render_entities)
|
|
|
|
def summarize_header(self):
|
|
header = {'rec_step': self.state.curr_step}
|
|
for entity_group in (x for x in self.state if x.name in ['Walls', 'DropOffLocations', 'ChargePods']):
|
|
header.update({f'rec{entity_group.name}': entity_group.summarize_states()})
|
|
return header
|
|
|
|
def summarize_state(self):
|
|
summary = {'step': self.state.curr_step}
|
|
|
|
# Todo: Protobuff Compatibility Section #######
|
|
# for entity_group in (x for x in self.state if x.name not in [c.WALLS, c.FLOORS]):
|
|
for entity_group in self.state:
|
|
summary.update({entity_group.name.lower(): entity_group.summarize_states()})
|
|
# TODO Section End ########
|
|
for key in list(summary.keys()):
|
|
if key not in ['step', 'walls', 'doors', 'agents', 'items', 'dirtPiles', 'batteries']:
|
|
del summary[key]
|
|
return summary
|
|
|
|
def save_params(self, filepath: Path):
|
|
# noinspection PyProtectedMember
|
|
filepath = Path(filepath)
|
|
filepath.parent.mkdir(parents=True, exist_ok=True)
|
|
shutil.copyfile(self._config_file, filepath)
|