diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index 2f44844..85cb6e2 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -1,5 +1,6 @@ from typing import List, Union +import gym import numpy as np from pathlib import Path @@ -29,7 +30,11 @@ class AgentState: raise AttributeError(f'"{key}" cannot be updated, this attr is not a part of {self.__class__.__name__}') -class BaseFactory: +class BaseFactory(gym.Env): + + @property + def action_space(self): + return self._registered_actions @property def movement_actions(self): @@ -44,13 +49,20 @@ class BaseFactory: self.max_steps = max_steps self.allow_vertical_movement = True self.allow_horizontal_movement = True + self.allow_no_OP = True + self._registered_actions = self.movement_actions + int(self.allow_no_OP) self.level = h.one_hot_level( h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt') ) self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}} + self.reset() - def reset(self) -> (np.ndarray, int, bool, dict): + def register_actions(self, n_actions): + self._registered_actions += n_actions + return True + + def reset(self) -> (np.ndarray, int, bool, dict): self.done = False self.steps = 0 self.cumulative_reward = 0 diff --git a/environments/factory/simple_factory_getting_dirty.py b/environments/factory/simple_factory_getting_dirty.py index e2e190e..93421bd 100644 --- a/environments/factory/simple_factory_getting_dirty.py +++ b/environments/factory/simple_factory_getting_dirty.py @@ -11,7 +11,6 @@ from environments import helpers as h from environments.factory.renderer import Renderer from environments.factory.renderer import Entity - DIRT_INDEX = -1 @@ -40,7 +39,7 @@ class GettingDirty(BaseFactory): height, width = self.state.shape[1:] self.renderer = Renderer(width, height, view_radius=0) - dirt = [Entity('dirt', [x, y], (min(self.state[DIRT_INDEX, x, y],1)), 'scale') for x, y in 
np.argwhere(self.state[DIRT_INDEX] > h.IS_FREE_CELL)] + dirt = [Entity('dirt', [x, y], self.state[DIRT_INDEX, x, y]) for x, y in np.argwhere(self.state[DIRT_INDEX] > h.IS_FREE_CELL)] walls = [Entity('dirt', pos) for pos in np.argwhere(self.state[h.LEVEL_IDX] > h.IS_FREE_CELL)] agents = [Entity('agent', pos) for pos in np.argwhere(self.state[h.AGENT_START_IDX] > h.IS_FREE_CELL)] @@ -101,7 +100,10 @@ class GettingDirty(BaseFactory): def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict): # TODO: What reward to use? - this_step_reward = 0 + current_dirt_amount = self.state[DIRT_INDEX].sum() + dirty_tiles = int(np.count_nonzero(self.state[DIRT_INDEX])) + + this_step_reward = -(dirty_tiles / current_dirt_amount) if current_dirt_amount else 0 for agent_state in agent_states: collisions = agent_state.collisions @@ -113,8 +115,8 @@ class GettingDirty(BaseFactory): for entity in collisions: if entity != self.string_slices["dirt"]: self.monitor.add(f'agent_{agent_state.i}_vs_{self.slice_strings[entity]}', 1) - self.monitor.set('dirt_amount', self.state[DIRT_INDEX].sum()) - self.monitor.set('dirty_tiles', len(np.nonzero(self.state[DIRT_INDEX]))) + self.monitor.set('dirt_amount', current_dirt_amount) + self.monitor.set('dirty_tiles', dirty_tiles) return this_step_reward, {} @@ -126,13 +128,13 @@ if __name__ == '__main__': monitor_list = list() for epoch in range(100): random_actions = [random.randint(0, 8) for _ in range(200)] - state, r, done, _ = factory.reset() - for action in random_actions: - state, r, done, info = factory.step(action) + env_state, reward, done_bool, _ = factory.reset() + for agent_i_action in random_actions: + env_state, reward, done_bool, info_obj = factory.step(agent_i_action) if render: factory.render() monitor_list.append(factory.monitor.to_pd_dataframe()) - print(f'Factory run {epoch} done, reward is:\n {r}') + print(f'Factory run {epoch} done, reward is:\n {reward}') from pathlib import Path import pickle