mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-11-02 13:37:27 +01:00
Multiple Fixes:
- Config Explainer - Rewards - Destination Reach Condition - Additional Step Callback
This commit is contained in:
@@ -151,10 +151,12 @@ class FactoryConfigParser(object):
|
||||
class_or_classes = locate_and_import_class(action, self.custom_modules_path)
|
||||
try:
|
||||
parsed_actions.extend(class_or_classes)
|
||||
for actions_class in class_or_classes:
|
||||
conf_kwargs[actions_class.__name__] = conf_kwargs[action]
|
||||
except TypeError:
|
||||
parsed_actions.append(class_or_classes)
|
||||
|
||||
parsed_actions = [x(**conf_kwargs.get(x, {})) for x in parsed_actions]
|
||||
parsed_actions = [x(**conf_kwargs.get(x.__name__, {})) for x in parsed_actions]
|
||||
|
||||
# Observation
|
||||
observations = list()
|
||||
|
||||
@@ -218,32 +218,6 @@ def is_move(action_name: str):
|
||||
"""
|
||||
return action_name in MOVEMAP.keys()
|
||||
|
||||
|
||||
def asset_str(agent):
|
||||
"""
|
||||
FIXME @ romue
|
||||
"""
|
||||
# What does this abonimation do?
|
||||
# if any([x is None for x in [cls._slices[j] for j in agent.collisions]]):
|
||||
# print('error')
|
||||
if step_result := agent.step_result:
|
||||
action = step_result['action_name']
|
||||
valid = step_result['action_valid']
|
||||
col_names = [x.name for x in step_result['collisions']]
|
||||
if any(c.AGENT in name for name in col_names):
|
||||
return 'agent_collision', 'blank'
|
||||
elif not valid or c.LEVEL in col_names or c.AGENT in col_names:
|
||||
return c.AGENT, 'invalid'
|
||||
elif valid and not is_move(action):
|
||||
return c.AGENT, 'valid'
|
||||
elif valid and is_move(action):
|
||||
return c.AGENT, 'move'
|
||||
else:
|
||||
return c.AGENT, 'idle'
|
||||
else:
|
||||
return c.AGENT, 'idle'
|
||||
|
||||
|
||||
def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
|
||||
"""
|
||||
Locate an object by name or dotted path.
|
||||
|
||||
@@ -51,7 +51,7 @@ class EnvMonitor(Wrapper):
|
||||
pass
|
||||
return
|
||||
|
||||
def save_run(self, filepath: Union[Path, str, None] = None, auto_plotting_keys=None):
|
||||
def save_monitor(self, filepath: Union[Path, str, None] = None, auto_plotting_keys=None):
|
||||
filepath = Path(filepath or self._filepath)
|
||||
filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with filepath.open('wb') as f:
|
||||
|
||||
@@ -25,6 +25,12 @@ class EnvRecorder(Wrapper):
|
||||
return self.env.reset()
|
||||
|
||||
def step(self, actions):
|
||||
"""
|
||||
Todo
|
||||
|
||||
:param actions:
|
||||
:return:
|
||||
"""
|
||||
obs_type, obs, reward, done, info = self.env.step(actions)
|
||||
if not self.episodes or self._curr_episode in self.episodes:
|
||||
summary: dict = self.env.summarize_state()
|
||||
|
||||
@@ -14,8 +14,9 @@ ENTITIES = 'Objects'
|
||||
OBSERVATIONS = 'Observations'
|
||||
RULES = 'Rule'
|
||||
TESTS = 'Tests'
|
||||
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Agent', 'GlobalPositions', 'Walls',
|
||||
'TemplateRule', 'Entities', 'EnvObjects', 'Zones', ]
|
||||
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Agent', 'GlobalPositions', 'Walls', 'Gamestate', 'Path',
|
||||
'Iterable', 'Move', 'Result', 'TemplateRule', 'Entities', 'EnvObjects', 'Zones', 'Collection',
|
||||
'State', 'Object', 'default_valid_reward', 'default_fail_reward', 'size']
|
||||
|
||||
|
||||
class ConfigExplainer:
|
||||
@@ -32,7 +33,9 @@ class ConfigExplainer:
|
||||
|
||||
:param custom_path: Path to your custom module folder.
|
||||
"""
|
||||
self.base_path = Path(__file__).parent.parent.resolve()
|
||||
|
||||
self.base_path = Path(__file__).parent.parent.resolve() /'environment'
|
||||
self.modules_path = Path(__file__).parent.parent.resolve() / 'modules'
|
||||
self.custom_path = Path(custom_path) if custom_path is not None else custom_path
|
||||
self.searchspace = [ACTION, GENERAL, ENTITIES, OBSERVATIONS, RULES, TESTS]
|
||||
|
||||
@@ -41,7 +44,13 @@ class ConfigExplainer:
|
||||
"""
|
||||
INTERNAL USE ONLY
|
||||
"""
|
||||
parameters = inspect.signature(class_to_explain).parameters
|
||||
this_search = class_to_explain
|
||||
parameters = dict(inspect.signature(class_to_explain).parameters)
|
||||
while this_search.__bases__:
|
||||
base_class = this_search.__bases__[0]
|
||||
parameters.update(dict(inspect.signature(base_class).parameters))
|
||||
this_search = base_class
|
||||
|
||||
explained = {class_to_explain.__name__:
|
||||
{key: val.default for key, val in parameters.items() if key not in EXCLUDED}
|
||||
}
|
||||
@@ -52,8 +61,10 @@ class ConfigExplainer:
|
||||
INTERNAL USE ONLY
|
||||
"""
|
||||
entities_base_cls = locate_and_import_class(identifier, self.base_path)
|
||||
module_paths = [x.resolve() for x in self.base_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
|
||||
found_entities = self._load_and_compare(entities_base_cls, module_paths)
|
||||
module_paths = [x.resolve() for x in self.modules_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
|
||||
base_paths = [x.resolve() for x in self.base_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
|
||||
found_entities = self._load_and_compare(entities_base_cls, base_paths)
|
||||
found_entities.update(self._load_and_compare(entities_base_cls, module_paths))
|
||||
if self.custom_path is not None:
|
||||
module_paths = [x.resolve() for x in self.custom_path.rglob('*.py') if x.is_file()
|
||||
and '__init__' not in x.name]
|
||||
@@ -91,16 +102,14 @@ class ConfigExplainer:
|
||||
print(f'Example config {"for " + tag + " " if tag else " "}dumped')
|
||||
print(f'See file: {filepath}')
|
||||
|
||||
def get_actions(self) -> list[str]:
|
||||
def get_actions(self) -> dict[str]:
|
||||
"""
|
||||
Retrieve all actions from module folders.
|
||||
|
||||
:returns: A list of all available actions.
|
||||
"""
|
||||
actions = self._get_by_identifier(ACTION)
|
||||
assert all(not x for x in actions.values()), 'Please only provide Names, no Mappings.'
|
||||
actions = list(actions.keys())
|
||||
actions.extend([c.MOVE8, c.MOVE4])
|
||||
actions.update({c.MOVE8: {}, c.MOVE4: {}})
|
||||
return actions
|
||||
|
||||
def get_all(self) -> dict[str]:
|
||||
@@ -172,13 +181,20 @@ class ConfigExplainer:
|
||||
except TypeError:
|
||||
e = [key]
|
||||
except AttributeError as err:
|
||||
if self.custom_path is not None:
|
||||
try:
|
||||
e = locate_and_import_class(key, self.base_path)(level_shape=(0, 0), pomdp_r=0).obs_pairs
|
||||
except TypeError:
|
||||
e = [key]
|
||||
try:
|
||||
e = locate_and_import_class(key, self.modules_path)(level_shape=(0, 0), pomdp_r=0).obs_pairs
|
||||
except TypeError:
|
||||
e = [key]
|
||||
except AttributeError as err2:
|
||||
if self.custom_path is not None:
|
||||
try:
|
||||
e = locate_and_import_class(key, self.base_path)(level_shape=(0, 0), pomdp_r=0).obs_pairs
|
||||
except TypeError:
|
||||
e = [key]
|
||||
else:
|
||||
raise err
|
||||
print(err.args)
|
||||
print(err2.args)
|
||||
exit(-9999)
|
||||
names.extend(e)
|
||||
return names
|
||||
|
||||
|
||||
Reference in New Issue
Block a user