Merge remote-tracking branch 'origin/main' into main

Author: romue
Date:   2021-05-18 11:30:24 +02:00

2 changed files with 35 additions and 15 deletions

File 1 of 2 (defines AgentState and BaseFactory):

@@ -1,5 +1,6 @@
 from typing import List, Union
+import gym
 import numpy as np
 from pathlib import Path
@@ -29,7 +30,11 @@ class AgentState:
             raise AttributeError(f'"{key}" cannot be updated, this attr is not a part of {self.__class__.__name__}')
 
-class BaseFactory:
+class BaseFactory(gym.Env):
+
+    @property
+    def action_space(self):
+        return self._registered_actions
 
     @property
     def movement_actions(self):
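
A note on the gym integration: gym's convention is for action_space to be a gym.spaces.Space object, while the property added here returns the plain action count. A minimal sketch of the Discrete-space variant (class and parameter names are illustrative, not from this commit):

    import gym
    from gym import spaces

    class DiscreteActionEnv(gym.Env):
        # Hypothetical minimal env: wraps the integer action count in a
        # gym.spaces.Discrete so standard gym tooling can sample/validate.
        def __init__(self, n_movement_actions: int = 8, allow_no_op: bool = True):
            self._registered_actions = n_movement_actions + int(allow_no_op)

        @property
        def action_space(self):
            return spaces.Discrete(self._registered_actions)

        def sample_action(self) -> int:
            return self.action_space.sample()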
@@ -44,13 +49,20 @@ class BaseFactory:
         self.max_steps = max_steps
         self.allow_vertical_movement = True
         self.allow_horizontal_movement = True
+        self.allow_no_OP = True
+        self._registered_actions = self.movement_actions + int(self.allow_no_OP)
         self.level = h.one_hot_level(
             h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
         )
         self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}}
         self.reset()
 
-    def reset(self) -> (np.ndarray, int, bool, dict):
+    def register_actions(self, n_actions):
+        self._registered_actions += n_actions
+        return True
+
+    def reset(self) -> (np.ndarray, int, bool, dict):
         self.done = False
         self.steps = 0
         self.cumulative_reward = 0
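
Two asides on this hunk: the tuple annotation on reset() is legal Python but not a real type (typing.Tuple[np.ndarray, int, bool, dict] would be the checkable form), and the new register_actions() hook lets a subclass grow the action set after the base constructor has counted movement and no-op actions. A hypothetical sketch of the intended call pattern (CleaningFactory and its extra action are illustrative, not from the diff):

    class CleaningFactory(BaseFactory):
        # Hypothetical subclass: registers one extra action (e.g. 'clean up')
        # on top of the movement/no-op actions counted by BaseFactory.
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.register_actions(n_actions=1)
            # the extra action now occupies the highest index, action_space - 1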

File 2 of 2 (defines DirtProperties and GettingDirty):

@@ -1,6 +1,7 @@
 from collections import OrderedDict
 from dataclasses import dataclass
 from typing import List
+import random
 import numpy as np
@@ -10,14 +11,15 @@ from environments import helpers as h
 from environments.factory.renderer import Renderer
 from environments.factory.renderer import Entity
 
 DIRT_INDEX = -1
 
 @dataclass
 class DirtProperties:
     clean_amount = 0.25
     max_spawn_ratio = 0.1
     gain_amount = 0.1
+    spawn_frequency = 5
 
 class GettingDirty(BaseFactory):
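
One caveat with DirtProperties as written: under @dataclass, bare assignments without type annotations are plain class attributes, not dataclass fields, so no __init__ parameters are generated for them. If per-instance configuration is intended, the attributes need annotations; a sketch of the annotated form:

    from dataclasses import dataclass

    @dataclass
    class AnnotatedDirtProperties:
        # annotated defaults become real fields with a generated __init__
        clean_amount: float = 0.25
        max_spawn_ratio: float = 0.1
        gain_amount: float = 0.1
        spawn_frequency: int = 5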
@@ -37,7 +39,7 @@ class GettingDirty(BaseFactory):
             height, width = self.state.shape[1:]
             self.renderer = Renderer(width, height, view_radius=0)
 
-        dirt = [Entity('dirt', [x, y], (min(self.state[DIRT_INDEX, x, y],1)), 'scale') for x, y in np.argwhere(self.state[DIRT_INDEX] > h.IS_FREE_CELL)]
+        dirt = [Entity('dirt', [x, y], self.state[DIRT_INDEX, x, y]) for x, y in np.argwhere(self.state[DIRT_INDEX] > h.IS_FREE_CELL)]
         walls = [Entity('dirt', pos) for pos in np.argwhere(self.state[h.LEVEL_IDX] > h.IS_FREE_CELL)]
         agents = [Entity('agent', pos) for pos in np.argwhere(self.state[h.AGENT_START_IDX] > h.IS_FREE_CELL)]
@@ -64,7 +66,11 @@ class GettingDirty(BaseFactory):
 
     def step(self, actions):
         _, _, _, info = super(GettingDirty, self).step(actions)
-        self.spawn_dirt()
+        if not self.next_dirt_spawn:
+            self.spawn_dirt()
+            self.next_dirt_spawn = self._dirt_properties.spawn_frequency
+        else:
+            self.next_dirt_spawn -= 1
         return self.state, self.cumulative_reward, self.done, info
 
     def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
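
The new spawn logic is a countdown timer: reset() seeds next_dirt_spawn with spawn_frequency, each step() decrements it, and when it hits zero dirt spawns and the counter is re-seeded, so with spawn_frequency = 5 dirt appears on every sixth step. The pattern in isolation (a standalone sketch, not the repo's code):

    class CountdownSpawner:
        # Standalone sketch of the timer used in step(): fire when the
        # counter hits zero, then re-seed it with the configured frequency.
        def __init__(self, frequency: int):
            self.frequency = frequency
            self.counter = frequency

        def tick(self) -> bool:
            if not self.counter:
                self.counter = self.frequency
                return True   # event fires this step
            self.counter -= 1
            return False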
@@ -89,11 +95,15 @@ class GettingDirty(BaseFactory):
         dirt_slice = np.zeros((1, *self.state.shape[1:]))
         self.state = np.concatenate((self.state, dirt_slice))  # dirt is now the last slice
         self.spawn_dirt()
+        self.next_dirt_spawn = self._dirt_properties.spawn_frequency
         return self.state, r, self.done, {}
 
     def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
         # TODO: What reward to use?
-        this_step_reward = 0
+        current_dirt_amount = self.state[DIRT_INDEX].sum()
+        dirty_tiles = len(np.nonzero(self.state[DIRT_INDEX]))
+        this_step_reward = -(dirty_tiles / current_dirt_amount)
 
         for agent_state in agent_states:
             collisions = agent_state.collisions
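
Two pitfalls in the new reward term are worth flagging (observations about the committed code, not part of the diff): len(np.nonzero(arr)) is the length of the tuple of index arrays that np.nonzero returns, i.e. arr.ndim, which is always 2 for a 2-D dirt slice regardless of how dirty it is; and on a fully clean board both numerator and denominator are zero, so the ratio degenerates to nan with a RuntimeWarning. A small demonstration plus a guarded alternative:

    import numpy as np

    dirt = np.array([[0.0, 0.3],
                     [0.5, 0.0]])

    len(np.nonzero(dirt))        # -> 2 (== dirt.ndim), NOT the dirty-tile count
    int(np.count_nonzero(dirt))  # -> 2 here, 0 on a clean board

    # guarded variant of the ratio used in calculate_reward():
    amount = dirt.sum()
    tiles = np.count_nonzero(dirt)
    reward = -(tiles / amount) if amount > 0 else 0.0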
@@ -105,14 +115,12 @@ class GettingDirty(BaseFactory):
             for entity in collisions:
                 if entity != self.string_slices["dirt"]:
                     self.monitor.add(f'agent_{agent_state.i}_vs_{self.slice_strings[entity]}', 1)
-        self.monitor.set('dirt_amount', self.state[DIRT_INDEX].sum())
-        self.monitor.set('dirty_tiles', len(np.nonzero(self.state[DIRT_INDEX])))
+        self.monitor.set('dirt_amount', current_dirt_amount)
+        self.monitor.set('dirty_tiles', dirty_tiles)
         return this_step_reward, {}
 
 if __name__ == '__main__':
-    import random
     render = True
 
     dirt_props = DirtProperties()
@@ -120,13 +128,13 @@ if __name__ == '__main__':
     monitor_list = list()
     for epoch in range(100):
         random_actions = [random.randint(0, 8) for _ in range(200)]
-        state, r, done, _ = factory.reset()
-        for action in random_actions:
-            state, r, done, info = factory.step(action)
+        env_state, reward, done_bool, _ = factory.reset()
+        for agent_i_action in random_actions:
+            env_state, reward, done_bool, info_obj = factory.step(agent_i_action)
             if render:
                 factory.render()
         monitor_list.append(factory.monitor.to_pd_dataframe())
-        print(f'Factory run {epoch} done, reward is:\n {r}')
+        print(f'Factory run {epoch} done, reward is:\n {reward}')
 
     from pathlib import Path
     import pickle
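
The hunk ends as the script imports Path and pickle, presumably to persist monitor_list; the actual destination is not shown in this diff. A hypothetical continuation (file name and directory are placeholders, not from the commit):

    # Hypothetical continuation -- the real output path is not in this hunk.
    out_file = Path('debug_out') / 'monitor.pick'
    out_file.parent.mkdir(parents=True, exist_ok=True)
    with out_file.open('wb') as f:
        pickle.dump(monitor_list, f)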