Rework for performance

This commit is contained in:
Steffen Illium
2022-01-10 15:54:22 +01:00
parent 78bf19f7f4
commit 435056f373
10 changed files with 525 additions and 469 deletions

View File

@ -1,6 +1,6 @@
import itertools
from collections import defaultdict
from typing import Tuple, Union, Dict, List
from typing import Tuple, Union, Dict, List, NamedTuple
import networkx as nx
import numpy as np
@ -38,37 +38,27 @@ class Constants:
OPEN_DOOR = 'open'
ACTION = 'action'
COLLISIONS = 'collision'
VALID = 'valid'
NOT_VALID = 'not_valid'
# Battery Env
CHARGE_POD = 'Charge_Pod'
BATTERIES = 'BATTERIES'
# Destination Env
DESTINATION = 'Destination'
REACHEDDESTINATION = 'ReachedDestination'
COLLISION = 'collision'
VALID = True
NOT_VALID = False
class EnvActions:
# Movements
NORTH = 'north'
EAST = 'east'
SOUTH = 'south'
WEST = 'west'
NORTHEAST = 'north_east'
SOUTHEAST = 'south_east'
SOUTHWEST = 'south_west'
NORTHWEST = 'north_west'
NORTH = 'north'
EAST = 'east'
SOUTH = 'south'
WEST = 'west'
NORTHEAST = 'north_east'
SOUTHEAST = 'south_east'
SOUTHWEST = 'south_west'
NORTHWEST = 'north_west'
# Other
NOOP = 'no_op'
# MOVE = 'move'
NOOP = 'no_op'
USE_DOOR = 'use_door'
CHARGE = 'charge'
WAIT_ON_DEST = 'wait'
@classmethod
def is_move(cls, other):
    """Return True when ``other`` equals one of ``cls.movement_actions()``.

    :param other: The action value to test.
    :return: ``True`` if any entry of ``cls.movement_actions()`` compares
             equal to ``other``, otherwise ``False``.
    """
    # Generator expression instead of a list: any() stops at the first match
    # and no intermediate list is allocated.
    return any(other == direction for direction in cls.movement_actions())
@ -86,8 +76,19 @@ class EnvActions:
return list(itertools.chain(cls.square_move(), cls.diagonal_move()))
class Rewards:
    """Namespace of fixed numeric reward values, one per named step outcome.

    NOTE(review): presumably read by the environment when scoring an agent's
    step -- confirm against the callers; only the values are visible here.
    """

    # Movement: a valid and a failed move currently carry the same value.
    MOVEMENTS_VALID = -0.001
    MOVEMENTS_FAIL = -0.001
    # No-op is valued lower than either movement outcome above.
    NOOP = -0.1
    # Door use: valid and failed attempts also share one value.
    USE_DOOR_VALID = -0.001
    USE_DOOR_FAIL = -0.001
    # Lowest value in this table.
    COLLISION = -1
m = EnvActions
c = Constants
r = Rewards
ACTIONMAP = defaultdict(lambda: (0, 0), {m.NORTH: (-1, 0), m.NORTHEAST: (-1, +1),
m.EAST: (0, 1), m.SOUTHEAST: (1, 1),
@ -184,15 +185,20 @@ def asset_str(agent):
# What does this abomination do?
# if any([x is None for x in [cls._slices[j] for j in agent.collisions]]):
# print('error')
col_names = [x.name for x in agent.temp_collisions]
if any(c.AGENT in name for name in col_names):
return 'agent_collision', 'blank'
elif not agent.temp_valid or c.LEVEL in col_names or c.AGENT in col_names:
return c.AGENT, 'invalid'
elif agent.temp_valid and not EnvActions.is_move(agent.temp_action):
return c.AGENT, 'valid'
elif agent.temp_valid and EnvActions.is_move(agent.temp_action):
return c.AGENT, 'move'
if step_result := agent.step_result:
action = step_result['action_name']
valid = step_result['action_valid']
col_names = [x.name for x in step_result['collisions']]
if any(c.AGENT in name for name in col_names):
return 'agent_collision', 'blank'
elif not valid or c.LEVEL in col_names or c.AGENT in col_names:
return c.AGENT, 'invalid'
elif valid and not EnvActions.is_move(action):
return c.AGENT, 'valid'
elif valid and EnvActions.is_move(action):
return c.AGENT, 'move'
else:
return c.AGENT, 'idle'
else:
return c.AGENT, 'idle'