mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-15 08:42:40 +02:00
initial n steps
This commit is contained in:
@ -35,24 +35,17 @@ Entities:
|
|||||||
# We need a special spawn rule...
|
# We need a special spawn rule...
|
||||||
spawnrule:
|
spawnrule:
|
||||||
# ...which assigns the destinations per agent
|
# ...which assigns the destinations per agent
|
||||||
SpawnDestinationsPerAgent:
|
SpawnDestinationOnAgent: {}
|
||||||
# we use this parameter
|
|
||||||
coords_or_quantity:
|
|
||||||
# to enable and assign special positions per agent
|
|
||||||
Wolfgang: 1
|
|
||||||
Karl-Heinz: 1
|
|
||||||
Kevin: 1
|
|
||||||
Juergen: 1
|
|
||||||
Soeren: 1
|
|
||||||
Walter: 1
|
|
||||||
Siggi: 1
|
|
||||||
Dennis: 1
|
|
||||||
|
|
||||||
Rules:
|
Rules:
|
||||||
# Utilities
|
# Utilities
|
||||||
WatchCollisions:
|
WatchCollisions:
|
||||||
done_at_collisions: false
|
done_at_collisions: false
|
||||||
|
|
||||||
|
# Initial random walk
|
||||||
|
DoRandomInitialSteps:
|
||||||
|
random_steps: 10
|
||||||
|
|
||||||
# Done Conditions
|
# Done Conditions
|
||||||
DoneAtDestinationReach:
|
DoneAtDestinationReach:
|
||||||
condition: simultanious
|
condition: simultanious
|
||||||
|
@ -136,6 +136,7 @@ class Factory(gym.Env):
|
|||||||
|
|
||||||
# All is set up, trigger entity spawn with variable pos
|
# All is set up, trigger entity spawn with variable pos
|
||||||
self.state.rules.do_all_reset(self.state)
|
self.state.rules.do_all_reset(self.state)
|
||||||
|
self.state.rules.do_all_post_spawn_reset(self.state)
|
||||||
|
|
||||||
# Build initial observations for all agents
|
# Build initial observations for all agents
|
||||||
self.obs_builder.reset(self.state)
|
self.obs_builder.reset(self.state)
|
||||||
|
@ -4,15 +4,17 @@ from random import shuffle
|
|||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from marl_factory_grid.environment.groups.objects import Objects
|
from marl_factory_grid.environment.groups.objects import Objects
|
||||||
from marl_factory_grid.utils.helpers import POS_MASK
|
from marl_factory_grid.utils.helpers import POS_MASK_8, POS_MASK_4
|
||||||
|
|
||||||
|
|
||||||
class Entities(Objects):
|
class Entities(Objects):
|
||||||
_entity = Objects
|
_entity = Objects
|
||||||
|
|
||||||
@staticmethod
|
def neighboring_positions(self, pos):
|
||||||
def neighboring_positions(pos):
|
return [tuple(x) for x in (POS_MASK_8 + pos).reshape(-1, 2) if tuple(x) in self._floor_positions]
|
||||||
return [tuple(x) for x in (POS_MASK + pos).reshape(-1, 2)]
|
|
||||||
|
def neighboring_4_positions(self, pos):
|
||||||
|
return [tuple(x) for x in (POS_MASK_4 + pos) if tuple(x) in self._floor_positions]
|
||||||
|
|
||||||
def get_entities_near_pos(self, pos):
|
def get_entities_near_pos(self, pos):
|
||||||
return [y for x in itemgetter(*self.neighboring_positions(pos))(self.pos_dict) for y in x]
|
return [y for x in itemgetter(*self.neighboring_positions(pos))(self.pos_dict) for y in x]
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
import abc
|
import abc
|
||||||
|
import random
|
||||||
from random import shuffle
|
from random import shuffle
|
||||||
from typing import List, Collection
|
from typing import List, Collection
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from marl_factory_grid.environment import rewards as r, constants as c
|
from marl_factory_grid.environment import rewards as r, constants as c
|
||||||
from marl_factory_grid.environment.entity.agent import Agent
|
from marl_factory_grid.environment.entity.agent import Agent
|
||||||
from marl_factory_grid.utils import helpers as h
|
from marl_factory_grid.utils import helpers as h
|
||||||
@ -37,6 +40,15 @@ class Rule(abc.ABC):
|
|||||||
TODO
|
TODO
|
||||||
|
|
||||||
|
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return []
|
||||||
|
|
||||||
|
def on_reset_post_spawn(self, state) -> List[TickResult]:
|
||||||
|
"""
|
||||||
|
TODO
|
||||||
|
|
||||||
|
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
return []
|
return []
|
||||||
@ -230,3 +242,33 @@ class WatchCollisions(Rule):
|
|||||||
if inter_entity_collision_detected or collision_in_step:
|
if inter_entity_collision_detected or collision_in_step:
|
||||||
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=self.reward_at_done)]
|
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=self.reward_at_done)]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class DoRandomInitialSteps(Rule):
|
||||||
|
def __init__(self, random_steps: 10):
|
||||||
|
"""
|
||||||
|
Special rule which spawns destinations, that are bound to a single agent a fixed set of positions.
|
||||||
|
Useful for introducing specialists, etc. ..
|
||||||
|
|
||||||
|
!!! This rule does not introduce any reward or done condition.
|
||||||
|
|
||||||
|
:param random_steps: Number of random steps agents perform in an environment.
|
||||||
|
Useful in the `N-Puzzle` configuration.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.random_steps = random_steps
|
||||||
|
|
||||||
|
def on_reset_post_spawn(self, state):
|
||||||
|
state.print("Random Initial Steps initiated....")
|
||||||
|
for _ in range(self.random_steps):
|
||||||
|
# Find free positions
|
||||||
|
free_pos = state.random_free_position
|
||||||
|
neighbor_positions = state.entities.neighboring_4_positions(free_pos)
|
||||||
|
random.shuffle(neighbor_positions)
|
||||||
|
chosen_agent = h.get_first(state[c.AGENT].by_pos(neighbor_positions.pop()))
|
||||||
|
assert isinstance(chosen_agent, Agent)
|
||||||
|
valid = chosen_agent.move(free_pos, state)
|
||||||
|
valid_str = " not" if not valid else ""
|
||||||
|
state.print(f"Move {chosen_agent.name} from {chosen_agent.last_pos} "
|
||||||
|
f"to {chosen_agent.pos} was{valid_str} valid.")
|
||||||
|
pass
|
||||||
|
@ -105,10 +105,10 @@ class SpawnDestinationsPerAgent(Rule):
|
|||||||
|
|
||||||
!!! This rule does not introduce any reward or done condition.
|
!!! This rule does not introduce any reward or done condition.
|
||||||
|
|
||||||
:param coords_or_quantity: Please provide a dictionary with agent names as keys; and a list of possible
|
:param coords_or_quantity: Please provide a dictionary with agent names as keys; and a list of possible
|
||||||
destination coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]}
|
destination coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]}
|
||||||
"""
|
"""
|
||||||
super(Rule, self).__init__()
|
super().__init__()
|
||||||
self.per_agent_positions = dict()
|
self.per_agent_positions = dict()
|
||||||
for agent_name, value in coords_or_quantity.items():
|
for agent_name, value in coords_or_quantity.items():
|
||||||
if isinstance(value, int):
|
if isinstance(value, int):
|
||||||
@ -143,3 +143,25 @@ class SpawnDestinationsPerAgent(Rule):
|
|||||||
continue
|
continue
|
||||||
state[d.DESTINATION].add_item(destination)
|
state[d.DESTINATION].add_item(destination)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SpawnDestinationOnAgent(Rule):
|
||||||
|
def __init__(self):
|
||||||
|
"""
|
||||||
|
Special rule which spawns a single destination bound to a single agent just `below` him. Usefull for
|
||||||
|
the `N-Puzzle` configurations.
|
||||||
|
|
||||||
|
!!! This rule does not introduce any reward or done condition.
|
||||||
|
|
||||||
|
:param coords_or_quantity: Please provide a dictionary with agent names as keys; and a list of possible
|
||||||
|
destination coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]}
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def on_reset(self, state: Gamestate):
|
||||||
|
state.print("Spawn Desitnations")
|
||||||
|
for agent in state[c.AGENT]:
|
||||||
|
destination = Destination(agent.pos, bind_to=agent)
|
||||||
|
state[d.DESTINATION].add_item(destination)
|
||||||
|
assert len(state[d.DESTINATION].by_pos(agent.pos)) == 1
|
||||||
|
pass
|
||||||
|
@ -27,9 +27,11 @@ IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignore
|
|||||||
'train_step', 'step', 'index', 'dirt_amount', 'dirty_pos_count', 'terminal_observation',
|
'train_step', 'step', 'index', 'dirt_amount', 'dirty_pos_count', 'terminal_observation',
|
||||||
'episode']
|
'episode']
|
||||||
|
|
||||||
POS_MASK = np.asarray([[[-1, -1], [0, -1], [1, -1]],
|
POS_MASK_8 = np.asarray([[[-1, -1], [0, -1], [1, -1]],
|
||||||
[[-1, 0], [0, 0], [1, 0]],
|
[[-1, 0], [0, 0], [1, 0]],
|
||||||
[[-1, 1], [0, 1], [1, 1]]])
|
[[-1, 1], [0, 1], [1, 1]]])
|
||||||
|
|
||||||
|
POS_MASK_4 = np.asarray([[0, -1], [-1, 0], [1, 0], [-1, 1], [0, 1], [1, 1]])
|
||||||
|
|
||||||
MOVEMAP = defaultdict(lambda: (0, 0),
|
MOVEMAP = defaultdict(lambda: (0, 0),
|
||||||
{c.NORTH: (-1, 0), c.NORTHEAST: (-1, 1),
|
{c.NORTH: (-1, 0), c.NORTHEAST: (-1, 1),
|
||||||
|
@ -47,6 +47,12 @@ class StepRules:
|
|||||||
state.print(rule_reset_printline)
|
state.print(rule_reset_printline)
|
||||||
return c.VALID
|
return c.VALID
|
||||||
|
|
||||||
|
def do_all_post_spawn_reset(self, state):
|
||||||
|
for rule in self.rules:
|
||||||
|
if rule_reset_printline := rule.on_reset_post_spawn(state):
|
||||||
|
state.print(rule_reset_printline)
|
||||||
|
return c.VALID
|
||||||
|
|
||||||
def tick_step_all(self, state):
|
def tick_step_all(self, state):
|
||||||
results = list()
|
results = list()
|
||||||
for rule in self.rules:
|
for rule in self.rules:
|
||||||
|
@ -26,10 +26,10 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
if explain_config:
|
if explain_config:
|
||||||
ce = ConfigExplainer()
|
ce = ConfigExplainer()
|
||||||
ce.save_all(run_path / 'all_out.yaml')
|
ce.save_all(run_path / 'all_available_configs.yaml')
|
||||||
|
|
||||||
# Path to config File
|
# Path to config File
|
||||||
path = Path('marl_factory_grid/configs/clean_and_bring.yaml')
|
path = Path('marl_factory_grid/configs/eight_puzzle.yaml')
|
||||||
|
|
||||||
# Env Init
|
# Env Init
|
||||||
factory = Factory(path)
|
factory = Factory(path)
|
||||||
|
Reference in New Issue
Block a user