mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-06 01:21:36 +02:00
Merge branch 'main' into unit_testing
# Conflicts: # marl_factory_grid/environment/factory.py # marl_factory_grid/utils/config_parser.py # marl_factory_grid/utils/states.py
This commit is contained in:
@ -0,0 +1,3 @@
|
||||
from . import helpers as h
|
||||
from . import helpers
|
||||
from .results import Result, DoneResult, ActionResult, TickResult
|
||||
|
@ -1,4 +1,5 @@
|
||||
import ast
|
||||
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Union, List
|
||||
@ -9,18 +10,17 @@ from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.environment.tests import Test
|
||||
from marl_factory_grid.utils.helpers import locate_and_import_class
|
||||
|
||||
DEFAULT_PATH = 'environment'
|
||||
MODULE_PATH = 'modules'
|
||||
from marl_factory_grid.environment.constants import DEFAULT_PATH, MODULE_PATH
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
|
||||
class FactoryConfigParser(object):
|
||||
default_entites = []
|
||||
default_rules = ['MaxStepsReached', 'Collision']
|
||||
default_rules = ['DoneAtMaxStepsReached', 'WatchCollision']
|
||||
default_actions = [c.MOVE8, c.NOOP]
|
||||
default_observations = [c.WALLS, c.AGENT]
|
||||
|
||||
def __init__(self, config_path, custom_modules_path: Union[None, PathLike] = None):
|
||||
def __init__(self, config_path, custom_modules_path: Union[PathLike] = None):
|
||||
self.config_path = Path(config_path)
|
||||
self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path
|
||||
self.config = yaml.safe_load(self.config_path.open())
|
||||
@ -44,6 +44,10 @@ class FactoryConfigParser(object):
|
||||
def rules(self):
|
||||
return self.config['Rules']
|
||||
|
||||
@property
|
||||
def tests(self):
|
||||
return self.config.get('Tests', [])
|
||||
|
||||
@property
|
||||
def agents(self):
|
||||
return self.config['Agents']
|
||||
@ -56,10 +60,12 @@ class FactoryConfigParser(object):
|
||||
return str(self.config)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.config[item]
|
||||
try:
|
||||
return self.config[item]
|
||||
except KeyError:
|
||||
print(f'The mandatory {item} section could not be found in your .config gile. Check Spelling!')
|
||||
|
||||
def load_entities(self):
|
||||
# entites = Entities()
|
||||
entity_classes = dict()
|
||||
entities = []
|
||||
if c.DEFAULTS in self.entities:
|
||||
@ -67,28 +73,40 @@ class FactoryConfigParser(object):
|
||||
entities.extend(x for x in self.entities if x != c.DEFAULTS)
|
||||
|
||||
for entity in entities:
|
||||
e1 = e2 = e3 = None
|
||||
try:
|
||||
folder_path = Path(__file__).parent.parent / DEFAULT_PATH
|
||||
entity_class = locate_and_import_class(entity, folder_path)
|
||||
except AttributeError as e1:
|
||||
except AttributeError as e:
|
||||
e1 = e
|
||||
try:
|
||||
folder_path = Path(__file__).parent.parent / MODULE_PATH
|
||||
entity_class = locate_and_import_class(entity, folder_path)
|
||||
except AttributeError as e2:
|
||||
try:
|
||||
folder_path = self.custom_modules_path
|
||||
entity_class = locate_and_import_class(entity, folder_path)
|
||||
except AttributeError as e3:
|
||||
ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x]
|
||||
print('### Error ### Error ### Error ### Error ### Error ###')
|
||||
print()
|
||||
print(f'Class "{entity}" was not found in "{folder_path.name}"')
|
||||
print('Possible Entitys are:', str(ents))
|
||||
print()
|
||||
print('Goodbye')
|
||||
print()
|
||||
exit()
|
||||
# raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are:', str(ents))
|
||||
module_path = Path(__file__).parent.parent / MODULE_PATH
|
||||
entity_class = locate_and_import_class(entity, module_path)
|
||||
except AttributeError as e:
|
||||
e2 = e
|
||||
if self.custom_modules_path:
|
||||
try:
|
||||
entity_class = locate_and_import_class(entity, self.custom_modules_path)
|
||||
except AttributeError as e:
|
||||
e3 = e
|
||||
pass
|
||||
if (e1 and e2) or e3:
|
||||
ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]]
|
||||
print('##############################################################')
|
||||
print('### Error ### Error ### Error ### Error ### Error ###')
|
||||
print('##############################################################')
|
||||
print(f'Class "{entity}" was not found in "{module_path.name}"')
|
||||
print(f'Class "{entity}" was not found in "{folder_path.name}"')
|
||||
print('##############################################################')
|
||||
if self.custom_modules_path:
|
||||
print(f'Class "{entity}" was not found in "{self.custom_modules_path}"')
|
||||
print('Possible Entitys are:', str(ents))
|
||||
print('##############################################################')
|
||||
print('Goodbye')
|
||||
print('##############################################################')
|
||||
print('### Error ### Error ### Error ### Error ### Error ###')
|
||||
print('##############################################################')
|
||||
exit(-99999)
|
||||
|
||||
entity_kwargs = self.entities.get(entity, {})
|
||||
entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
|
||||
@ -126,7 +144,12 @@ class FactoryConfigParser(object):
|
||||
observations.extend(self.default_observations)
|
||||
observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS)
|
||||
positions = [ast.literal_eval(x) for x in self.agents[name].get('Positions', [])]
|
||||
parsed_agents_conf[name] = dict(actions=parsed_actions, observations=observations, positions=positions)
|
||||
other_kwargs = {k: v for k, v in self.agents[name].items() if k not in
|
||||
['Actions', 'Observations', 'Positions']}
|
||||
parsed_agents_conf[name] = dict(
|
||||
actions=parsed_actions, observations=observations, positions=positions, other=other_kwargs
|
||||
)
|
||||
|
||||
return parsed_agents_conf
|
||||
|
||||
def load_env_rules(self) -> List[Rule]:
|
||||
@ -137,28 +160,69 @@ class FactoryConfigParser(object):
|
||||
rules.append({rule: {}})
|
||||
|
||||
return self._load_smth(rules, Rule)
|
||||
pass
|
||||
|
||||
def load_env_tests(self) -> List[Test]:
|
||||
def load_env_tests(self) -> List[Rule]:
|
||||
return self._load_smth(self.tests, None) # Test
|
||||
pass
|
||||
|
||||
def _load_smth(self, config, class_obj):
|
||||
rules = list()
|
||||
rules_names = list()
|
||||
|
||||
for rule in rules_names:
|
||||
for rule in config:
|
||||
e1 = e2 = e3 = None
|
||||
try:
|
||||
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
|
||||
rule_class = locate_and_import_class(rule, folder_path)
|
||||
except AttributeError:
|
||||
except AttributeError as e:
|
||||
e1 = e
|
||||
try:
|
||||
folder_path = (Path(__file__).parent.parent / MODULE_PATH)
|
||||
rule_class = locate_and_import_class(rule, folder_path)
|
||||
except AttributeError:
|
||||
rule_class = locate_and_import_class(rule, self.custom_modules_path)
|
||||
# Fixme This check does not work!
|
||||
# assert isinstance(rule_class, class_obj), f'{rule_class.__name__} is no valid "class_obj.__name__".'
|
||||
rule_kwargs = config.get(rule, {})
|
||||
rules.append(rule_class(**rule_kwargs))
|
||||
module_path = (Path(__file__).parent.parent / MODULE_PATH)
|
||||
rule_class = locate_and_import_class(rule, module_path)
|
||||
except AttributeError as e:
|
||||
e2 = e
|
||||
if self.custom_modules_path:
|
||||
try:
|
||||
rule_class = locate_and_import_class(rule, self.custom_modules_path)
|
||||
except AttributeError as e:
|
||||
e3 = e
|
||||
pass
|
||||
if (e1 and e2) or e3:
|
||||
ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]]
|
||||
print('### Error ### Error ### Error ### Error ### Error ###')
|
||||
print('')
|
||||
print(f'Class "{rule}" was not found in "{module_path.name}"')
|
||||
print(f'Class "{rule}" was not found in "{folder_path.name}"')
|
||||
if self.custom_modules_path:
|
||||
print(f'Class "{rule}" was not found in "{self.custom_modules_path}"')
|
||||
print('Possible Entitys are:', str(ents))
|
||||
print('')
|
||||
print('Goodbye')
|
||||
print('')
|
||||
exit(-99999)
|
||||
|
||||
if issubclass(rule_class, class_obj):
|
||||
rule_kwargs = config.get(rule, {})
|
||||
rules.append(rule_class(**(rule_kwargs or {})))
|
||||
return rules
|
||||
|
||||
def load_entity_spawn_rules(self, entities) -> List[Rule]:
|
||||
rules = list()
|
||||
rules_dicts = list()
|
||||
for e in entities:
|
||||
try:
|
||||
if spawn_rule := e.spawn_rule:
|
||||
rules_dicts.append(spawn_rule)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
for rule_dict in rules_dicts:
|
||||
for rule_name, rule_kwargs in rule_dict.items():
|
||||
try:
|
||||
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
|
||||
rule_class = locate_and_import_class(rule_name, folder_path)
|
||||
except AttributeError:
|
||||
try:
|
||||
folder_path = (Path(__file__).parent.parent / MODULE_PATH)
|
||||
rule_class = locate_and_import_class(rule_name, folder_path)
|
||||
except AttributeError:
|
||||
rule_class = locate_and_import_class(rule_name, self.custom_modules_path)
|
||||
rules.append(rule_class(**rule_kwargs))
|
||||
return rules
|
||||
|
@ -2,7 +2,7 @@ import importlib
|
||||
|
||||
from collections import defaultdict
|
||||
from pathlib import PurePath, Path
|
||||
from typing import Union, Dict, List
|
||||
from typing import Union, Dict, List, Iterable, Callable
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import ArrayLike
|
||||
@ -61,8 +61,8 @@ class ObservationTranslator:
|
||||
:param per_agent_named_obs_spaces: `Named observation space` one for each agent. Overloaded.
|
||||
:type per_agent_named_obs_spaces: Dict[str, dict]
|
||||
|
||||
:param placeholder_fill_value: Currently not fully implemented!!!
|
||||
:type placeholder_fill_value: Union[int, str] = 'N')
|
||||
:param placeholder_fill_value: Currently, not fully implemented!!!
|
||||
:type placeholder_fill_value: Union[int, str] = 'N'
|
||||
"""
|
||||
|
||||
if isinstance(placeholder_fill_value, str):
|
||||
@ -222,7 +222,7 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
|
||||
mod = importlib.import_module('.'.join(module_parts))
|
||||
all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
|
||||
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
|
||||
'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin',
|
||||
'TickResult', 'ActionResult', 'Action', 'Agent',
|
||||
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
|
||||
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any'
|
||||
]])
|
||||
@ -240,7 +240,13 @@ def add_bound_name(name_str, bound_e):
|
||||
|
||||
def add_pos_name(name_str, bound_e):
|
||||
if bound_e.var_has_position:
|
||||
return f'{name_str}({bound_e.pos})'
|
||||
return f'{name_str}@{bound_e.pos}'
|
||||
return name_str
|
||||
|
||||
|
||||
def get_first(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True):
|
||||
return next((x for x in iterable if filter_by(x)), None)
|
||||
|
||||
|
||||
def get_first_index(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True):
|
||||
return next((idx for idx, x in enumerate(iterable) if filter_by(x)), None)
|
||||
|
@ -47,6 +47,7 @@ class LevelParser(object):
|
||||
# All other
|
||||
for es_name in self.e_p_dict:
|
||||
e_class, e_kwargs = self.e_p_dict[es_name]['class'], self.e_p_dict[es_name]['kwargs']
|
||||
e_kwargs = e_kwargs if e_kwargs else {}
|
||||
|
||||
if hasattr(e_class, 'symbol') and e_class.symbol is not None:
|
||||
symbols = e_class.symbol
|
||||
|
@ -9,7 +9,7 @@ from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from marl_factory_grid.utils.plotting.compare_runs import plot_single_run
|
||||
from marl_factory_grid.utils.plotting.plot_single_runs import plot_single_run
|
||||
|
||||
|
||||
class EnvMonitor(Wrapper):
|
||||
@ -22,7 +22,6 @@ class EnvMonitor(Wrapper):
|
||||
self._monitor_df = pd.DataFrame()
|
||||
self._monitor_dict = dict()
|
||||
|
||||
|
||||
def step(self, action):
|
||||
obs_type, obs, reward, done, info = self.env.step(action)
|
||||
self._read_info(info)
|
||||
|
@ -2,11 +2,9 @@ from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Union, List
|
||||
|
||||
import yaml
|
||||
from gymnasium import Wrapper
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from gymnasium import Wrapper
|
||||
|
||||
|
||||
class EnvRecorder(Wrapper):
|
||||
@ -106,7 +104,7 @@ class EnvRecorder(Wrapper):
|
||||
out_dict = {'episodes': self._recorder_out_list}
|
||||
out_dict.update(
|
||||
{'n_episodes': self._curr_episode,
|
||||
'metadata':dict(
|
||||
'metadata': dict(
|
||||
level_name=self.env.params['General']['level_name'],
|
||||
verbose=False,
|
||||
n_agents=len(self.env.params['Agents']),
|
||||
|
@ -1,17 +1,16 @@
|
||||
import math
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from itertools import product
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
from numba import njit
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
from marl_factory_grid.environment.groups.utils import Combined
|
||||
import marl_factory_grid.utils.helpers as h
|
||||
from marl_factory_grid.utils.states import Gamestate
|
||||
from marl_factory_grid.utils.utility_classes import Floor
|
||||
from marl_factory_grid.utils.ray_caster import RayCaster
|
||||
from marl_factory_grid.utils.states import Gamestate
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
|
||||
|
||||
class OBSBuilder(object):
|
||||
@ -77,11 +76,13 @@ class OBSBuilder(object):
|
||||
|
||||
def place_entity_in_observation(self, obs_array, agent, e):
|
||||
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
|
||||
try:
|
||||
obs_array[x, y] += e.encoding
|
||||
except IndexError:
|
||||
# Seemed to be visible but is out of range
|
||||
pass
|
||||
if not min([y, x]) < 0:
|
||||
try:
|
||||
obs_array[x, y] += e.encoding
|
||||
except IndexError:
|
||||
# Seemed to be visible but is out of range
|
||||
pass
|
||||
pass
|
||||
|
||||
def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
|
||||
assert self._curr_env_step == state.curr_step, (
|
||||
@ -121,18 +122,24 @@ class OBSBuilder(object):
|
||||
e = self.all_obs[l_name]
|
||||
except KeyError:
|
||||
try:
|
||||
# Look for bound entity names!
|
||||
pattern = re.compile(f'{re.escape(l_name)}(.*){re.escape(agent.name)}')
|
||||
name = next((x for x in self.all_obs if pattern.search(x)), None)
|
||||
# Look for bound entity REPRs!
|
||||
pattern = re.compile(f'{re.escape(l_name)}'
|
||||
f'{re.escape("[")}(.*){re.escape("]")}'
|
||||
f'{re.escape("(")}{re.escape(agent.name)}{re.escape(")")}')
|
||||
name = next((key for key, val in self.all_obs.items()
|
||||
if pattern.search(str(val)) and isinstance(val, Object)), None)
|
||||
e = self.all_obs[name]
|
||||
except KeyError:
|
||||
try:
|
||||
e = next(v for k, v in self.all_obs.items() if l_name in k and agent.name in k)
|
||||
except StopIteration:
|
||||
raise KeyError(
|
||||
f'Check for spelling errors! \n '
|
||||
f'No combination of "{l_name} and {agent.name}" could not be found in:\n '
|
||||
f'{list(dict(self.all_obs).keys())}')
|
||||
print(f'# Check for spelling errors!')
|
||||
print(f'# No combination of "{l_name}" and "{agent.name}" could not be found in:')
|
||||
print(f'# {list(dict(self.all_obs).keys())}')
|
||||
print('#')
|
||||
print('# exiting...')
|
||||
print('#')
|
||||
exit(-99999)
|
||||
|
||||
try:
|
||||
positional = e.var_has_position
|
||||
@ -161,31 +168,30 @@ class OBSBuilder(object):
|
||||
try:
|
||||
light_map = np.zeros(self.obs_shape)
|
||||
visible_floor = self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False)
|
||||
if self.pomdp_r:
|
||||
for f in set(visible_floor):
|
||||
self.place_entity_in_observation(light_map, agent, f)
|
||||
else:
|
||||
for f in set(visible_floor):
|
||||
light_map[f.x, f.y] += f.encoding
|
||||
|
||||
for f in set(visible_floor):
|
||||
self.place_entity_in_observation(light_map, agent, f)
|
||||
# else:
|
||||
# for f in set(visible_floor):
|
||||
# light_map[f.x, f.y] += f.encoding
|
||||
self.curr_lightmaps[agent.name] = light_map
|
||||
except (KeyError, ValueError):
|
||||
print()
|
||||
pass
|
||||
return obs, self.obs_layers[agent.name]
|
||||
|
||||
def _sort_and_name_observation_conf(self, agent):
|
||||
'''
|
||||
"""
|
||||
Builds the useable observation scheme per agent from conf.yaml.
|
||||
:param agent:
|
||||
:return:
|
||||
'''
|
||||
"""
|
||||
# Fixme: no asymmetric shapes possible.
|
||||
self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape))
|
||||
obs_layers = []
|
||||
|
||||
for obs_str in agent.observations:
|
||||
if isinstance(obs_str, dict):
|
||||
obs_str, vals = next(obs_str.items().__iter__())
|
||||
obs_str, vals = h.get_first(obs_str.items())
|
||||
else:
|
||||
vals = None
|
||||
if obs_str == c.SELF:
|
||||
@ -214,129 +220,3 @@ class OBSBuilder(object):
|
||||
obs_layers.append(obs_str)
|
||||
self.obs_layers[agent.name] = obs_layers
|
||||
self.curr_lightmaps[agent.name] = np.zeros(self.obs_shape)
|
||||
|
||||
|
||||
class RayCaster:
|
||||
def __init__(self, agent, pomdp_r, degs=360):
|
||||
self.agent = agent
|
||||
self.pomdp_r = pomdp_r
|
||||
self.n_rays = (self.pomdp_r + 1) * 8
|
||||
self.degs = degs
|
||||
self.ray_targets = self.build_ray_targets()
|
||||
self.obs_shape_cube = np.array([self.pomdp_r, self.pomdp_r])
|
||||
self._cache_dict = {}
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self.agent.name})'
|
||||
|
||||
def build_ray_targets(self):
|
||||
north = np.array([0, -1]) * self.pomdp_r
|
||||
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
|
||||
rot_M = [
|
||||
[[math.cos(theta), -math.sin(theta)],
|
||||
[math.sin(theta), math.cos(theta)]] for theta in thetas
|
||||
]
|
||||
rot_M = np.stack(rot_M, 0)
|
||||
rot_M = np.unique(np.round(rot_M @ north), axis=0)
|
||||
return rot_M.astype(int)
|
||||
|
||||
def ray_block_cache(self, key, callback):
|
||||
if key not in self._cache_dict:
|
||||
self._cache_dict[key] = callback()
|
||||
return self._cache_dict[key]
|
||||
|
||||
def visible_entities(self, pos_dict, reset_cache=True):
|
||||
visible = list()
|
||||
if reset_cache:
|
||||
self._cache_dict = {}
|
||||
|
||||
for ray in self.get_rays():
|
||||
rx, ry = ray[0]
|
||||
for x, y in ray:
|
||||
cx, cy = x - rx, y - ry
|
||||
|
||||
entities_hit = pos_dict[(x, y)]
|
||||
hits = self.ray_block_cache((x, y),
|
||||
lambda: any(True for e in entities_hit if e.var_is_blocking_light)
|
||||
)
|
||||
|
||||
diag_hits = all([
|
||||
self.ray_block_cache(
|
||||
key,
|
||||
lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) and bool(
|
||||
pos_dict[key]))
|
||||
for key in ((x, y - cy), (x - cx, y))
|
||||
]) if (cx != 0 and cy != 0) else False
|
||||
|
||||
visible += entities_hit if not diag_hits else []
|
||||
if hits or diag_hits:
|
||||
break
|
||||
rx, ry = x, y
|
||||
return visible
|
||||
|
||||
def get_rays(self):
|
||||
a_pos = self.agent.pos
|
||||
outline = self.ray_targets + a_pos
|
||||
return self.bresenham_loop(a_pos, outline)
|
||||
|
||||
# todo do this once and cache the points!
|
||||
def get_fov_outline(self) -> np.ndarray:
|
||||
return self.ray_targets + self.agent.pos
|
||||
|
||||
def get_square_outline(self):
|
||||
agent = self.agent
|
||||
x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
|
||||
y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
|
||||
outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \
|
||||
+ list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
|
||||
return outline
|
||||
|
||||
@staticmethod
|
||||
@njit
|
||||
def bresenham_loop(a_pos, points):
|
||||
results = []
|
||||
for end in points:
|
||||
x1, y1 = a_pos
|
||||
x2, y2 = end
|
||||
dx = x2 - x1
|
||||
dy = y2 - y1
|
||||
|
||||
# Determine how steep the line is
|
||||
is_steep = abs(dy) > abs(dx)
|
||||
|
||||
# Rotate line
|
||||
if is_steep:
|
||||
x1, y1 = y1, x1
|
||||
x2, y2 = y2, x2
|
||||
|
||||
# Swap start and end points if necessary and store swap state
|
||||
swapped = False
|
||||
if x1 > x2:
|
||||
x1, x2 = x2, x1
|
||||
y1, y2 = y2, y1
|
||||
swapped = True
|
||||
|
||||
# Recalculate differentials
|
||||
dx = x2 - x1
|
||||
dy = y2 - y1
|
||||
|
||||
# Calculate error
|
||||
error = int(dx / 2.0)
|
||||
ystep = 1 if y1 < y2 else -1
|
||||
|
||||
# Iterate over bounding box generating points between start and end
|
||||
y = y1
|
||||
points = []
|
||||
for x in range(int(x1), int(x2) + 1):
|
||||
coord = [y, x] if is_steep else [x, y]
|
||||
points.append(coord)
|
||||
error -= abs(dy)
|
||||
if error < 0:
|
||||
y += ystep
|
||||
error += dx
|
||||
|
||||
# Reverse the list if the coordinates were swapped
|
||||
if swapped:
|
||||
points.reverse()
|
||||
results.append(points)
|
||||
return results
|
||||
|
@ -7,50 +7,11 @@ from typing import Union, List
|
||||
import pandas as pd
|
||||
|
||||
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
|
||||
from marl_factory_grid.utils.plotting.plotting import prepare_plot
|
||||
from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
|
||||
|
||||
MODEL_MAP = None
|
||||
|
||||
|
||||
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None):
|
||||
run_path = Path(run_path)
|
||||
df_list = list()
|
||||
if run_path.is_dir():
|
||||
monitor_file = next(run_path.glob('*monitor*.pick'))
|
||||
elif run_path.exists() and run_path.is_file():
|
||||
monitor_file = run_path
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
with monitor_file.open('rb') as f:
|
||||
monitor_df = pickle.load(f)
|
||||
|
||||
monitor_df = monitor_df.fillna(0)
|
||||
df_list.append(monitor_df)
|
||||
|
||||
df = pd.concat(df_list, ignore_index=True)
|
||||
df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
|
||||
if column_keys is not None:
|
||||
columns = [col for col in column_keys if col in df.columns]
|
||||
else:
|
||||
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
|
||||
|
||||
roll_n = 50
|
||||
|
||||
non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean()
|
||||
|
||||
df_melted = df[columns + ['Episode']].reset_index().melt(
|
||||
id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score"
|
||||
)
|
||||
|
||||
if df_melted['Episode'].max() > 800:
|
||||
skip_n = round(df_melted['Episode'].max() * 0.02)
|
||||
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||
|
||||
prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
||||
print('Plotting done.')
|
||||
|
||||
|
||||
def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
|
||||
run_path = Path(run_path)
|
||||
df_list = list()
|
48
marl_factory_grid/utils/plotting/plot_single_runs.py
Normal file
48
marl_factory_grid/utils/plotting/plot_single_runs.py
Normal file
@ -0,0 +1,48 @@
|
||||
import pickle
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
|
||||
from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
|
||||
|
||||
|
||||
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None,
|
||||
file_key: str ='monitor', file_ext: str ='pkl'):
|
||||
run_path = Path(run_path)
|
||||
df_list = list()
|
||||
if run_path.is_dir():
|
||||
monitor_file = next(run_path.glob(f'*{file_key}*.{file_ext}'))
|
||||
elif run_path.exists() and run_path.is_file():
|
||||
monitor_file = run_path
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
with monitor_file.open('rb') as f:
|
||||
monitor_df = pickle.load(f)
|
||||
|
||||
monitor_df = monitor_df.fillna(0)
|
||||
df_list.append(monitor_df)
|
||||
|
||||
df = pd.concat(df_list, ignore_index=True)
|
||||
df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
|
||||
if column_keys is not None:
|
||||
columns = [col for col in column_keys if col in df.columns]
|
||||
else:
|
||||
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
|
||||
|
||||
# roll_n = 50
|
||||
# non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean()
|
||||
|
||||
df_melted = df[columns + ['Episode']].reset_index().melt(
|
||||
id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score"
|
||||
)
|
||||
|
||||
if df_melted['Episode'].max() > 800:
|
||||
skip_n = round(df_melted['Episode'].max() * 0.02)
|
||||
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||
|
||||
prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
||||
print('Plotting done.')
|
@ -60,7 +60,7 @@ def prepare_center_double_column_legend(df, hue, style, hue_order):
|
||||
print('Struggling to plot Figure using LaTeX - going back to normal.')
|
||||
plt.close('all')
|
||||
sns.set(rc={'text.usetex': False}, style='whitegrid')
|
||||
fig = plt.figure(figsize=(10, 11))
|
||||
_ = plt.figure(figsize=(10, 11))
|
||||
lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
|
||||
ci=95, palette=PALETTE, hue_order=hue_order, legend=False)
|
||||
# plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
|
@ -19,7 +19,7 @@ class RayCaster:
|
||||
return f'{self.__class__.__name__}({self.agent.name})'
|
||||
|
||||
def build_ray_targets(self):
|
||||
north = np.array([0, -1])*self.pomdp_r
|
||||
north = np.array([0, -1]) * self.pomdp_r
|
||||
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
|
||||
rot_M = [
|
||||
[[math.cos(theta), -math.sin(theta)],
|
||||
@ -39,8 +39,9 @@ class RayCaster:
|
||||
if reset_cache:
|
||||
self._cache_dict = dict()
|
||||
|
||||
for ray in self.get_rays():
|
||||
for ray in self.get_rays(): # Do not check, just trust.
|
||||
rx, ry = ray[0]
|
||||
# self.ray_block_cache(ray[0], lambda: False) We do not do that, because of doors etc...
|
||||
for x, y in ray:
|
||||
cx, cy = x - rx, y - ry
|
||||
|
||||
@ -52,8 +53,9 @@ class RayCaster:
|
||||
diag_hits = all([
|
||||
self.ray_block_cache(
|
||||
key,
|
||||
lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light))
|
||||
for key in ((x, y-cy), (x-cx, y))
|
||||
# lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light)
|
||||
lambda: any(True for e in pos_dict[key] if e.var_is_blocking_light))
|
||||
for key in ((x, y - cy), (x - cx, y))
|
||||
]) if (cx != 0 and cy != 0) else False
|
||||
|
||||
visible += entities_hit if not diag_hits else []
|
||||
@ -75,8 +77,8 @@ class RayCaster:
|
||||
agent = self.agent
|
||||
x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
|
||||
y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
|
||||
outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \
|
||||
+ list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
|
||||
outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r]))
|
||||
outline += list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
|
||||
return outline
|
||||
|
||||
@staticmethod
|
||||
|
@ -31,7 +31,7 @@ class Renderer:
|
||||
|
||||
def __init__(self, lvl_shape: Tuple[int, int] = (16, 16),
|
||||
lvl_padded_shape: Union[Tuple[int, int], None] = None,
|
||||
cell_size: int = 40, fps: int = 7,
|
||||
cell_size: int = 40, fps: int = 7, factor: float = 0.9,
|
||||
grid_lines: bool = True, view_radius: int = 2):
|
||||
# TODO: Customn_assets paths
|
||||
self.grid_h, self.grid_w = lvl_shape
|
||||
@ -45,7 +45,7 @@ class Renderer:
|
||||
self.screen = pygame.display.set_mode(self.screen_size)
|
||||
self.clock = pygame.time.Clock()
|
||||
assets = list(self.ASSETS.rglob('*.png'))
|
||||
self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets}
|
||||
self.assets = {path.stem: self.load_asset(str(path), factor) for path in assets}
|
||||
self.fill_bg()
|
||||
|
||||
now = time.time()
|
||||
@ -110,22 +110,22 @@ class Renderer:
|
||||
pygame.quit()
|
||||
sys.exit()
|
||||
self.fill_bg()
|
||||
blits = deque()
|
||||
for entity in [x for x in entities]:
|
||||
bp = self.blit_params(entity)
|
||||
blits.append(bp)
|
||||
if entity.name.lower() == AGENT:
|
||||
if self.view_radius > 0:
|
||||
vis_rects = self.visibility_rects(bp, entity.aux)
|
||||
blits.extendleft(vis_rects)
|
||||
if entity.state != BLANK:
|
||||
agent_state_blits = self.blit_params(
|
||||
RenderEntity(entity.state, (entity.pos[0] + 0.12, entity.pos[1]), 0.48, SCALE)
|
||||
)
|
||||
textsurface = self.font.render(str(entity.id), False, (0, 0, 0))
|
||||
text_blit = dict(source=textsurface, dest=(bp['dest'].center[0]-.07*self.cell_size,
|
||||
bp['dest'].center[1]))
|
||||
blits += [agent_state_blits, text_blit]
|
||||
# First all others
|
||||
blits = deque(self.blit_params(x) for x in entities if not x.name.lower() == AGENT)
|
||||
# Then Agents, so that agents are rendered on top.
|
||||
for agent in (x for x in entities if x.name.lower() == AGENT):
|
||||
agent_blit = self.blit_params(agent)
|
||||
if self.view_radius > 0:
|
||||
vis_rects = self.visibility_rects(agent_blit, agent.aux)
|
||||
blits.extendleft(vis_rects)
|
||||
if agent.state != BLANK:
|
||||
state_blit = self.blit_params(
|
||||
RenderEntity(agent.state, (agent.pos[0] + 0.12, agent.pos[1]), 0.48, SCALE)
|
||||
)
|
||||
textsurface = self.font.render(str(agent.id), False, (0, 0, 0))
|
||||
text_blit = dict(source=textsurface, dest=(agent_blit['dest'].center[0]-.07*self.cell_size,
|
||||
agent_blit['dest'].center[1]))
|
||||
blits += [agent_blit, state_blit, text_blit]
|
||||
|
||||
for blit in blits:
|
||||
self.screen.blit(**blit)
|
||||
|
@ -1,9 +1,12 @@
|
||||
from typing import Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
|
||||
TYPE_VALUE = 'value'
|
||||
TYPE_REWARD = 'reward'
|
||||
types = [TYPE_VALUE, TYPE_REWARD]
|
||||
TYPES = [TYPE_VALUE, TYPE_REWARD]
|
||||
|
||||
|
||||
@dataclass
|
||||
class InfoObject:
|
||||
@ -18,17 +21,21 @@ class Result:
|
||||
validity: bool
|
||||
reward: Union[float, None] = None
|
||||
value: Union[float, None] = None
|
||||
entity: None = None
|
||||
entity: Object = None
|
||||
|
||||
def get_infos(self):
|
||||
n = self.entity.name if self.entity is not None else "Global"
|
||||
return [InfoObject(identifier=f'{n}_{self.identifier}_{t}',
|
||||
val_type=t, value=self.__getattribute__(t)) for t in types
|
||||
# Return multiple Info Dicts
|
||||
return [InfoObject(identifier=f'{n}_{self.identifier}',
|
||||
val_type=t, value=self.__getattribute__(t)) for t in TYPES
|
||||
if self.__getattribute__(t) is not None]
|
||||
|
||||
def __repr__(self):
|
||||
valid = "not " if not self.validity else ""
|
||||
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid: {self.reward})'
|
||||
reward = f" | Reward: {self.reward}" if self.reward is not None else ""
|
||||
value = f" | Value: {self.value}" if self.value is not None else ""
|
||||
entity = f" | by: {self.entity.name}" if self.entity is not None else ""
|
||||
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value}{entity})'
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -1,9 +1,12 @@
|
||||
from typing import List, Dict, Tuple
|
||||
from itertools import islice
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.utils.results import Result, DoneResult
|
||||
from marl_factory_grid.environment.tests import Test
|
||||
from marl_factory_grid.utils.results import Result
|
||||
|
||||
@ -60,7 +63,8 @@ class Gamestate(object):
|
||||
def moving_entites(self):
|
||||
return [y for x in self.entities for y in x if x.var_can_move]
|
||||
|
||||
def __init__(self, entities, agents_conf, rules: [Rule], tests: [Test], env_seed=69, verbose=False):
|
||||
def __init__(self, entities, agents_conf, rules: List[Rule], tests: [Test], lvl_shape, env_seed=69, verbose=False):
|
||||
self.lvl_shape = lvl_shape
|
||||
self.entities = entities
|
||||
self.curr_step = 0
|
||||
self.curr_actions = None
|
||||
@ -82,7 +86,52 @@ class Gamestate(object):
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})'
|
||||
|
||||
def tick(self, actions) -> List[Result]:
|
||||
@property
|
||||
def random_free_position(self) -> (int, int):
|
||||
"""
|
||||
Returns a single **free** position (x, y), which is **free** for spawning or walking.
|
||||
No Entity at this position possesses *var_is_blocking_pos* or *var_can_collide*.
|
||||
|
||||
:return: Single **free** position.
|
||||
"""
|
||||
return self.get_n_random_free_positions(1)[0]
|
||||
|
||||
def get_n_random_free_positions(self, n) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Returns a list of *n* **free** positions [(x, y), ... ], which are **free** for spawning or walking.
|
||||
No Entity at this position possesses *var_is_blocking_pos* or *var_can_collide*.
|
||||
|
||||
:return: List of n **free** positions.
|
||||
"""
|
||||
return list(islice(self.entities.free_positions_generator, n))
|
||||
|
||||
@property
|
||||
def random_position(self) -> (int, int):
|
||||
"""
|
||||
Returns a single available position (x, y), ignores all entity attributes.
|
||||
|
||||
:return: Single random position.
|
||||
"""
|
||||
return self.get_n_random_positions(1)[0]
|
||||
|
||||
def get_n_random_positions(self, n) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Returns a list of *n* available positions [(x, y), ... ], ignores all entity attributes.
|
||||
|
||||
:return: List of n random positions.
|
||||
"""
|
||||
return list(islice(self.entities.floorlist, n))
|
||||
|
||||
def tick(self, actions) -> list[Result]:
|
||||
"""
|
||||
Performs a single **Gamestate Tick** by calling the inner rule hooks in sequential order.
|
||||
- tick_pre_step_all: Things to do before the agents do their actions. Statechange, Moving, Spawning etc...
|
||||
- agent tick: Agents do their actions.
|
||||
- tick_step_all: Things to do after the agents did their actions. Statechange, Moving, Spawning etc...
|
||||
- tick_post_step_all: Things to do at the very end of each step. Counting, Reward calculations etc...
|
||||
|
||||
:return: List of *Result*-objects.
|
||||
"""
|
||||
results = list()
|
||||
test_results = list()
|
||||
self.curr_step += 1
|
||||
@ -112,11 +161,23 @@ class Gamestate(object):
|
||||
|
||||
return results
|
||||
|
||||
def print(self, string):
|
||||
def print(self, string) -> None:
|
||||
"""
|
||||
When *verbose* is active, print stuff.
|
||||
|
||||
:param string: *String* to print.
|
||||
:type string: str
|
||||
:return: Nothing
|
||||
"""
|
||||
if self.verbose:
|
||||
print(string)
|
||||
|
||||
def check_done(self):
|
||||
def check_done(self) -> List[DoneResult]:
|
||||
"""
|
||||
Iterate all **Rules** that override the *on_check_done* hook.
|
||||
|
||||
:return: List of Results
|
||||
"""
|
||||
results = list()
|
||||
for rule in self.rules:
|
||||
if on_check_done_result := rule.on_check_done(self):
|
||||
@ -124,24 +185,47 @@ class Gamestate(object):
|
||||
return results
|
||||
|
||||
def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]:
|
||||
positions = [pos for pos, entity_list_for_position in self.entities.pos_dict.items()
|
||||
if any([e.var_can_collide for e in entity_list_for_position])]
|
||||
"""
|
||||
Returns a list of positions [(x, y), ... ] on which collisions occur. This does not include agents,
|
||||
that were unable to move because their target direction was blocked, which is also a form of collision.
|
||||
|
||||
:return: List of positions.
|
||||
"""
|
||||
positions = [pos for pos, entities in self.entities.pos_dict.items() if
|
||||
len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2)
|
||||
]
|
||||
return positions
|
||||
|
||||
def check_move_validity(self, moving_entity, position):
|
||||
if moving_entity.pos != position and not any(
|
||||
entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]) and not (
|
||||
moving_entity.var_is_blocking_pos and self.entities.is_occupied(position)):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
def check_move_validity(self, moving_entity: Entity, target_position: (int, int)) -> bool:
|
||||
"""
|
||||
Whether it is safe to move to the target position and the moving entity does not introduce a blocking attribute,
|
||||
when the position is already occupied.
|
||||
|
||||
def check_pos_validity(self, position):
|
||||
if not any(entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
:param moving_entity: Entity
|
||||
:param target_position: pos
|
||||
:return: Safe to move to
|
||||
"""
|
||||
|
||||
is_not_blocked = self.check_pos_validity(target_position)
|
||||
will_not_block_others = moving_entity.var_is_blocking_pos and self.entities.is_occupied(target_position)
|
||||
|
||||
if moving_entity.pos != target_position and is_not_blocked and not will_not_block_others:
|
||||
return c.VALID
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
|
||||
def check_pos_validity(self, pos: (int, int)) -> bool:
|
||||
"""
|
||||
Check if *pos* is a valid position to move or spawn to.
|
||||
|
||||
:param pos: position to check
|
||||
:return: Whether pos is a valid target.
|
||||
"""
|
||||
|
||||
if not any(e.var_is_blocking_pos for e in self.entities.pos_dict[pos]) and pos in self.entities.floorlist:
|
||||
return c.VALID
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
|
||||
class StepTests:
|
||||
def __init__(self, *args):
|
||||
|
@ -28,7 +28,9 @@ class ConfigExplainer:
|
||||
|
||||
def explain_module(self, class_to_explain):
|
||||
parameters = inspect.signature(class_to_explain).parameters
|
||||
explained = {class_to_explain.__name__: {key: val.default for key, val in parameters.items() if key not in EXCLUDED}}
|
||||
explained = {class_to_explain.__name__:
|
||||
{key: val.default for key, val in parameters.items() if key not in EXCLUDED}
|
||||
}
|
||||
return explained
|
||||
|
||||
def _load_and_compare(self, compare_class, paths):
|
||||
@ -135,4 +137,3 @@ if __name__ == '__main__':
|
||||
ce.get_observations()
|
||||
ce.get_assets()
|
||||
all_conf = ce.get_all()
|
||||
print()
|
||||
|
@ -52,3 +52,6 @@ class Floor:
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.name)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Floor{self.pos}"
|
||||
|
Reference in New Issue
Block a user