Merge branch 'main' into 'unit_testing'

# Conflicts:
#   marl_factory_grid/algorithms/static/TSP_dirt_agent.py
#   marl_factory_grid/utils/config_parser.py
This commit is contained in:
Friedrich, Joel
2024-03-18 16:23:44 +01:00
98 changed files with 2608 additions and 554 deletions

View File

@ -1,3 +1,11 @@
from . import helpers as h
from . import helpers
from .results import Result, DoneResult, ActionResult, TickResult
"""
Utils
=====
Todo
"""

View File

@ -22,6 +22,12 @@ class FactoryConfigParser(object):
default_observations = [c.WALLS, c.AGENT]
def __init__(self, config_path, custom_modules_path: Union[PathLike] = None):
"""
This class parses the factory env config file.
:param config_path: Path to where the 'config.yml' is.
:param custom_modules_path: Additional search path for custom modules, levels, entities, etc.
"""
self.config_path = Path(config_path)
self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path
self.config = yaml.safe_load(self.config_path.open())
@ -40,7 +46,6 @@ class FactoryConfigParser(object):
self._n_abbr_dict = defaultdict(lambda: 'th', {1: 'st', 2: 'nd', 3: 'rd'})
return self._n_abbr_dict[n]
@property
def agent_actions(self):
return self._get_sub_list('Agents', "Actions")
@ -129,13 +134,25 @@ class FactoryConfigParser(object):
# Actions
conf_actions = self.agents[name]['Actions']
actions = list()
# Actions:
# Allowed
# - Noop
# - Move8
# ----
# Noop:
# South:
# reward_fail: 0.5
# ----
# Forbidden
# - South:
# reward_fail: 0.5
if isinstance(conf_actions, dict):
conf_kwargs = conf_actions.copy()
conf_actions = list(conf_actions.keys())
elif isinstance(conf_actions, list):
conf_kwargs = {}
if isinstance(conf_actions, dict):
if any(isinstance(x, dict) for x in conf_actions):
raise ValueError
pass
for action in conf_actions:
@ -152,11 +169,10 @@ class FactoryConfigParser(object):
except AttributeError:
class_or_classes = locate_and_import_class(action, self.custom_modules_path)
try:
# print(action)
# Handle Lists of Actions (e.g., Move8, Move4, Default)
parsed_actions.extend(class_or_classes)
# print(parsed_actions)
for actions_class in class_or_classes:
# break
conf_kwargs[actions_class.__name__] = conf_kwargs.get(action, {})
except TypeError:
parsed_actions.append(class_or_classes)
@ -174,7 +190,7 @@ class FactoryConfigParser(object):
['Actions', 'Observations', 'Positions', 'Clones']}
parsed_agents_conf[name] = dict(
actions=parsed_actions, observations=observations, positions=positions, other=other_kwargs
)
)
clones = self.agents[name].get('Clones', 0)
if clones:

View File

@ -10,14 +10,14 @@ from marl_factory_grid.environment import constants as c
"""
This file is used for:
1. string based definition
Use a class like `Constants`, to define attributes, which then reveal strings.
These can be used for naming convention along the environments as well as keys for mappings such as dicts etc.
When defining new envs, use class inheritance.
2. utility function definition
There are static utility functions which are not bound to a specific environment.
In this file they are defined to be used across the entire package.
1. string based definition
Use a class like `Constants`, to define attributes, which then reveal strings.
These can be used for naming convention along the environments as well as keys for mappings such as dicts etc.
When defining new envs, use class inheritance.
2. utility function definition
There are static utility functions which are not bound to a specific environment.
In this file they are defined to be used across the entire package.
"""
LEVELS_DIR = 'levels' # for use in studies and experiments
@ -54,15 +54,9 @@ class ObservationTranslator:
A string _identifier based approach is used.
Currently, it is not possible to mix different obs shapes.
:param this_named_observation_space: `Named observation space` of the joined environment.
:type this_named_observation_space: Dict[str, dict]
:param per_agent_named_obs_spaces: `Named observation space` one for each agent. Overloaded.
:type per_agent_named_obs_spaces: Dict[str, dict]
:param placeholder_fill_value: Currently, not fully implemented!!!
:type placeholder_fill_value: Union[int, str] = 'N'
"""
if isinstance(placeholder_fill_value, str):

View File

@ -16,7 +16,10 @@ class LevelParser(object):
@property
def pomdp_d(self):
"""
Internal Usage
Calculates the effective diameter of the POMDP observation space.
:return: The calculated effective diameter.
:rtype: int
"""
return self.pomdp_r * 2 + 1

View File

@ -0,0 +1,7 @@
"""
logging
=======
Todo
"""

View File

@ -17,6 +17,9 @@ class EnvMonitor(Wrapper):
ext = 'png'
def __init__(self, env, filepath: Union[str, PathLike] = None):
"""
EnvMonitor is a wrapper for Gymnasium environments that monitors and logs key information during interactions.
"""
super(EnvMonitor, self).__init__(env)
self._filepath = filepath
self._monitor_df = pd.DataFrame()
@ -52,6 +55,14 @@ class EnvMonitor(Wrapper):
return
def save_monitor(self, filepath: Union[Path, str, None] = None, auto_plotting_keys=None):
"""
Saves the monitoring data to a file and optionally generates plots.
:param filepath: The path to save the monitoring data file.
:type filepath: Union[Path, str, None]
:param auto_plotting_keys: Keys to use for automatic plot generation.
:type auto_plotting_keys: Any
"""
filepath = Path(filepath or self._filepath)
filepath.parent.mkdir(exist_ok=True, parents=True)
with filepath.open('wb') as f:

View File

@ -11,6 +11,16 @@ class EnvRecorder(Wrapper):
def __init__(self, env, filepath: Union[str, PathLike] = None,
episodes: Union[List[int], None] = None):
"""
EnvRecorder is a wrapper for OpenAI Gym environments that records state summaries during interactions.
:param env: The environment to record.
:type env: gym.Env
:param filepath: The path to save the recording data file.
:type filepath: Union[str, PathLike]
:param episodes: A list of episode numbers to record. If None, records all episodes.
:type episodes: Union[List[int], None]
"""
super(EnvRecorder, self).__init__(env)
self.filepath = filepath
self.episodes = episodes
@ -19,6 +29,9 @@ class EnvRecorder(Wrapper):
self._recorder_out_list = list()
def reset(self):
"""
Overrides the reset method to reset the environment and recording lists.
"""
self._curr_ep_recorder = list()
self._recorder_out_list = list()
self._curr_episode += 1
@ -26,10 +39,12 @@ class EnvRecorder(Wrapper):
def step(self, actions):
"""
Todo
Overrides the step method to record state summaries during each step.
:param actions:
:return:
:param actions: The actions taken in the environment.
:type actions: Any
:return: The observation, reward, done flag, and additional information.
:rtype: Tuple
"""
obs_type, obs, reward, done, info = self.env.step(actions)
if not self.episodes or self._curr_episode in self.episodes:
@ -55,6 +70,18 @@ class EnvRecorder(Wrapper):
save_occupation_map=False,
save_trajectory_map=False,
):
"""
Saves the recorded data to a file.
:param filepath: The path to save the recording data file.
:type filepath: Union[Path, str, None]
:param only_deltas: If True, saves only the differences between consecutive episodes.
:type only_deltas: bool
:param save_occupation_map: If True, saves an occupation map as a heatmap.
:type save_occupation_map: bool
:param save_trajectory_map: If True, saves a trajectory map.
:type save_trajectory_map: bool
"""
self._finalize()
filepath = Path(filepath or self.filepath)
filepath.parent.mkdir(exist_ok=True, parents=True)
@ -73,7 +100,6 @@ class EnvRecorder(Wrapper):
n_dests=0,
dwell_time=0,
spawn_frequency=0,
spawn_in_other_zone=False,
spawn_mode=''
)
rewards_dest = dict(

View File

@ -19,10 +19,10 @@ class OBSBuilder(object):
@property
def pomdp_d(self):
"""
TODO
Calculates the effective diameter of the POMDP observation space.
:return:
:return: The calculated effective diameter.
:rtype: int
"""
if self.pomdp_r:
return (self.pomdp_r * 2) + 1
@ -31,10 +31,17 @@ class OBSBuilder(object):
def __init__(self, level_shape: np.size, state: Gamestate, pomdp_r: int):
"""
TODO
OBSBuilder
==========
The OBSBuilder class is responsible for constructing observations in the environment.
:return:
:param level_shape: The shape of the level or environment.
:type level_shape: np.size
:param state: The current game state.
:type state: marl_factory_grid.environment.state.Gamestate
:param pomdp_r: The POMDP radius, influencing the size of the observation space.
:type pomdp_r: int
"""
self.all_obs = dict()
self.ray_caster = dict()
@ -52,6 +59,9 @@ class OBSBuilder(object):
self.reset(state)
def reset(self, state):
"""
Resets temporary information and constructs an empty observation array with possible placeholders.
"""
# Reset temporary information
self.curr_lightmaps = dict()
# Construct an empty obs (array) for possible placeholders
@ -61,6 +71,11 @@ class OBSBuilder(object):
return True
def observation_space(self, state):
"""
Returns the observation space for a single agent or a tuple of spaces for multiple agents.
:returns: The observation space for the agent(s).
:rtype: gym.Space|Tuple
"""
from gymnasium.spaces import Tuple, Box
self.reset(state)
obsn = self.build_for_all(state)
@ -71,13 +86,29 @@ class OBSBuilder(object):
return space
def named_observation_space(self, state):
"""
:returns: A dictionary of named observation spaces for all agents.
:rtype: dict
"""
self.reset(state)
return self.build_for_all(state)
def build_for_all(self, state) -> (dict, dict):
"""
Builds observations for all agents in the environment.
:returns: A dictionary of observations for all agents.
:rtype: dict
"""
return {agent.name: self.build_for_agent(agent, state)[0] for agent in state[c.AGENT]}
def build_named_for_all(self, state) -> Dict[str, Dict[str, np.ndarray]]:
"""
Builds named observations for all agents in the environment.
:returns: A dictionary containing named observations for all agents.
:rtype: dict
"""
named_obs_dict = {}
for agent in state[c.AGENT]:
obs, names = self.build_for_agent(agent, state)
@ -85,6 +116,16 @@ class OBSBuilder(object):
return named_obs_dict
def place_entity_in_observation(self, obs_array, agent, e):
"""
Places the encoding of an entity in the observation array relative to the agent's position.
:param obs_array: The observation array.
:type obs_array: np.ndarray
:param agent: the associated agent
:type agent: Agent
:param e: The entity to be placed in the observation.
:type e: Entity
"""
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
if not min([y, x]) < 0:
try:
@ -95,6 +136,12 @@ class OBSBuilder(object):
pass
def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
"""
Builds observations for a specific agent.
:returns: A tuple containing a list of observation names and the corresponding observation array
:rtype: Tuple[List[str], np.ndarray]
"""
try:
agent_want_obs = self.obs_layers[agent.name]
except KeyError:
@ -190,8 +237,8 @@ class OBSBuilder(object):
def _sort_and_name_observation_conf(self, agent):
"""
Builds the useable observation scheme per agent from conf.yaml.
:param agent:
:return:
:param agent: The agent for whom the observation scheme is built.
"""
# Fixme: no asymmetric shapes possible.
self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape))

View File

@ -0,0 +1,7 @@
"""
Plotting
========
Todo
"""

View File

@ -13,6 +13,16 @@ MODEL_MAP = None
def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
"""
Compare multiple runs with different seeds by generating a line plot that shows the evolution of scores (step rewards)
across episodes.
:param run_path: The path to the directory containing the monitor files for each run.
:type run_path: Union[str, PathLike]
:param use_tex: A boolean indicating whether to use TeX formatting in the plot.
:type use_tex: bool
"""
run_path = Path(run_path)
df_list = list()
for run, monitor_file in enumerate(run_path.rglob('monitor*.pick')):
@ -23,7 +33,7 @@ def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
monitor_df = monitor_df.fillna(0)
df_list.append(monitor_df)
df = pd.concat(df_list, ignore_index=True)
df = pd.concat(df_list, ignore_index=True)
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'}).sort_values(['Run', 'Episode'])
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
@ -49,6 +59,19 @@ def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
def compare_model_runs(run_path: Path, run_identifier: Union[str, int], parameter: Union[str, List[str]],
use_tex: bool = False):
"""
Compares multiple model runs based on specified parameters by generating a line plot showing the evolution of scores (step rewards)
across episodes.
:param run_path: The path to the directory containing the monitor files for each model run.
:type run_path: Path
:param run_identifier: A string or integer identifying the runs to compare.
:type run_identifier: Union[str, int]
:param parameter: A single parameter or a list of parameters to compare.
:type parameter: Union[str, List[str]]
:param use_tex: A boolean indicating whether to use TeX formatting in the plot.
:type use_tex: bool
"""
run_path = Path(run_path)
df_list = list()
parameter = [parameter] if isinstance(parameter, str) else parameter
@ -89,6 +112,20 @@ def compare_model_runs(run_path: Path, run_identifier: Union[str, int], paramete
def compare_all_parameter_runs(run_root_path: Path, parameter: Union[str, List[str]],
param_names: Union[List[str], None] = None, str_to_ignore='', use_tex: bool = False):
"""
Compares model runs across different parameter settings by generating a line plot showing the evolution of scores across episodes.
:param run_root_path: The root path to the directory containing the monitor files for all model runs.
:type run_root_path: Path
:param parameter: The parameter(s) to compare across different runs.
:type parameter: Union[str, List[str]]
:param param_names: A list of custom names for the parameters to be used as labels in the plot. If None, default names will be assigned.
:type param_names: Union[List[str], None]
:param str_to_ignore: A string pattern to ignore in parameter names.
:type str_to_ignore: str
:param use_tex: A boolean indicating whether to use TeX formatting in the plot.
:type use_tex: bool
"""
run_root_path = Path(run_root_path)
df_list = list()
parameter = [parameter] if isinstance(parameter, str) else parameter

View File

@ -10,7 +10,21 @@ from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None,
file_key: str ='monitor', file_ext: str ='pkl'):
file_key: str = 'monitor', file_ext: str = 'pkl'):
"""
Plots the Epoch score (step reward) over a single run based on monitoring data stored in a file.
:param run_path: The path to the directory containing monitoring data or directly to the monitoring file.
:type run_path: Union[str, PathLike]
:param use_tex: Flag indicating whether to use TeX for plotting.
:type use_tex: bool, optional
:param column_keys: Specific columns to include in the plot. If None, includes all columns except ignored ones.
:type column_keys: list or None, optional
:param file_key: The keyword to identify the monitoring file.
:type file_key: str, optional
:param file_ext: The extension of the monitoring file.
:type file_ext: str, optional
"""
run_path = Path(run_path)
df_list = list()
if run_path.is_dir():
@ -26,7 +40,7 @@ def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, colum
monitor_df = monitor_df.fillna(0)
df_list.append(monitor_df)
df = pd.concat(df_list, ignore_index=True)
df = pd.concat(df_list, ignore_index=True)
df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
if column_keys is not None:
columns = [col for col in column_keys if col in df.columns]

View File

@ -1,7 +1,6 @@
import seaborn as sns
import matplotlib as mpl
from matplotlib import pyplot as plt
PALETTE = 10 * (
"#377eb8",
"#4daf4a",
@ -20,6 +19,14 @@ PALETTE = 10 * (
def plot(filepath, ext='png'):
"""
Saves the current plot to a file and displays it.
:param filepath: The path to save the plot file.
:type filepath: str
:param ext: The file extension of the saved plot. Default is 'png'.
:type ext: str
"""
plt.tight_layout()
figure = plt.gcf()
ax = plt.gca()
@ -35,6 +42,20 @@ def plot(filepath, ext='png'):
def prepare_tex(df, hue, style, hue_order):
"""
Prepares a line plot for rendering in LaTeX.
:param df: The DataFrame containing the data to be plotted.
:type df: pandas.DataFrame
:param hue: Grouping variable that will produce lines with different colors.
:type hue: str
:param style: Grouping variable that will produce lines with different styles.
:type style: str
:param hue_order: Order for the levels of the hue variable in the plot.
:type hue_order: list
:return: The prepared line plot.
:rtype: matplotlib.axes._subplots.AxesSubplot
"""
sns.set(rc={'text.usetex': True}, style='whitegrid')
lineplot = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
hue_order=hue_order, hue=hue, style=style)
@ -45,6 +66,20 @@ def prepare_tex(df, hue, style, hue_order):
def prepare_plt(df, hue, style, hue_order):
"""
Prepares a line plot using matplotlib.
:param df: The DataFrame containing the data to be plotted.
:type df: pandas.DataFrame
:param hue: Grouping variable that will produce lines with different colors.
:type hue: str
:param style: Grouping variable that will produce lines with different styles.
:type style: str
:param hue_order: Order for the levels of the hue variable in the plot.
:type hue_order: list
:return: The prepared line plot.
:rtype: matplotlib.axes._subplots.AxesSubplot
"""
print('Struggling to plot Figure using LaTeX - going back to normal.')
plt.close('all')
sns.set(rc={'text.usetex': False}, style='whitegrid')
@ -57,6 +92,20 @@ def prepare_plt(df, hue, style, hue_order):
def prepare_center_double_column_legend(df, hue, style, hue_order):
"""
Prepares a line plot with a legend centered at the bottom and spread across two columns.
:param df: The DataFrame containing the data to be plotted.
:type df: pandas.DataFrame
:param hue: Grouping variable that will produce lines with different colors.
:type hue: str
:param style: Grouping variable that will produce lines with different styles.
:type style: str
:param hue_order: Order for the levels of the hue variable in the plot.
:type hue_order: list
:return: The prepared line plot.
:rtype: matplotlib.axes._subplots.AxesSubplot
"""
print('Struggling to plot Figure using LaTeX - going back to normal.')
plt.close('all')
sns.set(rc={'text.usetex': False}, style='whitegrid')
@ -70,6 +119,23 @@ def prepare_center_double_column_legend(df, hue, style, hue_order):
def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None, use_tex=False):
"""
Prepares a line plot for visualization. Based on the `use_tex` parameter, calls the prepare_tex or prepare_plt
function accordingly, followed by the plot function to save the plot.
:param filepath: The file path where the plot will be saved.
:type filepath: str
:param results_df: The DataFrame containing the data to be plotted.
:type results_df: pandas.DataFrame
:param ext: The file extension of the saved plot (default is 'png').
:type ext: str
:param hue: The variable to determine the color of the lines in the plot.
:type hue: str
:param style: The variable to determine the style of the lines in the plot (default is None).
:type style: str or None
:param use_tex: Whether to use LaTeX for text rendering (default is False).
:type use_tex: bool
"""
df = results_df.copy()
df[hue] = df[hue].str.replace('_', '-')
hue_order = sorted(list(df[hue].unique()))

View File

@ -8,10 +8,17 @@ from numba import njit
class RayCaster:
def __init__(self, agent, pomdp_r, degs=360):
"""
TODO
The RayCaster class enables agents in the environment to simulate field-of-view visibility,
providing methods for calculating visible entities and outlining the field of view based on
Bresenham's algorithm.
:return:
:param agent: The agent for which the RayCaster is initialized.
:type agent: Agent
:param pomdp_r: The range of the partially observable Markov decision process (POMDP).
:type pomdp_r: int
:param degs: The degrees of the field of view (FOV). Defaults to 360.
:type degs: int
:return: None
"""
self.agent = agent
self.pomdp_r = pomdp_r
@ -25,6 +32,12 @@ class RayCaster:
return f'{self.__class__.__name__}({self.agent.name})'
def build_ray_targets(self):
"""
Builds the targets for the rays based on the field of view (FOV).
:return: The targets for the rays.
:rtype: np.ndarray
"""
north = np.array([0, -1]) * self.pomdp_r
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
rot_M = [
@ -36,11 +49,31 @@ class RayCaster:
return rot_M.astype(int)
def ray_block_cache(self, key, callback):
"""
Retrieves or caches a value in the cache dictionary.
:param key: The key for the cache dictionary.
:type key: any
:param callback: The callback function to obtain the value if not present in the cache.
:type callback: callable
:return: The cached or newly computed value.
:rtype: any
"""
if key not in self._cache_dict:
self._cache_dict[key] = callback()
return self._cache_dict[key]
def visible_entities(self, pos_dict, reset_cache=True):
"""
Returns a list of visible entities based on the agent's field of view.
:param pos_dict: The dictionary containing positions of entities.
:type pos_dict: dict
:param reset_cache: Flag to reset the cache. Defaults to True.
:type reset_cache: bool
:return: A list of visible entities.
:rtype: list
"""
visible = list()
if reset_cache:
self._cache_dict = dict()
@ -71,15 +104,33 @@ class RayCaster:
return visible
def get_rays(self):
"""
Gets the rays for the agent.
:return: The rays for the agent.
:rtype: list
"""
a_pos = self.agent.pos
outline = self.ray_targets + a_pos
return self.bresenham_loop(a_pos, outline)
# todo do this once and cache the points!
def get_fov_outline(self) -> np.ndarray:
"""
Gets the field of view (FOV) outline.
:return: The FOV outline.
:rtype: np.ndarray
"""
return self.ray_targets + self.agent.pos
def get_square_outline(self):
"""
Gets the square outline for the agent.
:return: The square outline.
:rtype: list
"""
agent = self.agent
x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
@ -90,6 +141,16 @@ class RayCaster:
@staticmethod
@njit
def bresenham_loop(a_pos, points):
"""
Applies Bresenham's algorithm to calculate the points between two positions.
:param a_pos: The starting position.
:type a_pos: list
:param points: The ending positions.
:type points: list
:return: The list of points between the starting and ending positions.
:rtype: list
"""
results = []
for end in points:
x1, y1 = a_pos

View File

@ -34,12 +34,26 @@ class Renderer:
cell_size: int = 40, fps: int = 7, factor: float = 0.9,
grid_lines: bool = True, view_radius: int = 2):
"""
TODO
The Renderer class initializes and manages the rendering environment for the simulation,
providing methods for preparing entities for display, loading assets, calculating visibility rectangles and
rendering the entities on the screen with specified parameters.
:return:
:param lvl_shape: Tuple representing the shape of the level.
:type lvl_shape: Tuple[int, int]
:param lvl_padded_shape: Optional Tuple representing the padded shape of the level.
:type lvl_padded_shape: Union[Tuple[int, int], None]
:param cell_size: Size of each cell in pixels.
:type cell_size: int
:param fps: Frames per second for rendering.
:type fps: int
:param factor: Factor for resizing assets.
:type factor: float
:param grid_lines: Boolean indicating whether to display grid lines.
:type grid_lines: bool
:param view_radius: Radius for agent's field of view.
:type view_radius: int
"""
# TODO: Customn_assets paths
# TODO: Custom_assets paths
self.grid_h, self.grid_w = lvl_shape
self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape
self.cell_size = cell_size
@ -60,6 +74,9 @@ class Renderer:
print('Loading System font with pygame.font.Font took', time.time() - now)
def fill_bg(self):
"""
Fills the background of the screen with the specified BG color.
"""
self.screen.fill(Renderer.BG_COLOR)
if self.grid_lines:
w, h = self.screen_size
@ -69,6 +86,16 @@ class Renderer:
pygame.draw.rect(self.screen, Renderer.WHITE, rect, 1)
def blit_params(self, entity):
"""
Prepares parameters for blitting an entity on the screen. Blitting refers to the process of combining or copying
rectangular blocks of pixels from one part of a graphical buffer to another and is often used to efficiently
update the display by copying pre-drawn or cached images onto the screen.
:param entity: The entity to be blitted.
:type entity: Entity
:return: Dictionary containing source and destination information for blitting.
:rtype: dict
"""
offset_r, offset_c = (self.lvl_padded_shape[0] - self.grid_h) // 2, \
(self.lvl_padded_shape[1] - self.grid_w) // 2
@ -90,12 +117,31 @@ class Renderer:
return dict(source=img, dest=rect)
def load_asset(self, path, factor=1.0):
"""
Loads and resizes an asset from the specified path.
:param path: Path to the asset.
:type path: str
:param factor: Resizing factor for the asset.
:type factor: float
:return: Resized asset.
"""
s = int(factor*self.cell_size)
asset = pygame.image.load(path).convert_alpha()
asset = pygame.transform.smoothscale(asset, (s, s))
return asset
def visibility_rects(self, bp, view):
"""
Calculates the visibility rectangles for an agent.
:param bp: Blit parameters for the agent.
:type bp: dict
:param view: Agent's field of view.
:type view: np.ndarray
:return: List of visibility rectangles.
:rtype: List[dict]
"""
rects = []
for i, j in product(range(-self.view_radius, self.view_radius+1),
range(-self.view_radius, self.view_radius+1)):
@ -111,6 +157,14 @@ class Renderer:
return rects
def render(self, entities):
"""
Renders the entities on the screen.
:param entities: List of entities to be rendered.
:type entities: List[Entity]
:return: Transposed RGB observation array.
:rtype: np.ndarray
"""
for event in pygame.event.get():
if event.type == pygame.QUIT:
pygame.quit()

View File

@ -15,10 +15,12 @@ from marl_factory_grid.utils.results import Result
class StepRules:
def __init__(self, *args):
"""
TODO
Manages a collection of rules to be applied at each step of the environment.
The StepRules class allows you to organize and apply custom rules during the simulation, ensuring that the
corresponding hooks for all rules are called at the appropriate times.
:return:
:param args: Optional Rule objects to initialize the StepRules with.
"""
if args:
self.rules = list(args)
@ -92,10 +94,18 @@ class Gamestate(object):
def __init__(self, entities, agents_conf, rules: List[Rule], tests: [Test], lvl_shape, env_seed=69, verbose=False):
"""
TODO
The `Gamestate` class represents the state of the game environment.
:return:
:param lvl_shape: The shape of the game level.
:type lvl_shape: tuple
:param entities: The entities present in the environment.
:type entities: Entities
:param agents_conf: Agent configurations for the environment.
:type agents_conf: Any
:param verbose: Controls verbosity in the environment.
:type verbose: bool
:param rules: Organizes and applies custom rules during the simulation.
:type rules: StepRules
"""
self.lvl_shape = lvl_shape
self.entities = entities
@ -162,7 +172,7 @@ class Gamestate(object):
def tick(self, actions) -> list[Result]:
"""
Performs a single **Gamestate Tick**by calling the inner rule hooks in sequential order.
Performs a single **Gamestate Tick** by calling the inner rule hooks in sequential order.
- tick_pre_step_all: Things to do before the agents do their actions. Statechange, Moving, Spawning etc...
- agent tick: Agents do their actions.
- tick_step_all: Things to do after the agents did their actions. Statechange, Moving, Spawning etc...

View File

@ -15,7 +15,7 @@ OBSERVATIONS = 'Observations'
RULES = 'Rule'
TESTS = 'Tests'
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Agent', 'GlobalPositions', 'Walls', 'Gamestate', 'Path',
'Iterable', 'Move', 'Result', 'TemplateRule', 'Entities', 'EnvObjects', 'Zones', 'Collection',
'Iterable', 'Move', 'Result', 'TemplateRule', 'Entities', 'EnvObjects', 'Collection',
'State', 'Object', 'default_valid_reward', 'default_fail_reward', 'size']

View File

@ -6,7 +6,10 @@ import numpy as np
class MarlFrameStack(gym.ObservationWrapper):
"""todo @romue404"""
"""
todo @romue404
"""
def __init__(self, env):
super().__init__(env)