environment code changes for RL settings

This commit is contained in:
Julian Schönberger
2024-05-02 10:54:46 +02:00
parent 50bcf5d995
commit 48d708bbcd
3 changed files with 22 additions and 7 deletions

View File

@ -65,8 +65,12 @@ def add_env_props(cfg):
_ = factory.reset()
# Agent Init
cfg['agent'].update(dict(observation_size=list(factory.observation_space[0].shape),
n_actions=factory.action_space[0].n))
if len(factory.state.moving_entites) == 1: # Single agent setting
observation_size = list(factory.observation_space.shape)
else: # Multi-agent setting
observation_size = list(factory.observation_space[0].shape)
cfg['agent'].update(dict(observation_size=observation_size, n_actions=factory.action_space[0].n))
return factory

View File

@ -188,7 +188,9 @@ class SpawnAgents(Rule):
positions = agent_conf['positions'].copy()
other = agent_conf['other'].copy()
if position := h.get_first(x for x in positions if x in empty_positions):
# Spawn agent on random position if multiple spawn points are provided
func = random.choice if len(positions) else h.get_first
if position := func([x for x in positions if x in empty_positions]):
assert state.check_pos_validity(position), 'smth went wrong....'
agents.add_item(Agent(actions, observations, position, str_ident=agent_name, **other))
elif positions:

View File

@ -1,3 +1,4 @@
import ast
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.modules.clean_up.entitites import DirtPile
@ -74,11 +75,19 @@ class DirtPiles(Collection):
print("Exiting....")
exit()
coords_or_quantity = coords_or_quantity if coords_or_quantity else self.coords_or_quantity
n_new = int(abs(coords_or_quantity + (state.rng.uniform(-self.n_var, self.n_var))))
n_new = state.get_n_random_free_positions(n_new)
if isinstance(coords_or_quantity, int):
n_new = int(abs(coords_or_quantity + (state.rng.uniform(-self.n_var, self.n_var))))
n_new = state.get_n_random_free_positions(n_new)
else:
coords_or_quantity = ast.literal_eval(coords_or_quantity)
if isinstance(coords_or_quantity[0], int):
n_new = [coords_or_quantity]
else:
n_new = [pos for pos in coords_or_quantity]
amounts = [amount if amount else (self.initial_amount ) # removed rng amount
for _ in range(len(n_new))]
amounts = [amount if amount else (self.initial_amount + state.rng.uniform(-self.amount_var, self.amount_var))
for _ in range(coords_or_quantity)]
spawn_counter = 0
for idx, (pos, a) in enumerate(zip(n_new, amounts)):
if not self.global_amount > self.max_global_amount: