from itertools import islice
from typing import List, Tuple

import numpy as np

from marl_factory_grid.algorithms.static.utils import points_to_graph
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.rules import Rule, SpawnAgents
from marl_factory_grid.utils.results import Result, DoneResult


class StepRules:

    def __init__(self, *args):
        """
        Manages a collection of rules to be applied at each step of the environment.

        The StepRules class allows you to organize and apply custom rules during the simulation, ensuring that the
        corresponding hooks for all rules are called at the appropriate times.

        :param args: Optional Rule objects to initialize the StepRules with.
        """
        if args:
            self.rules = list(args)
        else:
            self.rules = list()

    def __repr__(self):
        return f'Rules{[x.name for x in self]}'

    def __iter__(self):
        return iter(self.rules)

    def append(self, item):
        """
        Add a rule to the collection.

        :param item: The rule to add; must be a *Rule* instance.
        :return: True
        """
        assert isinstance(item, Rule)
        self.rules.append(item)
        return True

    def _collect_results(self, hook_name, state):
        """
        Call the hook method named *hook_name* on every rule and concatenate what they return.

        Rules whose hook returns a falsy value (e.g. None or an empty list) contribute nothing.

        :param hook_name: Name of the Rule method to call, e.g. 'tick_step'.
        :param state: The current game state, passed to each hook.
        :return: Flat list of all hook results, in rule order.
        """
        results = list()
        for rule in self.rules:
            if hook_results := getattr(rule, hook_name)(state):
                results.extend(hook_results)
        return results

    def do_all_init(self, state, lvl_map):
        """
        Call *on_init* on every rule and print any non-empty message it returns.

        :param state: The current game state.
        :param lvl_map: The level map, forwarded to each rule's *on_init*.
        :return: c.VALID
        """
        for rule in self.rules:
            if rule_init_printline := rule.on_init(state, lvl_map):
                state.print(rule_init_printline)
        return c.VALID

    def do_all_reset(self, state):
        """
        Spawn agents via a fresh *SpawnAgents* rule, then call *on_reset* on every rule.

        Any non-empty message returned by a rule's *on_reset* is printed through *state.print*.

        :param state: The current game state.
        :return: c.VALID
        """
        SpawnAgents().on_reset(state)
        for rule in self.rules:
            if rule_reset_printline := rule.on_reset(state):
                state.print(rule_reset_printline)
        return c.VALID

    def do_all_post_spawn_reset(self, state):
        """
        Call *on_reset_post_spawn* on every rule and print any non-empty message it returns.

        :param state: The current game state.
        :return: c.VALID
        """
        for rule in self.rules:
            if rule_reset_printline := rule.on_reset_post_spawn(state):
                state.print(rule_reset_printline)
        return c.VALID

    def tick_step_all(self, state):
        """
        Collect the results of every rule's *tick_step* hook.

        :param state: The current game state.
        :return: Flat list of results.
        """
        return self._collect_results('tick_step', state)

    def tick_pre_step_all(self, state):
        """
        Collect the results of every rule's *tick_pre_step* hook.

        :param state: The current game state.
        :return: Flat list of results.
        """
        return self._collect_results('tick_pre_step', state)

    def tick_post_step_all(self, state):
        """
        Collect the results of every rule's *tick_post_step* hook.

        :param state: The current game state.
        :return: Flat list of results.
        """
        return self._collect_results('tick_post_step', state)


class Gamestate(object):

    @property
    def floortile_graph(self):
        """
        Lazily build and cache a graph over the floor positions.

        :return: The graph produced by *points_to_graph* from *entities.floorlist*.
        """
        # Compare against the None sentinel set in __init__: an already-built graph
        # that happens to be falsy must not be rebuilt on every access.
        if self._floortile_graph is None:
            self.print("Generating Floorgraph....")
            self._floortile_graph = points_to_graph(self.entities.floorlist)
        return self._floortile_graph

    @property
    def moving_entites(self):
        """
        Flattened list of entities taken from every entity collection whose *var_can_move* is set.

        (The property name keeps its original spelling for backward compatibility.)
        """
        return [y for x in self.entities for y in x if x.var_can_move]

    def __init__(self, entities, agents_conf, rules: List[Rule], lvl_shape, env_seed=69, verbose=False):
        """
        The `Gamestate` class represents the state of the game environment.

        :param entities: The entities present in the environment.
        :type entities: Entities
        :param agents_conf: Agent configurations for the environment.
        :type agents_conf: Any
        :param rules: Rules to apply during the simulation; they are wrapped in a *StepRules* collection.
        :type rules: List[Rule]
        :param lvl_shape: The shape of the game level.
        :type lvl_shape: tuple
        :param env_seed: Seed for the state's numpy random generator (*self.rng*).
        :type env_seed: int
        :param verbose: Controls verbosity in the environment.
        :type verbose: bool
        """
        self.lvl_shape = lvl_shape
        self.entities = entities
        self.curr_step = 0
        self.curr_actions = None
        self.agents_conf = agents_conf
        self.verbose = verbose
        self.rng = np.random.default_rng(env_seed)
        self.rules = StepRules(*rules)
        # Built lazily by the floortile_graph property.
        self._floortile_graph = None

    def reset(self):
        """Reset the step counter and the current actions."""
        self.curr_step = 0
        self.curr_actions = None

    def __getitem__(self, item):
        return self.entities[item]

    def __iter__(self):
        return iter(e for e in self.entities.values())

    def __contains__(self, item):
        return item in self.entities

    def __repr__(self):
        return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})'

    @property
    def random_free_position(self) -> tuple[int, int]:
        """
        Returns a single **free** position (x, y), which is **free** for spawning or walking.
        No Entity at this position possesses *var_is_blocking_pos* or *var_can_collide*.

        :return: Single **free** position.
        """
        return self.get_n_random_free_positions(1)[0]

    def get_n_random_free_positions(self, n) -> list[tuple[int, int]]:
        """
        Returns a list of *n* **free** positions [(x, y), ... ], which are **free** for spawning or walking.
        No Entity at this position possesses *var_is_blocking_pos* or *var_can_collide*.

        :param n: Number of positions to take.
        :return: List of at most n **free** positions.
        """
        return list(islice(self.entities.free_positions_generator, n))

    @property
    def random_position(self) -> tuple[int, int]:
        """
        Returns a single available position (x, y), ignoring all entity attributes.

        :return: Single random position.
        """
        return self.get_n_random_positions(1)[0]

    def get_n_random_positions(self, n) -> list[tuple[int, int]]:
        """
        Returns a list of *n* available positions [(x, y), ... ], ignoring all entity attributes.

        The positions are the first *n* entries of *entities.floorlist*; any randomness comes from
        that property (NOTE(review): presumably it is shuffled — confirm in Entities).

        :param n: Number of positions to take.
        :return: List of at most n positions.
        """
        return list(islice(self.entities.floorlist, n))

    def tick(self, actions) -> list[Result]:
        """
        Performs a single **Gamestate Tick** by calling the inner rule hooks in sequential order.
        - tick_pre_step_all: Things to do before the agents do their actions. Statechange, Moving, Spawning etc...
        - agent tick: Agents do their actions.
        - tick_step_all: Things to do after the agents did their actions. Statechange, Moving, Spawning etc...
        - tick_post_step_all: Things to do at the very end of each step. Counting, Reward calculations etc...

        :param actions: One action index per agent; the idx-th entry selects from *self[c.AGENT][idx].actions*.
        :return: List of *Result*-objects.
        """
        results = list()
        self.curr_step += 1
        for entity in self.entities.iter_entities():
            entity.clear_temp_state()

        results.extend(self.rules.tick_pre_step_all(self))

        # Main Agent Step
        for idx, action_int in enumerate(actions):
            # NOTE(review): the assignment relies on clear_temp_state() returning the agent itself.
            agent = self[c.AGENT][idx].clear_temp_state()
            if agent.var_is_paralyzed:
                # Paralyzed agents skip their action and produce no result this step.
                self.print(f"{agent.name} is paralied because of: {agent.paralyze_reasons}")
                continue
            action = agent.actions[action_int]
            action_result = action.do(agent, self)
            results.append(action_result)
            agent.set_state(action_result)

        results.extend(self.rules.tick_step_all(self))
        results.extend(self.rules.tick_post_step_all(self))
        return results

    def print(self, string) -> None:
        """
        When *verbose* is active, print stuff.

        :param string: *String* to print.
        :type string: str
        :return: Nothing
        """
        if self.verbose:
            print(string)

    def check_done(self) -> List[DoneResult]:
        """
        Iterate all **Rules** and collect the results of their *on_check_done* hook.

        :return: List of Results
        """
        results = list()
        for rule in self.rules:
            if on_check_done_result := rule.on_check_done(self):
                results.extend(on_check_done_result)
        return results

    def get_collision_positions(self) -> List[Tuple[int, int]]:
        """
        Returns a list of positions [(x, y), ... ] on which collisions occur, i.e. where at least two
        entities with *var_can_collide* share the position.
        This does not include agents that were unable to move because their target direction was blocked,
        which is also a form of collision.

        :return: List of positions.
        """
        positions = [pos for pos, entities in self.entities.pos_dict.items() if
                     len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2)
                     ]
        return positions

    def check_move_validity(self, moving_entity: Entity, target_position: tuple[int, int]) -> bool:
        """
        Whether it is safe to move to the target position, and the moving entity does not introduce a blocking
        attribute when the position is already occupied.
        !!! Will still report true even though there could be an entity which has var_can_collide == true !!!

        :param moving_entity: Entity
        :param target_position: pos
        :return: Safe to move to
        """
        is_not_blocked = self.check_pos_validity(target_position)
        # True when the mover itself blocks its cell and the target is already occupied.
        would_block_others = moving_entity.var_is_blocking_pos and self.entities.is_occupied(target_position)
        if moving_entity.pos != target_position and is_not_blocked and not would_block_others:
            return c.VALID
        else:
            return c.NOT_VALID

    def check_pos_validity(self, pos: tuple[int, int]) -> bool:
        """
        Check if *pos* is a valid position to move or spawn to: no entity there has *var_is_blocking_pos*
        and the position is a floor position.

        :param pos: position to check
        :return: Whether pos is a valid target.
        """
        if not any(e.var_is_blocking_pos for e in self.entities.pos_dict[pos]) and pos in self.entities.floorlist:
            return c.VALID
        else:
            return c.NOT_VALID