mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2026-01-15 23:41:39 +01:00
added allowed direction check for predict move
This commit is contained in:
@@ -135,18 +135,22 @@ class TSPBaseAgent(ABC):
|
|||||||
pass
|
pass
|
||||||
next_pos = self._static_route.pop(0)
|
next_pos = self._static_route.pop(0)
|
||||||
while next_pos == self.state.pos:
|
while next_pos == self.state.pos:
|
||||||
next_pos = self._static_route.pop(0)
|
if self._static_route:
|
||||||
|
next_pos = self._static_route.pop(0)
|
||||||
else:
|
else:
|
||||||
if not self._static_route:
|
if not self._static_route:
|
||||||
self._static_route = self.calculate_tsp_route(target_identifier)[:7]
|
self._static_route = self.calculate_tsp_route(target_identifier)[:7]
|
||||||
next_pos = self._static_route.pop(0)
|
next_pos = self._static_route.pop(0)
|
||||||
while next_pos == self.state.pos:
|
while next_pos == self.state.pos:
|
||||||
next_pos = self._static_route.pop(0)
|
if self._static_route:
|
||||||
|
next_pos = self._static_route.pop(0)
|
||||||
diff = np.subtract(next_pos, self.state.pos)
|
diff = np.subtract(next_pos, self.state.pos)
|
||||||
# Retrieve action based on the pos dif (like in: What do I have to do to get there?)
|
# Retrieve action based on the pos dif (like in: What do I have to do to get there?)
|
||||||
try:
|
try:
|
||||||
action = next(action for action, pos_diff in MOVEMAP.items() if np.all(diff == pos_diff))
|
allowed_directions = [action.name for action in self.state.actions if
|
||||||
|
action.name in ['north', 'east', 'south', 'west', 'north_east', 'south_east',
|
||||||
|
'south_west', 'north_west']]
|
||||||
|
action = next(action for action, pos_diff in MOVEMAP.items() if np.all(diff == pos_diff) and action in allowed_directions)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
print(f"No valid action found for pos diff: {diff}. Using fallback action.")
|
print(f"No valid action found for pos diff: {diff}. Using fallback action.")
|
||||||
action = choice(self.state.actions).name
|
action = choice(self.state.actions).name
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ Agents:
|
|||||||
Agent_horizontal:
|
Agent_horizontal:
|
||||||
Actions:
|
Actions:
|
||||||
- Noop
|
- Noop
|
||||||
- Move8
|
- Move4
|
||||||
Observations:
|
Observations:
|
||||||
- Walls
|
- Walls
|
||||||
- Other
|
- Other
|
||||||
@@ -27,7 +27,7 @@ Agents:
|
|||||||
Agent_vertical:
|
Agent_vertical:
|
||||||
Actions:
|
Actions:
|
||||||
- Noop
|
- Noop
|
||||||
- Move8
|
- Move4
|
||||||
Observations:
|
Observations:
|
||||||
- Walls
|
- Walls
|
||||||
- Other
|
- Other
|
||||||
|
|||||||
Reference in New Issue
Block a user