mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-05 09:01:36 +02:00
environment code changes for RL settings
This commit is contained in:
@ -65,8 +65,12 @@ def add_env_props(cfg):
|
||||
_ = factory.reset()
|
||||
|
||||
# Agent Init
|
||||
cfg['agent'].update(dict(observation_size=list(factory.observation_space[0].shape),
|
||||
n_actions=factory.action_space[0].n))
|
||||
if len(factory.state.moving_entites) == 1: # Single agent setting
|
||||
observation_size = list(factory.observation_space.shape)
|
||||
else: # Multi-agent setting
|
||||
observation_size = list(factory.observation_space[0].shape)
|
||||
cfg['agent'].update(dict(observation_size=observation_size, n_actions=factory.action_space[0].n))
|
||||
|
||||
return factory
|
||||
|
||||
|
||||
|
@ -188,7 +188,9 @@ class SpawnAgents(Rule):
|
||||
positions = agent_conf['positions'].copy()
|
||||
other = agent_conf['other'].copy()
|
||||
|
||||
if position := h.get_first(x for x in positions if x in empty_positions):
|
||||
# Spawn agent on random position if multiple spawn points are provided
|
||||
func = random.choice if len(positions) else h.get_first
|
||||
if position := func([x for x in positions if x in empty_positions]):
|
||||
assert state.check_pos_validity(position), 'smth went wrong....'
|
||||
agents.add_item(Agent(actions, observations, position, str_ident=agent_name, **other))
|
||||
elif positions:
|
||||
|
@ -1,3 +1,4 @@
|
||||
import ast
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.groups.collection import Collection
|
||||
from marl_factory_grid.modules.clean_up.entitites import DirtPile
|
||||
@ -74,11 +75,19 @@ class DirtPiles(Collection):
|
||||
print("Exiting....")
|
||||
exit()
|
||||
coords_or_quantity = coords_or_quantity if coords_or_quantity else self.coords_or_quantity
|
||||
n_new = int(abs(coords_or_quantity + (state.rng.uniform(-self.n_var, self.n_var))))
|
||||
n_new = state.get_n_random_free_positions(n_new)
|
||||
if isinstance(coords_or_quantity, int):
|
||||
n_new = int(abs(coords_or_quantity + (state.rng.uniform(-self.n_var, self.n_var))))
|
||||
n_new = state.get_n_random_free_positions(n_new)
|
||||
else:
|
||||
coords_or_quantity = ast.literal_eval(coords_or_quantity)
|
||||
if isinstance(coords_or_quantity[0], int):
|
||||
n_new = [coords_or_quantity]
|
||||
else:
|
||||
n_new = [pos for pos in coords_or_quantity]
|
||||
|
||||
amounts = [amount if amount else (self.initial_amount ) # removed rng amount
|
||||
for _ in range(len(n_new))]
|
||||
|
||||
amounts = [amount if amount else (self.initial_amount + state.rng.uniform(-self.amount_var, self.amount_var))
|
||||
for _ in range(coords_or_quantity)]
|
||||
spawn_counter = 0
|
||||
for idx, (pos, a) in enumerate(zip(n_new, amounts)):
|
||||
if not self.global_amount > self.max_global_amount:
|
||||
|
Reference in New Issue
Block a user