2023-12-29 10:37:39 +01:00

271 lines
9.5 KiB
Python

from itertools import islice
from typing import List, Tuple
import numpy as np
from marl_factory_grid.algorithms.static.utils import points_to_graph
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.rules import Rule, SpawnAgents
from marl_factory_grid.utils.results import Result, DoneResult
class StepRules:
    def __init__(self, *args):
        """
        Manages a collection of rules to be applied at each step of the environment.

        The StepRules class allows you to organize and apply custom rules during the simulation, ensuring that the
        corresponding hooks for all rules are called at the appropriate times.

        :param args: Optional Rule objects to initialize the StepRules with.
        """
        # list() of an empty tuple is already an empty list, so no special case is needed.
        self.rules = list(args)

    def __repr__(self):
        return f'Rules{[x.name for x in self]}'

    def __iter__(self):
        return iter(self.rules)

    def append(self, item):
        """
        Add a rule to the collection.

        :param item: The rule to add; must be a ``Rule`` instance.
        :return: Always ``True``.
        :raises AssertionError: If *item* is not a ``Rule`` (only checked while assertions are enabled).
        """
        assert isinstance(item, Rule)
        self.rules.append(item)
        return True

    def _print_hook_output(self, hook_name, state, *extra_args):
        """
        Call ``rule.<hook_name>(state, *extra_args)`` on every rule, in order, and pass each truthy
        return value to ``state.print``.
        """
        for rule in self.rules:
            if printline := getattr(rule, hook_name)(state, *extra_args):
                state.print(printline)

    def _collect_hook_results(self, hook_name, state):
        """
        Call ``rule.<hook_name>(state)`` on every rule, in order, and concatenate all truthy return values
        (each expected to be an iterable of results) into a single list.
        """
        results = list()
        for rule in self.rules:
            if hook_results := getattr(rule, hook_name)(state):
                results.extend(hook_results)
        return results

    def do_all_init(self, state, lvl_map):
        """
        Run every rule's ``on_init`` hook.

        :param state: The game state; its ``print`` method receives any text a hook returns.
        :param lvl_map: Level map handed through to each ``on_init`` call.
        :return: ``c.VALID``
        """
        self._print_hook_output('on_init', state, lvl_map)
        return c.VALID

    def do_all_reset(self, state):
        """
        Spawn agents via a fresh ``SpawnAgents`` rule, then run every rule's ``on_reset`` hook.

        :param state: The game state; its ``print`` method receives any text a hook returns.
        :return: ``c.VALID``
        """
        # Agents are spawned before any other rule's reset hook runs.
        SpawnAgents().on_reset(state)
        self._print_hook_output('on_reset', state)
        return c.VALID

    def do_all_post_spawn_reset(self, state):
        """
        Run every rule's ``on_reset_post_spawn`` hook.

        :param state: The game state; its ``print`` method receives any text a hook returns.
        :return: ``c.VALID``
        """
        self._print_hook_output('on_reset_post_spawn', state)
        return c.VALID

    def tick_step_all(self, state):
        """
        Run every rule's ``tick_step`` hook.

        :return: Concatenated list of all results returned by the hooks.
        """
        return self._collect_hook_results('tick_step', state)

    def tick_pre_step_all(self, state):
        """
        Run every rule's ``tick_pre_step`` hook.

        :return: Concatenated list of all results returned by the hooks.
        """
        return self._collect_hook_results('tick_pre_step', state)

    def tick_post_step_all(self, state):
        """
        Run every rule's ``tick_post_step`` hook.

        :return: Concatenated list of all results returned by the hooks.
        """
        return self._collect_hook_results('tick_post_step', state)
class Gamestate(object):

    @property
    def floortile_graph(self):
        """
        Graph over all floor positions, built lazily on first access and cached afterwards.

        :return: The object produced by ``points_to_graph(self.entities.floorlist)``.
        """
        # Compare against None rather than relying on truthiness: a legitimately empty (falsy) graph
        # would otherwise be regenerated on every single access.
        if self._floortile_graph is None:
            self.print("Generating Floorgraph....")
            self._floortile_graph = points_to_graph(self.entities.floorlist)
        return self._floortile_graph

    @property
    def moving_entites(self):
        """
        Flat list of the members ``y`` of every item ``x`` of ``self.entities`` whose ``x.var_can_move`` is
        truthy (presumably: the entity collections that can move — confirm against ``Entities.__iter__``).

        NOTE(review): the name is misspelled ("entites"); it is kept as-is because callers may rely on it.

        :return: List of entities.
        """
        return [y for x in self.entities for y in x if x.var_can_move]

    def __init__(self, entities, agents_conf, rules: List[Rule], lvl_shape, env_seed=69, verbose=False):
        """
        The `Gamestate` class represents the state of the game environment.

        :param entities: The entities present in the environment.
        :type entities: Entities
        :param agents_conf: Agent configurations for the environment.
        :type agents_conf: Any
        :param rules: Rules applied during the simulation; they are wrapped in a ``StepRules`` container.
        :type rules: List[Rule]
        :param lvl_shape: The shape of the game level.
        :type lvl_shape: tuple
        :param env_seed: Seed for the environment's numpy random generator (``self.rng``).
        :type env_seed: int
        :param verbose: Controls verbosity in the environment (see :meth:`print`).
        :type verbose: bool
        """
        self.lvl_shape = lvl_shape
        self.entities = entities
        self.curr_step = 0
        self.curr_actions = None
        self.agents_conf = agents_conf
        self.verbose = verbose
        self.rng = np.random.default_rng(env_seed)
        self.rules = StepRules(*rules)
        # Built lazily by the `floortile_graph` property.
        self._floortile_graph = None

    def reset(self):
        """
        Reset the step counter and the current actions.

        The cached floor-tile graph is intentionally left untouched.
        """
        self.curr_step = 0
        self.curr_actions = None

    def __getitem__(self, item):
        return self.entities[item]

    def __iter__(self):
        return iter(self.entities.values())

    def __contains__(self, item):
        return item in self.entities

    def __repr__(self):
        return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})'

    @property
    def random_free_position(self) -> Tuple[int, int]:
        """
        Returns a single **free** position (x, y), which is **free** for spawning or walking.
        No Entity at this position possesses *var_is_blocking_pos* or *var_can_collide*.

        :return: Single **free** position.
        :raises IndexError: If no free position is available.
        """
        return self.get_n_random_free_positions(1)[0]

    def get_n_random_free_positions(self, n) -> list[tuple[int, int]]:
        """
        Returns a list of up to *n* **free** positions [(x, y), ... ], which are **free** for spawning or walking.
        No Entity at this position possesses *var_is_blocking_pos* or *var_can_collide*.

        :param n: Maximum number of positions to return; fewer are returned if the generator is exhausted.
        :return: List of at most n **free** positions.
        """
        return list(islice(self.entities.free_positions_generator, n))

    @property
    def random_position(self) -> Tuple[int, int]:
        """
        Returns a single available position (x, y), ignores all entity attributes.

        :return: Single position.
        :raises IndexError: If the floor list is empty.
        """
        return self.get_n_random_positions(1)[0]

    def get_n_random_positions(self, n) -> list[tuple[int, int]]:
        """
        Returns a list of up to *n* available positions [(x, y), ... ], ignores all entity attributes.

        NOTE(review): this takes the first n entries of ``self.entities.floorlist``; the result is only random
        if that list is shuffled elsewhere — confirm.

        :param n: Maximum number of positions to return.
        :return: List of at most n positions.
        """
        return list(islice(self.entities.floorlist, n))

    def tick(self, actions) -> list[Result]:
        """
        Performs a single **Gamestate Tick** by calling the inner rule hooks in sequential order.

        - tick_pre_step_all: Things to do before the agents do their actions. Statechange, Moving, Spawning etc...
        - agent tick: Agents do their actions.
        - tick_step_all: Things to do after the agents did their actions. Statechange, Moving, Spawning etc...
        - tick_post_step_all: Things to do at the very end of each step. Counting, Reward calculations etc...

        :param actions: One action index per agent; ``actions[i]`` selects from ``self[c.AGENT][i].actions``.
        :return: List of *Result*-objects.
        """
        results = list()
        self.curr_step += 1
        for entity in self.entities.iter_entities():
            entity.clear_temp_state()

        results.extend(self.rules.tick_pre_step_all(self))

        # Main Agent Step
        for idx, action_int in enumerate(actions):
            # NOTE(review): relies on clear_temp_state() returning the agent itself — confirm in Agent.
            agent = self[c.AGENT][idx].clear_temp_state()
            if agent.var_is_paralyzed:
                self.print(f"{agent.name} is paralied because of: {agent.paralyze_reasons}")
                continue
            action = agent.actions[action_int]
            action_result = action.do(agent, self)
            results.append(action_result)
            agent.set_state(action_result)

        results.extend(self.rules.tick_step_all(self))
        results.extend(self.rules.tick_post_step_all(self))
        return results

    def print(self, string) -> None:
        """
        When *verbose* is active, print stuff.

        :param string: *String* to print.
        :type string: str
        :return: Nothing
        """
        if self.verbose:
            print(string)

    def check_done(self) -> List[DoneResult]:
        """
        Call the *on_check_done* hook of every **Rule** and collect what they return.

        :return: Concatenated list of all truthy hook results.
        """
        results = list()
        for rule in self.rules:
            if on_check_done_result := rule.on_check_done(self):
                results.extend(on_check_done_result)
        return results

    def get_collision_positions(self) -> List[Tuple[int, int]]:
        """
        Returns a list of positions [(x, y), ... ] on which at least two entities with *var_can_collide* are
        located. This does not include agents that were unable to move because their target direction was
        blocked, which is also a form of collision.

        :return: List of positions.
        """
        return [
            pos for pos, entities in self.entities.pos_dict.items()
            # The cheap size check short-circuits positions that cannot hold two colliders.
            if len(entities) >= 2 and sum(1 for e in entities if e.var_can_collide) >= 2
        ]

    def check_move_validity(self, moving_entity: Entity, target_position: Tuple[int, int]) -> bool:
        """
        Whether it is safe to move to the target position and the moving entity does not introduce a blocking
        attribute when the position is already occupied.

        !!! Will still report true even though there could be an entity with var_can_collide == True !!!

        :param moving_entity: Entity
        :param target_position: pos
        :return: ``c.VALID`` if the move is safe, otherwise ``c.NOT_VALID``.
        """
        is_not_blocked = self.check_pos_validity(target_position)
        # True when the mover itself blocks positions and the target is already occupied, i.e. moving
        # there would block whatever is already standing on it.
        would_block_others = moving_entity.var_is_blocking_pos and self.entities.is_occupied(target_position)
        if moving_entity.pos != target_position and is_not_blocked and not would_block_others:
            return c.VALID
        else:
            return c.NOT_VALID

    def check_pos_validity(self, pos: Tuple[int, int]) -> bool:
        """
        Check if *pos* is a valid position to move or spawn to: no entity there has *var_is_blocking_pos*
        and *pos* is part of the floor list.

        :param pos: position to check
        :return: ``c.VALID`` if pos is a valid target, otherwise ``c.NOT_VALID``.
        """
        # NOTE(review): if pos_dict is a defaultdict, this lookup inserts an empty entry for unseen
        # positions — confirm that is acceptable.
        if not any(e.var_is_blocking_pos for e in self.entities.pos_dict[pos]) and pos in self.entities.floorlist:
            return c.VALID
        else:
            return c.NOT_VALID