69 lines
2.8 KiB
Python

from os import PathLike
from pathlib import Path
from typing import Dict
import numpy as np
from marl_factory_grid.environment.groups.agents import Agents
from marl_factory_grid.environment.groups.global_entities import Entities
from marl_factory_grid.environment.groups.walls import Walls
from marl_factory_grid.utils import helpers as h
from marl_factory_grid.environment import constants as c
class LevelParser(object):
    """Parse a level file and build the initial ``Entities`` container from it.

    The level file is parsed once in ``__init__`` (``h.parse_level``); ``do_init`` then
    creates walls, agents and one collection per entry of ``entity_parse_dict``.
    """

    @property
    def pomdp_d(self):
        """Diameter derived from the observation radius: ``2 * pomdp_r + 1``."""
        return self.pomdp_r * 2 + 1

    def __init__(self, level_file_path: PathLike, entity_parse_dict: Dict[Entities, dict], pomdp_r=0):
        """
        :param level_file_path: Path to the level file, parsed with ``h.parse_level``.
        :param entity_parse_dict: Mapping whose values are dicts with a ``'class'`` entry (the
            collection class to build) and an optional ``'kwargs'`` entry (extra keyword
            arguments for that class; ``None`` or missing means no extra arguments).
        :param pomdp_r: Observation radius. When falsy (the default ``0``), ``size`` falls back
            to the total number of cells of the level.
        """
        self.pomdp_r = pomdp_r
        self.e_p_dict = entity_parse_dict
        self._parsed_level = h.parse_level(Path(level_file_path))
        # Only the shape is needed here; the one-hot wall layer has the level's shape.
        level_array = h.one_hot_level(self._parsed_level, c.SYMBOL_WALL)
        self.level_shape = level_array.shape
        # NOTE(review): with a partial view this yields pomdp_r ** 2 rather than the window
        # area pomdp_d ** 2; kept as-is — confirm which value the collections expect as `size`.
        self.size = self.pomdp_r ** 2 if self.pomdp_r else np.prod(self.level_shape)

    def get_coordinates_for_symbol(self, symbol, negate=False):
        """Return the indices of all cells marked with ``symbol`` (or, if ``negate``, not marked).

        :param symbol: Level symbol to look up via ``h.one_hot_level``.
        :param negate: When true, select cells whose value differs from ``c.VALUE_OCCUPIED_CELL``.
        :returns: The ``np.argwhere`` result: one row of indices per selected cell.
        """
        level_array = h.one_hot_level(self._parsed_level, symbol)
        if negate:
            return np.argwhere(level_array != c.VALUE_OCCUPIED_CELL)
        return np.argwhere(level_array == c.VALUE_OCCUPIED_CELL)

    def do_init(self):
        """Build and return the ``Entities`` container for the parsed level.

        Items are added in this order: walls (key ``c.WALLS``), agents (key ``c.AGENT``), then
        one collection per ``entity_parse_dict`` entry, keyed by that collection's ``name``.

        :raises ValueError: if an entity class declares a symbol that does not occur in the
            level file, or declares an empty collection of symbols.
        """
        # Global entities: every non-wall cell is a known position.
        list_of_all_positions = [tuple(f) for f in self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True)]
        entities = Entities(list_of_all_positions)

        # Walls
        walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size)
        entities.add_items({c.WALLS: walls})

        # Agents
        entities.add_items({c.AGENT: Agents(self.size)})

        # All other collections from the parse dict.
        for e_spec in self.e_p_dict.values():
            e = self._build_collection(e_spec['class'], e_spec.get('kwargs') or {})
            entities.add_items({e.name: e})
        return entities

    def _build_collection(self, e_class, e_kwargs):
        """Create one entity collection, placing it from the level file when the class has a symbol.

        :param e_class: Collection class; its optional ``symbol`` attribute is a single value or
            an iterable of values to look up in the level.
        :param e_kwargs: Keyword arguments forwarded to the collection.
        :raises ValueError: if a declared symbol is absent from the level, or the declared
            symbol collection is empty.
        """
        symbols = getattr(e_class, 'symbol', None)
        if symbols is None:
            # No level symbol: the collection is created without level-file coordinates.
            return e_class(self.size, **e_kwargs)

        symbols = [symbols] if isinstance(symbols, (str, int, float)) else list(symbols)
        if not symbols:
            # Without this guard the caller would re-use a stale or unbound collection.
            raise ValueError(f'{e_class} declares an empty symbol collection!\n'
                             f'Check your entity configuration!')

        for symbol in symbols:
            coordinates = self.get_coordinates_for_symbol(symbol)
            if len(coordinates) == 0:
                raise ValueError(f'No {e_class} (Symbol: {symbol}) could be found!\n'
                                 f'Check your level file!')
            # TODO: Get rid of this!
            # NOTE(review): each iteration replaces `e`, so with several symbols only the
            # collection built from the last symbol is returned — confirm this is intended.
            e = e_class.from_coordinates(coordinates.tolist(), self.size, entity_kwargs=e_kwargs)
        return e