mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-18 18:52:52 +02:00
Machines
This commit is contained in:
@ -176,7 +176,7 @@ def one_hot_level(level, symbol: str):
|
||||
|
||||
grid = np.array(level)
|
||||
binary_grid = np.zeros(grid.shape, dtype=np.int8)
|
||||
binary_grid[grid == symbol] = c.VALUE_OCCUPIED_CELL
|
||||
binary_grid[grid == str(symbol)] = c.VALUE_OCCUPIED_CELL
|
||||
return binary_grid
|
||||
|
||||
|
||||
@ -222,18 +222,15 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
|
||||
for module_path in module_paths:
|
||||
module_parts = [x.replace('.py', '') for idx, x in enumerate(module_path.parts) if idx >= package_pos]
|
||||
mod = importlib.import_module('.'.join(module_parts))
|
||||
all_found_modules.extend([x for x in dir(mod) if not(x.startswith('__') or len(x) < 2 or x.isupper())
|
||||
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union', 'random', 'Floor'
|
||||
'TickResult', 'ActionResult', 'Action', 'Agent', 'deque',
|
||||
'BoundEntityMixin', 'RenderEntity', 'TemplateRule', 'defaultdict',
|
||||
'is_move', 'Objects', 'PositionMixin', 'IsBoundMixin', 'EnvObject',
|
||||
'EnvObjects', 'Dict', 'locate_and_import_class', 'yaml', 'Any',
|
||||
'inspect']])
|
||||
all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
|
||||
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union', 'Floor'
|
||||
'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin',
|
||||
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
|
||||
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any'
|
||||
]])
|
||||
try:
|
||||
model_class = mod.__getattribute__(class_name)
|
||||
return model_class
|
||||
except AttributeError:
|
||||
continue
|
||||
raise AttributeError(f'Class "{class_name}" was not found!!!"\n'
|
||||
f'Check the {folder_path.name} name.\n'
|
||||
f'Possible Options are:\n{set(all_found_modules)}')
|
||||
raise AttributeError(f'Class "{class_name}" was not found in "{folder_path.name}"', list(set(all_found_modules)))
|
||||
|
Reference in New Issue
Block a user