Mirror of https://github.com/illiumst/marl-factory-grid.git
Synced 2025-05-23 07:16:44 +02:00
35 lines · 1.3 KiB · Python
from typing import List
|
|
|
|
from marl_factory_grid.environment.rules import Rule
|
|
from marl_factory_grid.environment import constants as c
|
|
from marl_factory_grid.utils.results import TickResult
|
|
from marl_factory_grid.modules.items import constants as i
|
|
|
|
|
|
class RespawnItems(Rule):
    """Rule that periodically respawns items.

    A countdown starting at ``respawn_freq`` is decremented once per
    environment step in :meth:`tick_post_step`. When it reaches zero, items
    are spawned through ``state[i.ITEM].trigger_spawn`` and the countdown is
    restarted from ``respawn_freq``, so spawns happen every ``respawn_freq``
    steps rather than on every step once the countdown has run out.
    """

    def __init__(self, n_items: int = 5, respawn_freq: int = 15, n_locations: int = 5):
        """
        :param n_items: Forwarded to ``trigger_spawn`` on every respawn
            (presumably the number of items to spawn — confirm against
            the item collection).
        :param respawn_freq: Number of steps between two respawn attempts;
            also forwarded to ``trigger_spawn``.
        :param n_locations: Stored as ``self.n_locations``; not read by this rule.
        """
        super().__init__()
        self.spawn_frequency = respawn_freq
        # Steps remaining until the next respawn attempt.
        self._next_item_spawn = respawn_freq
        self.n_items = n_items
        # NOTE(review): kept for configuration compatibility; this rule never reads it.
        self.n_locations = n_locations

    def tick_step(self, state):
        """Do nothing during the main step phase.

        The countdown and the spawn live only in :meth:`tick_post_step`, so
        each environment step advances the countdown exactly once and spawns
        at most once, and the spawn result is always reported.

        :return: An empty list.
        """
        return []

    def tick_post_step(self, state) -> List[TickResult]:
        """Advance the respawn countdown and spawn items when it expires.

        :param state: Environment state; ``state[i.ITEM]`` must provide
            ``trigger_spawn``.
        :return: ``[]`` on steps without a respawn attempt. Otherwise a single
            ``TickResult``: ``c.VALID`` carrying the spawn result's ``value``
            when ``trigger_spawn`` returned a truthy result, or
            ``c.NOT_VALID`` with value 0 when it did not.
        """
        self._next_item_spawn = max(0, self._next_item_spawn - 1)
        if self._next_item_spawn:
            return []

        # Restart the countdown so the next attempt happens spawn_frequency
        # steps from now, instead of re-spawning on every following step.
        self._next_item_spawn = self.spawn_frequency

        # NOTE(review): arguments are forwarded exactly as before; their
        # meaning to trigger_spawn is defined by the item collection.
        spawned_items = state[i.ITEM].trigger_spawn(state, self.n_items, self.spawn_frequency)
        if spawned_items:
            return [TickResult(self.name, validity=c.VALID, value=spawned_items.value)]
        return [TickResult(self.name, validity=c.NOT_VALID, value=0)]
|
|
|