updated stuff

This commit is contained in:
romue 2021-05-07 11:16:12 +02:00
parent e982630dd3
commit 6eb97e20b1
7 changed files with 108 additions and 1 deletions

View File

View File

@ -0,0 +1,38 @@
import numpy as np
from pathlib import Path
from environments import helpers as h
class Factory(object):
LEVELS_DIR = 'levels'
def __init__(self, level='simple', n_agents=1, max_steps=1e3):
self.n_agents = n_agents
self.max_steps = max_steps
self.level = h.one_hot_level(
h.parse_level(Path(__file__).parent / self.LEVELS_DIR / f'{level}.txt')
)#[np.newaxis, ...]
self.reset()
def reset(self):
self.done = False
self.agents = np.zeros((self.n_agents, *self.level.shape))
free_cells = np.argwhere(self.level == 0)
np.random.shuffle(free_cells)
for i in range(self.n_agents):
r, c = free_cells[i]
self.agents[i, r, c] = 1
free_cells = free_cells[self.n_agents:]
self.state = np.concatenate((self.level[np.newaxis, ...], self.agents), 0)
def step(self, actions):
assert type(actions) in [int, list]
if type(actions) == int:
actions = [actions]
# level, agent 1,..., agent n,
for i, a in enumerate(actions):
h.check_agent_move(state=self.state, dim=i+1, action=a)
if __name__ == '__main__':
factory = Factory(n_agents=1)
factory.step(0)

View File

View File

@ -0,0 +1,10 @@
----------
----------
--------#-
----------
----------
----------
----#-----
----------
----------
----------

View File

@ -1 +0,0 @@
import numpy as np

57
environments/helpers.py Normal file
View File

@ -0,0 +1,57 @@
import numpy as np
from pathlib import Path
# Constants
WALL = '#'
# Utility functions
def parse_level(path):
with path.open('r') as lvl:
level = list(map(lambda x: list(x.strip()), lvl.readlines()))
if len(set([len(line) for line in level])) > 1:
raise AssertionError('Every row of the level string must be of equal length.')
return level
def one_hot_level(level, wall_char=WALL):
grid = np.array(level)
binary_grid = np.zeros(grid.shape)
binary_grid[grid == wall_char] = 1
return binary_grid
def check_agent_move(state, dim, action):
agent_slice = state[dim] # horizontal slice from state tensor
agent_pos = np.argwhere(agent_slice == 1)
if len(agent_pos) > 1:
raise AssertionError('Only one agent per slice is allowed.')
x, y = agent_pos[0]
x_new, y_new = x, y
if action == 0: # North
x_new -= 1
elif action == 1: # East
y_new += 1
elif action == 2: # South
x_new += 1
elif action == 3: # West
y_new -= 1
elif action == 4: # NE
x_new -= 1
y_new += 1
elif action == 5: # SE
x_new += 1
y_new += 1
elif action == 6: # SW
x_new += 1
y_new -= 1
elif action == 7: # NW
x_new -= 1
y_new -= 1
if __name__ == '__main__':
x = parse_level(Path(__file__).parent / 'factory' / 'levels' / 'simple.txt')
y = one_hot_level(x)
print(np.argwhere(y == 0))

View File

@ -1,4 +1,5 @@
appdirs==1.4.4 appdirs==1.4.4
click==7.1.2
cycler==0.10.0 cycler==0.10.0
distlib==0.3.1 distlib==0.3.1
filelock==3.0.12 filelock==3.0.12
@ -6,6 +7,7 @@ kiwisolver==1.3.1
matplotlib==3.4.1 matplotlib==3.4.1
numpy==1.20.2 numpy==1.20.2
pandas==1.2.4 pandas==1.2.4
pep517==0.10.0
Pillow==8.2.0 Pillow==8.2.0
pyparsing==2.4.7 pyparsing==2.4.7
python-dateutil==2.8.1 python-dateutil==2.8.1
@ -13,6 +15,7 @@ pytz==2021.1
scipy==1.6.3 scipy==1.6.3
seaborn==0.11.1 seaborn==0.11.1
six==1.16.0 six==1.16.0
toml==0.10.2
torch==1.8.1 torch==1.8.1
torchaudio==0.8.1 torchaudio==0.8.1
torchvision==0.9.1 torchvision==0.9.1