mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-05 09:01:36 +02:00
environment code changes for RL settings
This commit is contained in:
@ -65,8 +65,12 @@ def add_env_props(cfg):
|
|||||||
_ = factory.reset()
|
_ = factory.reset()
|
||||||
|
|
||||||
# Agent Init
|
# Agent Init
|
||||||
cfg['agent'].update(dict(observation_size=list(factory.observation_space[0].shape),
|
if len(factory.state.moving_entites) == 1: # Single agent setting
|
||||||
n_actions=factory.action_space[0].n))
|
observation_size = list(factory.observation_space.shape)
|
||||||
|
else: # Multi-agent setting
|
||||||
|
observation_size = list(factory.observation_space[0].shape)
|
||||||
|
cfg['agent'].update(dict(observation_size=observation_size, n_actions=factory.action_space[0].n))
|
||||||
|
|
||||||
return factory
|
return factory
|
||||||
|
|
||||||
|
|
||||||
|
@ -188,7 +188,9 @@ class SpawnAgents(Rule):
|
|||||||
positions = agent_conf['positions'].copy()
|
positions = agent_conf['positions'].copy()
|
||||||
other = agent_conf['other'].copy()
|
other = agent_conf['other'].copy()
|
||||||
|
|
||||||
if position := h.get_first(x for x in positions if x in empty_positions):
|
# Spawn agent on random position if multiple spawn points are provided
|
||||||
|
func = random.choice if len(positions) else h.get_first
|
||||||
|
if position := func([x for x in positions if x in empty_positions]):
|
||||||
assert state.check_pos_validity(position), 'smth went wrong....'
|
assert state.check_pos_validity(position), 'smth went wrong....'
|
||||||
agents.add_item(Agent(actions, observations, position, str_ident=agent_name, **other))
|
agents.add_item(Agent(actions, observations, position, str_ident=agent_name, **other))
|
||||||
elif positions:
|
elif positions:
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import ast
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
from marl_factory_grid.environment.groups.collection import Collection
|
from marl_factory_grid.environment.groups.collection import Collection
|
||||||
from marl_factory_grid.modules.clean_up.entitites import DirtPile
|
from marl_factory_grid.modules.clean_up.entitites import DirtPile
|
||||||
@ -74,11 +75,19 @@ class DirtPiles(Collection):
|
|||||||
print("Exiting....")
|
print("Exiting....")
|
||||||
exit()
|
exit()
|
||||||
coords_or_quantity = coords_or_quantity if coords_or_quantity else self.coords_or_quantity
|
coords_or_quantity = coords_or_quantity if coords_or_quantity else self.coords_or_quantity
|
||||||
|
if isinstance(coords_or_quantity, int):
|
||||||
n_new = int(abs(coords_or_quantity + (state.rng.uniform(-self.n_var, self.n_var))))
|
n_new = int(abs(coords_or_quantity + (state.rng.uniform(-self.n_var, self.n_var))))
|
||||||
n_new = state.get_n_random_free_positions(n_new)
|
n_new = state.get_n_random_free_positions(n_new)
|
||||||
|
else:
|
||||||
|
coords_or_quantity = ast.literal_eval(coords_or_quantity)
|
||||||
|
if isinstance(coords_or_quantity[0], int):
|
||||||
|
n_new = [coords_or_quantity]
|
||||||
|
else:
|
||||||
|
n_new = [pos for pos in coords_or_quantity]
|
||||||
|
|
||||||
|
amounts = [amount if amount else (self.initial_amount ) # removed rng amount
|
||||||
|
for _ in range(len(n_new))]
|
||||||
|
|
||||||
amounts = [amount if amount else (self.initial_amount + state.rng.uniform(-self.amount_var, self.amount_var))
|
|
||||||
for _ in range(coords_or_quantity)]
|
|
||||||
spawn_counter = 0
|
spawn_counter = 0
|
||||||
for idx, (pos, a) in enumerate(zip(n_new, amounts)):
|
for idx, (pos, a) in enumerate(zip(n_new, amounts)):
|
||||||
if not self.global_amount > self.max_global_amount:
|
if not self.global_amount > self.max_global_amount:
|
||||||
|
Reference in New Issue
Block a user