mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00

# Conflicts: # marl_factory_grid/modules/doors/groups.py # marl_factory_grid/utils/states.py
42 lines
1.3 KiB
Python
42 lines
1.3 KiB
Python
from typing import List
|
|
|
|
from marl_factory_grid.environment.rules import Rule
|
|
from marl_factory_grid.utils.results import TickResult, DoneResult
|
|
from marl_factory_grid.environment import constants as c
|
|
from . import constants as M
|
|
|
|
|
|
class MoveMaintainers(Rule):
    """Rule that lets every maintainer act once per environment tick."""

    def __init__(self):
        """
        This rule is responsible for moving the maintainers at every step of the environment.
        """
        super().__init__()

    def tick_step(self, state) -> List[TickResult]:
        """
        Call ``tick(state)`` on each maintainer and gather what they return.

        :param state: Environment state; it is indexed with ``M.MAINTAINERS`` to obtain
                      the iterable of maintainers.
        :return: One entry per maintainer, in iteration order, holding the value returned
                 by that maintainer's ``tick`` call (presumably a ``TickResult`` - confirm
                 against the maintainer entity).
        """
        return [unit.tick(state) for unit in state[M.MAINTAINERS]]
|
|
|
|
|
|
class DoneAtMaintainerCollision(Rule):
    """Done-check rule that reports agents standing on a maintainer's position."""

    def __init__(self):
        """
        When active, this rule stops the environment after a maintainer reports a collision with another entity.
        """
        super().__init__()

    def on_check_done(self, state) -> List[DoneResult]:
        """
        Build a ``DoneResult`` for every agent that shares a position with a maintainer.

        Only the entities stored under ``c.AGENT`` are compared against the maintainer
        positions; other entity types are not inspected here.

        :param state: Environment state; indexed with ``c.AGENT`` (agents, via ``.values()``)
                      and ``M.MAINTAINERS`` (maintainers, via ``.positions``).
        :return: One ``DoneResult`` per colliding agent, carrying ``c.VALID``, this rule's
                 name as identifier and ``M.MAINTAINER_COLLISION_REWARD`` as reward.
                 Empty when no agent position is among the maintainer positions.
        """
        # Look up agents first, then maintainer positions, matching the original access order.
        agent_entities = state[c.AGENT].values()
        occupied = state[M.MAINTAINERS].positions
        return [
            DoneResult(entity=agent, validity=c.VALID, identifier=self.name,
                       reward=M.MAINTAINER_COLLISION_REWARD)
            for agent in agent_entities
            if agent.pos in occupied
        ]
|