Remove BoundDestination Object

Add new variable 'var_can_be_bound'.
Adjust observations accordingly.
This commit is contained in:
Steffen Illium
2023-10-12 17:14:32 +02:00
parent e326a95bf4
commit f5c6317158
22 changed files with 98 additions and 110 deletions

View File

@ -5,7 +5,8 @@ Agents:
- Move8 - Move8
Observations: Observations:
- Walls - Walls
- BoundDestination - Other
- Destination
Positions: Positions:
- (2, 1) - (2, 1)
- (2, 5) - (2, 5)
@ -15,12 +16,13 @@ Agents:
- Move8 - Move8
Observations: Observations:
- Walls - Walls
- BoundDestination - Other
- Destination
Positions: Positions:
- (2, 1) - (2, 1)
- (2, 5) - (2, 5)
Entities: Entities:
BoundDestinations: {} Destinations: {}
General: General:
env_seed: 69 env_seed: 69
@ -32,7 +34,7 @@ General:
Rules: Rules:
SpawnAgents: {} SpawnAgents: {}
Collision: Collision:
done_at_collisions: true done_at_collisions: false
FixedDestinationSpawn: FixedDestinationSpawn:
per_agent_positions: per_agent_positions:
Wolfgang: Wolfgang:

View File

@ -6,7 +6,7 @@ General:
verbose: false verbose: false
Entities: Entities:
BoundDestinations: {} Destinations: {}
Doors: {} Doors: {}
GlobalPositions: {} GlobalPositions: {}
Zones: {} Zones: {}
@ -36,7 +36,7 @@ Agents:
- Walls - Walls
- Other - Other
- Doors - Doors
- BoundDestination - Destination
Sigmund: Sigmund:
Actions: Actions:
- Move8 - Move8
@ -47,5 +47,5 @@ Agents:
- Combined: - Combined:
- Other - Other
- Walls - Walls
- BoundDestination - Destination
- Doors - Doors

View File

@ -56,6 +56,7 @@ class Entity(EnvObject, abc.ABC):
return last_x - curr_x, last_y - curr_y return last_x - curr_x, last_y - curr_y
def destroy(self): def destroy(self):
if
valid = self._collection.remove_item(self) valid = self._collection.remove_item(self)
for observer in self.observers: for observer in self.observers:
observer.notify_del_entity(self) observer.notify_del_entity(self)
@ -73,10 +74,17 @@ class Entity(EnvObject, abc.ABC):
return valid return valid
return not_same_tile return not_same_tile
def __init__(self, tile, **kwargs): def __init__(self, tile, bind_to=None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self._status = None self._status = None
self._tile = tile self._tile = tile
if bind_to:
try:
self.bind_to(bind_to)
except AttributeError:
print(f'Objects of {self.__class__.__name__} can not be bound to other entities.')
exit()
assert tile.enter(self, spawn=True), "Positions was not valid!" assert tile.enter(self, spawn=True), "Positions was not valid!"
def summarize_state(self) -> dict: def summarize_state(self) -> dict:

View File

@ -9,10 +9,16 @@ class BoundEntityMixin:
@property @property
def name(self): def name(self):
if self.bound_entity:
return f'{self.__class__.__name__}({self.bound_entity.name})' return f'{self.__class__.__name__}({self.bound_entity.name})'
else:
print()
def belongs_to_entity(self, entity): def belongs_to_entity(self, entity):
return entity == self.bound_entity return entity == self.bound_entity
def bind_to(self, entity): def bind_to(self, entity):
self._bound_entity = entity self._bound_entity = entity
def unbind(self):
self._bound_entity = None

View File

@ -91,6 +91,13 @@ class EnvObject(Object):
except AttributeError: except AttributeError:
return False return False
@property
def var_can_be_bound(self):
try:
return self._collection.var_can_be_bound or False
except AttributeError:
return False
@property @property
def var_can_move(self): def var_can_move(self):
try: try:

View File

@ -90,7 +90,7 @@ class Factory(gym.Env):
# Parse the agent conf # Parse the agent conf
parsed_agents_conf = self.conf.parse_agents_conf() parsed_agents_conf = self.conf.parse_agents_conf()
self.state = Gamestate(entities, parsed_agents_conf, rules, self.conf.env_seed) self.state = Gamestate(entities, parsed_agents_conf, rules, self.conf.env_seed, self.conf.verbose)
# All is set up, trigger entity init with variable pos # All is set up, trigger entity init with variable pos
self.state.rules.do_all_init(self.state, self.map) self.state.rules.do_all_init(self.state, self.map)
@ -235,10 +235,6 @@ class Factory(gym.Env):
del summary[key] del summary[key]
return summary return summary
def print(self, string):
if self.conf.verbose:
print(string)
def save_params(self, filepath: Path): def save_params(self, filepath: Path):
# noinspection PyProtectedMember # noinspection PyProtectedMember
filepath = Path(filepath) filepath = Path(filepath)

View File

@ -11,10 +11,6 @@ class Agents(PositionMixin, EnvObjects):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@property
def obs_pairs(self):
return [(a.name, a) for a in self]
@property @property
def action_space(self): def action_space(self):
from gymnasium import spaces from gymnasium import spaces

View File

@ -9,6 +9,7 @@ class EnvObjects(Objects):
var_can_collide: bool = False var_can_collide: bool = False
var_has_position: bool = False var_has_position: bool = False
var_can_move: bool = False var_can_move: bool = False
var_can_be_bound: bool = False
@property @property
def encodings(self): def encodings(self):

View File

@ -92,11 +92,11 @@ class HasBoundMixin:
def by_entity(self, entity): def by_entity(self, entity):
try: try:
return next((x for x in self if x.belongs_to_entity(entity))) return next((x for x in self if x.belongs_to_entity(entity)))
except StopIteration: except (StopIteration, AttributeError):
return None return None
def idx_by_entity(self, entity): def idx_by_entity(self, entity):
try: try:
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity))) return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
except StopIteration: except (StopIteration, AttributeError):
return None return None

View File

@ -24,7 +24,9 @@ class Objects:
@property @property
def obs_pairs(self): def obs_pairs(self):
return [(self.name, self)] pair_list = [(self.name, self)]
pair_list.extend([(a.name, a) for a in self])
return pair_list
@property @property
def names(self): def names(self):

View File

@ -124,11 +124,12 @@ class Collision(Rule):
pass pass
results.append(TickResult(entity=guest, identifier=c.COLLISION, results.append(TickResult(entity=guest, identifier=c.COLLISION,
reward=r.COLLISION, validity=c.VALID)) reward=r.COLLISION, validity=c.VALID))
self.curr_done = True self.curr_done = True if self.done_at_collisions else False
return results return results
def on_check_done(self, state) -> List[DoneResult]: def on_check_done(self, state) -> List[DoneResult]:
inter_entity_collision_detected = self.curr_done and self.done_at_collisions if self.done_at_collisions:
inter_entity_collision_detected = self.curr_done
move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT]) move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT])
if inter_entity_collision_detected or move_failed: if inter_entity_collision_detected or move_failed:
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)] return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)]

View File

@ -1,28 +0,0 @@
from random import shuffle
from typing import List, Tuple
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.environment import constants as c
from marl_factory_grid.modules.destinations import constants as d
from marl_factory_grid.modules.destinations.entitites import BoundDestination
class NarrowCorridorSpawn(Rule):
def __init__(self, positions: List[Tuple[int, int]], fixed: bool = False):
super().__init__()
self.fixed = fixed
self.positions = positions
def on_init(self, state, lvl_map):
if not self.fixed:
shuffle(self.positions)
for agent in state[c.AGENT]:
pass
def trigger_destination_spawn(self, state):
for (agent_name, position_list) in self.per_agent_positions.items():
agent = state[c.AGENT][agent_name]
destinations = [BoundDestination(agent, pos) for pos in position_list]
state[d.DESTINATION].add_items(destinations)
return c.VALID

View File

@ -1,4 +1,4 @@
from .actions import DestAction from .actions import DestAction
from .entitites import Destination from .entitites import Destination
from .groups import Destinations, BoundDestinations from .groups import Destinations
from .rules import DestinationReachAll, DestinationSpawn from .rules import DestinationReachAll, DestinationSpawn

View File

@ -13,9 +13,7 @@ class DestAction(Action):
super().__init__(d.DESTINATION) super().__init__(d.DESTINATION)
def do(self, entity, state) -> Union[None, ActionResult]: def do(self, entity, state) -> Union[None, ActionResult]:
dest_entities = d.DESTINATION if d.DESTINATION in state else d.BOUNDDESTINATION if destination := state[d.DESTINATION].by_pos(entity.pos):
assert dest_entities
if destination := state[dest_entities].by_pos(entity.pos):
valid = destination.do_wait_action(entity) valid = destination.do_wait_action(entity)
state.print(f'{entity.name} just waited at {entity.pos}') state.print(f'{entity.name} just waited at {entity.pos}')
else: else:

View File

@ -1,7 +1,6 @@
# Destination Env # Destination Env
DESTINATION = 'Destinations' DESTINATION = 'Destinations'
BOUNDDESTINATION = 'BoundDestinations'
DEST_SYMBOL = 1 DEST_SYMBOL = 1
WAIT_ON_DEST = 'WAIT' WAIT_ON_DEST = 'WAIT'

View File

@ -8,13 +8,18 @@ from marl_factory_grid.utils.render import RenderEntity
from marl_factory_grid.modules.destinations import constants as d from marl_factory_grid.modules.destinations import constants as d
class Destination(Entity): class Destination(BoundEntityMixin, Entity):
var_can_move = False var_can_move = False
var_can_collide = False var_can_collide = False
var_has_position = True var_has_position = True
var_is_blocking_pos = False var_is_blocking_pos = False
var_is_blocking_light = False var_is_blocking_light = False
var_can_be_bound = True # Introduce this globally!
@property
def was_reached(self):
return self._was_reached
@property @property
def encoding(self): def encoding(self):
@ -22,6 +27,7 @@ class Destination(Entity):
def __init__(self, *args, action_counts=0, **kwargs): def __init__(self, *args, action_counts=0, **kwargs):
super(Destination, self).__init__(*args, **kwargs) super(Destination, self).__init__(*args, **kwargs)
self._was_reached = False
self.action_counts = action_counts self.action_counts = action_counts
self._per_agent_actions = defaultdict(lambda: 0) self._per_agent_actions = defaultdict(lambda: 0)
@ -30,8 +36,14 @@ class Destination(Entity):
return c.VALID return c.VALID
@property @property
def is_considered_reached(self): def has_just_been_reached(self):
agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in self.tile.guests_that_can_collide) if self.was_reached:
return False
agent_at_position = any(self.bound_entity == x for x in self.tile.guests_that_can_collide)
if self.bound_entity:
return ((agent_at_position and not self.action_counts)
or self._per_agent_actions[self.bound_entity.name] >= self.action_counts >= 1)
else:
return agent_at_position or any(x >= self.action_counts for x in self._per_agent_actions.values()) return agent_at_position or any(x >= self.action_counts for x in self._per_agent_actions.values())
def agent_did_action(self, agent: Agent): def agent_did_action(self, agent: Agent):
@ -44,21 +56,10 @@ class Destination(Entity):
return state_summary return state_summary
def render(self): def render(self):
if self.was_reached:
return None
else:
return RenderEntity(d.DESTINATION, self.pos) return RenderEntity(d.DESTINATION, self.pos)
def mark_as_reached(self):
class BoundDestination(BoundEntityMixin, Destination): self._was_reached = True
@property
def encoding(self):
return d.DEST_SYMBOL
def __init__(self, entity, *args, **kwargs):
self.bind_to(entity)
super().__init__(*args, **kwargs)
@property
def is_considered_reached(self):
agent_at_position = any(self.bound_entity == x for x in self.tile.guests_that_can_collide)
return ((agent_at_position and not self.action_counts)
or self._per_agent_actions[self.bound_entity.name] >= self.action_counts >= 1)

View File

@ -1,6 +1,6 @@
from marl_factory_grid.environment.groups.env_objects import EnvObjects from marl_factory_grid.environment.groups.env_objects import EnvObjects
from marl_factory_grid.environment.groups.mixins import PositionMixin, HasBoundMixin from marl_factory_grid.environment.groups.mixins import PositionMixin, HasBoundMixin
from marl_factory_grid.modules.destinations.entitites import Destination, BoundDestination from marl_factory_grid.modules.destinations.entitites import Destination
class Destinations(PositionMixin, EnvObjects): class Destinations(PositionMixin, EnvObjects):
@ -14,12 +14,3 @@ class Destinations(PositionMixin, EnvObjects):
def __repr__(self): def __repr__(self):
return super(Destinations, self).__repr__() return super(Destinations, self).__repr__()
class BoundDestinations(HasBoundMixin, Destinations):
_entity = BoundDestination
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

View File

@ -1,12 +1,12 @@
import ast import ast
from random import shuffle from random import shuffle
from typing import List, Union, Dict, Tuple from typing import List, Dict, Tuple
from marl_factory_grid.environment.rules import Rule from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult, DoneResult from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.modules.destinations import constants as d, rewards as r from marl_factory_grid.modules.destinations import constants as d, rewards as r
from marl_factory_grid.modules.destinations.entitites import Destination, BoundDestination from marl_factory_grid.modules.destinations.entitites import Destination
class DestinationReachAll(Rule): class DestinationReachAll(Rule):
@ -16,21 +16,28 @@ class DestinationReachAll(Rule):
def tick_step(self, state) -> List[TickResult]: def tick_step(self, state) -> List[TickResult]:
results = [] results = []
for dest in list(state[next(key for key in state.entities.names if d.DESTINATION in key)]): for dest in state[d.DESTINATION]:
if dest.is_considered_reached: if dest.has_just_been_reached and not dest.was_reached:
agent = state[c.AGENT].by_pos(dest.pos) # Dest has just been reached, some agent needs to stand here, grab any first.
for agent in state[c.AGENT].by_pos(dest.pos):
if dest.bound_entity:
if dest.bound_entity == agent:
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent)) results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent))
state.print(f'{dest.name} is reached now, removing...')
assert dest.destroy(), f'{dest.name} could not be destroyed. Critical Error.'
else: else:
pass pass
return [TickResult(self.name, validity=c.VALID, reward=0, entity=None)] else:
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent))
state.print(f'{dest.name} is reached now, mark as reached...')
dest.mark_as_reached()
else:
pass
return results
def tick_post_step(self, state) -> List[TickResult]: def tick_post_step(self, state) -> List[TickResult]:
return [] return []
def on_check_done(self, state) -> List[DoneResult]: def on_check_done(self, state) -> List[DoneResult]:
if not len(state[next(key for key in state.entities.names if d.DESTINATION in key)]): if all(x.was_reached for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)] return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)] return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)]
@ -41,7 +48,7 @@ class DestinationReachAny(DestinationReachAll):
super(DestinationReachAny, self).__init__() super(DestinationReachAny, self).__init__()
def on_check_done(self, state) -> List[DoneResult]: def on_check_done(self, state) -> List[DoneResult]:
if not len(state[next(key for key in state.entities.names if d.DESTINATION in key)]): if any(x.was_reached for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)] return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
return [] return []
@ -95,10 +102,10 @@ class FixedDestinationSpawn(Rule):
shuffle(position_list) shuffle(position_list)
while True: while True:
pos = position_list.pop() pos = position_list.pop()
if pos != agent.pos and not state[d.BOUNDDESTINATION].by_pos(pos): if pos != agent.pos and not state[d.DESTINATION].by_pos(pos):
destination = BoundDestination(agent, state[c.FLOORS].by_pos(pos)) destination = Destination(state[c.FLOORS].by_pos(pos), bind_to=agent)
break break
else: else:
continue continue
state[d.BOUNDDESTINATION].add_item(destination) state[d.DESTINATION].add_item(destination)
pass pass

View File

@ -1,5 +1,5 @@
####### #######
###-### ###-###
#1---2# #-----#
###-### ###-###
####### #######

View File

@ -2,7 +2,7 @@ from random import choices, choice
from . import constants as z, Zone from . import constants as z, Zone
from ..destinations import constants as d from ..destinations import constants as d
from ..destinations.entitites import BoundDestination from ... import Destination
from ...environment.rules import Rule from ...environment.rules import Rule
from ...environment import constants as c from ...environment import constants as c
@ -66,9 +66,10 @@ class IndividualDestinationZonePlacement(Rule):
already_has_destination = True already_has_destination = True
while already_has_destination: while already_has_destination:
tile = choice(other_zones).random_tile tile = choice(other_zones).random_tile
if state[d.BOUNDDESTINATION].by_pos(tile.pos) is None: if state[d.DESTINATION].by_pos(tile.pos) is None:
already_has_destination = False already_has_destination = False
destination = BoundDestination(agent, tile) destination = Destination(tile, bind_to=agent)
state[d.BOUNDDESTINATION].add_item(destination)
state[d.DESTINATION].add_item(destination)
continue continue
return c.VALID return c.VALID