Rework for performance
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from typing import Tuple, Union, Dict, List
|
||||
from typing import Tuple, Union, Dict, List, NamedTuple
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
@ -38,37 +38,27 @@ class Constants:
|
||||
OPEN_DOOR = 'open'
|
||||
|
||||
ACTION = 'action'
|
||||
COLLISIONS = 'collision'
|
||||
VALID = 'valid'
|
||||
NOT_VALID = 'not_valid'
|
||||
|
||||
# Battery Env
|
||||
CHARGE_POD = 'Charge_Pod'
|
||||
BATTERIES = 'BATTERIES'
|
||||
|
||||
# Destination Env
|
||||
DESTINATION = 'Destination'
|
||||
REACHEDDESTINATION = 'ReachedDestination'
|
||||
COLLISION = 'collision'
|
||||
VALID = True
|
||||
NOT_VALID = False
|
||||
|
||||
|
||||
class EnvActions:
|
||||
# Movements
|
||||
NORTH = 'north'
|
||||
EAST = 'east'
|
||||
SOUTH = 'south'
|
||||
WEST = 'west'
|
||||
NORTHEAST = 'north_east'
|
||||
SOUTHEAST = 'south_east'
|
||||
SOUTHWEST = 'south_west'
|
||||
NORTHWEST = 'north_west'
|
||||
NORTH = 'north'
|
||||
EAST = 'east'
|
||||
SOUTH = 'south'
|
||||
WEST = 'west'
|
||||
NORTHEAST = 'north_east'
|
||||
SOUTHEAST = 'south_east'
|
||||
SOUTHWEST = 'south_west'
|
||||
NORTHWEST = 'north_west'
|
||||
|
||||
# Other
|
||||
NOOP = 'no_op'
|
||||
# MOVE = 'move'
|
||||
NOOP = 'no_op'
|
||||
USE_DOOR = 'use_door'
|
||||
|
||||
CHARGE = 'charge'
|
||||
WAIT_ON_DEST = 'wait'
|
||||
|
||||
@classmethod
|
||||
def is_move(cls, other):
|
||||
return any([other == direction for direction in cls.movement_actions()])
|
||||
@ -86,8 +76,19 @@ class EnvActions:
|
||||
return list(itertools.chain(cls.square_move(), cls.diagonal_move()))
|
||||
|
||||
|
||||
class Rewards:
|
||||
|
||||
MOVEMENTS_VALID = -0.001
|
||||
MOVEMENTS_FAIL = -0.001
|
||||
NOOP = -0.1
|
||||
USE_DOOR_VALID = -0.001
|
||||
USE_DOOR_FAIL = -0.001
|
||||
COLLISION = -1
|
||||
|
||||
|
||||
m = EnvActions
|
||||
c = Constants
|
||||
r = Rewards
|
||||
|
||||
ACTIONMAP = defaultdict(lambda: (0, 0), {m.NORTH: (-1, 0), m.NORTHEAST: (-1, +1),
|
||||
m.EAST: (0, 1), m.SOUTHEAST: (1, 1),
|
||||
@ -184,15 +185,20 @@ def asset_str(agent):
|
||||
# What does this abomination do?
|
||||
# if any([x is None for x in [cls._slices[j] for j in agent.collisions]]):
|
||||
# print('error')
|
||||
col_names = [x.name for x in agent.temp_collisions]
|
||||
if any(c.AGENT in name for name in col_names):
|
||||
return 'agent_collision', 'blank'
|
||||
elif not agent.temp_valid or c.LEVEL in col_names or c.AGENT in col_names:
|
||||
return c.AGENT, 'invalid'
|
||||
elif agent.temp_valid and not EnvActions.is_move(agent.temp_action):
|
||||
return c.AGENT, 'valid'
|
||||
elif agent.temp_valid and EnvActions.is_move(agent.temp_action):
|
||||
return c.AGENT, 'move'
|
||||
if step_result := agent.step_result:
|
||||
action = step_result['action_name']
|
||||
valid = step_result['action_valid']
|
||||
col_names = [x.name for x in step_result['collisions']]
|
||||
if any(c.AGENT in name for name in col_names):
|
||||
return 'agent_collision', 'blank'
|
||||
elif not valid or c.LEVEL in col_names or c.AGENT in col_names:
|
||||
return c.AGENT, 'invalid'
|
||||
elif valid and not EnvActions.is_move(action):
|
||||
return c.AGENT, 'valid'
|
||||
elif valid and EnvActions.is_move(action):
|
||||
return c.AGENT, 'move'
|
||||
else:
|
||||
return c.AGENT, 'idle'
|
||||
else:
|
||||
return c.AGENT, 'idle'
|
||||
|
||||
|
Reference in New Issue
Block a user