new rules, new spawn logic, small fixes, default and narrow corridor debugged

This commit is contained in:
Steffen Illium
2023-11-09 17:50:20 +01:00
parent 9b9c6e0385
commit 06a5130b25
67 changed files with 768 additions and 921 deletions

View File

@ -1,48 +1,35 @@
from random import shuffle
import networkx as nx
import numpy as np
from ...algorithms.static.utils import points_to_graph
from ...environment import constants as c
from ...environment.actions import Action, ALL_BASEACTIONS
from ...environment.entity.entity import Entity
from ..doors import constants as do
from ..maintenance import constants as mi
from ...utils.helpers import MOVEMAP
from ...utils.utility_classes import RenderEntity
from ...utils.states import Gamestate
from ...utils import helpers as h
from ...utils.utility_classes import RenderEntity, Floor
from ..doors import DoorUse
class Maintainer(Entity):
@property
def var_can_collide(self):
return True
@property
def var_can_move(self):
return False
@property
def var_is_blocking_light(self):
return False
@property
def var_has_position(self):
return True
def __init__(self, state: Gamestate, objective: str, action: Action, *args, **kwargs):
def __init__(self, objective: str, action: Action, *args, **kwargs):
super().__init__(*args, **kwargs)
self.action = action
self.actions = [x() for x in ALL_BASEACTIONS]
self.actions = [x() for x in ALL_BASEACTIONS] + [DoorUse()]
self.objective = objective
self._path = None
self._next = []
self._last = []
self._last_serviced = 'None'
self._floortile_graph = points_to_graph(state.entities.floorlist)
self._floortile_graph = None
def tick(self, state):
if found_objective := state[self.objective].by_pos(self.pos):
if found_objective := h.get_first(state[self.objective].by_pos(self.pos)):
if found_objective.name != self._last_serviced:
self.action.do(self, state)
self._last_serviced = found_objective.name
@ -54,24 +41,27 @@ class Maintainer(Entity):
return action.do(self, state)
def get_move_action(self, state) -> Action:
if not self._floortile_graph:
state.print("Generating Floorgraph....")
self._floortile_graph = points_to_graph(state.entities.floorlist)
if self._path is None or not self._path:
if not self._next:
self._next = list(state[self.objective].values())
self._next = list(state[self.objective].values()) + [Floor(*state.random_free_position)]
shuffle(self._next)
self._last = []
self._last.append(self._next.pop())
state.print("Calculating shortest path....")
self._path = self.calculate_route(self._last[-1])
if door := self._door_is_close(state):
if door.is_closed:
# Translate the action_object to an integer to have the same output as any other model
action = do.ACTION_DOOR_USE
else:
action = self._predict_move(state)
if door := self._closed_door_in_path(state):
state.print(f"{self} found {door} that is closed. Attempt to open.")
# Translate the action_object to an integer to have the same output as any other model
action = do.ACTION_DOOR_USE
else:
action = self._predict_move(state)
# Translate the action_object to an integer to have the same output as any other model
try:
action_obj = next(x for x in self.actions if x.name == action)
action_obj = h.get_first(self.actions, lambda x: x.name == action)
except (StopIteration, UnboundLocalError):
print('Will not happen')
raise EnvironmentError
@ -81,11 +71,10 @@ class Maintainer(Entity):
route = nx.shortest_path(self._floortile_graph, self.pos, entity.pos)
return route[1:]
def _door_is_close(self, state):
state.print("Found a door that is close.")
try:
return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name)
except StopIteration:
def _closed_door_in_path(self, state):
if self._path:
return h.get_first(state[do.DOORS].by_pos(self._path[0]), lambda x: x.is_closed)
else:
return None
def _predict_move(self, state):
@ -96,7 +85,7 @@ class Maintainer(Entity):
next_pos = self._path.pop(0)
diff = np.subtract(next_pos, self.pos)
# Retrieve action based on the pos dif (like in: What do I have to do to get there?)
action = next(action for action, pos_diff in MOVEMAP.items() if np.all(diff == pos_diff))
action = next(action for action, pos_diff in h.MOVEMAP.items() if np.all(diff == pos_diff))
return action
def render(self):

View File

@ -1,4 +1,4 @@
from typing import Union, List, Tuple
from typing import Union, List, Tuple, Dict
from marl_factory_grid.environment.groups.collection import Collection
from .entities import Maintainer
@ -10,25 +10,21 @@ from ...utils.states import Gamestate
class Maintainers(Collection):
_entity = Maintainer
@property
def var_can_collide(self):
return True
var_can_collide = True
var_can_move = True
var_is_blocking_light = False
var_has_position = True
@property
def var_can_move(self):
return True
@property
def var_is_blocking_light(self):
return False
@property
def var_has_position(self):
return True
def __init__(self, size, *args, coords_or_quantity: int = None,
spawnrule: Union[None, Dict[str, dict]] = None,
**kwargs):
super(Collection, self).__init__(*args, **kwargs)
self._coords_or_quantity = coords_or_quantity
self.size = size
self._spawnrule = spawnrule
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):
state = entity_args[0]
self.add_items([self._entity(state, mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity])
self.add_items([self._entity(mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity])

View File

@ -4,29 +4,24 @@ from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c
from . import rewards as r
from . import constants as M
from marl_factory_grid.utils.states import Gamestate
class MaintenanceRule(Rule):
class MoveMaintainers(Rule):
def __init__(self, n_maintainer: int = 1, *args, **kwargs):
super(MaintenanceRule, self).__init__(*args, **kwargs)
self.n_maintainer = n_maintainer
def on_init(self, state: Gamestate, lvl_map):
state[M.MAINTAINERS].spawn(state.entities.empty_positions[:self.n_maintainer], state)
pass
def tick_pre_step(self, state) -> List[TickResult]:
pass
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def tick_step(self, state) -> List[TickResult]:
for maintainer in state[M.MAINTAINERS]:
maintainer.tick(state)
# Todo: Return a Result Object.
return []
def tick_post_step(self, state) -> List[TickResult]:
pass
class DoneAtMaintainerCollision(Rule):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def on_check_done(self, state) -> List[DoneResult]:
agents = list(state[c.AGENT].values())