mirror of
				https://github.com/illiumst/marl-factory-grid.git
				synced 2025-10-31 04:37:25 +01:00 
			
		
		
		
	initial n steps
This commit is contained in:
		| @@ -35,24 +35,17 @@ Entities: | ||||
|     # We need a special spawn rule... | ||||
|     spawnrule: | ||||
|       # ...which assigns the destinations per agent | ||||
|       SpawnDestinationsPerAgent: | ||||
|         # we use this parameter | ||||
|         coords_or_quantity: | ||||
|           # to enable and assign special positions per agent | ||||
|           Wolfgang: 1 | ||||
|           Karl-Heinz: 1 | ||||
|           Kevin: 1 | ||||
|           Juergen: 1 | ||||
|           Soeren: 1 | ||||
|           Walter: 1 | ||||
|           Siggi: 1 | ||||
|           Dennis: 1 | ||||
|       SpawnDestinationOnAgent: {} | ||||
|  | ||||
| Rules: | ||||
|   # Utilities | ||||
|   WatchCollisions: | ||||
|     done_at_collisions: false | ||||
|  | ||||
|   # Initial random walk | ||||
|   DoRandomInitialSteps: | ||||
|     random_steps: 10 | ||||
|  | ||||
|   # Done Conditions | ||||
|   DoneAtDestinationReach: | ||||
|     condition: simultanious | ||||
|   | ||||
| @@ -136,6 +136,7 @@ class Factory(gym.Env): | ||||
|  | ||||
|         # All is set up, trigger entity spawn with variable pos | ||||
|         self.state.rules.do_all_reset(self.state) | ||||
|         self.state.rules.do_all_post_spawn_reset(self.state) | ||||
|  | ||||
|         # Build initial observations for all agents | ||||
|         self.obs_builder.reset(self.state) | ||||
|   | ||||
| @@ -4,15 +4,17 @@ from random import shuffle | ||||
| from typing import Dict | ||||
|  | ||||
| from marl_factory_grid.environment.groups.objects import Objects | ||||
| from marl_factory_grid.utils.helpers import POS_MASK | ||||
| from marl_factory_grid.utils.helpers import POS_MASK_8, POS_MASK_4 | ||||
|  | ||||
|  | ||||
| class Entities(Objects): | ||||
|     _entity = Objects | ||||
|  | ||||
|     @staticmethod | ||||
|     def neighboring_positions(pos): | ||||
|         return [tuple(x) for x in (POS_MASK + pos).reshape(-1, 2)] | ||||
|     def neighboring_positions(self, pos): | ||||
|         return [tuple(x) for x in (POS_MASK_8 + pos).reshape(-1, 2) if tuple(x) in self._floor_positions] | ||||
|  | ||||
|     def neighboring_4_positions(self, pos): | ||||
|         return [tuple(x) for x in (POS_MASK_4 + pos) if tuple(x) in self._floor_positions] | ||||
|  | ||||
|     def get_entities_near_pos(self, pos): | ||||
|         return [y for x in itemgetter(*self.neighboring_positions(pos))(self.pos_dict) for y in x] | ||||
|   | ||||
| @@ -1,7 +1,10 @@ | ||||
| import abc | ||||
| import random | ||||
| from random import shuffle | ||||
| from typing import List, Collection | ||||
|  | ||||
| import numpy as np | ||||
|  | ||||
| from marl_factory_grid.environment import rewards as r, constants as c | ||||
| from marl_factory_grid.environment.entity.agent import Agent | ||||
| from marl_factory_grid.utils import helpers as h | ||||
| @@ -37,6 +40,15 @@ class Rule(abc.ABC): | ||||
|         TODO | ||||
|  | ||||
|  | ||||
|         :return: | ||||
|         """ | ||||
|         return [] | ||||
|  | ||||
|     def on_reset_post_spawn(self, state) -> List[TickResult]: | ||||
|         """ | ||||
|         TODO | ||||
|  | ||||
|  | ||||
|         :return: | ||||
|         """ | ||||
|         return [] | ||||
| @@ -230,3 +242,33 @@ class WatchCollisions(Rule): | ||||
|             if inter_entity_collision_detected or collision_in_step: | ||||
|                 return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=self.reward_at_done)] | ||||
|         return [] | ||||
|  | ||||
|  | ||||
class DoRandomInitialSteps(Rule):
    def __init__(self, random_steps: int = 10):
        """
        Special rule which lets agents perform a number of random moves right after all entities are spawned.

        In every step a random free position is drawn and an agent standing on one of its neighbouring
        positions is moved onto it. Useful in the `N-Puzzle` configuration.

        !!! This rule does not introduce any reward or done condition.

        :param random_steps:  Number of random steps agents perform in an environment.
        """
        super().__init__()
        self.random_steps = random_steps

    def on_reset_post_spawn(self, state):
        """
        Perform `self.random_steps` random moves; every move and its validity is reported via `state.print`.

        :param state: The current game state.
        :return: An empty list, this rule does not report any results.
        """
        state.print("Random Initial Steps initiated....")
        for _ in range(self.random_steps):
            # Find a free position and the floor positions next to it
            free_pos = state.random_free_position
            neighbor_positions = state.entities.neighboring_4_positions(free_pos)
            if not neighbor_positions:
                # Nothing can move onto an isolated free position; skip instead of failing on `pop()`.
                state.print(f"No neighboring position found for {free_pos}, step skipped.")
                continue
            random.shuffle(neighbor_positions)
            chosen_agent = h.get_first(state[c.AGENT].by_pos(neighbor_positions.pop()))
            assert isinstance(chosen_agent, Agent), "Expected an agent next to the free position."
            valid = chosen_agent.move(free_pos, state)
            valid_str = " not" if not valid else ""
            state.print(f"Move {chosen_agent.name} from {chosen_agent.last_pos} "
                        f"to {chosen_agent.pos} was{valid_str} valid.")
        return []
|   | ||||
| @@ -105,10 +105,10 @@ class SpawnDestinationsPerAgent(Rule): | ||||
|  | ||||
|         !!! This rule does not introduce any reward or done condition. | ||||
|  | ||||
|         :param coords_or_quantity: Please provide a dictionary with agent names as keys; and a list of possible | ||||
|                                      destination coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]} | ||||
|         :param coords_or_quantity:  Please provide a dictionary with agent names as keys; and a list of possible | ||||
|                                         destination coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]} | ||||
|         """ | ||||
|         super(Rule, self).__init__() | ||||
|         super().__init__() | ||||
|         self.per_agent_positions = dict() | ||||
|         for agent_name, value in coords_or_quantity.items(): | ||||
|             if isinstance(value, int): | ||||
| @@ -143,3 +143,25 @@ class SpawnDestinationsPerAgent(Rule): | ||||
|                     continue | ||||
|             state[d.DESTINATION].add_item(destination) | ||||
|         pass | ||||
|  | ||||
|  | ||||
class SpawnDestinationOnAgent(Rule):
    def __init__(self):
        """
        Special rule which places one destination per agent on the agent's own position and binds it to that
        agent. Useful for the `N-Puzzle` configurations.

        !!! This rule does not introduce any reward or done condition.
        """
        super().__init__()

    def on_reset(self, state: Gamestate):
        """
        Create a bound destination at the position of every agent and add it to the destination collection.

        :param state: The current game state.
        """
        state.print("Spawn Desitnations")
        destinations = state[d.DESTINATION]
        for agent in state[c.AGENT]:
            agent_pos = agent.pos
            destinations.add_item(Destination(agent_pos, bind_to=agent))
            # There has to be exactly one destination on the agent's position.
            assert len(destinations.by_pos(agent_pos)) == 1
|   | ||||
| @@ -27,9 +27,11 @@ IGNORED_DF_COLUMNS = ['Episode', 'Run',  # For plotting, which values are ignore | ||||
|                       'train_step', 'step', 'index', 'dirt_amount', 'dirty_pos_count', 'terminal_observation', | ||||
|                       'episode'] | ||||
|  | ||||
| POS_MASK = np.asarray([[[-1, -1], [0, -1], [1, -1]], | ||||
|                        [[-1, 0], [0, 0], [1, 0]], | ||||
|                        [[-1, 1], [0, 1], [1, 1]]]) | ||||
| POS_MASK_8 = np.asarray([[[-1, -1], [0, -1], [1, -1]], | ||||
|                          [[-1, 0],  [0, 0],  [1, 0]], | ||||
|                          [[-1, 1],  [0, 1],  [1, 1]]]) | ||||
|  | ||||
# Offsets of the four direct (non-diagonal) neighbours of a position: one step along exactly one axis.
POS_MASK_4 = np.asarray([[0, -1], [-1, 0], [1, 0], [0, 1]])
|  | ||||
| MOVEMAP = defaultdict(lambda: (0, 0), | ||||
|                       {c.NORTH: (-1, 0), c.NORTHEAST: (-1, 1), | ||||
|   | ||||
| @@ -47,6 +47,12 @@ class StepRules: | ||||
|                 state.print(rule_reset_printline) | ||||
|         return c.VALID | ||||
|  | ||||
|     def do_all_post_spawn_reset(self, state): | ||||
|         for rule in self.rules: | ||||
|             if rule_reset_printline := rule.on_reset_post_spawn(state): | ||||
|                 state.print(rule_reset_printline) | ||||
|         return c.VALID | ||||
|  | ||||
|     def tick_step_all(self, state): | ||||
|         results = list() | ||||
|         for rule in self.rules: | ||||
|   | ||||
| @@ -26,10 +26,10 @@ if __name__ == '__main__': | ||||
|  | ||||
|     if explain_config: | ||||
|         ce = ConfigExplainer() | ||||
|         ce.save_all(run_path / 'all_out.yaml') | ||||
|         ce.save_all(run_path / 'all_available_configs.yaml') | ||||
|  | ||||
|     # Path to config File | ||||
|     path = Path('marl_factory_grid/configs/clean_and_bring.yaml') | ||||
|     path = Path('marl_factory_grid/configs/eight_puzzle.yaml') | ||||
|  | ||||
|     # Env Init | ||||
|     factory = Factory(path) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Steffen Illium
					Steffen Illium