mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-05 00:51:35 +02:00
All relevant functional code for A2C Dirt Quadrant setting with small changes to the environment + Different configs for single agent and multiagent settings
This commit is contained in:
@ -5,6 +5,7 @@ from typing import List, Collection
|
||||
|
||||
import numpy as np
|
||||
|
||||
import marl_factory_grid
|
||||
from marl_factory_grid.environment import rewards as r, constants as c
|
||||
from marl_factory_grid.environment.entity.agent import Agent
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
@ -180,6 +181,11 @@ class SpawnAgents(Rule):
|
||||
pass
|
||||
|
||||
def on_reset(self, state):
|
||||
spawn_rule = None
|
||||
for rule in state.rules.rules:
|
||||
if isinstance(rule, marl_factory_grid.environment.rules.AgentSpawnRule):
|
||||
spawn_rule = rule.spawn_rule
|
||||
|
||||
agents = state[c.AGENT]
|
||||
for agent_name, agent_conf in state.agents_conf.items():
|
||||
empty_positions = state.entities.empty_positions
|
||||
@ -187,10 +193,9 @@ class SpawnAgents(Rule):
|
||||
observations = agent_conf['observations'].copy()
|
||||
positions = agent_conf['positions'].copy()
|
||||
other = agent_conf['other'].copy()
|
||||
positions_pointer = agent_conf['pos_pointer']
|
||||
|
||||
# Spawn agent on random position if multiple spawn points are provided
|
||||
func = random.choice if len(positions) else h.get_first
|
||||
if position := func([x for x in positions if x in empty_positions]):
|
||||
if position := self._get_position(spawn_rule, positions, empty_positions, positions_pointer):
|
||||
assert state.check_pos_validity(position), 'smth went wrong....'
|
||||
agents.add_item(Agent(actions, observations, position, str_ident=agent_name, **other))
|
||||
elif positions:
|
||||
@ -200,6 +205,20 @@ class SpawnAgents(Rule):
|
||||
agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name, **other))
|
||||
return []
|
||||
|
||||
def _get_position(self, spawn_rule, positions, empty_positions, positions_pointer):
|
||||
if spawn_rule and spawn_rule == "random":
|
||||
position = random.choice(([x for x in positions if x in empty_positions]))
|
||||
elif spawn_rule and spawn_rule == "order":
|
||||
position = ([x for x in positions if x in empty_positions])[positions_pointer]
|
||||
else:
|
||||
position = h.get_first([x for x in positions if x in empty_positions])
|
||||
|
||||
return position
|
||||
|
||||
class AgentSpawnRule(Rule):
|
||||
def __init__(self, spawn_rule):
|
||||
self.spawn_rule = spawn_rule
|
||||
super().__init__()
|
||||
|
||||
class DoneAtMaxStepsReached(Rule):
|
||||
|
||||
|
Reference in New Issue
Block a user