new rules, new spawn logic, small fixes, default and narrow corridor debugged

This commit is contained in:
Steffen Illium
2023-11-09 17:50:20 +01:00
parent 9b9c6e0385
commit 06a5130b25
67 changed files with 768 additions and 921 deletions

View File

@@ -0,0 +1,3 @@
from . import helpers as h
from . import helpers
from .results import Result, DoneResult, ActionResult, TickResult

View File

@@ -1,28 +1,24 @@
import ast
from collections import defaultdict
from os import PathLike
from pathlib import Path
from typing import Union
from typing import Union, List
import yaml
from marl_factory_grid.environment.groups.agents import Agents
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.helpers import locate_and_import_class
from marl_factory_grid.environment.constants import DEFAULT_PATH, MODULE_PATH
from marl_factory_grid.environment import constants as c
DEFAULT_PATH = 'environment'
MODULE_PATH = 'modules'
class FactoryConfigParser(object):
default_entites = []
default_rules = ['MaxStepsReached', 'Collision']
default_rules = ['DoneAtMaxStepsReached', 'WatchCollision']
default_actions = [c.MOVE8, c.NOOP]
default_observations = [c.WALLS, c.AGENT]
def __init__(self, config_path, custom_modules_path: Union[None, PathLike] = None):
def __init__(self, config_path, custom_modules_path: Union[PathLike] = None):
self.config_path = Path(config_path)
self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path
self.config = yaml.safe_load(self.config_path.open())
@@ -46,6 +42,10 @@ class FactoryConfigParser(object):
def rules(self):
return self.config['Rules']
@property
def tests(self):
return self.config.get('Tests', [])
@property
def agents(self):
return self.config['Agents']
@@ -61,7 +61,6 @@ class FactoryConfigParser(object):
return self.config[item]
def load_entities(self):
# entites = Entities()
entity_classes = dict()
entities = []
if c.DEFAULTS in self.entities:
@@ -69,28 +68,40 @@ class FactoryConfigParser(object):
entities.extend(x for x in self.entities if x != c.DEFAULTS)
for entity in entities:
e1 = e2 = e3 = None
try:
folder_path = Path(__file__).parent.parent / DEFAULT_PATH
entity_class = locate_and_import_class(entity, folder_path)
except AttributeError as e1:
except AttributeError as e:
e1 = e
try:
folder_path = Path(__file__).parent.parent / MODULE_PATH
entity_class = locate_and_import_class(entity, folder_path)
except AttributeError as e2:
try:
folder_path = self.custom_modules_path
entity_class = locate_and_import_class(entity, folder_path)
except AttributeError as e3:
ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x]
print('### Error ### Error ### Error ### Error ### Error ###')
print()
print(f'Class "{entity}" was not found in "{folder_path.name}"')
print('Possible Entitys are:', str(ents))
print()
print('Goodbye')
print()
exit()
# raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are:', str(ents))
module_path = Path(__file__).parent.parent / MODULE_PATH
entity_class = locate_and_import_class(entity, module_path)
except AttributeError as e:
e2 = e
if self.custom_modules_path:
try:
entity_class = locate_and_import_class(entity, self.custom_modules_path)
except AttributeError as e:
e3 = e
pass
if (e1 and e2) or e3:
ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]]
print('##############################################################')
print('### Error ### Error ### Error ### Error ### Error ###')
print('##############################################################')
print(f'Class "{entity}" was not found in "{module_path.name}"')
print(f'Class "{entity}" was not found in "{folder_path.name}"')
print('##############################################################')
if self.custom_modules_path:
print(f'Class "{entity}" was not found in "{self.custom_modules_path}"')
print('Possible Entitys are:', str(ents))
print('##############################################################')
print('Goodbye')
print('##############################################################')
print('### Error ### Error ### Error ### Error ### Error ###')
print('##############################################################')
exit(-99999)
entity_kwargs = self.entities.get(entity, {})
entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
@@ -128,31 +139,86 @@ class FactoryConfigParser(object):
observations.extend(self.default_observations)
observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS)
positions = [ast.literal_eval(x) for x in self.agents[name].get('Positions', [])]
parsed_agents_conf[name] = dict(actions=parsed_actions, observations=observations, positions=positions)
other_kwargs = {k: v for k, v in self.agents[name].items() if k not in
['Actions', 'Observations', 'Positions']}
parsed_agents_conf[name] = dict(
actions=parsed_actions, observations=observations, positions=positions, other=other_kwargs
)
return parsed_agents_conf
def load_rules(self):
# entites = Entities()
rules_classes = dict()
rules = []
def load_env_rules(self) -> List[Rule]:
rules = self.rules.copy()
if c.DEFAULTS in self.rules:
for rule in self.default_rules:
if rule not in rules:
rules.append(rule)
rules.extend(x for x in self.rules if x != c.DEFAULTS)
rules.append({rule: {}})
for rule in rules:
return self._load_smth(rules, Rule)
def load_env_tests(self) -> List[Rule]:
return self._load_smth(self.tests, None) # Test
def _load_smth(self, config, class_obj):
rules = list()
rules_names = list()
for rule in config:
e1 = e2 = e3 = None
try:
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
rule_class = locate_and_import_class(rule, folder_path)
except AttributeError:
except AttributeError as e:
e1 = e
try:
folder_path = (Path(__file__).parent.parent / MODULE_PATH)
rule_class = locate_and_import_class(rule, folder_path)
module_path = (Path(__file__).parent.parent / MODULE_PATH)
rule_class = locate_and_import_class(rule, module_path)
except AttributeError as e:
e2 = e
if self.custom_modules_path:
try:
rule_class = locate_and_import_class(rule, self.custom_modules_path)
except AttributeError as e:
e3 = e
pass
if (e1 and e2) or e3:
ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]]
print('### Error ### Error ### Error ### Error ### Error ###')
print('')
print(f'Class "{rule}" was not found in "{module_path.name}"')
print(f'Class "{rule}" was not found in "{folder_path.name}"')
if self.custom_modules_path:
print(f'Class "{rule}" was not found in "{self.custom_modules_path}"')
print('Possible Entitys are:', str(ents))
print('')
print('Goodbye')
print('')
exit(-99999)
if issubclass(rule_class, class_obj):
rule_kwargs = config.get(rule, {})
rules.append(rule_class(**(rule_kwargs or {})))
return rules
def load_entity_spawn_rules(self, entities) -> List[Rule]:
rules = list()
rules_dicts = list()
for e in entities:
try:
if spawn_rule := e.spawn_rule:
rules_dicts.append(spawn_rule)
except AttributeError:
pass
for rule_dict in rules_dicts:
for rule_name, rule_kwargs in rule_dict.items():
try:
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
rule_class = locate_and_import_class(rule_name, folder_path)
except AttributeError:
rule_class = locate_and_import_class(rule, self.custom_modules_path)
# Fixme This check does not work!
# assert isinstance(rule_class, Rule), f'{rule_class.__name__} is no valid "Rule".'
rule_kwargs = self.rules.get(rule, {})
rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}})
return rules_classes
try:
folder_path = (Path(__file__).parent.parent / MODULE_PATH)
rule_class = locate_and_import_class(rule_name, folder_path)
except AttributeError:
rule_class = locate_and_import_class(rule_name, self.custom_modules_path)
rules.append(rule_class(**rule_kwargs))
return rules

View File

@@ -2,7 +2,7 @@ import importlib
from collections import defaultdict
from pathlib import PurePath, Path
from typing import Union, Dict, List
from typing import Union, Dict, List, Iterable, Callable
import numpy as np
from numpy.typing import ArrayLike
@@ -222,7 +222,7 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
mod = importlib.import_module('.'.join(module_parts))
all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin',
'TickResult', 'ActionResult', 'Action', 'Agent',
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any'
]])
@@ -240,7 +240,13 @@ def add_bound_name(name_str, bound_e):
def add_pos_name(name_str, bound_e):
if bound_e.var_has_position:
return f'{name_str}({bound_e.pos})'
return f'{name_str}@{bound_e.pos}'
return name_str
def get_first(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True):
return next((x for x in iterable if filter_by(x)), None)
def get_first_index(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True):
return next((idx for idx, x in enumerate(iterable) if filter_by(x)), None)

View File

@@ -47,6 +47,7 @@ class LevelParser(object):
# All other
for es_name in self.e_p_dict:
e_class, e_kwargs = self.e_p_dict[es_name]['class'], self.e_p_dict[es_name]['kwargs']
e_kwargs = e_kwargs if e_kwargs else {}
if hasattr(e_class, 'symbol') and e_class.symbol is not None:
symbols = e_class.symbol

View File

@@ -1,17 +1,17 @@
import math
import re
from collections import defaultdict
from itertools import product
from typing import Dict, List
import numpy as np
from numba import njit
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.object import _Object
from marl_factory_grid.environment.groups.utils import Combined
import marl_factory_grid.utils.helpers as h
from marl_factory_grid.utils.states import Gamestate
from marl_factory_grid.utils.utility_classes import Floor
from marl_factory_grid.utils.ray_caster import RayCaster
from marl_factory_grid.utils.states import Gamestate
from marl_factory_grid.utils import helpers as h
class OBSBuilder(object):
@@ -77,11 +77,13 @@ class OBSBuilder(object):
def place_entity_in_observation(self, obs_array, agent, e):
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
try:
obs_array[x, y] += e.encoding
except IndexError:
# Seemded to be visible but is out of range
pass
if not min([y, x]) < 0:
try:
obs_array[x, y] += e.encoding
except IndexError:
# Seemded to be visible but is out of range
pass
pass
def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
assert self._curr_env_step == state.curr_step, (
@@ -121,18 +123,24 @@ class OBSBuilder(object):
e = self.all_obs[l_name]
except KeyError:
try:
# Look for bound entity names!
pattern = re.compile(f'{re.escape(l_name)}(.*){re.escape(agent.name)}')
name = next((x for x in self.all_obs if pattern.search(x)), None)
# Look for bound entity REPRs!
pattern = re.compile(f'{re.escape(l_name)}'
f'{re.escape("[")}(.*){re.escape("]")}'
f'{re.escape("(")}{re.escape(agent.name)}{re.escape(")")}')
name = next((key for key, val in self.all_obs.items()
if pattern.search(str(val)) and isinstance(val, _Object)), None)
e = self.all_obs[name]
except KeyError:
try:
e = next(v for k, v in self.all_obs.items() if l_name in k and agent.name in k)
except StopIteration:
raise KeyError(
f'Check for spelling errors! \n '
f'No combination of "{l_name} and {agent.name}" could not be found in:\n '
f'{list(dict(self.all_obs).keys())}')
print(f'# Check for spelling errors!')
print(f'# No combination of "{l_name}" and "{agent.name}" could not be found in:')
print(f'# {list(dict(self.all_obs).keys())}')
print('#')
print('# exiting...')
print('#')
exit(-99999)
try:
positional = e.var_has_position
@@ -161,15 +169,14 @@ class OBSBuilder(object):
try:
light_map = np.zeros(self.obs_shape)
visible_floor = self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False)
if self.pomdp_r:
for f in set(visible_floor):
self.place_entity_in_observation(light_map, agent, f)
else:
for f in set(visible_floor):
light_map[f.x, f.y] += f.encoding
for f in set(visible_floor):
self.place_entity_in_observation(light_map, agent, f)
# else:
# for f in set(visible_floor):
# light_map[f.x, f.y] += f.encoding
self.curr_lightmaps[agent.name] = light_map
except (KeyError, ValueError):
print()
pass
return obs, self.obs_layers[agent.name]
@@ -185,7 +192,7 @@ class OBSBuilder(object):
for obs_str in agent.observations:
if isinstance(obs_str, dict):
obs_str, vals = next(obs_str.items().__iter__())
obs_str, vals = h.get_first(obs_str.items())
else:
vals = None
if obs_str == c.SELF:
@@ -214,129 +221,3 @@ class OBSBuilder(object):
obs_layers.append(obs_str)
self.obs_layers[agent.name] = obs_layers
self.curr_lightmaps[agent.name] = np.zeros(self.obs_shape)
class RayCaster:
def __init__(self, agent, pomdp_r, degs=360):
self.agent = agent
self.pomdp_r = pomdp_r
self.n_rays = (self.pomdp_r + 1) * 8
self.degs = degs
self.ray_targets = self.build_ray_targets()
self.obs_shape_cube = np.array([self.pomdp_r, self.pomdp_r])
self._cache_dict = {}
def __repr__(self):
return f'{self.__class__.__name__}({self.agent.name})'
def build_ray_targets(self):
north = np.array([0, -1]) * self.pomdp_r
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
rot_M = [
[[math.cos(theta), -math.sin(theta)],
[math.sin(theta), math.cos(theta)]] for theta in thetas
]
rot_M = np.stack(rot_M, 0)
rot_M = np.unique(np.round(rot_M @ north), axis=0)
return rot_M.astype(int)
def ray_block_cache(self, key, callback):
if key not in self._cache_dict:
self._cache_dict[key] = callback()
return self._cache_dict[key]
def visible_entities(self, pos_dict, reset_cache=True):
visible = list()
if reset_cache:
self._cache_dict = {}
for ray in self.get_rays():
rx, ry = ray[0]
for x, y in ray:
cx, cy = x - rx, y - ry
entities_hit = pos_dict[(x, y)]
hits = self.ray_block_cache((x, y),
lambda: any(True for e in entities_hit if e.var_is_blocking_light)
)
diag_hits = all([
self.ray_block_cache(
key,
lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) and bool(
pos_dict[key]))
for key in ((x, y - cy), (x - cx, y))
]) if (cx != 0 and cy != 0) else False
visible += entities_hit if not diag_hits else []
if hits or diag_hits:
break
rx, ry = x, y
return visible
def get_rays(self):
a_pos = self.agent.pos
outline = self.ray_targets + a_pos
return self.bresenham_loop(a_pos, outline)
# todo do this once and cache the points!
def get_fov_outline(self) -> np.ndarray:
return self.ray_targets + self.agent.pos
def get_square_outline(self):
agent = self.agent
x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \
+ list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
return outline
@staticmethod
@njit
def bresenham_loop(a_pos, points):
results = []
for end in points:
x1, y1 = a_pos
x2, y2 = end
dx = x2 - x1
dy = y2 - y1
# Determine how steep the line is
is_steep = abs(dy) > abs(dx)
# Rotate line
if is_steep:
x1, y1 = y1, x1
x2, y2 = y2, x2
# Swap start and end points if necessary and store swap state
swapped = False
if x1 > x2:
x1, x2 = x2, x1
y1, y2 = y2, y1
swapped = True
# Recalculate differentials
dx = x2 - x1
dy = y2 - y1
# Calculate error
error = int(dx / 2.0)
ystep = 1 if y1 < y2 else -1
# Iterate over bounding box generating points between start and end
y = y1
points = []
for x in range(int(x1), int(x2) + 1):
coord = [y, x] if is_steep else [x, y]
points.append(coord)
error -= abs(dy)
if error < 0:
y += ystep
error += dx
# Reverse the list if the coordinates were swapped
if swapped:
points.reverse()
results.append(points)
return results

View File

@@ -39,8 +39,9 @@ class RayCaster:
if reset_cache:
self._cache_dict = dict()
for ray in self.get_rays():
for ray in self.get_rays(): # Do not check, just trust.
rx, ry = ray[0]
# self.ray_block_cache(ray[0], lambda: False) We do not do that, because of doors etc...
for x, y in ray:
cx, cy = x - rx, y - ry
@@ -52,7 +53,8 @@ class RayCaster:
diag_hits = all([
self.ray_block_cache(
key,
lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light))
lambda: any(True for e in pos_dict[key] if e.var_is_blocking_light))
# lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light))
for key in ((x, y-cy), (x-cx, y))
]) if (cx != 0 and cy != 0) else False

View File

@@ -31,7 +31,7 @@ class Renderer:
def __init__(self, lvl_shape: Tuple[int, int] = (16, 16),
lvl_padded_shape: Union[Tuple[int, int], None] = None,
cell_size: int = 40, fps: int = 7,
cell_size: int = 40, fps: int = 7, factor: float = 0.9,
grid_lines: bool = True, view_radius: int = 2):
# TODO: Customn_assets paths
self.grid_h, self.grid_w = lvl_shape
@@ -45,7 +45,7 @@ class Renderer:
self.screen = pygame.display.set_mode(self.screen_size)
self.clock = pygame.time.Clock()
assets = list(self.ASSETS.rglob('*.png'))
self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets}
self.assets = {path.stem: self.load_asset(str(path), factor) for path in assets}
self.fill_bg()
now = time.time()
@@ -110,22 +110,22 @@ class Renderer:
pygame.quit()
sys.exit()
self.fill_bg()
blits = deque()
for entity in [x for x in entities]:
bp = self.blit_params(entity)
blits.append(bp)
if entity.name.lower() == AGENT:
if self.view_radius > 0:
vis_rects = self.visibility_rects(bp, entity.aux)
blits.extendleft(vis_rects)
if entity.state != BLANK:
agent_state_blits = self.blit_params(
RenderEntity(entity.state, (entity.pos[0] + 0.12, entity.pos[1]), 0.48, SCALE)
)
textsurface = self.font.render(str(entity.id), False, (0, 0, 0))
text_blit = dict(source=textsurface, dest=(bp['dest'].center[0]-.07*self.cell_size,
bp['dest'].center[1]))
blits += [agent_state_blits, text_blit]
# First all others
blits = deque(self.blit_params(x) for x in entities if not x.name.lower() == AGENT)
# Then Agents, so that agents are rendered on top.
for agent in (x for x in entities if x.name.lower() == AGENT):
agent_blit = self.blit_params(agent)
if self.view_radius > 0:
vis_rects = self.visibility_rects(agent_blit, agent.aux)
blits.extendleft(vis_rects)
if agent.state != BLANK:
state_blit = self.blit_params(
RenderEntity(agent.state, (agent.pos[0] + 0.12, agent.pos[1]), 0.48, SCALE)
)
textsurface = self.font.render(str(agent.id), False, (0, 0, 0))
text_blit = dict(source=textsurface, dest=(agent_blit['dest'].center[0]-.07*self.cell_size,
agent_blit['dest'].center[1]))
blits += [agent_blit, state_blit, text_blit]
for blit in blits:
self.screen.blit(**blit)

View File

@@ -28,7 +28,10 @@ class Result:
def __repr__(self):
valid = "not " if not self.validity else ""
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid: {self.reward})'
reward = f" | Reward: {self.reward}" if self.reward is not None else ""
value = f" | Value: {self.value}" if self.value is not None else ""
entity = f" | by: {self.entity.name}" if self.entity is not None else ""
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value})'
@dataclass

View File

@@ -1,3 +1,4 @@
from itertools import islice
from typing import List, Dict, Tuple
import numpy as np
@@ -59,14 +60,15 @@ class Gamestate(object):
def moving_entites(self):
return [y for x in self.entities for y in x if x.var_can_move]
def __init__(self, entities, agents_conf, rules: Dict[str, dict], env_seed=69, verbose=False):
def __init__(self, entities, agents_conf, rules: List[Rule], lvl_shape, env_seed=69, verbose=False):
self.lvl_shape = lvl_shape
self.entities = entities
self.curr_step = 0
self.curr_actions = None
self.agents_conf = agents_conf
self.verbose = verbose
self.rng = np.random.default_rng(env_seed)
self.rules = StepRules(*(v['class'](**v['kwargs']) for v in rules.values()))
self.rules = StepRules(*rules)
def __getitem__(self, item):
return self.entities[item]
@@ -80,6 +82,13 @@ class Gamestate(object):
def __repr__(self):
return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})'
@property
def random_free_position(self):
return self.get_n_random_free_positions(1)[0]
def get_n_random_free_positions(self, n):
return list(islice(self.entities.free_positions_generator, n))
def tick(self, actions) -> List[Result]:
results = list()
self.curr_step += 1
@@ -115,8 +124,7 @@ class Gamestate(object):
return results
def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]:
positions = [pos for pos, entity_list_for_position in self.entities.pos_dict.items()
if any([e.var_can_collide for e in entity_list_for_position])]
positions = [pos for pos, entities in self.entities.pos_dict.items() if len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2)]
return positions
def check_move_validity(self, moving_entity, position):

View File

@@ -135,4 +135,3 @@ if __name__ == '__main__':
ce.get_observations()
ce.get_assets()
all_conf = ce.get_all()
print()

View File

@@ -52,3 +52,6 @@ class Floor:
def __hash__(self):
return hash(self.name)
def __repr__(self):
return f"Floor{self.pos}"