From 48d708bbcd0ceefb64fac91213a6b768d1db2874 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Julian=20Sch=C3=B6nberger?=
Date: Thu, 2 May 2024 10:54:46 +0200
Subject: [PATCH] environment code changes for RL settings

---
Review notes (this section is ignored by `git am`):
- rules.py: the spawn position is now drawn from the list of *free* configured
  positions. random.choice is only used when more than one free candidate
  exists; otherwise h.get_first is used, so an empty candidate list no longer
  makes random.choice raise IndexError and the existing `elif positions:`
  branch is reached instead.
- Line breaks, indentation and blank context lines of this patch were
  reconstructed from a whitespace-collapsed copy; verify them against the
  target tree before applying.

 marl_factory_grid/algorithms/utils.py        |  8 ++++++--
 marl_factory_grid/environment/rules.py       |  5 ++++-
 marl_factory_grid/modules/clean_up/groups.py | 17 +++++++++++++----
 3 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/marl_factory_grid/algorithms/utils.py b/marl_factory_grid/algorithms/utils.py
index b472cba..562a95d 100644
--- a/marl_factory_grid/algorithms/utils.py
+++ b/marl_factory_grid/algorithms/utils.py
@@ -65,8 +65,12 @@ def add_env_props(cfg):
     _ = factory.reset()
 
     # Agent Init
-    cfg['agent'].update(dict(observation_size=list(factory.observation_space[0].shape),
-                             n_actions=factory.action_space[0].n))
+    if len(factory.state.moving_entites) == 1:  # Single agent setting
+        observation_size = list(factory.observation_space.shape)
+    else:  # Multi-agent setting
+        observation_size = list(factory.observation_space[0].shape)
+    cfg['agent'].update(dict(observation_size=observation_size, n_actions=factory.action_space[0].n))
+
     return factory
 
 
diff --git a/marl_factory_grid/environment/rules.py b/marl_factory_grid/environment/rules.py
index 873339e..c3669f1 100644
--- a/marl_factory_grid/environment/rules.py
+++ b/marl_factory_grid/environment/rules.py
@@ -188,7 +188,10 @@ class SpawnAgents(Rule):
             positions = agent_conf['positions'].copy()
             other = agent_conf['other'].copy()
 
-            if position := h.get_first(x for x in positions if x in empty_positions):
+            # Spawn agent on a random free position if multiple spawn points are available
+            free_positions = [x for x in positions if x in empty_positions]
+            func = random.choice if len(free_positions) > 1 else h.get_first
+            if position := func(free_positions):
                 assert state.check_pos_validity(position), 'smth went wrong....'
                 agents.add_item(Agent(actions, observations, position, str_ident=agent_name, **other))
             elif positions:
diff --git a/marl_factory_grid/modules/clean_up/groups.py b/marl_factory_grid/modules/clean_up/groups.py
index de108a1..8a99439 100644
--- a/marl_factory_grid/modules/clean_up/groups.py
+++ b/marl_factory_grid/modules/clean_up/groups.py
@@ -1,3 +1,4 @@
+import ast
 from marl_factory_grid.environment import constants as c
 from marl_factory_grid.environment.groups.collection import Collection
 from marl_factory_grid.modules.clean_up.entitites import DirtPile
@@ -74,11 +75,19 @@ class DirtPiles(Collection):
             print("Exiting....")
             exit()
         coords_or_quantity = coords_or_quantity if coords_or_quantity else self.coords_or_quantity
-        n_new = int(abs(coords_or_quantity + (state.rng.uniform(-self.n_var, self.n_var))))
-        n_new = state.get_n_random_free_positions(n_new)
+        if isinstance(coords_or_quantity, int):
+            n_new = int(abs(coords_or_quantity + (state.rng.uniform(-self.n_var, self.n_var))))
+            n_new = state.get_n_random_free_positions(n_new)
+        else:
+            coords_or_quantity = ast.literal_eval(coords_or_quantity)
+            if isinstance(coords_or_quantity[0], int):
+                n_new = [coords_or_quantity]
+            else:
+                n_new = [pos for pos in coords_or_quantity]
+
+        amounts = [amount if amount else (self.initial_amount ) # removed rng amount
+                   for _ in range(len(n_new))]
 
-        amounts = [amount if amount else (self.initial_amount + state.rng.uniform(-self.amount_var, self.amount_var))
-                   for _ in range(coords_or_quantity)]
         spawn_counter = 0
         for idx, (pos, a) in enumerate(zip(n_new, amounts)):
             if not self.global_amount > self.max_global_amount: