mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-24 04:11:36 +02:00
Remove BoundDestination Object
New Variable 'var_can_be_bound' Observations adjusted accordingly
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
from .actions import DestAction
|
||||
from .entitites import Destination
|
||||
from .groups import Destinations, BoundDestinations
|
||||
from .groups import Destinations
|
||||
from .rules import DestinationReachAll, DestinationSpawn
|
||||
|
@ -13,9 +13,7 @@ class DestAction(Action):
|
||||
super().__init__(d.DESTINATION)
|
||||
|
||||
def do(self, entity, state) -> Union[None, ActionResult]:
|
||||
dest_entities = d.DESTINATION if d.DESTINATION in state else d.BOUNDDESTINATION
|
||||
assert dest_entities
|
||||
if destination := state[dest_entities].by_pos(entity.pos):
|
||||
if destination := state[d.DESTINATION].by_pos(entity.pos):
|
||||
valid = destination.do_wait_action(entity)
|
||||
state.print(f'{entity.name} just waited at {entity.pos}')
|
||||
else:
|
||||
|
@ -1,7 +1,6 @@
|
||||
|
||||
# Destination Env
|
||||
DESTINATION = 'Destinations'
|
||||
BOUNDDESTINATION = 'BoundDestinations'
|
||||
DEST_SYMBOL = 1
|
||||
|
||||
WAIT_ON_DEST = 'WAIT'
|
||||
|
@ -8,13 +8,18 @@ from marl_factory_grid.utils.render import RenderEntity
|
||||
from marl_factory_grid.modules.destinations import constants as d
|
||||
|
||||
|
||||
class Destination(Entity):
|
||||
class Destination(BoundEntityMixin, Entity):
|
||||
|
||||
var_can_move = False
|
||||
var_can_collide = False
|
||||
var_has_position = True
|
||||
var_is_blocking_pos = False
|
||||
var_is_blocking_light = False
|
||||
var_can_be_bound = True # Introduce this globally!
|
||||
|
||||
@property
|
||||
def was_reached(self):
|
||||
return self._was_reached
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
@ -22,6 +27,7 @@ class Destination(Entity):
|
||||
|
||||
def __init__(self, *args, action_counts=0, **kwargs):
|
||||
super(Destination, self).__init__(*args, **kwargs)
|
||||
self._was_reached = False
|
||||
self.action_counts = action_counts
|
||||
self._per_agent_actions = defaultdict(lambda: 0)
|
||||
|
||||
@ -30,9 +36,15 @@ class Destination(Entity):
|
||||
return c.VALID
|
||||
|
||||
@property
|
||||
def is_considered_reached(self):
|
||||
agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in self.tile.guests_that_can_collide)
|
||||
return agent_at_position or any(x >= self.action_counts for x in self._per_agent_actions.values())
|
||||
def has_just_been_reached(self):
|
||||
if self.was_reached:
|
||||
return False
|
||||
agent_at_position = any(self.bound_entity == x for x in self.tile.guests_that_can_collide)
|
||||
if self.bound_entity:
|
||||
return ((agent_at_position and not self.action_counts)
|
||||
or self._per_agent_actions[self.bound_entity.name] >= self.action_counts >= 1)
|
||||
else:
|
||||
return agent_at_position or any(x >= self.action_counts for x in self._per_agent_actions.values())
|
||||
|
||||
def agent_did_action(self, agent: Agent):
|
||||
return self._per_agent_actions[agent.name] >= self.action_counts
|
||||
@ -44,21 +56,10 @@ class Destination(Entity):
|
||||
return state_summary
|
||||
|
||||
def render(self):
|
||||
return RenderEntity(d.DESTINATION, self.pos)
|
||||
if self.was_reached:
|
||||
return None
|
||||
else:
|
||||
return RenderEntity(d.DESTINATION, self.pos)
|
||||
|
||||
|
||||
class BoundDestination(BoundEntityMixin, Destination):
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return d.DEST_SYMBOL
|
||||
|
||||
def __init__(self, entity, *args, **kwargs):
|
||||
self.bind_to(entity)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def is_considered_reached(self):
|
||||
agent_at_position = any(self.bound_entity == x for x in self.tile.guests_that_can_collide)
|
||||
return ((agent_at_position and not self.action_counts)
|
||||
or self._per_agent_actions[self.bound_entity.name] >= self.action_counts >= 1)
|
||||
def mark_as_reached(self):
|
||||
self._was_reached = True
|
||||
|
@ -1,6 +1,6 @@
|
||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||
from marl_factory_grid.environment.groups.mixins import PositionMixin, HasBoundMixin
|
||||
from marl_factory_grid.modules.destinations.entitites import Destination, BoundDestination
|
||||
from marl_factory_grid.modules.destinations.entitites import Destination
|
||||
|
||||
|
||||
class Destinations(PositionMixin, EnvObjects):
|
||||
@ -14,12 +14,3 @@ class Destinations(PositionMixin, EnvObjects):
|
||||
|
||||
def __repr__(self):
|
||||
return super(Destinations, self).__repr__()
|
||||
|
||||
|
||||
class BoundDestinations(HasBoundMixin, Destinations):
|
||||
|
||||
_entity = BoundDestination
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
@ -1,12 +1,12 @@
|
||||
import ast
|
||||
from random import shuffle
|
||||
from typing import List, Union, Dict, Tuple
|
||||
from typing import List, Dict, Tuple
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.utils.results import TickResult, DoneResult
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
from marl_factory_grid.modules.destinations import constants as d, rewards as r
|
||||
from marl_factory_grid.modules.destinations.entitites import Destination, BoundDestination
|
||||
from marl_factory_grid.modules.destinations.entitites import Destination
|
||||
|
||||
|
||||
class DestinationReachAll(Rule):
|
||||
@ -16,21 +16,28 @@ class DestinationReachAll(Rule):
|
||||
|
||||
def tick_step(self, state) -> List[TickResult]:
|
||||
results = []
|
||||
for dest in list(state[next(key for key in state.entities.names if d.DESTINATION in key)]):
|
||||
if dest.is_considered_reached:
|
||||
agent = state[c.AGENT].by_pos(dest.pos)
|
||||
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent))
|
||||
state.print(f'{dest.name} is reached now, removing...')
|
||||
assert dest.destroy(), f'{dest.name} could not be destroyed. Critical Error.'
|
||||
for dest in state[d.DESTINATION]:
|
||||
if dest.has_just_been_reached and not dest.was_reached:
|
||||
# Dest has just been reached, some agent needs to stand here, grab any first.
|
||||
for agent in state[c.AGENT].by_pos(dest.pos):
|
||||
if dest.bound_entity:
|
||||
if dest.bound_entity == agent:
|
||||
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent))
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent))
|
||||
state.print(f'{dest.name} is reached now, mark as reached...')
|
||||
dest.mark_as_reached()
|
||||
else:
|
||||
pass
|
||||
return [TickResult(self.name, validity=c.VALID, reward=0, entity=None)]
|
||||
pass
|
||||
return results
|
||||
|
||||
def tick_post_step(self, state) -> List[TickResult]:
|
||||
return []
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
if not len(state[next(key for key in state.entities.names if d.DESTINATION in key)]):
|
||||
if all(x.was_reached for x in state[d.DESTINATION]):
|
||||
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
|
||||
return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)]
|
||||
|
||||
@ -41,7 +48,7 @@ class DestinationReachAny(DestinationReachAll):
|
||||
super(DestinationReachAny, self).__init__()
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
if not len(state[next(key for key in state.entities.names if d.DESTINATION in key)]):
|
||||
if any(x.was_reached for x in state[d.DESTINATION]):
|
||||
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
|
||||
return []
|
||||
|
||||
@ -95,10 +102,10 @@ class FixedDestinationSpawn(Rule):
|
||||
shuffle(position_list)
|
||||
while True:
|
||||
pos = position_list.pop()
|
||||
if pos != agent.pos and not state[d.BOUNDDESTINATION].by_pos(pos):
|
||||
destination = BoundDestination(agent, state[c.FLOORS].by_pos(pos))
|
||||
if pos != agent.pos and not state[d.DESTINATION].by_pos(pos):
|
||||
destination = Destination(state[c.FLOORS].by_pos(pos), bind_to=agent)
|
||||
break
|
||||
else:
|
||||
continue
|
||||
state[d.BOUNDDESTINATION].add_item(destination)
|
||||
state[d.DESTINATION].add_item(destination)
|
||||
pass
|
||||
|
Reference in New Issue
Block a user