Redid the spawn procedure and destination objects

This commit is contained in:
Steffen Illium
2023-10-11 16:36:48 +02:00
parent e64fa84ef1
commit e326a95bf4
32 changed files with 266 additions and 146 deletions

View File

@ -1,4 +1,4 @@
from .actions import DestAction
from .entitites import Destination
from .groups import ReachedDestinations, Destinations
from .rules import DestinationDone, DestinationReach, DestinationSpawn
from .groups import Destinations, BoundDestinations
from .rules import DestinationReachAll, DestinationSpawn

View File

@ -3,8 +3,6 @@
DESTINATION = 'Destinations'
BOUNDDESTINATION = 'BoundDestinations'
DEST_SYMBOL = 1
DEST_REACHED_REWARD = 0.5
DEST_REACHED = 'ReachedDestinations'
WAIT_ON_DEST = 'WAIT'

View File

@ -16,42 +16,31 @@ class Destination(Entity):
var_is_blocking_pos = False
var_is_blocking_light = False
@property
def any_agent_has_dwelled(self):
return bool(len(self._per_agent_times))
@property
def currently_dwelling_names(self):
return list(self._per_agent_times.keys())
@property
def encoding(self):
return d.DEST_SYMBOL
def __init__(self, *args, dwell_time: int = 0, **kwargs):
def __init__(self, *args, action_counts=0, **kwargs):
super(Destination, self).__init__(*args, **kwargs)
self.dwell_time = dwell_time
self._per_agent_times = defaultdict(lambda: dwell_time)
self.action_counts = action_counts
self._per_agent_actions = defaultdict(lambda: 0)
def do_wait_action(self, agent: Agent):
self._per_agent_times[agent.name] -= 1
self._per_agent_actions[agent.name] += 1
return c.VALID
def leave(self, agent: Agent):
del self._per_agent_times[agent.name]
@property
def is_considered_reached(self):
agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in self.tile.guests_that_can_collide)
return (agent_at_position and not self.dwell_time) or any(x == 0 for x in self._per_agent_times.values())
return agent_at_position or any(x >= self.action_counts for x in self._per_agent_actions.values())
def agent_is_dwelling(self, agent: Agent):
return self._per_agent_times[agent.name] < self.dwell_time
def agent_did_action(self, agent: Agent):
return self._per_agent_actions[agent.name] >= self.action_counts
def summarize_state(self) -> dict:
state_summary = super().summarize_state()
state_summary.update(per_agent_times=[
dict(belongs_to=key, time=val) for key, val in self._per_agent_times.items()], dwell_time=self.dwell_time)
dict(belongs_to=key, time=val) for key, val in self._per_agent_actions.items()], counts=self.action_counts)
return state_summary
def render(self):
@ -68,9 +57,8 @@ class BoundDestination(BoundEntityMixin, Destination):
self.bind_to(entity)
super().__init__(*args, **kwargs)
@property
def is_considered_reached(self):
agent_at_position = any(self.bound_entity == x for x in self.tile.guests_that_can_collide)
return (agent_at_position and not self.dwell_time) \
or any(x == 0 for x in self._per_agent_times[self.bound_entity.name])
return ((agent_at_position and not self.action_counts)
or self._per_agent_actions[self.bound_entity.name] >= self.action_counts >= 1)

View File

@ -23,14 +23,3 @@ class BoundDestinations(HasBoundMixin, Destinations):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
class ReachedDestinations(Destinations):
_entity = Destination
is_blocking_light = False
can_collide = False
def __init__(self, *args, **kwargs):
super(ReachedDestinations, self).__init__(*args, **kwargs)
def __repr__(self):
return super(ReachedDestinations, self).__repr__()

View File

@ -1,84 +1,61 @@
from typing import List, Union
import ast
from random import shuffle
from typing import List, Union, Dict, Tuple
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c
from marl_factory_grid.modules.destinations import constants as d, rewards as r
from marl_factory_grid.modules.destinations.entitites import Destination
from marl_factory_grid.modules.destinations.entitites import Destination, BoundDestination
class DestinationReach(Rule):
class DestinationReachAll(Rule):
def __init__(self, n_dests: int = 1, tiles: Union[List, None] = None):
super(DestinationReach, self).__init__()
self.n_dests = n_dests or len(tiles)
self._tiles = tiles
def __init__(self):
super(DestinationReachAll, self).__init__()
def tick_step(self, state) -> List[TickResult]:
for dest in list(state[d.DESTINATION].values()):
results = []
for dest in list(state[next(key for key in state.entities.names if d.DESTINATION in key)]):
if dest.is_considered_reached:
dest.change_parent_collection(state[d.DEST_REACHED])
agent = state[c.AGENT].by_pos(dest.pos)
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent))
state.print(f'{dest.name} is reached now, removing...')
assert dest.destroy(), f'{dest.name} could not be destroyed. Critical Error.'
else:
for agent_name in dest.currently_dwelling_names:
agent = state[c.AGENT][agent_name]
if agent.pos == dest.pos:
state.print(f'{agent.name} is still waiting.')
pass
else:
dest.leave(agent)
state.print(f'{agent.name} left the destination early.')
pass
return [TickResult(self.name, validity=c.VALID, reward=0, entity=None)]
def tick_post_step(self, state) -> List[TickResult]:
results = list()
for reached_dest in state[d.DEST_REACHED]:
for guest in reached_dest.tile.guests:
if guest in state[c.AGENT]:
state.print(f'{guest.name} just reached destination at {guest.pos}')
state[d.DEST_REACHED].delete_env_object(reached_dest)
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=guest))
return results
class DestinationDone(Rule):
def __init__(self):
super(DestinationDone, self).__init__()
def on_check_done(self, state) -> List[DoneResult]:
if not len(state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
return []
class DoneOnReach(Rule):
def __init__(self):
super(DoneOnReach, self).__init__()
def on_check_done(self, state) -> List[DoneResult]:
dests = [x.pos for x in state[d.DESTINATION]]
agents = [x.pos for x in state[c.AGENT]]
if any([x in dests for x in agents]):
if not len(state[next(key for key in state.entities.names if d.DESTINATION in key)]):
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)]
class DestinationReachAny(DestinationReachAll):
def __init__(self):
super(DestinationReachAny, self).__init__()
def on_check_done(self, state) -> List[DoneResult]:
if not len(state[next(key for key in state.entities.names if d.DESTINATION in key)]):
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
return []
class DestinationSpawn(Rule):
def __init__(self, spawn_frequency: int = 5, n_dests: int = 1,
def __init__(self, n_dests: int = 1,
spawn_mode: str = d.MODE_GROUPED):
super(DestinationSpawn, self).__init__()
self.spawn_frequency = spawn_frequency
self.n_dests = n_dests
self.spawn_mode = spawn_mode
def on_init(self, state, lvl_map):
# noinspection PyAttributeOutsideInit
self._dest_spawn_timer = self.spawn_frequency
self.trigger_destination_spawn(self.n_dests, state)
pass
@ -88,16 +65,40 @@ class DestinationSpawn(Rule):
def tick_step(self, state) -> List[TickResult]:
if n_dest_spawn := max(0, self.n_dests - len(state[d.DESTINATION])):
if self.spawn_mode == d.MODE_GROUPED and n_dest_spawn == self.n_dests:
validity = state.rules['DestinationReach'].trigger_destination_spawn(n_dest_spawn, state)
validity = self.trigger_destination_spawn(n_dest_spawn, state)
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
elif self.spawn_mode == d.MODE_SINGLE and n_dest_spawn:
validity = self.trigger_destination_spawn(n_dest_spawn, state)
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
else:
pass
@staticmethod
def trigger_destination_spawn(n_dests, state, tiles=None):
tiles = tiles or state[c.FLOOR].empty_tiles[:n_dests]
if destinations := [Destination(tile) for tile in tiles]:
def trigger_destination_spawn(self, n_dests, state):
empty_positions = state[c.FLOORS].empty_tiles[:n_dests]
if destinations := [Destination(pos) for pos in empty_positions]:
state[d.DESTINATION].add_items(destinations)
state.print(f'{n_dests} new destinations have been spawned')
return c.VALID
else:
state.print('No Destiantions are spawning, limit is reached.')
return c.NOT_VALID
class FixedDestinationSpawn(Rule):
    """Rule that places one ``BoundDestination`` per configured agent during ``on_init``.

    For every configured agent a position is drawn at random from that agent's
    candidate list. A candidate is rejected when it equals the agent's current
    position or when ``state[d.BOUNDDESTINATION].by_pos`` already reports
    something at that position.
    """

    def __init__(self, per_agent_positions: Dict[str, List[Tuple[int, int]]]):
        """Store the candidate positions per agent.

        :param per_agent_positions: Maps an agent name (matched as a substring of
            the agents' ``name`` in ``on_init``) to its candidate positions. Each
            position may be a string literal such as ``"(1, 2)"``, which is parsed
            with ``ast.literal_eval``, or an already parsed object, which is kept
            as it is.
        """
        # Previously ``super(Rule, self).__init__()``: that starts the MRO lookup
        # *after* ``Rule`` and therefore never executed ``Rule.__init__``.
        super(FixedDestinationSpawn, self).__init__()
        self.per_agent_positions = {
            key: [self._parse_position(x) for x in val]
            for key, val in per_agent_positions.items()
        }

    @staticmethod
    def _parse_position(position):
        """Return ``position`` parsed from its literal form if it is a string, else unchanged."""
        # ``ast.literal_eval`` raises on non-string, non-AST input such as a tuple,
        # so only strings are routed through it.
        if isinstance(position, str):
            return ast.literal_eval(position)
        return position

    def on_init(self, state, lvl_map):
        """Create and register one bound destination for each configured agent.

        :param state: Game state; indexed with ``c.AGENT``, ``c.FLOORS`` and
            ``d.BOUNDDESTINATION`` here.
        :param lvl_map: Unused by this rule.
        :raises IndexError: If none of an agent's candidate positions is usable.
        :raises StopIteration: If no agent name contains the configured key.
        """
        for agent_name, position_list in self.per_agent_positions.items():
            agent = next(x for x in state[c.AGENT] if agent_name in x.name)  # Fixme: Ugly AF
            # Shuffle a copy: popping from the stored list would shrink the
            # configuration every time ``on_init`` runs (NOTE(review): presumably
            # once per environment reset - confirm against the caller).
            candidates = list(position_list)
            shuffle(candidates)
            destination = None
            while candidates:
                pos = candidates.pop()
                if pos != agent.pos and not state[d.BOUNDDESTINATION].by_pos(pos):
                    destination = BoundDestination(agent, state[c.FLOORS].by_pos(pos))
                    break
            if destination is None:
                # Same exception type as the former ``pop`` on an empty list,
                # but with a message that names the affected agent.
                raise IndexError(f'No free destination position left for agent "{agent_name}".')
            state[d.BOUNDDESTINATION].add_item(destination)