diff --git a/environments/factory/__init__.py b/environments/factory/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/environments/factory/factory_cleaning.py b/environments/factory/factory_cleaning.py new file mode 100644 index 0000000..4372579 --- /dev/null +++ b/environments/factory/factory_cleaning.py @@ -0,0 +1,38 @@ +import numpy as np +from pathlib import Path +from environments import helpers as h + + +class Factory(object): + LEVELS_DIR = 'levels' + + def __init__(self, level='simple', n_agents=1, max_steps=1e3): + self.n_agents = n_agents + self.max_steps = max_steps + self.level = h.one_hot_level( + h.parse_level(Path(__file__).parent / self.LEVELS_DIR / f'{level}.txt') + )#[np.newaxis, ...] + self.reset() + + def reset(self): + self.done = False + self.agents = np.zeros((self.n_agents, *self.level.shape)) + free_cells = np.argwhere(self.level == 0) + np.random.shuffle(free_cells) + for i in range(self.n_agents): + r, c = free_cells[i] + self.agents[i, r, c] = 1 + free_cells = free_cells[self.n_agents:] + self.state = np.concatenate((self.level[np.newaxis, ...], self.agents), 0) + + def step(self, actions): + assert type(actions) in [int, list] + if type(actions) == int: + actions = [actions] + # level, agent 1,..., agent n, + for i, a in enumerate(actions): + h.check_agent_move(state=self.state, dim=i+1, action=a) + +if __name__ == '__main__': + factory = Factory(n_agents=1) + factory.step(0) \ No newline at end of file diff --git a/environments/factory/levels/__init__.py b/environments/factory/levels/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/environments/factory/levels/simple.txt b/environments/factory/levels/simple.txt new file mode 100644 index 0000000..4fe7dd8 --- /dev/null +++ b/environments/factory/levels/simple.txt @@ -0,0 +1,10 @@ +---------- +---------- +--------#- +---------- +---------- +---------- +----#----- +---------- +---------- +---------- \ No newline at end of file diff --git a/environments/factory_cleaning.py b/environments/factory_cleaning.py deleted file mode 100644 index 2247c67..0000000 --- a/environments/factory_cleaning.py +++ /dev/null @@ -1 +0,0 @@ -import numpy as np \ No newline at end of file diff --git a/environments/helpers.py b/environments/helpers.py new file mode 100644 index 0000000..57f593d --- /dev/null +++ b/environments/helpers.py @@ -0,0 +1,57 @@ +import numpy as np +from pathlib import Path + +# Constants +WALL = '#' + + +# Utility functions +def parse_level(path): + with path.open('r') as lvl: + level = list(map(lambda x: list(x.strip()), lvl.readlines())) + if len(set([len(line) for line in level])) > 1: + raise AssertionError('Every row of the level string must be of equal length.') + return level + + +def one_hot_level(level, wall_char=WALL): + grid = np.array(level) + binary_grid = np.zeros(grid.shape) + binary_grid[grid == wall_char] = 1 + return binary_grid + + +def check_agent_move(state, dim, action): + agent_slice = state[dim] # horizontal slice from state tensor + agent_pos = np.argwhere(agent_slice == 1) + if len(agent_pos) > 1: + raise AssertionError('Only one agent per slice is allowed.') + x, y = agent_pos[0] + x_new, y_new = x, y + if action == 0: # North + x_new -= 1 + elif action == 1: # East + y_new += 1 + elif action == 2: # South + x_new += 1 + elif action == 3: # West + y_new -= 1 + elif action == 4: # NE + x_new -= 1 + y_new += 1 + elif action == 5: # SE + x_new += 1 + y_new += 1 + elif action == 6: # SW + x_new += 1 + y_new -= 1 + elif action == 7: # NW + x_new -= 1 + y_new -= 1 + + + +if __name__ == '__main__': + x = parse_level(Path(__file__).parent / 'factory' / 'levels' / 'simple.txt') + y = one_hot_level(x) + print(np.argwhere(y == 0)) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index f365a79..68d27bc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ appdirs==1.4.4 +click==7.1.2 cycler==0.10.0 distlib==0.3.1 filelock==3.0.12 @@ -6,6 +7,7 @@ kiwisolver==1.3.1 matplotlib==3.4.1 numpy==1.20.2 pandas==1.2.4 +pep517==0.10.0 Pillow==8.2.0 pyparsing==2.4.7 python-dateutil==2.8.1 @@ -13,6 +15,7 @@ pytz==2021.1 scipy==1.6.3 seaborn==0.11.1 six==1.16.0 +toml==0.10.2 torch==1.8.1 torchaudio==0.8.1 torchvision==0.9.1