mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-22 11:41:34 +02:00
Machines
This commit is contained in:
@ -18,11 +18,11 @@ class FactoryConfigParser(object):
|
||||
default_entites = []
|
||||
default_rules = ['MaxStepsReached', 'Collision']
|
||||
default_actions = [c.MOVE8, c.NOOP]
|
||||
default_observations = [c.WALLS, c.AGENTS]
|
||||
default_observations = [c.WALLS, c.AGENT]
|
||||
|
||||
def __init__(self, config_path, custom_modules_path: Union[None, PathLike] = None):
|
||||
self.config_path = Path(config_path)
|
||||
self.custom_modules_path = Path(config_path) if custom_modules_path is not None else custom_modules_path
|
||||
self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path
|
||||
self.config = yaml.safe_load(self.config_path.open())
|
||||
self.do_record = False
|
||||
|
||||
@ -69,12 +69,20 @@ class FactoryConfigParser(object):
|
||||
|
||||
for entity in entities:
|
||||
try:
|
||||
folder_path = MODULE_PATH if entity not in self.default_entites else DEFAULT_PATH
|
||||
folder_path = (Path(__file__) / '..' / '..' / '..' / folder_path)
|
||||
entity_class = locate_and_import_class(entity, folder_path)
|
||||
except AttributeError:
|
||||
folder_path = self.custom_modules_path
|
||||
folder_path = Path(__file__).parent.parent / DEFAULT_PATH
|
||||
entity_class = locate_and_import_class(entity, folder_path)
|
||||
except AttributeError as e1:
|
||||
try:
|
||||
folder_path = Path(__file__).parent.parent / MODULE_PATH
|
||||
entity_class = locate_and_import_class(entity, folder_path)
|
||||
except AttributeError as e2:
|
||||
try:
|
||||
folder_path = self.custom_modules_path
|
||||
entity_class = locate_and_import_class(entity, folder_path)
|
||||
except AttributeError as e3:
|
||||
ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x]
|
||||
raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are>:', str(ents))
|
||||
|
||||
entity_kwargs = self.entities.get(entity, {})
|
||||
entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
|
||||
entity_classes.update({entity: {'class': entity_class, 'kwargs': entity_kwargs, 'symbol': entity_symbol}})
|
||||
@ -92,7 +100,7 @@ class FactoryConfigParser(object):
|
||||
parsed_actions = list()
|
||||
for action in actions:
|
||||
folder_path = MODULE_PATH if action not in base_env_actions else DEFAULT_PATH
|
||||
folder_path = (Path(__file__) / '..' / '..' / '..' / folder_path)
|
||||
folder_path = Path(__file__).parent.parent / folder_path
|
||||
try:
|
||||
class_or_classes = locate_and_import_class(action, folder_path)
|
||||
except AttributeError:
|
||||
@ -124,12 +132,15 @@ class FactoryConfigParser(object):
|
||||
rules.extend(x for x in self.rules if x != c.DEFAULTS)
|
||||
|
||||
for rule in rules:
|
||||
folder_path = MODULE_PATH if rule not in self.default_rules else DEFAULT_PATH
|
||||
folder_path = (Path(__file__) / '..' / '..' / '..' / folder_path)
|
||||
try:
|
||||
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
|
||||
rule_class = locate_and_import_class(rule, folder_path)
|
||||
except AttributeError:
|
||||
rule_class = locate_and_import_class(rule, self.custom_modules_path)
|
||||
try:
|
||||
folder_path = (Path(__file__).parent.parent / MODULE_PATH)
|
||||
rule_class = locate_and_import_class(rule, folder_path)
|
||||
except AttributeError:
|
||||
rule_class = locate_and_import_class(rule, self.custom_modules_path)
|
||||
rule_kwargs = self.rules.get(rule, {})
|
||||
rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}})
|
||||
return rules_classes
|
||||
|
@ -176,7 +176,7 @@ def one_hot_level(level, symbol: str):
|
||||
|
||||
grid = np.array(level)
|
||||
binary_grid = np.zeros(grid.shape, dtype=np.int8)
|
||||
binary_grid[grid == symbol] = c.VALUE_OCCUPIED_CELL
|
||||
binary_grid[grid == str(symbol)] = c.VALUE_OCCUPIED_CELL
|
||||
return binary_grid
|
||||
|
||||
|
||||
@ -222,18 +222,15 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
|
||||
for module_path in module_paths:
|
||||
module_parts = [x.replace('.py', '') for idx, x in enumerate(module_path.parts) if idx >= package_pos]
|
||||
mod = importlib.import_module('.'.join(module_parts))
|
||||
all_found_modules.extend([x for x in dir(mod) if not(x.startswith('__') or len(x) < 2 or x.isupper())
|
||||
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union', 'random', 'Floor'
|
||||
'TickResult', 'ActionResult', 'Action', 'Agent', 'deque',
|
||||
'BoundEntityMixin', 'RenderEntity', 'TemplateRule', 'defaultdict',
|
||||
'is_move', 'Objects', 'PositionMixin', 'IsBoundMixin', 'EnvObject',
|
||||
'EnvObjects', 'Dict', 'locate_and_import_class', 'yaml', 'Any',
|
||||
'inspect']])
|
||||
all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
|
||||
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union', 'Floor'
|
||||
'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin',
|
||||
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
|
||||
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any'
|
||||
]])
|
||||
try:
|
||||
model_class = mod.__getattribute__(class_name)
|
||||
return model_class
|
||||
except AttributeError:
|
||||
continue
|
||||
raise AttributeError(f'Class "{class_name}" was not found!!!"\n'
|
||||
f'Check the {folder_path.name} name.\n'
|
||||
f'Possible Options are:\n{set(all_found_modules)}')
|
||||
raise AttributeError(f'Class "{class_name}" was not found in "{folder_path.name}"', list(set(all_found_modules)))
|
||||
|
@ -24,31 +24,40 @@ class LevelParser(object):
|
||||
self.level_shape = level_array.shape
|
||||
self.size = self.pomdp_r**2 if self.pomdp_r else np.prod(self.level_shape)
|
||||
|
||||
def get_coordinates_for_symbol(self, symbol, negate=False):
|
||||
level_array = h.one_hot_level(self._parsed_level, symbol)
|
||||
if negate:
|
||||
return np.argwhere(level_array != c.VALUE_OCCUPIED_CELL)
|
||||
else:
|
||||
return np.argwhere(level_array == c.VALUE_OCCUPIED_CELL)
|
||||
|
||||
def do_init(self):
|
||||
entities = Entities()
|
||||
# Walls
|
||||
level_array = h.one_hot_level(self._parsed_level, c.SYMBOL_WALL)
|
||||
|
||||
walls = Walls.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL), self.size)
|
||||
walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size)
|
||||
entities.add_items({c.WALL: walls})
|
||||
|
||||
# Floor
|
||||
floor = Floors.from_coordinates(np.argwhere(level_array == c.VALUE_FREE_CELL), self.size)
|
||||
floor = Floors.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True), self.size)
|
||||
entities.add_items({c.FLOOR: floor})
|
||||
|
||||
# All other
|
||||
for es_name in self.e_p_dict:
|
||||
e_class, e_kwargs = self.e_p_dict[es_name]['class'], self.e_p_dict[es_name]['kwargs']
|
||||
|
||||
if hasattr(e_class, 'symbol'):
|
||||
level_array = h.one_hot_level(self._parsed_level, symbol=e_class.symbol)
|
||||
if np.any(level_array):
|
||||
e = e_class.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist(),
|
||||
entities[c.FLOOR], self.size, entity_kwargs=e_kwargs
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'No {e_class} (Symbol: {e_class.symbol}) could be found!\n'
|
||||
f'Check your level file!')
|
||||
if hasattr(e_class, 'symbol') and e_class.symbol is not None:
|
||||
symbols = e_class.symbol
|
||||
if isinstance(symbols, (str, int, float)):
|
||||
symbols = [symbols]
|
||||
for symbol in symbols:
|
||||
level_array = h.one_hot_level(self._parsed_level, symbol=symbol)
|
||||
if np.any(level_array):
|
||||
e = e_class.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist(),
|
||||
entities[c.FLOOR], self.size, entity_kwargs=e_kwargs
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'No {e_class} (Symbol: {e_class.symbol}) could be found!\n'
|
||||
f'Check your level file!')
|
||||
else:
|
||||
e = e_class(self.size, **e_kwargs)
|
||||
entities.add_items({e.name: e})
|
||||
|
@ -6,11 +6,10 @@ from typing import Dict, List
|
||||
import numpy as np
|
||||
from numba import njit
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.groups.utils import Combined
|
||||
from marl_factory_grid.utils.states import Gamestate
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
|
||||
class OBSBuilder(object):
|
||||
|
||||
@ -111,10 +110,10 @@ class OBSBuilder(object):
|
||||
e = next(x for x in self.all_obs if l_name in x and agent.name in x)
|
||||
except StopIteration:
|
||||
raise KeyError(
|
||||
f'Check typing!\n{l_name} could not be found in:\n{dict(self.all_obs).keys()}')
|
||||
f'Check typing! {l_name} could not be found in: {list(dict(self.all_obs).keys())}')
|
||||
|
||||
try:
|
||||
positional = e.has_position
|
||||
positional = e.var_has_position
|
||||
except AttributeError:
|
||||
positional = False
|
||||
if positional:
|
||||
@ -172,7 +171,7 @@ class OBSBuilder(object):
|
||||
obs_layers.append(combined.name)
|
||||
elif obs_str == c.OTHERS:
|
||||
obs_layers.extend([x for x in self.all_obs if x != agent.name and x.startswith(f'{c.AGENT}[')])
|
||||
elif obs_str == c.AGENTS:
|
||||
elif obs_str == c.AGENT:
|
||||
obs_layers.extend([x for x in self.all_obs if x.startswith(f'{c.AGENT}[')])
|
||||
else:
|
||||
obs_layers.append(obs_str)
|
||||
@ -222,7 +221,7 @@ class RayCaster:
|
||||
entities_hit = entities.pos_dict[(x, y)]
|
||||
hits = self.ray_block_cache(cache_blocking,
|
||||
(x, y),
|
||||
lambda: any(e.is_blocking_light for e in entities_hit),
|
||||
lambda: any(e.var_is_blocking_light for e in entities_hit),
|
||||
entities)
|
||||
|
||||
try:
|
||||
@ -237,8 +236,8 @@ class RayCaster:
|
||||
self.ray_block_cache(
|
||||
cache_blocking,
|
||||
key,
|
||||
# lambda: all(False for e in entities.pos_dict[key] if not e.is_blocking_light),
|
||||
lambda: any(e.is_blocking_light for e in entities.pos_dict[key]),
|
||||
# lambda: all(False for e in entities.pos_dict[key] if not e.var_is_blocking_light),
|
||||
lambda: any(e.var_is_blocking_light for e in entities.pos_dict[key]),
|
||||
entities)
|
||||
for key in ((x, y-cy), (x-cx, y))
|
||||
]) if (cx != 0 and cy != 0) else False
|
||||
|
@ -27,13 +27,13 @@ class Renderer:
|
||||
BG_COLOR = (178, 190, 195) # (99, 110, 114)
|
||||
WHITE = (223, 230, 233) # (200, 200, 200)
|
||||
AGENT_VIEW_COLOR = (9, 132, 227)
|
||||
ASSETS = Path(__file__).parent.parent / 'assets'
|
||||
MODULE_ASSETS = Path(__file__).parent.parent.parent / 'modules'
|
||||
ASSETS = Path(__file__).parent.parent
|
||||
|
||||
def __init__(self, lvl_shape: Tuple[int, int] = (16, 16),
|
||||
lvl_padded_shape: Union[Tuple[int, int], None] = None,
|
||||
cell_size: int = 40, fps: int = 7,
|
||||
grid_lines: bool = True, view_radius: int = 2):
|
||||
# TODO: Customn_assets paths
|
||||
self.grid_h, self.grid_w = lvl_shape
|
||||
self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape
|
||||
self.cell_size = cell_size
|
||||
@ -44,7 +44,7 @@ class Renderer:
|
||||
self.screen_size = (self.grid_w*cell_size, self.grid_h*cell_size)
|
||||
self.screen = pygame.display.set_mode(self.screen_size)
|
||||
self.clock = pygame.time.Clock()
|
||||
assets = list(self.ASSETS.rglob('*.png')) + list(self.MODULE_ASSETS.rglob('*.png'))
|
||||
assets = list(self.ASSETS.rglob('*.png'))
|
||||
self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets}
|
||||
self.fill_bg()
|
||||
|
||||
|
@ -1,8 +1,6 @@
|
||||
from typing import Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
|
||||
TYPE_VALUE = 'value'
|
||||
TYPE_REWARD = 'reward'
|
||||
types = [TYPE_VALUE, TYPE_REWARD]
|
||||
@ -20,7 +18,7 @@ class Result:
|
||||
validity: bool
|
||||
reward: Union[float, None] = None
|
||||
value: Union[float, None] = None
|
||||
entity: Union[Entity, None] = None
|
||||
entity: None = None
|
||||
|
||||
def get_infos(self):
|
||||
n = self.entity.name if self.entity is not None else "Global"
|
||||
|
@ -2,10 +2,11 @@ from typing import List, Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.wall_floor import Floor
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.utils.results import Result
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
|
||||
class StepRules:
|
||||
@ -26,9 +27,9 @@ class StepRules:
|
||||
self.rules.append(item)
|
||||
return True
|
||||
|
||||
def do_all_init(self, state):
|
||||
def do_all_init(self, state, lvl_map):
|
||||
for rule in self.rules:
|
||||
if rule_init_printline := rule.on_init(state):
|
||||
if rule_init_printline := rule.on_init(state, lvl_map):
|
||||
state.print(rule_init_printline)
|
||||
return c.VALID
|
||||
|
||||
@ -58,7 +59,7 @@ class Gamestate(object):
|
||||
|
||||
@property
|
||||
def moving_entites(self):
|
||||
return [y for x in self.entities for y in x if x.can_move]
|
||||
return [y for x in self.entities for y in x if x.var_can_move]
|
||||
|
||||
def __init__(self, entitites, rules: Dict[str, dict], env_seed=69, verbose=False):
|
||||
self.entities = entitites
|
||||
@ -107,6 +108,6 @@ class Gamestate(object):
|
||||
|
||||
def get_all_tiles_with_collisions(self) -> List[Floor]:
|
||||
tiles = [self[c.FLOOR].by_pos(pos) for pos, e in self.entities.pos_dict.items()
|
||||
if sum([x.can_collide for x in e]) > 1]
|
||||
if sum([x.var_can_collide for x in e]) > 1]
|
||||
# tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1]
|
||||
return tiles
|
||||
|
@ -15,7 +15,7 @@ ENTITIES = 'Objects'
|
||||
OBSERVATIONS = 'Observations'
|
||||
RULES = 'Rule'
|
||||
ASSETS = 'Assets'
|
||||
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move',
|
||||
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Floor', 'Agent', 'GlobalPositions', 'Walls',
|
||||
'TemplateRule', 'Entities', 'EnvObjects', 'Zones', ]
|
||||
|
||||
|
||||
|
@ -1,21 +1,6 @@
|
||||
import gymnasium as gym
|
||||
|
||||
|
||||
class EnvCombiner(object):
|
||||
|
||||
def __init__(self, *envs_cls):
|
||||
self._env_dict = {env_cls.__name__: env_cls for env_cls in envs_cls}
|
||||
|
||||
@staticmethod
|
||||
def combine_cls(name, *envs_cls):
|
||||
return type(name, envs_cls, {})
|
||||
|
||||
def build(self):
|
||||
name = f'{"".join([x.lower().replace("factory").capitalize() for x in self._env_dict.keys()])}Factory'
|
||||
|
||||
return self.combine_cls(name, tuple(self._env_dict.values()))
|
||||
|
||||
|
||||
class MarlFrameStack(gym.ObservationWrapper):
|
||||
"""todo @romue404"""
|
||||
def __init__(self, env):
|
||||
|
Reference in New Issue
Block a user