mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-10-22 10:06:52 +02:00
Machines
This commit is contained in:
@@ -17,7 +17,7 @@ class Rule(abc.ABC):
|
||||
def __repr__(self):
|
||||
return f'{self.name}'
|
||||
|
||||
def on_init(self, state):
|
||||
def on_init(self, state, lvl_map):
|
||||
return []
|
||||
|
||||
def on_reset(self):
|
||||
@@ -42,7 +42,7 @@ class MaxStepsReached(Rule):
|
||||
super().__init__()
|
||||
self.max_steps = max_steps
|
||||
|
||||
def on_init(self, state):
|
||||
def on_init(self, state, lvl_map):
|
||||
pass
|
||||
|
||||
def on_check_done(self, state):
|
||||
@@ -51,6 +51,20 @@ class MaxStepsReached(Rule):
|
||||
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
|
||||
|
||||
|
||||
class AssignGlobalPositions(Rule):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
from marl_factory_grid.environment.entity.util import GlobalPosition
|
||||
for agent in state[c.AGENT]:
|
||||
gp = GlobalPosition(lvl_map.level_shape)
|
||||
gp.bind_to(agent)
|
||||
state[c.GLOBALPOSITIONS].add_item(gp)
|
||||
return []
|
||||
|
||||
|
||||
class Collision(Rule):
|
||||
|
||||
def __init__(self, done_at_collisions: bool = False):
|
||||
|
Reference in New Issue
Block a user