This commit is contained in:
Steffen Illium
2023-07-06 12:01:25 +02:00
parent dc134d71e0
commit 836495a884
72 changed files with 742 additions and 298 deletions

View File

@ -18,11 +18,11 @@ class FactoryConfigParser(object):
default_entites = []
default_rules = ['MaxStepsReached', 'Collision']
default_actions = [c.MOVE8, c.NOOP]
default_observations = [c.WALLS, c.AGENTS]
default_observations = [c.WALLS, c.AGENT]
def __init__(self, config_path, custom_modules_path: Union[None, PathLike] = None):
self.config_path = Path(config_path)
self.custom_modules_path = Path(config_path) if custom_modules_path is not None else custom_modules_path
self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path
self.config = yaml.safe_load(self.config_path.open())
self.do_record = False
@ -69,12 +69,20 @@ class FactoryConfigParser(object):
for entity in entities:
try:
folder_path = MODULE_PATH if entity not in self.default_entites else DEFAULT_PATH
folder_path = (Path(__file__) / '..' / '..' / '..' / folder_path)
entity_class = locate_and_import_class(entity, folder_path)
except AttributeError:
folder_path = self.custom_modules_path
folder_path = Path(__file__).parent.parent / DEFAULT_PATH
entity_class = locate_and_import_class(entity, folder_path)
except AttributeError as e1:
try:
folder_path = Path(__file__).parent.parent / MODULE_PATH
entity_class = locate_and_import_class(entity, folder_path)
except AttributeError as e2:
try:
folder_path = self.custom_modules_path
entity_class = locate_and_import_class(entity, folder_path)
except AttributeError as e3:
ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x]
raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are>:', str(ents))
entity_kwargs = self.entities.get(entity, {})
entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
entity_classes.update({entity: {'class': entity_class, 'kwargs': entity_kwargs, 'symbol': entity_symbol}})
@ -92,7 +100,7 @@ class FactoryConfigParser(object):
parsed_actions = list()
for action in actions:
folder_path = MODULE_PATH if action not in base_env_actions else DEFAULT_PATH
folder_path = (Path(__file__) / '..' / '..' / '..' / folder_path)
folder_path = Path(__file__).parent.parent / folder_path
try:
class_or_classes = locate_and_import_class(action, folder_path)
except AttributeError:
@ -124,12 +132,15 @@ class FactoryConfigParser(object):
rules.extend(x for x in self.rules if x != c.DEFAULTS)
for rule in rules:
folder_path = MODULE_PATH if rule not in self.default_rules else DEFAULT_PATH
folder_path = (Path(__file__) / '..' / '..' / '..' / folder_path)
try:
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
rule_class = locate_and_import_class(rule, folder_path)
except AttributeError:
rule_class = locate_and_import_class(rule, self.custom_modules_path)
try:
folder_path = (Path(__file__).parent.parent / MODULE_PATH)
rule_class = locate_and_import_class(rule, folder_path)
except AttributeError:
rule_class = locate_and_import_class(rule, self.custom_modules_path)
rule_kwargs = self.rules.get(rule, {})
rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}})
return rules_classes

View File

@ -176,7 +176,7 @@ def one_hot_level(level, symbol: str):
grid = np.array(level)
binary_grid = np.zeros(grid.shape, dtype=np.int8)
binary_grid[grid == symbol] = c.VALUE_OCCUPIED_CELL
binary_grid[grid == str(symbol)] = c.VALUE_OCCUPIED_CELL
return binary_grid
@ -222,18 +222,15 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
for module_path in module_paths:
module_parts = [x.replace('.py', '') for idx, x in enumerate(module_path.parts) if idx >= package_pos]
mod = importlib.import_module('.'.join(module_parts))
all_found_modules.extend([x for x in dir(mod) if not(x.startswith('__') or len(x) < 2 or x.isupper())
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union', 'random', 'Floor'
'TickResult', 'ActionResult', 'Action', 'Agent', 'deque',
'BoundEntityMixin', 'RenderEntity', 'TemplateRule', 'defaultdict',
'is_move', 'Objects', 'PositionMixin', 'IsBoundMixin', 'EnvObject',
'EnvObjects', 'Dict', 'locate_and_import_class', 'yaml', 'Any',
'inspect']])
all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union', 'Floor'
'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin',
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any'
]])
try:
model_class = mod.__getattribute__(class_name)
return model_class
except AttributeError:
continue
raise AttributeError(f'Class "{class_name}" was not found!!!"\n'
f'Check the {folder_path.name} name.\n'
f'Possible Options are:\n{set(all_found_modules)}')
raise AttributeError(f'Class "{class_name}" was not found in "{folder_path.name}"', list(set(all_found_modules)))

View File

@ -24,31 +24,40 @@ class LevelParser(object):
self.level_shape = level_array.shape
self.size = self.pomdp_r**2 if self.pomdp_r else np.prod(self.level_shape)
def get_coordinates_for_symbol(self, symbol, negate=False):
level_array = h.one_hot_level(self._parsed_level, symbol)
if negate:
return np.argwhere(level_array != c.VALUE_OCCUPIED_CELL)
else:
return np.argwhere(level_array == c.VALUE_OCCUPIED_CELL)
def do_init(self):
entities = Entities()
# Walls
level_array = h.one_hot_level(self._parsed_level, c.SYMBOL_WALL)
walls = Walls.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL), self.size)
walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size)
entities.add_items({c.WALL: walls})
# Floor
floor = Floors.from_coordinates(np.argwhere(level_array == c.VALUE_FREE_CELL), self.size)
floor = Floors.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True), self.size)
entities.add_items({c.FLOOR: floor})
# All other
for es_name in self.e_p_dict:
e_class, e_kwargs = self.e_p_dict[es_name]['class'], self.e_p_dict[es_name]['kwargs']
if hasattr(e_class, 'symbol'):
level_array = h.one_hot_level(self._parsed_level, symbol=e_class.symbol)
if np.any(level_array):
e = e_class.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist(),
entities[c.FLOOR], self.size, entity_kwargs=e_kwargs
)
else:
raise ValueError(f'No {e_class} (Symbol: {e_class.symbol}) could be found!\n'
f'Check your level file!')
if hasattr(e_class, 'symbol') and e_class.symbol is not None:
symbols = e_class.symbol
if isinstance(symbols, (str, int, float)):
symbols = [symbols]
for symbol in symbols:
level_array = h.one_hot_level(self._parsed_level, symbol=symbol)
if np.any(level_array):
e = e_class.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist(),
entities[c.FLOOR], self.size, entity_kwargs=e_kwargs
)
else:
raise ValueError(f'No {e_class} (Symbol: {e_class.symbol}) could be found!\n'
f'Check your level file!')
else:
e = e_class(self.size, **e_kwargs)
entities.add_items({e.name: e})

View File

@ -6,11 +6,10 @@ from typing import Dict, List
import numpy as np
from numba import njit
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.groups.utils import Combined
from marl_factory_grid.utils.states import Gamestate
from marl_factory_grid.environment import constants as c
class OBSBuilder(object):
@ -111,10 +110,10 @@ class OBSBuilder(object):
e = next(x for x in self.all_obs if l_name in x and agent.name in x)
except StopIteration:
raise KeyError(
f'Check typing!\n{l_name} could not be found in:\n{dict(self.all_obs).keys()}')
f'Check typing! {l_name} could not be found in: {list(dict(self.all_obs).keys())}')
try:
positional = e.has_position
positional = e.var_has_position
except AttributeError:
positional = False
if positional:
@ -172,7 +171,7 @@ class OBSBuilder(object):
obs_layers.append(combined.name)
elif obs_str == c.OTHERS:
obs_layers.extend([x for x in self.all_obs if x != agent.name and x.startswith(f'{c.AGENT}[')])
elif obs_str == c.AGENTS:
elif obs_str == c.AGENT:
obs_layers.extend([x for x in self.all_obs if x.startswith(f'{c.AGENT}[')])
else:
obs_layers.append(obs_str)
@ -222,7 +221,7 @@ class RayCaster:
entities_hit = entities.pos_dict[(x, y)]
hits = self.ray_block_cache(cache_blocking,
(x, y),
lambda: any(e.is_blocking_light for e in entities_hit),
lambda: any(e.var_is_blocking_light for e in entities_hit),
entities)
try:
@ -237,8 +236,8 @@ class RayCaster:
self.ray_block_cache(
cache_blocking,
key,
# lambda: all(False for e in entities.pos_dict[key] if not e.is_blocking_light),
lambda: any(e.is_blocking_light for e in entities.pos_dict[key]),
# lambda: all(False for e in entities.pos_dict[key] if not e.var_is_blocking_light),
lambda: any(e.var_is_blocking_light for e in entities.pos_dict[key]),
entities)
for key in ((x, y-cy), (x-cx, y))
]) if (cx != 0 and cy != 0) else False

View File

@ -27,13 +27,13 @@ class Renderer:
BG_COLOR = (178, 190, 195) # (99, 110, 114)
WHITE = (223, 230, 233) # (200, 200, 200)
AGENT_VIEW_COLOR = (9, 132, 227)
ASSETS = Path(__file__).parent.parent / 'assets'
MODULE_ASSETS = Path(__file__).parent.parent.parent / 'modules'
ASSETS = Path(__file__).parent.parent
def __init__(self, lvl_shape: Tuple[int, int] = (16, 16),
lvl_padded_shape: Union[Tuple[int, int], None] = None,
cell_size: int = 40, fps: int = 7,
grid_lines: bool = True, view_radius: int = 2):
# TODO: Customn_assets paths
self.grid_h, self.grid_w = lvl_shape
self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape
self.cell_size = cell_size
@ -44,7 +44,7 @@ class Renderer:
self.screen_size = (self.grid_w*cell_size, self.grid_h*cell_size)
self.screen = pygame.display.set_mode(self.screen_size)
self.clock = pygame.time.Clock()
assets = list(self.ASSETS.rglob('*.png')) + list(self.MODULE_ASSETS.rglob('*.png'))
assets = list(self.ASSETS.rglob('*.png'))
self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets}
self.fill_bg()

View File

@ -1,8 +1,6 @@
from typing import Union
from dataclasses import dataclass
from marl_factory_grid.environment.entity.entity import Entity
TYPE_VALUE = 'value'
TYPE_REWARD = 'reward'
types = [TYPE_VALUE, TYPE_REWARD]
@ -20,7 +18,7 @@ class Result:
validity: bool
reward: Union[float, None] = None
value: Union[float, None] = None
entity: Union[Entity, None] = None
entity: None = None
def get_infos(self):
n = self.entity.name if self.entity is not None else "Global"

View File

@ -2,10 +2,11 @@ from typing import List, Dict
import numpy as np
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.wall_floor import Floor
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import Result
from marl_factory_grid.environment import constants as c
class StepRules:
@ -26,9 +27,9 @@ class StepRules:
self.rules.append(item)
return True
def do_all_init(self, state):
def do_all_init(self, state, lvl_map):
for rule in self.rules:
if rule_init_printline := rule.on_init(state):
if rule_init_printline := rule.on_init(state, lvl_map):
state.print(rule_init_printline)
return c.VALID
@ -58,7 +59,7 @@ class Gamestate(object):
@property
def moving_entites(self):
return [y for x in self.entities for y in x if x.can_move]
return [y for x in self.entities for y in x if x.var_can_move]
def __init__(self, entitites, rules: Dict[str, dict], env_seed=69, verbose=False):
self.entities = entitites
@ -107,6 +108,6 @@ class Gamestate(object):
def get_all_tiles_with_collisions(self) -> List[Floor]:
tiles = [self[c.FLOOR].by_pos(pos) for pos, e in self.entities.pos_dict.items()
if sum([x.can_collide for x in e]) > 1]
if sum([x.var_can_collide for x in e]) > 1]
# tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1]
return tiles

View File

@ -15,7 +15,7 @@ ENTITIES = 'Objects'
OBSERVATIONS = 'Observations'
RULES = 'Rule'
ASSETS = 'Assets'
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move',
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Floor', 'Agent', 'GlobalPositions', 'Walls',
'TemplateRule', 'Entities', 'EnvObjects', 'Zones', ]

View File

@ -1,21 +1,6 @@
import gymnasium as gym
class EnvCombiner(object):
def __init__(self, *envs_cls):
self._env_dict = {env_cls.__name__: env_cls for env_cls in envs_cls}
@staticmethod
def combine_cls(name, *envs_cls):
return type(name, envs_cls, {})
def build(self):
name = f'{"".join([x.lower().replace("factory").capitalize() for x in self._env_dict.keys()])}Factory'
return self.combine_cls(name, tuple(self._env_dict.values()))
class MarlFrameStack(gym.ObservationWrapper):
"""todo @romue404"""
def __init__(self, env):