mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-09-13 22:44:00 +02:00
new rules, new spawn logic, small fixes, default and narrow corridor debugged
This commit is contained in:
@@ -1,3 +1,2 @@
|
||||
from .entitites import Machine
|
||||
from .groups import Machines
|
||||
from .rules import MachineRule
|
||||
|
@@ -5,6 +5,7 @@ from marl_factory_grid.utils.results import ActionResult
|
||||
|
||||
from marl_factory_grid.modules.machines import constants as m, rewards as r
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
|
||||
|
||||
class MachineAction(Action):
|
||||
@@ -13,13 +14,10 @@ class MachineAction(Action):
|
||||
super().__init__(m.MACHINE_ACTION)
|
||||
|
||||
def do(self, entity, state) -> Union[None, ActionResult]:
|
||||
if machine := state[m.MACHINES].by_pos(entity.pos):
|
||||
if machine := h.get_first(state[m.MACHINES].by_pos(entity.pos)):
|
||||
if valid := machine.maintain():
|
||||
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_VALID)
|
||||
else:
|
||||
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_FAIL)
|
||||
else:
|
||||
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.MAINTAIN_FAIL)
|
||||
|
||||
|
||||
|
||||
|
@@ -8,22 +8,6 @@ from . import constants as m
|
||||
|
||||
class Machine(Entity):
|
||||
|
||||
@property
|
||||
def var_can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_move(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_is_blocking_light(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return self._encodings[self.status]
|
||||
@@ -46,12 +30,12 @@ class Machine(Entity):
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
|
||||
def tick(self):
|
||||
def tick(self, state):
|
||||
# if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]):
|
||||
if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.state.entities.pos_dict[self.pos]]):
|
||||
return TickResult(identifier=self.name, validity=c.VALID, reward=0, entity=self)
|
||||
if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in state.entities.pos_dict[self.pos]]):
|
||||
return TickResult(identifier=self.name, validity=c.VALID, entity=self)
|
||||
# elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]):
|
||||
elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.state.entities.pos_dict[self.pos]]):
|
||||
elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in state.entities.pos_dict[self.pos]]):
|
||||
self.status = m.STATE_WORK
|
||||
self.reset_counter()
|
||||
return None
|
||||
|
@@ -1,28 +0,0 @@
|
||||
from typing import List
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.utils.results import TickResult, DoneResult
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.modules.machines import constants as m
|
||||
from marl_factory_grid.modules.machines.entitites import Machine
|
||||
|
||||
|
||||
class MachineRule(Rule):
|
||||
|
||||
def __init__(self, n_machines: int = 2):
|
||||
super(MachineRule, self).__init__()
|
||||
self.n_machines = n_machines
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
state[m.MACHINES].spawn(state.entities.empty_positions())
|
||||
|
||||
def tick_pre_step(self, state) -> List[TickResult]:
|
||||
pass
|
||||
|
||||
def tick_step(self, state) -> List[TickResult]:
|
||||
pass
|
||||
|
||||
def tick_post_step(self, state) -> List[TickResult]:
|
||||
pass
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
pass
|
||||
|
Reference in New Issue
Block a user