added changes from code submission branch and coin entity

This commit is contained in:
Chanumask
2024-09-06 11:01:42 +02:00
parent 33e40deecf
commit 5476f617c6
42 changed files with 1429 additions and 68 deletions

View File

@@ -81,7 +81,7 @@ class Factory(gym.Env):
def __init__(self, config_file: Union[str, PathLike], custom_modules_path: Union[None, PathLike] = None,
custom_level_path: Union[None, PathLike] = None):
"""
Initializes the marl-factory-grid as Gym environment.
Initializes the rl-factory-grid as Gym environment.
:param config_file: Path to the configuration file.
:type config_file: Union[str, PathLike]
@@ -271,15 +271,37 @@ class Factory(gym.Env):
if not self._renderer: # lazy init
from marl_factory_grid.utils.renderer import Renderer
global Renderer
self._renderer = Renderer(self.map.level_shape, view_radius=self.conf.pomdp_r, fps=10)
self._renderer = Renderer(self.map.level_shape, view_radius=self.conf.pomdp_r, fps=10)
render_entities = self.state.entities.render()
# Hide entities where certain conditions are met (e.g., amount <= 0 for DirtPiles)
render_entities = self.filter_entities(render_entities)
# Mask entities based on dynamic conditions instead of hardcoding level-specific logic
if self.conf['General']['level_name'] == 'two_rooms':
render_entities = self.mask_entities(render_entities)
if self.conf.pomdp_r:
for render_entity in render_entities:
if render_entity.name == c.AGENT:
render_entity.aux = self.obs_builder.curr_lightmaps[render_entity.real_name]
return self._renderer.render(render_entities, self._recorder)
def filter_entities(self, entities):
""" Generalized method to filter out entities that shouldn't be rendered. """
if 'DirtPiles' in self.state.entities.keys():
entities = [entity for entity in entities if not (entity.name == 'DirtPiles' and entity.amount <= 0)]
return entities
def mask_entities(self, entities):
""" Generalized method to mask entities based on dynamic conditions. """
for entity in entities:
if entity.name == 'CoinPiles':
entity.mask = 'Destinations'
entity.mask_value = 1
return entities
def set_recorder(self, recorder):
self._recorder = recorder
@@ -298,7 +320,7 @@ class Factory(gym.Env):
summary.update({entity_group.name.lower(): entity_group.summarize_states()})
# TODO Section End ########
for key in list(summary.keys()):
if key not in ['step', 'walls', 'doors', 'agents', 'items', 'dirtPiles', 'batteries']:
if key not in ['step', 'walls', 'doors', 'agents', 'items', 'dirtPiles', 'batteries', 'coinPiles']:
del summary[key]
return summary