mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-22 14:56:43 +02:00
103 lines
4.0 KiB
Python
103 lines
4.0 KiB
Python
from os import PathLike
|
|
from pathlib import Path
|
|
from typing import Dict
|
|
|
|
import numpy as np
|
|
|
|
from marl_factory_grid.environment.groups.agents import Agents
|
|
from marl_factory_grid.environment.groups.global_entities import Entities
|
|
from marl_factory_grid.environment.groups.walls import Walls
|
|
from marl_factory_grid.utils import helpers as h
|
|
from marl_factory_grid.environment import constants as c
|
|
|
|
|
|
class LevelParser(object):
    """
    Parses a level file and builds the initial entity state of an environment from it.

    The level file is read once in the constructor; :meth:`do_init` then creates the wall and agent
    collections plus one collection per entry of the entity parse dict.
    """

    @property
    def pomdp_d(self):
        """
        Calculates the effective diameter of the POMDP observation space.

        :return: The calculated effective diameter, i.e. ``2 * pomdp_r + 1``.
        :rtype: int
        """
        return self.pomdp_r * 2 + 1

    def __init__(self, level_file_path: PathLike, entity_parse_dict: Dict[Entities, dict], pomdp_r=0):
        """
        Parses a level file and stores everything needed to create the initial state of the environment.

        :param level_file_path: Path to the level file.
        :type level_file_path: PathLike

        :param entity_parse_dict: Dictionary specifying how to parse different entities. Every value is
                                  read through its ``'class'`` and ``'kwargs'`` keys in :meth:`do_init`.
        :type entity_parse_dict: Dict[Entities, dict]

        :param pomdp_r: The POMDP radius. Defaults to 0.
        :type pomdp_r: int
        """
        self.pomdp_r = pomdp_r
        self.e_p_dict = entity_parse_dict
        self._parsed_level = h.parse_level(Path(level_file_path))

        # The wall mask is only needed here to learn the shape of the level.
        level_array = h.one_hot_level(self._parsed_level, c.SYMBOL_WALL)
        self.level_shape = level_array.shape

        # `size` is handed to every entity collection created in `do_init`.
        # NOTE(review): the POMDP branch squares the radius (pomdp_r), not the diameter (pomdp_d) --
        # confirm that this is the intended value.
        self.size = self.pomdp_r ** 2 if self.pomdp_r else np.prod(self.level_shape)

    def get_coordinates_for_symbol(self, symbol, negate=False) -> np.ndarray:
        """
        Get the coordinates for a given symbol in the parsed level.

        :param symbol: The symbol to search for.
        :param negate: If True, get coordinates not matching the symbol. Defaults to False.

        :return: Array of coordinates, one row per matching cell (as produced by ``np.argwhere``).
        :rtype: np.ndarray
        """
        level_array = h.one_hot_level(self._parsed_level, symbol)
        if negate:
            return np.argwhere(level_array != c.VALUE_OCCUPIED_CELL)
        return np.argwhere(level_array == c.VALUE_OCCUPIED_CELL)

    def do_init(self) -> Entities:
        """
        Initialize the environment map state by creating entities such as Walls, Agents or Machines according to the
        entity parse dict.

        :return: The entities container holding the walls, the agents and one collection per parse dict entry.
        :rtype: Entities

        :raises ValueError: If an entity class declares an empty symbol collection, or if one of its symbols
                            does not occur in the level file.
        """
        # Global Entities: seeded with every position that is not a wall.
        list_of_all_positions = [tuple(f) for f in self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True)]
        entities = Entities(list_of_all_positions)

        # Walls
        walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size)
        entities.add_items({c.WALLS: walls})

        # Agents
        entities.add_items({c.AGENT: Agents(self.size)})

        # All other
        for es_name in self.e_p_dict:
            e_class, e_kwargs = self.e_p_dict[es_name]['class'], self.e_p_dict[es_name]['kwargs']
            e_kwargs = e_kwargs if e_kwargs else {}

            if hasattr(e_class, 'symbol') and e_class.symbol is not None:
                symbols = e_class.symbol
                if isinstance(symbols, (str, int, float)):
                    symbols = [symbols]
                symbols = list(symbols)

                # Without this guard an empty symbol collection would skip the loop below and leave `e`
                # unbound, or silently reuse the collection built for the previous parse dict entry.
                if not symbols:
                    raise ValueError(f'{e_class} declares an empty symbol collection!\n'
                                     f'Check your entity definition!')

                # NOTE(review): `e` is rebound for every symbol, so only the collection built from the
                # last symbol is registered below -- confirm whether multi-symbol classes should merge.
                for symbol in symbols:
                    level_array = h.one_hot_level(self._parsed_level, symbol=symbol)
                    if np.any(level_array):
                        # TODO: Get rid of this!
                        e = e_class.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist(),
                                                     self.size, entity_kwargs=e_kwargs)
                    else:
                        # Report the symbol that is actually missing, not the whole symbol collection.
                        raise ValueError(f'No {e_class} (Symbol: {symbol}) could be found!\n'
                                         f'Check your level file!')
            else:
                # Classes without a level symbol are created empty, configured only by their kwargs.
                e = e_class(self.size, **e_kwargs)

            entities.add_items({e.name: e})
        return entities
|