mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-12-06 15:40:37 +01:00
added changes from code submission branch and coin entity
This commit is contained in:
@@ -81,7 +81,7 @@ class Factory(gym.Env):
|
||||
def __init__(self, config_file: Union[str, PathLike], custom_modules_path: Union[None, PathLike] = None,
|
||||
custom_level_path: Union[None, PathLike] = None):
|
||||
"""
|
||||
Initializes the marl-factory-grid as Gym environment.
|
||||
Initializes the rl-factory-grid as Gym environment.
|
||||
|
||||
:param config_file: Path to the configuration file.
|
||||
:type config_file: Union[str, PathLike]
|
||||
@@ -271,15 +271,37 @@ class Factory(gym.Env):
|
||||
if not self._renderer: # lazy init
|
||||
from marl_factory_grid.utils.renderer import Renderer
|
||||
global Renderer
|
||||
self._renderer = Renderer(self.map.level_shape, view_radius=self.conf.pomdp_r, fps=10)
|
||||
self._renderer = Renderer(self.map.level_shape, view_radius=self.conf.pomdp_r, fps=10)
|
||||
|
||||
render_entities = self.state.entities.render()
|
||||
|
||||
# Hide entities where certain conditions are met (e.g., amount <= 0 for DirtPiles)
|
||||
render_entities = self.filter_entities(render_entities)
|
||||
|
||||
# Mask entities based on dynamic conditions instead of hardcoding level-specific logic
|
||||
if self.conf['General']['level_name'] == 'two_rooms':
|
||||
render_entities = self.mask_entities(render_entities)
|
||||
|
||||
if self.conf.pomdp_r:
|
||||
for render_entity in render_entities:
|
||||
if render_entity.name == c.AGENT:
|
||||
render_entity.aux = self.obs_builder.curr_lightmaps[render_entity.real_name]
|
||||
return self._renderer.render(render_entities, self._recorder)
|
||||
|
||||
def filter_entities(self, entities):
|
||||
""" Generalized method to filter out entities that shouldn't be rendered. """
|
||||
if 'DirtPiles' in self.state.entities.keys():
|
||||
entities = [entity for entity in entities if not (entity.name == 'DirtPiles' and entity.amount <= 0)]
|
||||
return entities
|
||||
|
||||
def mask_entities(self, entities):
|
||||
""" Generalized method to mask entities based on dynamic conditions. """
|
||||
for entity in entities:
|
||||
if entity.name == 'CoinPiles':
|
||||
entity.mask = 'Destinations'
|
||||
entity.mask_value = 1
|
||||
return entities
|
||||
|
||||
def set_recorder(self, recorder):
|
||||
self._recorder = recorder
|
||||
|
||||
@@ -298,7 +320,7 @@ class Factory(gym.Env):
|
||||
summary.update({entity_group.name.lower(): entity_group.summarize_states()})
|
||||
# TODO Section End ########
|
||||
for key in list(summary.keys()):
|
||||
if key not in ['step', 'walls', 'doors', 'agents', 'items', 'dirtPiles', 'batteries']:
|
||||
if key not in ['step', 'walls', 'doors', 'agents', 'items', 'dirtPiles', 'batteries', 'coinPiles']:
|
||||
del summary[key]
|
||||
return summary
|
||||
|
||||
|
||||
Reference in New Issue
Block a user