mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-22 03:31:35 +02:00
new rules, new spawn logic, small fixes, default and narrow corridor debugged
This commit is contained in:
@ -1,3 +1,4 @@
|
||||
from itertools import islice
|
||||
from typing import List, Dict, Tuple
|
||||
|
||||
import numpy as np
|
||||
@ -59,14 +60,15 @@ class Gamestate(object):
|
||||
def moving_entites(self):
|
||||
return [y for x in self.entities for y in x if x.var_can_move]
|
||||
|
||||
def __init__(self, entities, agents_conf, rules: Dict[str, dict], env_seed=69, verbose=False):
|
||||
def __init__(self, entities, agents_conf, rules: List[Rule], lvl_shape, env_seed=69, verbose=False):
|
||||
self.lvl_shape = lvl_shape
|
||||
self.entities = entities
|
||||
self.curr_step = 0
|
||||
self.curr_actions = None
|
||||
self.agents_conf = agents_conf
|
||||
self.verbose = verbose
|
||||
self.rng = np.random.default_rng(env_seed)
|
||||
self.rules = StepRules(*(v['class'](**v['kwargs']) for v in rules.values()))
|
||||
self.rules = StepRules(*rules)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.entities[item]
|
||||
@ -80,6 +82,13 @@ class Gamestate(object):
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})'
|
||||
|
||||
@property
|
||||
def random_free_position(self):
|
||||
return self.get_n_random_free_positions(1)[0]
|
||||
|
||||
def get_n_random_free_positions(self, n):
|
||||
return list(islice(self.entities.free_positions_generator, n))
|
||||
|
||||
def tick(self, actions) -> List[Result]:
|
||||
results = list()
|
||||
self.curr_step += 1
|
||||
@ -115,8 +124,7 @@ class Gamestate(object):
|
||||
return results
|
||||
|
||||
def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]:
|
||||
positions = [pos for pos, entity_list_for_position in self.entities.pos_dict.items()
|
||||
if any([e.var_can_collide for e in entity_list_for_position])]
|
||||
positions = [pos for pos, entities in self.entities.pos_dict.items() if len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2)]
|
||||
return positions
|
||||
|
||||
def check_move_validity(self, moving_entity, position):
|
||||
|
Reference in New Issue
Block a user