diff --git a/marl_factory_grid/environment/factory.py b/marl_factory_grid/environment/factory.py index c7f1af1..94caac6 100644 --- a/marl_factory_grid/environment/factory.py +++ b/marl_factory_grid/environment/factory.py @@ -100,6 +100,7 @@ class Factory(gym.Env): parsed_entities = self.conf.load_entities() self.map = LevelParser(self.level_filepath, parsed_entities, self.conf.pomdp_r) + self.levels_that_require_masking = ['two_rooms'] # Init for later usage: # noinspection PyTypeChecker @@ -279,7 +280,7 @@ class Factory(gym.Env): render_entities = self.filter_entities(render_entities) # Mask entities based on dynamic conditions instead of hardcoding level-specific logic - if self.conf['General']['level_name'] == 'two_rooms': + if self.conf['General']['level_name'] in self.levels_that_require_masking: render_entities = self.mask_entities(render_entities) if self.conf.pomdp_r: