diff --git a/marl_factory_grid/modules/batteries/actions.py b/marl_factory_grid/modules/batteries/actions.py index 85e843b..fe69180 100644 --- a/marl_factory_grid/modules/batteries/actions.py +++ b/marl_factory_grid/modules/batteries/actions.py @@ -15,7 +15,7 @@ class Charge(Action): def do(self, entity, state) -> Union[None, ActionResult]: if charge_pod := h.get_first(state[b.CHARGE_PODS].by_pos(entity.pos)): - valid = h.get_first(charge_pod.charge_battery(state[b.BATTERIES].by_entity(entity))) + valid = h.get_first(charge_pod.charge_battery(entity, state)) if valid: state.print(f'{entity.name} just charged batteries at {charge_pod.name}.') else: diff --git a/marl_factory_grid/modules/batteries/entitites.py b/marl_factory_grid/modules/batteries/entitites.py index 7675fe9..e9006b9 100644 --- a/marl_factory_grid/modules/batteries/entitites.py +++ b/marl_factory_grid/modules/batteries/entitites.py @@ -1,4 +1,5 @@ from marl_factory_grid.environment import constants as c +from marl_factory_grid.environment.entity.agent import Agent from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.entity.object import Object from marl_factory_grid.modules.batteries import constants as b @@ -62,11 +63,11 @@ class ChargePod(Entity): self.charge_rate = charge_rate self.multi_charge = multi_charge - def charge_battery(self, battery: Battery): - if battery.charge_level == 1.0: + def charge_battery(self, entity, state): + battery = state[b.BATTERIES].by_entity(entity) + if battery.charge_level >= 1.0: return c.NOT_VALID - if sum(1 for key, val in self.state.entities.pos_dict[self.pos] for guest in val if - 'agent' in guest.name.lower()) > 1: + if len([x for x in state[c.AGENT].by_pos(entity.pos)]) > 1: return c.NOT_VALID valid = battery.do_charge_action(self.charge_rate) return valid diff --git a/marl_factory_grid/modules/maintenance/entities.py b/marl_factory_grid/modules/maintenance/entities.py index 459d70f..759fdda 100644 --- a/marl_factory_grid/modules/maintenance/entities.py +++ b/marl_factory_grid/modules/maintenance/entities.py @@ -26,10 +26,9 @@ class Maintainer(Entity): self._last_serviced = 'None' def tick(self, state): - self.clear_temp_state if found_objective := h.get_first(state[self.objective].by_pos(self.pos)): if found_objective.name != self._last_serviced: - self.action.do(self, state) + result = self.action.do(self, state) self._last_serviced = found_objective.name else: action = self.get_move_action(state)