mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-12-06 15:40:37 +01:00
Maintainer and pos_dicts fixed. Are sets now.
This commit is contained in:
@@ -23,7 +23,7 @@ class Rule(abc.ABC):
|
||||
def on_init(self, state, lvl_map):
|
||||
return []
|
||||
|
||||
def on_reset(self):
|
||||
def on_reset(self, state) -> List[TickResult]:
|
||||
return []
|
||||
|
||||
def tick_pre_step(self, state) -> List[TickResult]:
|
||||
@@ -55,7 +55,7 @@ class SpawnEntity(Rule):
|
||||
self.collection = collection
|
||||
self.ignore_blocking = ignore_blocking
|
||||
|
||||
def on_init(self, state, lvl_map) -> [TickResult]:
|
||||
def on_reset(self, state) -> [TickResult]:
|
||||
results = self.collection.trigger_spawn(state, ignore_blocking=self.ignore_blocking)
|
||||
pos_str = f' on: {[x.pos for x in self.collection]}' if self.collection.var_has_position else ''
|
||||
state.print(f'Initial {self.collection.__class__.__name__} were spawned{pos_str}')
|
||||
@@ -68,8 +68,7 @@ class SpawnAgents(Rule):
|
||||
super().__init__()
|
||||
pass
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
# agents = Agents(lvl_map.size)
|
||||
def on_reset(self, state):
|
||||
agents = state[c.AGENT]
|
||||
empty_positions = state.entities.empty_positions[:len(state.agents_conf)]
|
||||
for agent_name, agent_conf in state.agents_conf.items():
|
||||
@@ -101,9 +100,6 @@ class DoneAtMaxStepsReached(Rule):
|
||||
super().__init__()
|
||||
self.max_steps = max_steps
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
pass
|
||||
|
||||
def on_check_done(self, state):
|
||||
if self.max_steps <= state.curr_step:
|
||||
return [DoneResult(validity=c.VALID, identifier=self.name)]
|
||||
@@ -115,7 +111,7 @@ class AssignGlobalPositions(Rule):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
def on_reset(self, state, lvl_map):
|
||||
from marl_factory_grid.environment.entity.util import GlobalPosition
|
||||
for agent in state[c.AGENT]:
|
||||
gp = GlobalPosition(agent, lvl_map.level_shape)
|
||||
|
||||
Reference in New Issue
Block a user