mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-11-02 21:47:25 +01:00
fix mismatching signatures of spawn
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
from typing import Union, List, Tuple
|
||||
|
||||
from marl_factory_grid.environment.groups.collection import Collection
|
||||
from marl_factory_grid.modules.batteries.entitites import Pod, Battery
|
||||
|
||||
|
||||
class Batteries(Collection):
|
||||
|
||||
_entity = Battery
|
||||
|
||||
@property
|
||||
@@ -33,9 +34,14 @@ class Batteries(Collection):
|
||||
batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)]
|
||||
self.add_items(batteries)
|
||||
|
||||
# def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): hat keine pos
|
||||
# agents = entity_args[0]
|
||||
# initial_charge_level = entity_args[1]
|
||||
# batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)]
|
||||
# self.add_items(batteries)
|
||||
|
||||
|
||||
class ChargePods(Collection):
|
||||
|
||||
_entity = Pod
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
||||
@@ -49,7 +49,7 @@ class BatteryDecharge(Rule):
|
||||
self.per_action_costs = per_action_costs
|
||||
self.initial_charge = initial_charge
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
def on_init(self, state, lvl_map): # on reset?
|
||||
assert len(state[c.AGENT]), "There are no agents, did you already spawn them?"
|
||||
state[b.BATTERIES].spawn(state[c.AGENT], self.initial_charge)
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Union, List, Tuple
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.utils.results import Result
|
||||
from marl_factory_grid.environment.groups.collection import Collection
|
||||
@@ -36,9 +38,10 @@ class DirtPiles(Collection):
|
||||
self.max_global_amount = max_global_amount
|
||||
self.max_local_amount = max_local_amount
|
||||
|
||||
def spawn(self, then_dirty_positions, amount_s) -> Result:
|
||||
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):
|
||||
amount_s = entity_args[0]
|
||||
spawn_counter = 0
|
||||
for idx, pos in enumerate(then_dirty_positions):
|
||||
for idx, pos in enumerate(coords_or_quantity):
|
||||
if not self.amount > self.max_global_amount:
|
||||
amount = amount_s[idx] if isinstance(amount_s, list) else amount_s
|
||||
if dirt := self.by_pos(pos):
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Union, List, Tuple
|
||||
|
||||
from marl_factory_grid.environment.groups.collection import Collection
|
||||
|
||||
from .entitites import Machine
|
||||
@@ -21,3 +23,4 @@ class Machines(Collection):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Machines, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
@@ -13,8 +13,7 @@ class MachineRule(Rule):
|
||||
self.n_machines = n_machines
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
# TODO Move to spawn!!!
|
||||
state[m.MACHINES].add_items(Machine(pos) for pos in state.entities.empty_positions())
|
||||
state[m.MACHINES].spawn(state.entities.empty_positions())
|
||||
|
||||
def tick_pre_step(self, state) -> List[TickResult]:
|
||||
pass
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Union, List, Tuple
|
||||
|
||||
from marl_factory_grid.environment.groups.collection import Collection
|
||||
from .entities import Maintainer
|
||||
from ..machines import constants as mc
|
||||
@@ -27,5 +29,6 @@ class Maintainers(Collection):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def spawn(self, position, state: Gamestate):
|
||||
self.add_items([self._entity(state, mc.MACHINES, MachineAction(), pos) for pos in position])
|
||||
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):
|
||||
state = entity_args[0]
|
||||
self.add_items([self._entity(state, mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity])
|
||||
|
||||
@@ -14,7 +14,6 @@ class MaintenanceRule(Rule):
|
||||
self.n_maintainer = n_maintainer
|
||||
|
||||
def on_init(self, state: Gamestate, lvl_map):
|
||||
# Move to spawn? : #TODO
|
||||
state[M.MAINTAINERS].spawn(state.entities.empty_positions[:self.n_maintainer], state)
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user