Item and Dirt Factory Working again
This commit is contained in:
@ -1,7 +1,5 @@
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Union, Dict, List
|
||||
|
||||
import networkx as nx
|
||||
@ -20,7 +18,7 @@ IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amo
|
||||
|
||||
|
||||
# Constants
|
||||
class Constants(Enum):
|
||||
class Constants:
|
||||
WALL = '#'
|
||||
WALLS = 'Walls'
|
||||
FLOOR = 'Floor'
|
||||
@ -44,14 +42,6 @@ class Constants(Enum):
|
||||
VALID = 'valid'
|
||||
NOT_VALID = 'not_valid'
|
||||
|
||||
# Dirt Env
|
||||
DIRT = 'Dirt'
|
||||
|
||||
# Item Env
|
||||
ITEM = 'Item'
|
||||
INVENTORY = 'Inventory'
|
||||
DROP_OFF = 'Drop_Off'
|
||||
|
||||
# Battery Env
|
||||
CHARGE_POD = 'Charge_Pod'
|
||||
BATTERIES = 'BATTERIES'
|
||||
@ -60,14 +50,9 @@ class Constants(Enum):
|
||||
DESTINATION = 'Destination'
|
||||
REACHEDDESTINATION = 'ReachedDestination'
|
||||
|
||||
def __bool__(self):
|
||||
if 'not_' in self.value:
|
||||
return False
|
||||
else:
|
||||
return bool(self.value)
|
||||
|
||||
|
||||
class MovingAction(Enum):
|
||||
class EnvActions:
|
||||
# Movements
|
||||
NORTH = 'north'
|
||||
EAST = 'east'
|
||||
SOUTH = 'south'
|
||||
@ -77,29 +62,31 @@ class MovingAction(Enum):
|
||||
SOUTHWEST = 'south_west'
|
||||
NORTHWEST = 'north_west'
|
||||
|
||||
@classmethod
|
||||
def is_member(cls, other):
|
||||
return any([other == direction for direction in cls])
|
||||
|
||||
@classmethod
|
||||
def square(cls):
|
||||
return [cls.NORTH, cls.EAST, cls.SOUTH, cls.WEST]
|
||||
|
||||
@classmethod
|
||||
def diagonal(cls):
|
||||
return [cls.NORTHEAST, cls.SOUTHEAST, cls.SOUTHWEST, cls.NORTHWEST]
|
||||
|
||||
|
||||
class EnvActions(Enum):
|
||||
NOOP = 'no_op'
|
||||
# Other
|
||||
NOOP = 'no_op'
|
||||
USE_DOOR = 'use_door'
|
||||
CLEAN_UP = 'clean_up'
|
||||
ITEM_ACTION = 'item_action'
|
||||
|
||||
CHARGE = 'charge'
|
||||
WAIT_ON_DEST = 'wait'
|
||||
|
||||
@classmethod
|
||||
def is_move(cls, other):
|
||||
return any([other == direction for direction in cls.movement_actions()])
|
||||
|
||||
m = MovingAction
|
||||
@classmethod
|
||||
def square_move(cls):
|
||||
return [cls.NORTH, cls.EAST, cls.SOUTH, cls.WEST]
|
||||
|
||||
@classmethod
|
||||
def diagonal_move(cls):
|
||||
return [cls.NORTHEAST, cls.SOUTHEAST, cls.SOUTHWEST, cls.NORTHWEST]
|
||||
|
||||
@classmethod
|
||||
def movement_actions(cls):
|
||||
return list(itertools.chain(cls.square_move(), cls.diagonal_move()))
|
||||
|
||||
|
||||
m = EnvActions
|
||||
c = Constants
|
||||
|
||||
ACTIONMAP = defaultdict(lambda: (0, 0), {m.NORTH: (-1, 0), m.NORTHEAST: (-1, +1),
|
||||
@ -171,13 +158,10 @@ def parse_level(path):
|
||||
return level
|
||||
|
||||
|
||||
def one_hot_level(level, wall_char: Union[c, str] = c.WALL):
|
||||
def one_hot_level(level, wall_char: str = c.WALL):
|
||||
grid = np.array(level)
|
||||
binary_grid = np.zeros(grid.shape, dtype=np.int8)
|
||||
if wall_char in c:
|
||||
binary_grid[grid == wall_char.value] = c.OCCUPIED_CELL.value
|
||||
else:
|
||||
binary_grid[grid == wall_char] = c.OCCUPIED_CELL.value
|
||||
binary_grid[grid == wall_char] = c.OCCUPIED_CELL
|
||||
return binary_grid
|
||||
|
||||
|
||||
@ -198,19 +182,19 @@ def check_position(slice_to_check_against: ArrayLike, position_to_check: Tuple[i
|
||||
|
||||
def asset_str(agent):
|
||||
# What does this abonimation do?
|
||||
# if any([x is None for x in [self._slices[j] for j in agent.collisions]]):
|
||||
# if any([x is None for x in [cls._slices[j] for j in agent.collisions]]):
|
||||
# print('error')
|
||||
col_names = [x.name for x in agent.temp_collisions]
|
||||
if any(c.AGENT.value in name for name in col_names):
|
||||
if any(c.AGENT in name for name in col_names):
|
||||
return 'agent_collision', 'blank'
|
||||
elif not agent.temp_valid or c.LEVEL.name in col_names or c.AGENT.name in col_names:
|
||||
return c.AGENT.value, 'invalid'
|
||||
elif agent.temp_valid and not MovingAction.is_member(agent.temp_action):
|
||||
return c.AGENT.value, 'valid'
|
||||
elif agent.temp_valid and MovingAction.is_member(agent.temp_action):
|
||||
return c.AGENT.value, 'move'
|
||||
elif not agent.temp_valid or c.LEVEL in col_names or c.AGENT in col_names:
|
||||
return c.AGENT, 'invalid'
|
||||
elif agent.temp_valid and not EnvActions.is_move(agent.temp_action):
|
||||
return c.AGENT, 'valid'
|
||||
elif agent.temp_valid and EnvActions.is_move(agent.temp_action):
|
||||
return c.AGENT, 'move'
|
||||
else:
|
||||
return c.AGENT.value, 'idle'
|
||||
return c.AGENT, 'idle'
|
||||
|
||||
|
||||
def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True):
|
||||
@ -229,9 +213,3 @@ def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, all
|
||||
elif allow_manhattan_connections and not allow_euclidean_connections and not all(diff) and any(diff):
|
||||
graph.add_edge(a, b)
|
||||
return graph
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parsed_level = parse_level(Path(__file__).parent / 'factory' / 'levels' / 'simple.txt')
|
||||
y = one_hot_level(parsed_level)
|
||||
print(np.argwhere(y == 0))
|
||||
|
Reference in New Issue
Block a user