Merge branch 'main' into unit_testing

# Conflicts:
#	marl_factory_grid/environment/factory.py
#	marl_factory_grid/utils/config_parser.py
#	marl_factory_grid/utils/states.py
This commit is contained in:
Chanumask
2023-11-10 10:54:00 +01:00
97 changed files with 1088 additions and 1239 deletions

View File

@ -0,0 +1,3 @@
from . import helpers as h
from . import helpers
from .results import Result, DoneResult, ActionResult, TickResult

View File

@ -1,4 +1,5 @@
import ast
from os import PathLike
from pathlib import Path
from typing import Union, List
@ -9,18 +10,17 @@ from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.environment.tests import Test
from marl_factory_grid.utils.helpers import locate_and_import_class
DEFAULT_PATH = 'environment'
MODULE_PATH = 'modules'
from marl_factory_grid.environment.constants import DEFAULT_PATH, MODULE_PATH
from marl_factory_grid.environment import constants as c
class FactoryConfigParser(object):
default_entites = []
default_rules = ['MaxStepsReached', 'Collision']
default_rules = ['DoneAtMaxStepsReached', 'WatchCollision']
default_actions = [c.MOVE8, c.NOOP]
default_observations = [c.WALLS, c.AGENT]
def __init__(self, config_path, custom_modules_path: Union[None, PathLike] = None):
def __init__(self, config_path, custom_modules_path: Union[PathLike] = None):
self.config_path = Path(config_path)
self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path
self.config = yaml.safe_load(self.config_path.open())
@ -44,6 +44,10 @@ class FactoryConfigParser(object):
def rules(self):
return self.config['Rules']
@property
def tests(self):
    """Return the optional ``Tests`` section of the loaded config, or an empty list if it is absent."""
    return self.config.get('Tests', [])
@property
def agents(self):
return self.config['Agents']
@ -56,10 +60,12 @@ class FactoryConfigParser(object):
return str(self.config)
def __getitem__(self, item):
return self.config[item]
try:
return self.config[item]
except KeyError:
print(f'The mandatory {item} section could not be found in your .config gile. Check Spelling!')
def load_entities(self):
# entites = Entities()
entity_classes = dict()
entities = []
if c.DEFAULTS in self.entities:
@ -67,28 +73,40 @@ class FactoryConfigParser(object):
entities.extend(x for x in self.entities if x != c.DEFAULTS)
for entity in entities:
e1 = e2 = e3 = None
try:
folder_path = Path(__file__).parent.parent / DEFAULT_PATH
entity_class = locate_and_import_class(entity, folder_path)
except AttributeError as e1:
except AttributeError as e:
e1 = e
try:
folder_path = Path(__file__).parent.parent / MODULE_PATH
entity_class = locate_and_import_class(entity, folder_path)
except AttributeError as e2:
try:
folder_path = self.custom_modules_path
entity_class = locate_and_import_class(entity, folder_path)
except AttributeError as e3:
ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x]
print('### Error ### Error ### Error ### Error ### Error ###')
print()
print(f'Class "{entity}" was not found in "{folder_path.name}"')
print('Possible Entitys are:', str(ents))
print()
print('Goodbye')
print()
exit()
# raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are:', str(ents))
module_path = Path(__file__).parent.parent / MODULE_PATH
entity_class = locate_and_import_class(entity, module_path)
except AttributeError as e:
e2 = e
if self.custom_modules_path:
try:
entity_class = locate_and_import_class(entity, self.custom_modules_path)
except AttributeError as e:
e3 = e
pass
if (e1 and e2) or e3:
ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]]
print('##############################################################')
print('### Error ### Error ### Error ### Error ### Error ###')
print('##############################################################')
print(f'Class "{entity}" was not found in "{module_path.name}"')
print(f'Class "{entity}" was not found in "{folder_path.name}"')
print('##############################################################')
if self.custom_modules_path:
print(f'Class "{entity}" was not found in "{self.custom_modules_path}"')
print('Possible Entitys are:', str(ents))
print('##############################################################')
print('Goodbye')
print('##############################################################')
print('### Error ### Error ### Error ### Error ### Error ###')
print('##############################################################')
exit(-99999)
entity_kwargs = self.entities.get(entity, {})
entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
@ -126,7 +144,12 @@ class FactoryConfigParser(object):
observations.extend(self.default_observations)
observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS)
positions = [ast.literal_eval(x) for x in self.agents[name].get('Positions', [])]
parsed_agents_conf[name] = dict(actions=parsed_actions, observations=observations, positions=positions)
other_kwargs = {k: v for k, v in self.agents[name].items() if k not in
['Actions', 'Observations', 'Positions']}
parsed_agents_conf[name] = dict(
actions=parsed_actions, observations=observations, positions=positions, other=other_kwargs
)
return parsed_agents_conf
def load_env_rules(self) -> List[Rule]:
@ -137,28 +160,69 @@ class FactoryConfigParser(object):
rules.append({rule: {}})
return self._load_smth(rules, Rule)
pass
def load_env_tests(self) -> List[Test]:
def load_env_tests(self) -> List[Rule]:
return self._load_smth(self.tests, None) # Test
pass
def _load_smth(self, config, class_obj):
rules = list()
rules_names = list()
for rule in rules_names:
for rule in config:
e1 = e2 = e3 = None
try:
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
rule_class = locate_and_import_class(rule, folder_path)
except AttributeError:
except AttributeError as e:
e1 = e
try:
folder_path = (Path(__file__).parent.parent / MODULE_PATH)
rule_class = locate_and_import_class(rule, folder_path)
except AttributeError:
rule_class = locate_and_import_class(rule, self.custom_modules_path)
# Fixme This check does not work!
# assert isinstance(rule_class, class_obj), f'{rule_class.__name__} is no valid "class_obj.__name__".'
rule_kwargs = config.get(rule, {})
rules.append(rule_class(**rule_kwargs))
module_path = (Path(__file__).parent.parent / MODULE_PATH)
rule_class = locate_and_import_class(rule, module_path)
except AttributeError as e:
e2 = e
if self.custom_modules_path:
try:
rule_class = locate_and_import_class(rule, self.custom_modules_path)
except AttributeError as e:
e3 = e
pass
if (e1 and e2) or e3:
ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]]
print('### Error ### Error ### Error ### Error ### Error ###')
print('')
print(f'Class "{rule}" was not found in "{module_path.name}"')
print(f'Class "{rule}" was not found in "{folder_path.name}"')
if self.custom_modules_path:
print(f'Class "{rule}" was not found in "{self.custom_modules_path}"')
print('Possible Entitys are:', str(ents))
print('')
print('Goodbye')
print('')
exit(-99999)
if issubclass(rule_class, class_obj):
rule_kwargs = config.get(rule, {})
rules.append(rule_class(**(rule_kwargs or {})))
return rules
def load_entity_spawn_rules(self, entities) -> List[Rule]:
    """
    Instantiate the spawn rules announced by the given entities.

    Every element of *entities* may expose a ``spawn_rule`` attribute holding a mapping of
    ``{rule_name: rule_kwargs}``. Elements without that attribute, or with a falsy one, are skipped.
    Each rule class is looked up in the default environment folder first, then in the bundled
    modules folder and finally - only if one was configured - in the custom modules folder.

    :param entities: Iterable of objects that may carry a ``spawn_rule`` attribute.
    :return: List of instantiated rules, in the order in which they were announced.
    :raises AttributeError: The lookup error of the last searched folder, if the rule class
                            could not be located anywhere.
    """
    # Phase 1: collect all announced spawn rules; entities without one are simply skipped.
    rules_dicts = list()
    for entity in entities:
        try:
            if spawn_rule := entity.spawn_rule:
                rules_dicts.append(spawn_rule)
        except AttributeError:
            pass

    # The search order is the same for every rule, so it is built only once.
    package_root = Path(__file__).parent.parent
    search_paths = [package_root / DEFAULT_PATH, package_root / MODULE_PATH]
    if self.custom_modules_path:
        # Consistent with the other loaders: never hand a missing custom path to the locator.
        search_paths.append(self.custom_modules_path)
    last_idx = len(search_paths) - 1

    # Phase 2: resolve and instantiate every rule.
    rules = list()
    for rule_dict in rules_dicts:
        for rule_name, rule_kwargs in rule_dict.items():
            rule_class = None
            for idx, folder_path in enumerate(search_paths):
                try:
                    rule_class = locate_and_import_class(rule_name, folder_path)
                    break
                except AttributeError:
                    if idx == last_idx:
                        # Nothing left to try; surface the original lookup error.
                        raise
            # A rule announced without arguments may carry ``None`` instead of a mapping.
            rules.append(rule_class(**(rule_kwargs or {})))
    return rules

View File

@ -2,7 +2,7 @@ import importlib
from collections import defaultdict
from pathlib import PurePath, Path
from typing import Union, Dict, List
from typing import Union, Dict, List, Iterable, Callable, Any
import numpy as np
from numpy.typing import ArrayLike
@ -61,8 +61,8 @@ class ObservationTranslator:
:param per_agent_named_obs_spaces: `Named observation space` one for each agent. Overloaded.
type per_agent_named_obs_spaces: Dict[str, dict]
:param placeholder_fill_value: Currently not fully implemented!!!
:type placeholder_fill_value: Union[int, str] = 'N')
:param placeholder_fill_value: Currently, not fully implemented!!!
:type placeholder_fill_value: Union[int, str] = 'N'
"""
if isinstance(placeholder_fill_value, str):
@ -222,7 +222,7 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
mod = importlib.import_module('.'.join(module_parts))
all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin',
'TickResult', 'ActionResult', 'Action', 'Agent',
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any'
]])
@ -240,7 +240,13 @@ def add_bound_name(name_str, bound_e):
def add_pos_name(name_str, bound_e):
if bound_e.var_has_position:
return f'{name_str}({bound_e.pos})'
return f'{name_str}@{bound_e.pos}'
return name_str
def get_first(iterable: Iterable, filter_by: Callable[[Any], bool] = lambda _: True):
    """
    Return the first element of *iterable* that is accepted by *filter_by*.

    :param iterable: Elements to search; consumed lazily and only as far as needed.
    :param filter_by: Predicate called with each element; by default every element is accepted.
    :return: The first accepted element, or ``None`` if no element is accepted
             (which includes the case of an empty iterable).
    """
    return next((x for x in iterable if filter_by(x)), None)
def get_first_index(iterable: Iterable, filter_by: Callable[[Any], bool] = lambda _: True):
    """
    Return the zero-based index of the first element of *iterable* accepted by *filter_by*.

    :param iterable: Elements to search; consumed lazily and only as far as needed.
    :param filter_by: Predicate called with each element; by default every element is accepted.
    :return: Index of the first accepted element, or ``None`` if no element is accepted
             (which includes the case of an empty iterable).
    """
    return next((idx for idx, x in enumerate(iterable) if filter_by(x)), None)

View File

@ -47,6 +47,7 @@ class LevelParser(object):
# All other
for es_name in self.e_p_dict:
e_class, e_kwargs = self.e_p_dict[es_name]['class'], self.e_p_dict[es_name]['kwargs']
e_kwargs = e_kwargs if e_kwargs else {}
if hasattr(e_class, 'symbol') and e_class.symbol is not None:
symbols = e_class.symbol

View File

@ -9,7 +9,7 @@ from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
import pandas as pd
from marl_factory_grid.utils.plotting.compare_runs import plot_single_run
from marl_factory_grid.utils.plotting.plot_single_runs import plot_single_run
class EnvMonitor(Wrapper):
@ -22,7 +22,6 @@ class EnvMonitor(Wrapper):
self._monitor_df = pd.DataFrame()
self._monitor_dict = dict()
def step(self, action):
obs_type, obs, reward, done, info = self.env.step(action)
self._read_info(info)

View File

@ -2,11 +2,9 @@ from os import PathLike
from pathlib import Path
from typing import Union, List
import yaml
from gymnasium import Wrapper
import numpy as np
import pandas as pd
from gymnasium import Wrapper
class EnvRecorder(Wrapper):
@ -106,7 +104,7 @@ class EnvRecorder(Wrapper):
out_dict = {'episodes': self._recorder_out_list}
out_dict.update(
{'n_episodes': self._curr_episode,
'metadata':dict(
'metadata': dict(
level_name=self.env.params['General']['level_name'],
verbose=False,
n_agents=len(self.env.params['Agents']),

View File

@ -1,17 +1,16 @@
import math
import re
from collections import defaultdict
from itertools import product
from typing import Dict, List
import numpy as np
from numba import njit
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.object import Object
from marl_factory_grid.environment.groups.utils import Combined
import marl_factory_grid.utils.helpers as h
from marl_factory_grid.utils.states import Gamestate
from marl_factory_grid.utils.utility_classes import Floor
from marl_factory_grid.utils.ray_caster import RayCaster
from marl_factory_grid.utils.states import Gamestate
from marl_factory_grid.utils import helpers as h
class OBSBuilder(object):
@ -77,11 +76,13 @@ class OBSBuilder(object):
def place_entity_in_observation(self, obs_array, agent, e):
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
try:
obs_array[x, y] += e.encoding
except IndexError:
# Seemed to be visible but is out of range
pass
if not min([y, x]) < 0:
try:
obs_array[x, y] += e.encoding
except IndexError:
# Seemed to be visible but is out of range
pass
pass
def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
assert self._curr_env_step == state.curr_step, (
@ -121,18 +122,24 @@ class OBSBuilder(object):
e = self.all_obs[l_name]
except KeyError:
try:
# Look for bound entity names!
pattern = re.compile(f'{re.escape(l_name)}(.*){re.escape(agent.name)}')
name = next((x for x in self.all_obs if pattern.search(x)), None)
# Look for bound entity REPRs!
pattern = re.compile(f'{re.escape(l_name)}'
f'{re.escape("[")}(.*){re.escape("]")}'
f'{re.escape("(")}{re.escape(agent.name)}{re.escape(")")}')
name = next((key for key, val in self.all_obs.items()
if pattern.search(str(val)) and isinstance(val, Object)), None)
e = self.all_obs[name]
except KeyError:
try:
e = next(v for k, v in self.all_obs.items() if l_name in k and agent.name in k)
except StopIteration:
raise KeyError(
f'Check for spelling errors! \n '
f'No combination of "{l_name} and {agent.name}" could not be found in:\n '
f'{list(dict(self.all_obs).keys())}')
print(f'# Check for spelling errors!')
print(f'# No combination of "{l_name}" and "{agent.name}" could not be found in:')
print(f'# {list(dict(self.all_obs).keys())}')
print('#')
print('# exiting...')
print('#')
exit(-99999)
try:
positional = e.var_has_position
@ -161,31 +168,30 @@ class OBSBuilder(object):
try:
light_map = np.zeros(self.obs_shape)
visible_floor = self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False)
if self.pomdp_r:
for f in set(visible_floor):
self.place_entity_in_observation(light_map, agent, f)
else:
for f in set(visible_floor):
light_map[f.x, f.y] += f.encoding
for f in set(visible_floor):
self.place_entity_in_observation(light_map, agent, f)
# else:
# for f in set(visible_floor):
# light_map[f.x, f.y] += f.encoding
self.curr_lightmaps[agent.name] = light_map
except (KeyError, ValueError):
print()
pass
return obs, self.obs_layers[agent.name]
def _sort_and_name_observation_conf(self, agent):
'''
"""
Builds the useable observation scheme per agent from conf.yaml.
:param agent:
:return:
'''
"""
# Fixme: no asymmetric shapes possible.
self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape))
obs_layers = []
for obs_str in agent.observations:
if isinstance(obs_str, dict):
obs_str, vals = next(obs_str.items().__iter__())
obs_str, vals = h.get_first(obs_str.items())
else:
vals = None
if obs_str == c.SELF:
@ -214,129 +220,3 @@ class OBSBuilder(object):
obs_layers.append(obs_str)
self.obs_layers[agent.name] = obs_layers
self.curr_lightmaps[agent.name] = np.zeros(self.obs_shape)
class RayCaster:
    """
    Casts rays outwards from an agent's position to determine which entities are visible.

    The ray end points are precomputed once, relative to the origin, and shifted to the agent's
    current position whenever rays are requested.
    """

    def __init__(self, agent, pomdp_r, degs=360):
        """
        :param agent: Object whose ``pos``, ``x``, ``y`` and ``name`` attributes are read by this class.
        :param pomdp_r: View radius; also used as the length of every ray.
        :param degs: Opening angle in degrees over which the ray directions are spread.
        """
        self.agent = agent
        self.pomdp_r = pomdp_r
        # Number of ray directions grows linearly with the radius.
        self.n_rays = (self.pomdp_r + 1) * 8
        self.degs = degs
        self.ray_targets = self.build_ray_targets()
        self.obs_shape_cube = np.array([self.pomdp_r, self.pomdp_r])
        # Memo of per-position blocking results, see ``ray_block_cache``.
        self._cache_dict = {}

    def __repr__(self):
        return f'{self.__class__.__name__}({self.agent.name})'

    def build_ray_targets(self):
        """
        Build the ray end points relative to the origin.

        The vector ``[0, -1] * pomdp_r`` is rotated by ``n_rays`` angles spread over ``degs`` degrees;
        the results are rounded, de-duplicated and cast to integers.

        :return: Integer array of unique end points, one row per ray.
        """
        north = np.array([0, -1]) * self.pomdp_r
        thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
        # One 2x2 rotation matrix per angle.
        rot_M = [
            [[math.cos(theta), -math.sin(theta)],
             [math.sin(theta), math.cos(theta)]] for theta in thetas
        ]
        rot_M = np.stack(rot_M, 0)
        # Rounding maps several angles onto the same grid cell; keep each cell only once.
        rot_M = np.unique(np.round(rot_M @ north), axis=0)
        return rot_M.astype(int)

    def ray_block_cache(self, key, callback):
        """
        Return the memoised result for *key*, computing it with *callback* on the first request.

        :param key: Cache key (positions are used as keys by ``visible_entities``).
        :param callback: Zero-argument callable, only invoked when *key* is not cached yet.
        :return: The cached value for *key*.
        """
        if key not in self._cache_dict:
            self._cache_dict[key] = callback()
        return self._cache_dict[key]

    def visible_entities(self, pos_dict, reset_cache=True):
        """
        Collect the entities that are hit by the agent's rays.

        Every ray is walked from its first point outwards. The entities found on a cell are added
        to the result and the ray ends as soon as a cell contains a light-blocking entity. On a
        diagonal step the two orthogonally adjacent cells are inspected too: if both are non-empty
        and hold only light-blocking entities, the ray ends without adding the current cell.

        :param pos_dict: Mapping indexed by ``(x, y)`` that yields the entities at that position.
        :param reset_cache: Whether to drop the memoised blocking results before casting.
        :return: List of entities hit; an entity may appear more than once when several rays
                 cross its cell.
        """
        visible = list()
        if reset_cache:
            self._cache_dict = {}
        for ray in self.get_rays():
            rx, ry = ray[0]
            for x, y in ray:
                # Step taken since the previous cell of this ray.
                cx, cy = x - rx, y - ry
                entities_hit = pos_dict[(x, y)]
                hits = self.ray_block_cache((x, y),
                                            lambda: any(True for e in entities_hit if e.var_is_blocking_light)
                                            )
                # NOTE(review): this uses the same cache and key space as ``hits`` above although
                # the cached predicate differs (all-blocking and non-empty vs. any-blocking) - confirm
                # that a result cached by one cannot be served to the other.
                diag_hits = all([
                    self.ray_block_cache(
                        key,
                        lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) and bool(
                            pos_dict[key]))
                    for key in ((x, y - cy), (x - cx, y))
                ]) if (cx != 0 and cy != 0) else False
                visible += entities_hit if not diag_hits else []
                if hits or diag_hits:
                    break
                rx, ry = x, y
        return visible

    def get_rays(self):
        """Return one list of grid points per ray, from the agent's position to each ray target."""
        a_pos = self.agent.pos
        outline = self.ray_targets + a_pos
        return self.bresenham_loop(a_pos, outline)

    # todo do this once and cache the points!
    def get_fov_outline(self) -> np.ndarray:
        """Return the ray end points shifted to the agent's current position."""
        return self.ray_targets + self.agent.pos

    def get_square_outline(self):
        """
        Return the border cells of the square with radius ``pomdp_r`` around the agent.

        :return: List of ``(x, y)`` tuples: first the two horizontal edges, then the two vertical
                 edges. The corner cells are therefore contained twice.
        """
        agent = self.agent
        x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
        y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
        outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \
                  + list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
        return outline

    @staticmethod
    @njit
    def bresenham_loop(a_pos, points):
        """
        Rasterise a straight grid line from *a_pos* to every end point in *points*.

        :param a_pos: Start position ``(x, y)`` shared by all lines.
        :param points: Iterable of end positions ``(x, y)``.
        :return: List with one entry per end point; each entry is the list of ``[x, y]``
                 coordinates of that line, ordered from *a_pos* towards the end point.
        """
        results = []
        for end in points:
            x1, y1 = a_pos
            x2, y2 = end
            dx = x2 - x1
            dy = y2 - y1
            # Determine how steep the line is
            is_steep = abs(dy) > abs(dx)
            # Rotate line
            if is_steep:
                x1, y1 = y1, x1
                x2, y2 = y2, x2
            # Swap start and end points if necessary and store swap state
            swapped = False
            if x1 > x2:
                x1, x2 = x2, x1
                y1, y2 = y2, y1
                swapped = True
            # Recalculate differentials
            dx = x2 - x1
            dy = y2 - y1
            # Calculate error
            error = int(dx / 2.0)
            ystep = 1 if y1 < y2 else -1
            # Iterate over bounding box generating points between start and end
            y = y1
            # NOTE(review): this rebinds the name of the ``points`` parameter while the outer loop
            # is still iterating over it - consider a distinct local name.
            points = []
            for x in range(int(x1), int(x2) + 1):
                # Undo the steep-line rotation when emitting the coordinate.
                coord = [y, x] if is_steep else [x, y]
                points.append(coord)
                error -= abs(dy)
                if error < 0:
                    y += ystep
                    error += dx
            # Reverse the list if the coordinates were swapped
            if swapped:
                points.reverse()
            results.append(points)
        return results

View File

@ -7,50 +7,11 @@ from typing import Union, List
import pandas as pd
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
from marl_factory_grid.utils.plotting.plotting import prepare_plot
from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
MODEL_MAP = None
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None):
    """
    Plot the monitored measurements of a single run as a line plot over episodes.

    :param run_path: Either a pickled monitor file, or a directory in which the first file
                     matching ``*monitor*.pick`` is used.
    :param use_tex: Forwarded to ``prepare_plot``.
    :param column_keys: Optional names of the columns to plot; names missing from the data are
                        ignored. By default, all columns not listed in ``IGNORED_DF_COLUMNS`` are plotted.
    :raises FileNotFoundError: If *run_path* is a directory without a matching monitor file.
    :raises ValueError: If *run_path* is neither an existing directory nor an existing file.
    """
    run_path = Path(run_path)
    if run_path.is_dir():
        # A default for ``next`` keeps a bare StopIteration from leaking out of this function.
        monitor_file = next(run_path.glob('*monitor*.pick'), None)
        if monitor_file is None:
            raise FileNotFoundError(f'No file matching "*monitor*.pick" was found in "{run_path}".')
    elif run_path.is_file():
        monitor_file = run_path
    else:
        raise ValueError(f'"{run_path}" is neither an existing directory nor an existing file.')

    # NOTE: unpickling can execute arbitrary code; only load monitor files from trusted sources.
    with monitor_file.open('rb') as f:
        monitor_df = pickle.load(f)

    df = pd.concat([monitor_df.fillna(0)], ignore_index=True)
    df = df.rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
    if column_keys is not None:
        columns = [col for col in column_keys if col in df.columns]
    else:
        columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]

    df_melted = df[columns + ['Episode']].reset_index().melt(
        id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score"
    )

    # Thin out very long runs: keep only every ``skip_n``-th episode.
    if df_melted['Episode'].max() > 800:
        skip_n = round(df_melted['Episode'].max() * 0.02)
        df_melted = df_melted[df_melted['Episode'] % skip_n == 0]

    prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
    print('Plotting done.')
def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
run_path = Path(run_path)
df_list = list()

View File

@ -0,0 +1,48 @@
import pickle
from os import PathLike
from pathlib import Path
from typing import Union
import pandas as pd
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None,
                    file_key: str = 'monitor', file_ext: str = 'pkl'):
    """
    Plot the monitored measurements of a single run as a line plot over episodes.

    :param run_path: Either a pickled monitor file, or a directory in which the first file
                     matching ``*{file_key}*.{file_ext}`` is used.
    :param use_tex: Forwarded to ``prepare_plot``.
    :param column_keys: Optional names of the columns to plot; names missing from the data are
                        ignored. By default, all columns not listed in ``IGNORED_DF_COLUMNS`` are plotted.
    :param file_key: Substring the monitor file name has to contain (directory lookup only).
    :param file_ext: Extension of the monitor file, without the leading dot (directory lookup only).
    :raises FileNotFoundError: If *run_path* is a directory without a matching monitor file.
    :raises ValueError: If *run_path* is neither an existing directory nor an existing file.
    """
    run_path = Path(run_path)
    if run_path.is_dir():
        pattern = f'*{file_key}*.{file_ext}'
        # A default for ``next`` keeps a bare StopIteration from leaking out of this function.
        monitor_file = next(run_path.glob(pattern), None)
        if monitor_file is None:
            raise FileNotFoundError(f'No file matching "{pattern}" was found in "{run_path}".')
    elif run_path.is_file():
        monitor_file = run_path
    else:
        raise ValueError(f'"{run_path}" is neither an existing directory nor an existing file.')

    # NOTE: unpickling can execute arbitrary code; only load monitor files from trusted sources.
    with monitor_file.open('rb') as f:
        monitor_df = pickle.load(f)

    df = pd.concat([monitor_df.fillna(0)], ignore_index=True)
    df = df.rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
    if column_keys is not None:
        columns = [col for col in column_keys if col in df.columns]
    else:
        columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]

    df_melted = df[columns + ['Episode']].reset_index().melt(
        id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score"
    )

    # Thin out very long runs: keep only every ``skip_n``-th episode.
    if df_melted['Episode'].max() > 800:
        skip_n = round(df_melted['Episode'].max() * 0.02)
        df_melted = df_melted[df_melted['Episode'] % skip_n == 0]

    prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
    print('Plotting done.')

View File

@ -60,7 +60,7 @@ def prepare_center_double_column_legend(df, hue, style, hue_order):
print('Struggling to plot Figure using LaTeX - going back to normal.')
plt.close('all')
sns.set(rc={'text.usetex': False}, style='whitegrid')
fig = plt.figure(figsize=(10, 11))
_ = plt.figure(figsize=(10, 11))
lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
ci=95, palette=PALETTE, hue_order=hue_order, legend=False)
# plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)

View File

@ -19,7 +19,7 @@ class RayCaster:
return f'{self.__class__.__name__}({self.agent.name})'
def build_ray_targets(self):
north = np.array([0, -1])*self.pomdp_r
north = np.array([0, -1]) * self.pomdp_r
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
rot_M = [
[[math.cos(theta), -math.sin(theta)],
@ -39,8 +39,9 @@ class RayCaster:
if reset_cache:
self._cache_dict = dict()
for ray in self.get_rays():
for ray in self.get_rays(): # Do not check, just trust.
rx, ry = ray[0]
# self.ray_block_cache(ray[0], lambda: False) We do not do that, because of doors etc...
for x, y in ray:
cx, cy = x - rx, y - ry
@ -52,8 +53,9 @@ class RayCaster:
diag_hits = all([
self.ray_block_cache(
key,
lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light))
for key in ((x, y-cy), (x-cx, y))
# lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light)
lambda: any(True for e in pos_dict[key] if e.var_is_blocking_light))
for key in ((x, y - cy), (x - cx, y))
]) if (cx != 0 and cy != 0) else False
visible += entities_hit if not diag_hits else []
@ -75,8 +77,8 @@ class RayCaster:
agent = self.agent
x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \
+ list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r]))
outline += list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
return outline
@staticmethod

View File

@ -31,7 +31,7 @@ class Renderer:
def __init__(self, lvl_shape: Tuple[int, int] = (16, 16),
lvl_padded_shape: Union[Tuple[int, int], None] = None,
cell_size: int = 40, fps: int = 7,
cell_size: int = 40, fps: int = 7, factor: float = 0.9,
grid_lines: bool = True, view_radius: int = 2):
# TODO: Custom asset paths
self.grid_h, self.grid_w = lvl_shape
@ -45,7 +45,7 @@ class Renderer:
self.screen = pygame.display.set_mode(self.screen_size)
self.clock = pygame.time.Clock()
assets = list(self.ASSETS.rglob('*.png'))
self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets}
self.assets = {path.stem: self.load_asset(str(path), factor) for path in assets}
self.fill_bg()
now = time.time()
@ -110,22 +110,22 @@ class Renderer:
pygame.quit()
sys.exit()
self.fill_bg()
blits = deque()
for entity in [x for x in entities]:
bp = self.blit_params(entity)
blits.append(bp)
if entity.name.lower() == AGENT:
if self.view_radius > 0:
vis_rects = self.visibility_rects(bp, entity.aux)
blits.extendleft(vis_rects)
if entity.state != BLANK:
agent_state_blits = self.blit_params(
RenderEntity(entity.state, (entity.pos[0] + 0.12, entity.pos[1]), 0.48, SCALE)
)
textsurface = self.font.render(str(entity.id), False, (0, 0, 0))
text_blit = dict(source=textsurface, dest=(bp['dest'].center[0]-.07*self.cell_size,
bp['dest'].center[1]))
blits += [agent_state_blits, text_blit]
# First all others
blits = deque(self.blit_params(x) for x in entities if not x.name.lower() == AGENT)
# Then Agents, so that agents are rendered on top.
for agent in (x for x in entities if x.name.lower() == AGENT):
agent_blit = self.blit_params(agent)
if self.view_radius > 0:
vis_rects = self.visibility_rects(agent_blit, agent.aux)
blits.extendleft(vis_rects)
if agent.state != BLANK:
state_blit = self.blit_params(
RenderEntity(agent.state, (agent.pos[0] + 0.12, agent.pos[1]), 0.48, SCALE)
)
textsurface = self.font.render(str(agent.id), False, (0, 0, 0))
text_blit = dict(source=textsurface, dest=(agent_blit['dest'].center[0]-.07*self.cell_size,
agent_blit['dest'].center[1]))
blits += [agent_blit, state_blit, text_blit]
for blit in blits:
self.screen.blit(**blit)

View File

@ -1,9 +1,12 @@
from typing import Union
from dataclasses import dataclass
from marl_factory_grid.environment.entity.object import Object
TYPE_VALUE = 'value'
TYPE_REWARD = 'reward'
types = [TYPE_VALUE, TYPE_REWARD]
TYPES = [TYPE_VALUE, TYPE_REWARD]
@dataclass
class InfoObject:
@ -18,17 +21,21 @@ class Result:
validity: bool
reward: Union[float, None] = None
value: Union[float, None] = None
entity: None = None
entity: Object = None
def get_infos(self):
n = self.entity.name if self.entity is not None else "Global"
return [InfoObject(identifier=f'{n}_{self.identifier}_{t}',
val_type=t, value=self.__getattribute__(t)) for t in types
# Return multiple Info Dicts
return [InfoObject(identifier=f'{n}_{self.identifier}',
val_type=t, value=self.__getattribute__(t)) for t in TYPES
if self.__getattribute__(t) is not None]
def __repr__(self):
valid = "not " if not self.validity else ""
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid: {self.reward})'
reward = f" | Reward: {self.reward}" if self.reward is not None else ""
value = f" | Value: {self.value}" if self.value is not None else ""
entity = f" | by: {self.entity.name}" if self.entity is not None else ""
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value}{entity})'
@dataclass

View File

@ -1,9 +1,12 @@
from typing import List, Dict, Tuple
from itertools import islice
from typing import List, Tuple
import numpy as np
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import Result, DoneResult
from marl_factory_grid.environment.tests import Test
from marl_factory_grid.utils.results import Result
@ -60,7 +63,8 @@ class Gamestate(object):
def moving_entites(self):
return [y for x in self.entities for y in x if x.var_can_move]
def __init__(self, entities, agents_conf, rules: [Rule], tests: [Test], env_seed=69, verbose=False):
def __init__(self, entities, agents_conf, rules: List[Rule], tests: [Test], lvl_shape, env_seed=69, verbose=False):
self.lvl_shape = lvl_shape
self.entities = entities
self.curr_step = 0
self.curr_actions = None
@ -82,7 +86,52 @@ class Gamestate(object):
def __repr__(self):
return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})'
def tick(self, actions) -> List[Result]:
@property
def random_free_position(self) -> (int, int):
    """
    Returns a single **free** position (x, y), which is **free** for spawning or walking.

    Delegates to ``get_n_random_free_positions(1)`` and returns its first element.
    NOTE(review): presumably no entity at such a position possesses *var_is_blocking_pos* or
    *var_can_collide*; this is decided by ``entities.free_positions_generator`` - confirm there.
    NOTE(review): the ``[0]`` lookup fails with an IndexError if no free position is produced.

    :return: Single **free** position.
    """
    return self.get_n_random_free_positions(1)[0]
def get_n_random_free_positions(self, n) -> list[tuple[int, int]]:
    """
    Returns a list of up to *n* **free** positions [(x, y), ... ], which are **free** for spawning or walking.

    The positions are the first *n* items taken from ``entities.free_positions_generator``; fewer are
    returned if it yields fewer.
    NOTE(review): presumably no entity at such a position possesses *var_is_blocking_pos* or
    *var_can_collide*; this is decided by the generator - confirm there.

    :param n: Maximum number of positions to return.
    :return: List of at most n **free** positions.
    """
    return list(islice(self.entities.free_positions_generator, n))
@property
def random_position(self) -> (int, int):
    """
    Returns a single available position (x, y), ignores all entity attributes.

    Delegates to ``get_n_random_positions(1)`` and returns its first element.
    NOTE(review): the ``[0]`` lookup fails with an IndexError if no position is available.

    :return: Single position.
    """
    return self.get_n_random_positions(1)[0]
def get_n_random_positions(self, n) -> list[tuple[int, int]]:
    """
    Returns a list of up to *n* available positions [(x, y), ... ], ignores all entity attributes.

    The positions are the first *n* items of ``entities.floorlist``.
    NOTE(review): the result is only random if ``entities.floorlist`` iterates in random
    order - confirm against its implementation.

    :param n: Maximum number of positions to return.
    :return: List of at most n positions.
    """
    return list(islice(self.entities.floorlist, n))
def tick(self, actions) -> list[Result]:
"""
Performs a single **Gamestate Tick** by calling the inner rule hooks in sequential order.
- tick_pre_step_all: Things to do before the agents do their actions. Statechange, Moving, Spawning etc...
- agent tick: Agents do their actions.
- tick_step_all: Things to do after the agents did their actions. Statechange, Moving, Spawning etc...
- tick_post_step_all: Things to do at the very end of each step. Counting, Reward calculations etc...
:return: List of *Result*-objects.
"""
results = list()
test_results = list()
self.curr_step += 1
@ -112,11 +161,23 @@ class Gamestate(object):
return results
def print(self, string):
def print(self, string) -> None:
"""
When *verbose* is active, print stuff.
:param string: *String* to print.
:type string: str
:return: Nothing
"""
if self.verbose:
print(string)
def check_done(self):
def check_done(self) -> List[DoneResult]:
"""
Iterate all **Rules** that override the *on_check_done* hook.
:return: List of Results
"""
results = list()
for rule in self.rules:
if on_check_done_result := rule.on_check_done(self):
@ -124,24 +185,47 @@ class Gamestate(object):
return results
def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]:
positions = [pos for pos, entity_list_for_position in self.entities.pos_dict.items()
if any([e.var_can_collide for e in entity_list_for_position])]
"""
Returns a list positions [(x, y), ... ] on which collisions occur. This does not include agents,
that were unable to move because their target direction was blocked, also a form of collision.
:return: List of positions.
"""
positions = [pos for pos, entities in self.entities.pos_dict.items() if
len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2)
]
return positions
def check_move_validity(self, moving_entity, position):
if moving_entity.pos != position and not any(
entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]) and not (
moving_entity.var_is_blocking_pos and self.entities.is_occupied(position)):
return True
else:
return False
def check_move_validity(self, moving_entity: Entity, target_position: (int, int)) -> bool:
"""
Whether it is safe to move to the target position and the moving entity does not introduce a blocking attribute
when the position is already occupied.
def check_pos_validity(self, position):
if not any(entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]):
return True
else:
return False
:param moving_entity: Entity
:param target_position: pos
:return: Safe to move to
"""
is_not_blocked = self.check_pos_validity(target_position)
will_not_block_others = moving_entity.var_is_blocking_pos and self.entities.is_occupied(target_position)
if moving_entity.pos != target_position and is_not_blocked and not will_not_block_others:
return c.VALID
else:
return c.NOT_VALID
def check_pos_validity(self, pos: (int, int)) -> bool:
    """
    Check if *pos* is a valid position to move or spawn to.

    A position is valid when none of the entities registered for *pos* in ``entities.pos_dict``
    has ``var_is_blocking_pos`` set and *pos* is contained in ``entities.floorlist``.

    :param pos: position to check
    :return: ``c.VALID`` if *pos* is a valid target, ``c.NOT_VALID`` otherwise.
    """
    if not any(e.var_is_blocking_pos for e in self.entities.pos_dict[pos]) and pos in self.entities.floorlist:
        return c.VALID
    else:
        return c.NOT_VALID
class StepTests:
def __init__(self, *args):

View File

@ -28,7 +28,9 @@ class ConfigExplainer:
def explain_module(self, class_to_explain):
parameters = inspect.signature(class_to_explain).parameters
explained = {class_to_explain.__name__: {key: val.default for key, val in parameters.items() if key not in EXCLUDED}}
explained = {class_to_explain.__name__:
{key: val.default for key, val in parameters.items() if key not in EXCLUDED}
}
return explained
def _load_and_compare(self, compare_class, paths):
@ -135,4 +137,3 @@ if __name__ == '__main__':
ce.get_observations()
ce.get_assets()
all_conf = ce.get_all()
print()

View File

@ -52,3 +52,6 @@ class Floor:
def __hash__(self):
    """Hash the object by its ``name`` attribute."""
    return hash(self.name)
def __repr__(self):
    """Return a short debug representation: ``Floor`` directly followed by the object's ``pos``."""
    return f"Floor{self.pos}"