Remove BoundDestination Object

New Variable 'var_can_be_bound'
Observations adjusted accordingly
This commit is contained in:
Steffen Illium
2023-10-12 17:14:32 +02:00
parent e326a95bf4
commit f5c6317158
22 changed files with 98 additions and 110 deletions

View File

@@ -1,12 +1,12 @@
import ast
from random import shuffle
from typing import List, Union, Dict, Tuple
from typing import List, Dict, Tuple
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c
from marl_factory_grid.modules.destinations import constants as d, rewards as r
from marl_factory_grid.modules.destinations.entitites import Destination, BoundDestination
from marl_factory_grid.modules.destinations.entitites import Destination
class DestinationReachAll(Rule):
@@ -16,21 +16,28 @@ class DestinationReachAll(Rule):
def tick_step(self, state) -> List[TickResult]:
results = []
for dest in list(state[next(key for key in state.entities.names if d.DESTINATION in key)]):
if dest.is_considered_reached:
agent = state[c.AGENT].by_pos(dest.pos)
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent))
state.print(f'{dest.name} is reached now, removing...')
assert dest.destroy(), f'{dest.name} could not be destroyed. Critical Error.'
for dest in state[d.DESTINATION]:
if dest.has_just_been_reached and not dest.was_reached:
# Dest has just been reached, some agent needs to stand here, grab any first.
for agent in state[c.AGENT].by_pos(dest.pos):
if dest.bound_entity:
if dest.bound_entity == agent:
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent))
else:
pass
else:
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent))
state.print(f'{dest.name} is reached now, mark as reached...')
dest.mark_as_reached()
else:
pass
return [TickResult(self.name, validity=c.VALID, reward=0, entity=None)]
pass
return results
def tick_post_step(self, state) -> List[TickResult]:
return []
def on_check_done(self, state) -> List[DoneResult]:
if not len(state[next(key for key in state.entities.names if d.DESTINATION in key)]):
if all(x.was_reached for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)]
@@ -41,7 +48,7 @@ class DestinationReachAny(DestinationReachAll):
super(DestinationReachAny, self).__init__()
def on_check_done(self, state) -> List[DoneResult]:
if not len(state[next(key for key in state.entities.names if d.DESTINATION in key)]):
if any(x.was_reached for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
return []
@@ -95,10 +102,10 @@ class FixedDestinationSpawn(Rule):
shuffle(position_list)
while True:
pos = position_list.pop()
if pos != agent.pos and not state[d.BOUNDDESTINATION].by_pos(pos):
destination = BoundDestination(agent, state[c.FLOORS].by_pos(pos))
if pos != agent.pos and not state[d.DESTINATION].by_pos(pos):
destination = Destination(state[c.FLOORS].by_pos(pos), bind_to=agent)
break
else:
continue
state[d.BOUNDDESTINATION].add_item(destination)
state[d.DESTINATION].add_item(destination)
pass