mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-24 04:11:36 +02:00
new rules, new spawn logic, small fixes, default and narrow corridor debugged
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
from .actions import BtryCharge
|
||||
from .entitites import Pod, Battery
|
||||
from .entitites import ChargePod, Battery
|
||||
from .groups import ChargePods, Batteries
|
||||
from .rules import DoneAtBatteryDischarge, BatteryDecharge
|
||||
|
@ -6,6 +6,7 @@ from marl_factory_grid.utils.results import ActionResult
|
||||
|
||||
from marl_factory_grid.modules.batteries import constants as b
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
|
||||
|
||||
class BtryCharge(Action):
|
||||
@ -14,8 +15,8 @@ class BtryCharge(Action):
|
||||
super().__init__(b.ACTION_CHARGE)
|
||||
|
||||
def do(self, entity, state) -> Union[None, ActionResult]:
|
||||
if charge_pod := state[b.CHARGE_PODS].by_pos(entity.pos):
|
||||
valid = charge_pod.charge_battery(state[b.BATTERIES].by_entity(entity))
|
||||
if charge_pod := h.get_first(state[b.CHARGE_PODS].by_pos(entity.pos)):
|
||||
valid = h.get_first(charge_pod.charge_battery(state[b.BATTERIES].by_entity(entity)))
|
||||
if valid:
|
||||
state.print(f'{entity.name} just charged batteries at {charge_pod.name}.')
|
||||
else:
|
||||
|
BIN
marl_factory_grid/modules/batteries/chargepods.png
Normal file
BIN
marl_factory_grid/modules/batteries/chargepods.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 7.9 KiB |
@ -50,7 +50,7 @@ class Battery(_Object):
|
||||
return summary
|
||||
|
||||
|
||||
class Pod(Entity):
|
||||
class ChargePod(Entity):
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
@ -58,7 +58,7 @@ class Pod(Entity):
|
||||
|
||||
def __init__(self, *args, charge_rate: float = 0.4,
|
||||
multi_charge: bool = False, **kwargs):
|
||||
super(Pod, self).__init__(*args, **kwargs)
|
||||
super(ChargePod, self).__init__(*args, **kwargs)
|
||||
self.charge_rate = charge_rate
|
||||
self.multi_charge = multi_charge
|
||||
|
||||
|
@ -1,52 +1,36 @@
|
||||
from typing import Union, List, Tuple
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.groups.collection import Collection
|
||||
from marl_factory_grid.modules.batteries.entitites import Pod, Battery
|
||||
from marl_factory_grid.modules.batteries.entitites import ChargePod, Battery
|
||||
from marl_factory_grid.utils.results import Result
|
||||
|
||||
|
||||
class Batteries(Collection):
|
||||
_entity = Battery
|
||||
|
||||
@property
|
||||
def var_is_blocking_light(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_move(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_be_bound(self):
|
||||
return True
|
||||
var_has_position = False
|
||||
var_can_be_bound = True
|
||||
|
||||
@property
|
||||
def obs_tag(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Batteries, self).__init__(*args, **kwargs)
|
||||
def __init__(self, size, initial_charge_level: float=1.0, *args, **kwargs):
|
||||
super(Batteries, self).__init__(size, *args, **kwargs)
|
||||
self.initial_charge_level = initial_charge_level
|
||||
|
||||
def spawn(self, agents, initial_charge_level):
|
||||
batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)]
|
||||
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], agents, *entity_args, **entity_kwargs):
|
||||
batteries = [self._entity(self.initial_charge_level, agent) for _, agent in enumerate(agents)]
|
||||
self.add_items(batteries)
|
||||
|
||||
# def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): hat keine pos
|
||||
# agents = entity_args[0]
|
||||
# initial_charge_level = entity_args[1]
|
||||
# batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)]
|
||||
# self.add_items(batteries)
|
||||
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, **entity_kwargs):
|
||||
self.spawn(0, state[c.AGENT])
|
||||
return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))
|
||||
|
||||
|
||||
class ChargePods(Collection):
|
||||
_entity = Pod
|
||||
_entity = ChargePod
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ChargePods, self).__init__(*args, **kwargs)
|
||||
|
@ -49,10 +49,6 @@ class BatteryDecharge(Rule):
|
||||
self.per_action_costs = per_action_costs
|
||||
self.initial_charge = initial_charge
|
||||
|
||||
def on_init(self, state, lvl_map): # on reset?
|
||||
assert len(state[c.AGENT]), "There are no agents, did you already spawn them?"
|
||||
state[b.BATTERIES].spawn(state[c.AGENT], self.initial_charge)
|
||||
|
||||
def tick_step(self, state) -> List[TickResult]:
|
||||
# Decharge
|
||||
batteries = state[b.BATTERIES]
|
||||
@ -66,7 +62,7 @@ class BatteryDecharge(Rule):
|
||||
|
||||
batteries.by_entity(agent).decharge(energy_consumption)
|
||||
|
||||
results.append(TickResult(self.name, reward=0, entity=agent, validity=c.VALID))
|
||||
results.append(TickResult(self.name, entity=agent, validity=c.VALID))
|
||||
|
||||
return results
|
||||
|
||||
@ -82,13 +78,13 @@ class BatteryDecharge(Rule):
|
||||
if self.paralyze_agents_on_discharge:
|
||||
btry.bound_entity.paralyze(self.name)
|
||||
results.append(
|
||||
TickResult("Paralyzed", entity=btry.bound_entity, reward=0, validity=c.VALID)
|
||||
TickResult("Paralyzed", entity=btry.bound_entity, validity=c.VALID)
|
||||
)
|
||||
state.print(f'{btry.bound_entity.name} has just been paralyzed!')
|
||||
if btry.bound_entity.var_is_paralyzed and not btry.is_discharged:
|
||||
btry.bound_entity.de_paralyze(self.name)
|
||||
results.append(
|
||||
TickResult("De-Paralyzed", entity=btry.bound_entity, reward=0, validity=c.VALID)
|
||||
TickResult("De-Paralyzed", entity=btry.bound_entity, validity=c.VALID)
|
||||
)
|
||||
state.print(f'{btry.bound_entity.name} has just been de-paralyzed!')
|
||||
return results
|
||||
@ -132,7 +128,7 @@ class DoneAtBatteryDischarge(BatteryDecharge):
|
||||
if any_discharged or all_discharged:
|
||||
return [DoneResult(self.name, validity=c.VALID, reward=self.reward_discharge_done)]
|
||||
else:
|
||||
return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)]
|
||||
return [DoneResult(self.name, validity=c.NOT_VALID)]
|
||||
|
||||
|
||||
class SpawnChargePods(Rule):
|
||||
@ -155,7 +151,7 @@ class SpawnChargePods(Rule):
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
pod_collection = state[b.CHARGE_PODS]
|
||||
empty_positions = state.entities.empty_positions()
|
||||
empty_positions = state.entities.empty_positions
|
||||
pods = pod_collection.from_coordinates(empty_positions, entity_kwargs=dict(
|
||||
multi_charge=self.multi_charge, charge_rate=self.charge_rate)
|
||||
)
|
||||
|
Reference in New Issue
Block a user