mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-11-02 21:47:25 +01:00
Maintainer and pos_dicts fixed. Are sets now.
This commit is contained in:
@@ -127,30 +127,3 @@ class DoneAtBatteryDischarge(BatteryDecharge):
|
||||
return [DoneResult(self.name, validity=c.VALID, reward=self.reward_discharge_done)]
|
||||
else:
|
||||
return [DoneResult(self.name, validity=c.NOT_VALID)]
|
||||
|
||||
|
||||
class SpawnChargePods(Rule):
|
||||
|
||||
def __init__(self, n_pods: int, charge_rate: float = 0.4, multi_charge: bool = False):
|
||||
"""
|
||||
Spawn Chargepods in accordance to the given parameters.
|
||||
|
||||
:type n_pods: int
|
||||
:param n_pods: How many charge pods are there?
|
||||
:type charge_rate: float
|
||||
:param charge_rate: How much juice does each use of the charge action top up?
|
||||
:type multi_charge: bool
|
||||
:param multi_charge: Whether multiple agents are able to charge at the same time.
|
||||
"""
|
||||
super().__init__()
|
||||
self.multi_charge = multi_charge
|
||||
self.charge_rate = charge_rate
|
||||
self.n_pods = n_pods
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
pod_collection = state[b.CHARGE_PODS]
|
||||
empty_positions = state.entities.empty_positions
|
||||
pods = pod_collection.from_coordinates(empty_positions, entity_kwargs=dict(
|
||||
multi_charge=self.multi_charge, charge_rate=self.charge_rate)
|
||||
)
|
||||
pod_collection.add_items(pods)
|
||||
|
||||
@@ -34,7 +34,12 @@ class DirtPiles(Collection):
|
||||
self.coords_or_quantity = coords_or_quantity
|
||||
self.initial_amount = initial_amount
|
||||
|
||||
def trigger_spawn(self, state, coords_or_quantity=0, amount=0) -> [Result]:
|
||||
def trigger_spawn(self, state, coords_or_quantity=0, amount=0, ignore_blocking=False) -> [Result]:
|
||||
if ignore_blocking:
|
||||
print("##########################################")
|
||||
print("Blocking should not be ignored for this Entity")
|
||||
print("Exiting....")
|
||||
exit()
|
||||
coords_or_quantity = coords_or_quantity if coords_or_quantity else self.coords_or_quantity
|
||||
n_new = int(abs(coords_or_quantity + (state.rng.uniform(-self.n_var, self.n_var))))
|
||||
n_new = state.get_n_random_free_positions(n_new)
|
||||
|
||||
@@ -106,7 +106,7 @@ class SpawnDestinationsPerAgent(Rule):
|
||||
super(Rule, self).__init__()
|
||||
self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in coords_or_quantity.items()}
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
def on_reset(self, state, lvl_map):
|
||||
for (agent_name, position_list) in self.per_agent_positions.items():
|
||||
agent = h.get_first(state[c.AGENT], lambda x: agent_name in x.name)
|
||||
assert agent
|
||||
|
||||
@@ -15,7 +15,7 @@ class DoorUse(Action):
|
||||
# Check if agent really is standing on a door:
|
||||
e = state.entities.get_entities_near_pos(entity.pos)
|
||||
try:
|
||||
# Only one door opens TODO introcude loop
|
||||
# Only one door opens TODO introduce loop
|
||||
door = next(x for x in e if x.name.startswith(d.DOOR))
|
||||
valid = door.use()
|
||||
state.print(f'{entity.name} just used a {door.name} at {door.pos}')
|
||||
|
||||
@@ -117,3 +117,7 @@ class Door(Entity):
|
||||
def _reset_timer(self):
|
||||
self._time_to_close = self._auto_close_interval
|
||||
return True
|
||||
|
||||
def reset(self):
|
||||
self._close()
|
||||
self._reset_timer()
|
||||
|
||||
@@ -23,3 +23,7 @@ class Doors(Collection):
|
||||
results.append(tick_result)
|
||||
# TODO: Should return a Result object, not a random dict.
|
||||
return results
|
||||
|
||||
def reset(self):
|
||||
for door in self:
|
||||
door.reset()
|
||||
|
||||
@@ -40,6 +40,6 @@ class IndicateDoorAreaInObservation(Rule):
|
||||
# Could then be combined with the "Combine"-approach.
|
||||
super().__init__()
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
def on_reset(self, state, lvl_map):
|
||||
for door in state[d.DOORS]:
|
||||
state[d.DOORS].add_items([DoorIndicator(x) for x in state.entities.neighboring_positions(door.pos)])
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.utils.results import TickResult
|
||||
|
||||
|
||||
class AgentSingleZonePlacementBeta(Rule):
|
||||
|
||||
def __init__(self):
|
||||
raise NotImplementedError()
|
||||
# TODO!!!! Is this concept needed any more?
|
||||
super().__init__()
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
agents = state[c.AGENT]
|
||||
if len(self.coordinates) == len(agents):
|
||||
coordinates = self.coordinates
|
||||
elif len(self.coordinates) > len(agents):
|
||||
coordinates = random.choices(self.coordinates, k=len(agents))
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
for agent, pos in zip(agents, coordinates):
|
||||
agent.move(pos, state)
|
||||
|
||||
def tick_step(self, state):
|
||||
return []
|
||||
|
||||
def tick_post_step(self, state) -> List[TickResult]:
|
||||
return []
|
||||
@@ -3,8 +3,6 @@ from random import shuffle
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
|
||||
|
||||
from ...algorithms.static.utils import points_to_graph
|
||||
from ...environment import constants as c
|
||||
from ...environment.actions import Action, ALL_BASEACTIONS
|
||||
from ...environment.entity.entity import Entity
|
||||
@@ -26,7 +24,6 @@ class Maintainer(Entity):
|
||||
self._next = []
|
||||
self._last = []
|
||||
self._last_serviced = 'None'
|
||||
self._floortile_graph = None
|
||||
|
||||
def tick(self, state):
|
||||
if found_objective := h.get_first(state[self.objective].by_pos(self.pos)):
|
||||
@@ -41,17 +38,18 @@ class Maintainer(Entity):
|
||||
return action.do(self, state)
|
||||
|
||||
def get_move_action(self, state) -> Action:
|
||||
if not self._floortile_graph:
|
||||
state.print("Generating Floorgraph....")
|
||||
self._floortile_graph = points_to_graph(state.entities.floorlist)
|
||||
if self._path is None or not self._path:
|
||||
if self._path is None or not len(self._path):
|
||||
if not self._next:
|
||||
self._next = list(state[self.objective].values()) + [Floor(*state.random_free_position)]
|
||||
shuffle(self._next)
|
||||
self._last = []
|
||||
self._last.append(self._next.pop())
|
||||
state.print("Calculating shortest path....")
|
||||
self._path = self.calculate_route(self._last[-1])
|
||||
self._path = self.calculate_route(self._last[-1], state.floortile_graph)
|
||||
if not self._path:
|
||||
self._last.append(self._next.pop())
|
||||
state.print("Calculating shortest path.... Again....")
|
||||
self._path = self.calculate_route(self._last[-1], state.floortile_graph)
|
||||
|
||||
if door := self._closed_door_in_path(state):
|
||||
state.print(f"{self} found {door} that is closed. Attempt to open.")
|
||||
@@ -67,8 +65,8 @@ class Maintainer(Entity):
|
||||
raise EnvironmentError
|
||||
return action_obj
|
||||
|
||||
def calculate_route(self, entity):
|
||||
route = nx.shortest_path(self._floortile_graph, self.pos, entity.pos)
|
||||
def calculate_route(self, entity, floortile_graph):
|
||||
route = nx.shortest_path(floortile_graph, self.pos, entity.pos)
|
||||
return route[1:]
|
||||
|
||||
def _closed_door_in_path(self, state):
|
||||
|
||||
@@ -14,14 +14,8 @@ class Maintainers(Collection):
|
||||
var_is_blocking_light = False
|
||||
var_has_position = True
|
||||
|
||||
def __init__(self, size, *args, coords_or_quantity: int = None,
|
||||
spawnrule: Union[None, Dict[str, dict]] = None,
|
||||
**kwargs):
|
||||
super(Collection, self).__init__(*args, **kwargs)
|
||||
self._coords_or_quantity = coords_or_quantity
|
||||
self.size = size
|
||||
self._spawnrule = spawnrule
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):
|
||||
self.add_items([self._entity(mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity])
|
||||
|
||||
@@ -11,19 +11,21 @@ class ZoneInit(Rule):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._zones = list()
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
zones = []
|
||||
z_idx = 1
|
||||
|
||||
while z_idx:
|
||||
zone_positions = lvl_map.get_coordinates_for_symbol(z_idx)
|
||||
if len(zone_positions):
|
||||
zones.append(Zone(zone_positions))
|
||||
self._zones.append(Zone(zone_positions))
|
||||
z_idx += 1
|
||||
else:
|
||||
z_idx = 0
|
||||
state[z.ZONES].add_items(zones)
|
||||
|
||||
def on_reset(self, state):
|
||||
state[z.ZONES].add_items(self._zones)
|
||||
return []
|
||||
|
||||
|
||||
@@ -32,7 +34,7 @@ class AgentSingleZonePlacement(Rule):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
def on_reset(self, state):
|
||||
n_agents = len(state[c.AGENT])
|
||||
assert len(state[z.ZONES]) >= n_agents
|
||||
|
||||
@@ -48,19 +50,16 @@ class AgentSingleZonePlacement(Rule):
|
||||
class IndividualDestinationZonePlacement(Rule):
|
||||
|
||||
def __init__(self):
|
||||
raise NotImplementedError("This is rpetty new, and needs to be debugged, after the zones")
|
||||
super().__init__()
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
def on_reset(self, state):
|
||||
for agent in state[c.AGENT]:
|
||||
self.trigger_destination_spawn(agent, state)
|
||||
pass
|
||||
return []
|
||||
|
||||
def tick_step(self, state):
|
||||
self.trigger_spawn(agent, state)
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def trigger_destination_spawn(agent, state):
|
||||
def trigger_spawn(agent, state):
|
||||
agent_zones = state[z.ZONES].by_pos(agent.pos)
|
||||
other_zones = [x for x in state[z.ZONES] if x not in agent_zones]
|
||||
already_has_destination = True
|
||||
|
||||
Reference in New Issue
Block a user