Maintainer and pos_dicts fixed. Are sets now.

This commit is contained in:
Steffen Illium
2023-11-10 14:16:48 +01:00
parent 6711a0976b
commit 9b289591ba
22 changed files with 100 additions and 131 deletions

View File

@@ -23,7 +23,7 @@ class Rule(abc.ABC):
def on_init(self, state, lvl_map):
return []
def on_reset(self):
def on_reset(self, state) -> List[TickResult]:
return []
def tick_pre_step(self, state) -> List[TickResult]:
@@ -55,7 +55,7 @@ class SpawnEntity(Rule):
self.collection = collection
self.ignore_blocking = ignore_blocking
def on_init(self, state, lvl_map) -> [TickResult]:
def on_reset(self, state) -> [TickResult]:
results = self.collection.trigger_spawn(state, ignore_blocking=self.ignore_blocking)
pos_str = f' on: {[x.pos for x in self.collection]}' if self.collection.var_has_position else ''
state.print(f'Initial {self.collection.__class__.__name__} were spawned{pos_str}')
@@ -68,8 +68,7 @@ class SpawnAgents(Rule):
super().__init__()
pass
def on_init(self, state, lvl_map):
# agents = Agents(lvl_map.size)
def on_reset(self, state):
agents = state[c.AGENT]
empty_positions = state.entities.empty_positions[:len(state.agents_conf)]
for agent_name, agent_conf in state.agents_conf.items():
@@ -101,9 +100,6 @@ class DoneAtMaxStepsReached(Rule):
super().__init__()
self.max_steps = max_steps
def on_init(self, state, lvl_map):
pass
def on_check_done(self, state):
if self.max_steps <= state.curr_step:
return [DoneResult(validity=c.VALID, identifier=self.name)]
@@ -115,7 +111,7 @@ class AssignGlobalPositions(Rule):
def __init__(self):
super().__init__()
def on_init(self, state, lvl_map):
def on_reset(self, state, lvl_map):
from marl_factory_grid.environment.entity.util import GlobalPosition
for agent in state[c.AGENT]:
gp = GlobalPosition(agent, lvl_map.level_shape)