mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-11 15:32:42 +02:00
major redesign ob observations and entittes
This commit is contained in:
120
environment/utils/config_parser.py
Normal file
120
environment/utils/config_parser.py
Normal file
@ -0,0 +1,120 @@
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
from environment.groups.global_entities import Entities
|
||||
from environment.groups.agents import Agents
|
||||
from environment.entity.agent import Agent
|
||||
from environment.utils.helpers import locate_and_import_class
|
||||
from environment import constants as c
|
||||
|
||||
|
||||
DEFAULT_PATH = 'environment'
|
||||
MODULE_PATH = 'modules'
|
||||
|
||||
|
||||
class FactoryConfigParser(object):
|
||||
|
||||
default_entites = []
|
||||
default_rules = ['MaxStepsReached', 'Collision']
|
||||
default_actions = [c.MOVE8, c.NOOP]
|
||||
default_observations = [c.WALLS, c.AGENTS]
|
||||
|
||||
def __init__(self, config_path):
|
||||
self.config_path = Path(config_path)
|
||||
self.config = yaml.safe_load(config_path.open())
|
||||
self.do_record = False
|
||||
|
||||
def __getattr__(self, item):
|
||||
return self['General'][item]
|
||||
|
||||
def _get_sub_list(self, primary_key: str, sub_key: str):
|
||||
return [{key: [s for k, v in val.items() if k == sub_key for s in v] for key, val in x.items()
|
||||
} for x in self.config[primary_key]]
|
||||
|
||||
@property
|
||||
def agent_actions(self):
|
||||
return self._get_sub_list('Agents', "Actions")
|
||||
|
||||
@property
|
||||
def agent_observations(self):
|
||||
return self._get_sub_list('Agents', "Observations")
|
||||
|
||||
@property
|
||||
def rules(self):
|
||||
return self.config['Rules']
|
||||
|
||||
@property
|
||||
def agents(self):
|
||||
return self.config['Agents']
|
||||
|
||||
@property
|
||||
def entities(self):
|
||||
return self.config['Entities']
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.config)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.config[item]
|
||||
|
||||
def load_entities(self):
|
||||
# entites = Entities()
|
||||
entity_classes = dict()
|
||||
entities = []
|
||||
if c.DEFAULTS in self.entities:
|
||||
entities.extend(self.default_entites)
|
||||
entities.extend(x for x in self.entities if x != c.DEFAULTS)
|
||||
|
||||
for entity in entities:
|
||||
folder_path = MODULE_PATH if entity not in self.default_entites else DEFAULT_PATH
|
||||
entity_class = locate_and_import_class(entity, folder_path)
|
||||
entity_kwargs = self.entities.get(entity, {})
|
||||
entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
|
||||
entity_classes.update({entity: {'class': entity_class, 'kwargs': entity_kwargs, 'symbol': entity_symbol}})
|
||||
return entity_classes
|
||||
|
||||
def load_agents(self, size, free_tiles):
|
||||
agents = Agents(size)
|
||||
base_env_actions = self.default_actions.copy() + [c.MOVE4]
|
||||
for name in self.agents:
|
||||
# Actions
|
||||
actions = list()
|
||||
if c.DEFAULTS in self.agents[name]['Actions']:
|
||||
actions.extend(self.default_actions)
|
||||
actions.extend(x for x in self.agents[name]['Actions'] if x != c.DEFAULTS)
|
||||
parsed_actions = list()
|
||||
for action in actions:
|
||||
folder_path = MODULE_PATH if action not in base_env_actions else DEFAULT_PATH
|
||||
class_or_classes = locate_and_import_class(action, folder_path)
|
||||
try:
|
||||
parsed_actions.extend(class_or_classes)
|
||||
except TypeError:
|
||||
parsed_actions.append(class_or_classes)
|
||||
parsed_actions = [x() for x in parsed_actions]
|
||||
|
||||
# Observation
|
||||
observations = list()
|
||||
if c.DEFAULTS in self.agents[name]['Observations']:
|
||||
observations.extend(self.default_observations)
|
||||
observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS)
|
||||
agent = Agent(parsed_actions, observations, free_tiles.pop(), str_ident=name)
|
||||
agents.add_item(agent)
|
||||
return agents
|
||||
|
||||
def load_rules(self):
|
||||
# entites = Entities()
|
||||
rules_classes = dict()
|
||||
rules = []
|
||||
if c.DEFAULTS in self.rules:
|
||||
for rule in self.default_rules:
|
||||
if rule not in rules:
|
||||
rules.append(rule)
|
||||
rules.extend(x for x in self.rules if x != c.DEFAULTS)
|
||||
|
||||
for rule in rules:
|
||||
folder_path = MODULE_PATH if rule not in self.default_rules else DEFAULT_PATH
|
||||
rule_class = locate_and_import_class(rule, folder_path)
|
||||
rule_kwargs = self.rules.get(rule, {})
|
||||
rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}})
|
||||
return rules_classes
|
Reference in New Issue
Block a user