# 2023-10-16 11:06:51 +02:00
#
# 206 lines
# 7.7 KiB
# Python

import shutil
from collections import defaultdict
from itertools import chain
from os import PathLike
from pathlib import Path
from typing import Union
import gymnasium as gym
from marl_factory_grid.utils.level_parser import LevelParser
from marl_factory_grid.utils.observation_builder import OBSBuilder
from marl_factory_grid.utils.config_parser import FactoryConfigParser
from marl_factory_grid.utils import helpers as h
import marl_factory_grid.environment.constants as c
from marl_factory_grid.utils.states import Gamestate
class Factory(gym.Env):
    """Grid-world factory environment exposing a gym-style ``reset`` / ``step`` / ``render`` API.

    The environment is assembled from a YAML configuration file (read through ``FactoryConfigParser``)
    and a level text file (read through ``LevelParser``). All per-episode data lives in ``self.state``
    (a ``Gamestate``); observations are produced by ``self.obs_builder`` (an ``OBSBuilder``).
    Both attributes are (re)created on every call to ``reset``.
    """

    @property
    def action_space(self):
        """Action space as reported by the agent collection of the current state."""
        return self.state[c.AGENT].action_space

    @property
    def named_action_space(self):
        """Named action space as reported by the agent collection of the current state."""
        return self.state[c.AGENT].named_action_space

    @property
    def observation_space(self):
        """Observation space computed by the observation builder for the current state."""
        return self.obs_builder.observation_space(self.state)

    @property
    def named_observation_space(self):
        """Named observation space computed by the observation builder for the current state."""
        return self.obs_builder.named_observation_space(self.state)

    @property
    def params(self) -> dict:
        """Return the content of the configuration file, parsed with ``yaml.safe_load``.

        The file is re-read on every access.
        """
        # Imported lazily, so PyYAML is only touched when the parameters are actually requested.
        import yaml

        # Context manager: the file handle is closed deterministically instead of being leaked.
        with Path(self._config_file).open() as config_stream:
            return yaml.safe_load(config_stream)

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the environment when leaving a ``with`` block.

        NOTE(review): no ``__enter__`` is defined in this class; presumably it is inherited from
        ``gym.Env`` - confirm.
        """
        self.close()

    def __init__(self, config_file: Union[str, PathLike], custom_modules_path: Union[None, PathLike] = None,
                 custom_level_path: Union[None, str, PathLike] = None):
        """Parse configuration and level, then reset the environment once.

        :param config_file: Path of the configuration file handed to ``FactoryConfigParser``.
        :param custom_modules_path: Optional path forwarded to ``FactoryConfigParser``.
        :param custom_level_path: Optional path of a level file. If omitted, the level is looked up as
                                  ``<level_name>.txt`` in the packaged levels directory (``h.LEVELS_DIR``).
        """
        # Declared here for type checkers only; the values are assigned in ``reset``.
        self.state: Gamestate
        self.obs_builder: OBSBuilder
        self.map: LevelParser

        self._config_file = config_file
        self.conf = FactoryConfigParser(self._config_file, custom_modules_path)

        if custom_level_path is None:
            levels_root = Path(__file__).parent.parent / h.LEVELS_DIR
            self.level_filepath = levels_root / f'{self.conf.level_name}.txt'
        else:
            self.level_filepath = Path(custom_level_path)

        # Expensive - stays ``None`` until ``render`` is called for the first time.
        self._renderer = None

        parsed_entities = self.conf.load_entities()
        self.map = LevelParser(self.level_filepath, parsed_entities, self.conf.pomdp_r)

        # Bring the environment into its initial state, so it is usable right after construction.
        # The initial observations returned by ``reset`` are discarded here.
        self.reset()

    def __getitem__(self, item):
        """Shortcut for ``self.state.entities[item]``."""
        return self.state.entities[item]

    def reset(self) -> (dict, dict):
        """Discard the running episode (if any) and build a fresh state.

        :return: The result of ``OBSBuilder.refresh_and_build_for_all`` for the new state
                 (``step`` unpacks the same call into an observation dict and an info dict).
        """
        # ``state`` does not exist on the very first call, and it is ``None`` if an earlier reset
        # failed half way through; in both cases there is nothing to clean up.
        if getattr(self, 'state', None) is not None:
            for entity_group in self.state.entities:
                # Best effort: groups whose first element offers no usable ``reset_uid`` are skipped.
                try:
                    entity_group[0].reset_uid()
                except (AttributeError, TypeError):
                    pass
            self.state = None

        # Entities come from the parsed level, rules from the configuration.
        entities = self.map.do_init()
        rules = self.conf.load_rules()

        # noinspection PyAttributeOutsideInit
        self.state = Gamestate(entities, rules, self.conf.env_seed)

        # Agents are loaded after the state exists, because loading needs the state's floor list.
        agents = self.conf.load_agents(self.map.size, self.state.entities.floorlist)
        self.state.entities.add_item({c.AGENT: agents})

        # Everything is in place (agents included): let the rules run their own initialisation.
        self.state.rules.do_all_init(self.state, self.map)

        # noinspection PyAttributeOutsideInit
        self.obs_builder = OBSBuilder(self.map.level_shape, self.state, self.map.pomdp_r)
        return self.obs_builder.refresh_and_build_for_all(self.state)

    def step(self, actions):
        """Advance the environment by one tick.

        :param actions: List of actions handed to ``Gamestate.tick``. Any non-list value is treated as
                        a single action and wrapped as ``[int(actions)]``.
        :return: 5-tuple ``(None, observations, reward, done, info)``; ``observations`` is the list of
                 values of the observation dict, ``reward`` and ``done`` come from
                 ``summarize_step_results``, ``info`` additionally holds ``step_reward`` and ``step``.
        """
        if not isinstance(actions, list):
            actions = [int(actions)]

        # Apply rules, perform actions, tick the state; afterwards evaluate the done conditions.
        tick_result = self.state.tick(actions)
        done_results = self.state.check_done()

        reward, info, done = self.summarize_step_results(tick_result, done_results)

        # ``reward`` is a per-agent list only when individual rewards are configured; otherwise it is
        # already a scalar, which ``sum`` cannot iterate over.
        step_reward = sum(reward) if isinstance(reward, list) else reward
        info.update(step_reward=step_reward, step=self.state.curr_step)

        obs, reset_info = self.obs_builder.refresh_and_build_for_all(self.state)
        info.update(reset_info)
        return None, list(obs.values()), reward, done, info

    def summarize_step_results(self, tick_results: list, done_check_results: list) -> (int, dict, bool):
        """Fold the results of one tick into reward, info dict and done flag.

        :param tick_results: Results returned by ``Gamestate.tick``.
        :param done_check_results: Results returned by ``Gamestate.check_done``.
        :return: ``(reward, info, done)``. With ``conf.individual_rewards`` the reward is a list with one
                 entry per agent (each including the global reward share), otherwise it is the sum of
                 all rewards. ``info`` maps info identifiers to their accumulated values. ``done`` is
                 ``True`` as soon as one done-check result has a truthy ``validity``.
        """
        # ``defaultdict(float)`` starts every key at 0.0 and, unlike a lambda factory, stays picklable.
        rewards = defaultdict(float)
        combined_info_dict = defaultdict(float)

        for result in chain(tick_results, done_check_results):
            if result.reward is not None:
                # Rewards whose result carries no named entity are booked as global reward.
                try:
                    rewards[result.entity.name] += result.reward
                except AttributeError:
                    rewards['global'] += result.reward
            for info in result.get_infos():
                assert isinstance(info.value, (float, int)), \
                    f'Info "{info.identifier}" has a non-numeric value: {info.value!r}'
                combined_info_dict[info.identifier] += info.value

        # The first valid done-check result (if any) ends the episode.
        done_reason = next((x for x in done_check_results if x.validity), None)
        done = done_reason is not None
        if done:
            self.state.print(f'Env done, Reason: {done_reason.identifier}.')

        if self.conf.individual_rewards:
            # The global share is removed from the mapping and added to every agent's own reward.
            global_rewards = rewards.pop('global', 0.0)
            reward = [rewards[agent.name] + global_rewards for agent in self.state[c.AGENT]]
            self.state.print(f"rewards are {rewards}")
        else:
            reward = sum(rewards.values())
            self.state.print(f"reward is {reward}")
        return reward, combined_info_dict, done

    # noinspection PyGlobalUndefined
    def render(self, mode='human'):
        """Render the current state and return whatever ``Renderer.render`` returns.

        :param mode: Accepted for signature compatibility; not used by this implementation.
        """
        if not self._renderer:  # lazy init
            # Declared before the import binds the name; as before, ``Renderer`` ends up as a module
            # level name after the first call.
            global Renderer
            from marl_factory_grid.utils.renderer import Renderer
            self._renderer = Renderer(self.map.level_shape, view_radius=self.conf.pomdp_r, fps=10)

        render_entities = self.state.entities.render()
        if self.conf.pomdp_r:
            # With a configured view radius, agents additionally carry their current light map.
            for render_entity in render_entities:
                if render_entity.name == c.AGENT:
                    render_entity.aux = self.obs_builder.curr_lightmaps[render_entity.real_name]
        return self._renderer.render(render_entities)

    def summarize_header(self):
        """Return a dict with the current step (``rec_step``) and the summarized states of the entity
        groups named ``Walls``, ``Floors``, ``DropOffLocations`` and ``ChargePods`` (keys ``rec<Name>``).
        """
        header_groups = ('Walls', 'Floors', 'DropOffLocations', 'ChargePods')
        header = {'rec_step': self.state.curr_step}
        for entity_group in self.state:
            if entity_group.name in header_groups:
                header[f'rec{entity_group.name}'] = entity_group.summarize_states()
        return header

    def summarize_state(self):
        """Return a dict with the current step and the summarized states of selected entity groups.

        Keys are the lower-cased group names; only the names listed in ``kept_keys`` are reported and
        the floor group is never reported.
        """
        # Todo: Protobuff Compatibility Section #######
        # Keys are compared after lower-casing, hence the whitelist has to be lower case as well
        # (the former entry 'dirtPiles' could never match).
        kept_keys = {'step', 'walls', 'doors', 'agents', 'items', 'dirtpiles', 'batteries'}
        summary = {'step': self.state.curr_step}
        for entity_group in self.state:
            if entity_group.name == c.FLOORS:
                continue
            key = entity_group.name.lower()
            if key in kept_keys:
                summary[key] = entity_group.summarize_states()
        # TODO Section End ########
        return summary

    def print(self, string):
        """Print ``string`` only if the configuration is set to verbose."""
        if self.conf.verbose:
            print(string)

    def save_params(self, filepath: Path):
        """Copy the configuration file to ``filepath``, creating missing parent directories."""
        target = Path(filepath)
        target.parent.mkdir(parents=True, exist_ok=True)
        shutil.copyfile(self._config_file, target)