Merge branch 'main' into unit_testing

This commit is contained in:
Chanumask
2023-11-28 12:28:20 +01:00
21 changed files with 270 additions and 171 deletions

View File

@ -153,10 +153,12 @@ class FactoryConfigParser(object):
class_or_classes = locate_and_import_class(action, self.custom_modules_path)
try:
parsed_actions.extend(class_or_classes)
for actions_class in class_or_classes:
conf_kwargs[actions_class.__name__] = conf_kwargs[action]
except TypeError:
parsed_actions.append(class_or_classes)
parsed_actions = [x(**conf_kwargs.get(x, {})) for x in parsed_actions]
parsed_actions = [x(**conf_kwargs.get(x.__name__, {})) for x in parsed_actions]
# Observation
observations = list()

View File

@ -27,9 +27,11 @@ IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignore
'train_step', 'step', 'index', 'dirt_amount', 'dirty_pos_count', 'terminal_observation',
'episode']
POS_MASK = np.asarray([[[-1, -1], [0, -1], [1, -1]],
[[-1, 0], [0, 0], [1, 0]],
[[-1, 1], [0, 1], [1, 1]]])
POS_MASK_8 = np.asarray([[[-1, -1], [0, -1], [1, -1]],
[[-1, 0], [0, 0], [1, 0]],
[[-1, 1], [0, 1], [1, 1]]])
POS_MASK_4 = np.asarray([[0, -1], [-1, 0], [1, 0], [-1, 1], [0, 1], [1, 1]])
MOVEMAP = defaultdict(lambda: (0, 0),
{c.NORTH: (-1, 0), c.NORTHEAST: (-1, 1),
@ -216,32 +218,6 @@ def is_move(action_name: str):
"""
return action_name in MOVEMAP.keys()
def asset_str(agent):
"""
FIXME @ romue
"""
# What does this abonimation do?
# if any([x is None for x in [cls._slices[j] for j in agent.collisions]]):
# print('error')
if step_result := agent.step_result:
action = step_result['action_name']
valid = step_result['action_valid']
col_names = [x.name for x in step_result['collisions']]
if any(c.AGENT in name for name in col_names):
return 'agent_collision', 'blank'
elif not valid or c.LEVEL in col_names or c.AGENT in col_names:
return c.AGENT, 'invalid'
elif valid and not is_move(action):
return c.AGENT, 'valid'
elif valid and is_move(action):
return c.AGENT, 'move'
else:
return c.AGENT, 'idle'
else:
return c.AGENT, 'idle'
def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
"""
Locate an object by name or dotted path.

View File

@ -51,7 +51,7 @@ class EnvMonitor(Wrapper):
pass
return
def save_run(self, filepath: Union[Path, str, None] = None, auto_plotting_keys=None):
def save_monitor(self, filepath: Union[Path, str, None] = None, auto_plotting_keys=None):
filepath = Path(filepath or self._filepath)
filepath.parent.mkdir(exist_ok=True, parents=True)
with filepath.open('wb') as f:

View File

@ -25,6 +25,12 @@ class EnvRecorder(Wrapper):
return self.env.reset()
def step(self, actions):
"""
Todo
:param actions:
:return:
"""
obs_type, obs, reward, done, info = self.env.step(actions)
if not self.episodes or self._curr_episode in self.episodes:
summary: dict = self.env.summarize_state()

View File

@ -2,9 +2,11 @@ from typing import Union
from dataclasses import dataclass
from marl_factory_grid.environment.entity.object import Object
import marl_factory_grid.environment.constants as c
TYPE_VALUE = 'value'
TYPE_REWARD = 'reward'
TYPES = [TYPE_VALUE, TYPE_REWARD]
@ -32,8 +34,9 @@ class Result:
"""
identifier: str
validity: bool
reward: Union[float, None] = None
value: Union[float, None] = None
reward: float | None = None
value: float | None = None
collision: bool | None = None
entity: Object = None
def get_infos(self):
@ -68,8 +71,17 @@ class ActionResult(Result):
super().__init__(*args, **kwargs)
self.action_introduced_collision = action_introduced_collision
pass
def __repr__(self):
sr = super().__repr__()
return sr + f" | {c.COLLISION}" if self.action_introduced_collision is not None else ""
def get_infos(self):
base_infos = super().get_infos()
if self.action_introduced_collision:
i = InfoObject(identifier=f'{self.entity.name}_{c.COLLISION}', val_type=TYPE_VALUE, value=1)
return base_infos + [i]
else:
return base_infos
@dataclass
class DoneResult(Result):

View File

@ -49,6 +49,12 @@ class StepRules:
state.print(rule_reset_printline)
return c.VALID
def do_all_post_spawn_reset(self, state):
for rule in self.rules:
if rule_reset_printline := rule.on_reset_post_spawn(state):
state.print(rule_reset_printline)
return c.VALID
def tick_step_all(self, state):
results = list()
for rule in self.rules:

View File

@ -14,8 +14,9 @@ ENTITIES = 'Objects'
OBSERVATIONS = 'Observations'
RULES = 'Rule'
TESTS = 'Tests'
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Agent', 'GlobalPositions', 'Walls',
'TemplateRule', 'Entities', 'EnvObjects', 'Zones', ]
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Agent', 'GlobalPositions', 'Walls', 'Gamestate', 'Path',
'Iterable', 'Move', 'Result', 'TemplateRule', 'Entities', 'EnvObjects', 'Zones', 'Collection',
'State', 'Object', 'default_valid_reward', 'default_fail_reward', 'size']
class ConfigExplainer:
@ -32,7 +33,9 @@ class ConfigExplainer:
:param custom_path: Path to your custom module folder.
"""
self.base_path = Path(__file__).parent.parent.resolve()
self.base_path = Path(__file__).parent.parent.resolve() /'environment'
self.modules_path = Path(__file__).parent.parent.resolve() / 'modules'
self.custom_path = Path(custom_path) if custom_path is not None else custom_path
self.searchspace = [ACTION, GENERAL, ENTITIES, OBSERVATIONS, RULES, TESTS]
@ -41,9 +44,16 @@ class ConfigExplainer:
"""
INTERNAL USE ONLY
"""
parameters = inspect.signature(class_to_explain).parameters
this_search = class_to_explain
parameters = dict(inspect.signature(class_to_explain).parameters)
while this_search.__bases__:
base_class = this_search.__bases__[0]
parameters.update(dict(inspect.signature(base_class).parameters))
this_search = base_class
explained = {class_to_explain.__name__:
{key: val.default for key, val in parameters.items() if key not in EXCLUDED}
{key: val.default if val.default != inspect._empty else '!' for key, val in parameters.items()
if key not in EXCLUDED}
}
return explained
@ -52,8 +62,10 @@ class ConfigExplainer:
INTERNAL USE ONLY
"""
entities_base_cls = locate_and_import_class(identifier, self.base_path)
module_paths = [x.resolve() for x in self.base_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
found_entities = self._load_and_compare(entities_base_cls, module_paths)
module_paths = [x.resolve() for x in self.modules_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
base_paths = [x.resolve() for x in self.base_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
found_entities = self._load_and_compare(entities_base_cls, base_paths)
found_entities.update(self._load_and_compare(entities_base_cls, module_paths))
if self.custom_path is not None:
module_paths = [x.resolve() for x in self.custom_path.rglob('*.py') if x.is_file()
and '__init__' not in x.name]
@ -91,16 +103,14 @@ class ConfigExplainer:
print(f'Example config {"for " + tag + " " if tag else " "}dumped')
print(f'See file: {filepath}')
def get_actions(self) -> list[str]:
def get_actions(self) -> dict[str]:
"""
Retrieve all actions from module folders.
:returns: A list of all available actions.
"""
actions = self._get_by_identifier(ACTION)
assert all(not x for x in actions.values()), 'Please only provide Names, no Mappings.'
actions = list(actions.keys())
actions.extend([c.MOVE8, c.MOVE4])
actions.update({c.MOVE8: {}, c.MOVE4: {}})
return actions
def get_all(self) -> dict[str]:
@ -125,6 +135,8 @@ class ConfigExplainer:
:returns: A list of all available entities.
"""
entities = self._get_by_identifier(ENTITIES)
for key in ['Combined', 'Agents', 'Inventory']:
del entities[key]
return entities
@staticmethod
@ -172,13 +184,20 @@ class ConfigExplainer:
except TypeError:
e = [key]
except AttributeError as err:
if self.custom_path is not None:
try:
e = locate_and_import_class(key, self.base_path)(level_shape=(0, 0), pomdp_r=0).obs_pairs
except TypeError:
e = [key]
try:
e = locate_and_import_class(key, self.modules_path)(level_shape=(0, 0), pomdp_r=0).obs_pairs
except TypeError:
e = [key]
except AttributeError as err2:
if self.custom_path is not None:
try:
e = locate_and_import_class(key, self.base_path)(level_shape=(0, 0), pomdp_r=0).obs_pairs
except TypeError:
e = [key]
else:
raise err
print(err.args)
print(err2.args)
exit(-9999)
names.extend(e)
return names