mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-11-02 21:47:25 +01:00
checked rule returns
This commit is contained in:
@@ -60,7 +60,7 @@ class BatteryDecharge(Rule):
|
||||
|
||||
batteries.by_entity(agent).decharge(energy_consumption)
|
||||
|
||||
results.append(TickResult(self.name, entity=agent, validity=c.VALID))
|
||||
results.append(TickResult(self.name, entity=agent, validity=c.VALID, value=energy_consumption))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ class DoneOnAllDirtCleaned(Rule):
|
||||
def on_check_done(self, state) -> [DoneResult]:
|
||||
if len(state[d.DIRT]) == 0 and state.curr_step:
|
||||
return [DoneResult(validity=c.VALID, identifier=self.name, reward=self.reward)]
|
||||
return [DoneResult(validity=c.NOT_VALID, identifier=self.name)]
|
||||
return []
|
||||
|
||||
|
||||
class RespawnDirt(Rule):
|
||||
@@ -81,5 +81,6 @@ class EntitiesSmearDirtOnMove(Rule):
|
||||
old_pos_dirt = next(iter(old_pos_dirt))
|
||||
if smeared_dirt := round(old_pos_dirt.amount * self.smear_ratio, 2):
|
||||
if state[d.DIRT].spawn(entity.pos, amount=smeared_dirt):
|
||||
results.append(TickResult(identifier=self.name, entity=entity, validity=c.VALID))
|
||||
results.append(TickResult(identifier=self.name, entity=entity,
|
||||
validity=c.VALID, value=smeared_dirt))
|
||||
return results
|
||||
|
||||
@@ -90,6 +90,8 @@ class DoneAtDestinationReach(DestinationReachReward):
|
||||
pass
|
||||
else:
|
||||
dest.unmark_as_reached()
|
||||
return [DoneResult(f'{dest}_unmarked_as_reached',
|
||||
validity=c.NOT_VALID, entity=dest)]
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import List
|
||||
|
||||
import marl_factory_grid.modules.maintenance.constants
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.utils.results import TickResult, DoneResult
|
||||
from marl_factory_grid.environment import constants as c
|
||||
@@ -31,5 +30,5 @@ class DoneAtMaintainerCollision(Rule):
|
||||
for agent in agents:
|
||||
if agent.pos in m_pos:
|
||||
done_results.append(DoneResult(entity=agent, validity=c.VALID, identifier=self.name,
|
||||
reward=marl_factory_grid.modules.maintenance.constants.MAINTAINER_COLLISION_REWARD))
|
||||
reward=M.MAINTAINER_COLLISION_REWARD))
|
||||
return done_results
|
||||
|
||||
@@ -43,9 +43,6 @@ class AgentSingleZonePlacement(Rule):
|
||||
agent.move(state[z.ZONES][z_idxs.pop()].random_pos, state)
|
||||
return []
|
||||
|
||||
def tick_step(self, state):
|
||||
return []
|
||||
|
||||
|
||||
class IndividualDestinationZonePlacement(Rule):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user