dirt spawn frequency

This commit is contained in:
steffen-illium 2021-05-18 11:06:43 +02:00
parent 2acf91b395
commit c3224cf971

View File

@ -1,6 +1,7 @@
from collections import OrderedDict from collections import OrderedDict
from dataclasses import dataclass from dataclasses import dataclass
from typing import List from typing import List
import random
import numpy as np import numpy as np
@ -9,13 +10,15 @@ from environments import helpers as h
from environments.factory.renderer import Renderer from environments.factory.renderer import Renderer
DIRT_INDEX = -1 DIRT_INDEX = -1
@dataclass @dataclass
class DirtProperties: class DirtProperties:
clean_amount = 0.25 clean_amount = 0.25
max_spawn_ratio = 0.1 max_spawn_ratio = 0.1
gain_amount = 0.1 gain_amount = 0.1
spawn_frequency = 5
class GettingDirty(BaseFactory): class GettingDirty(BaseFactory):
@ -62,7 +65,11 @@ class GettingDirty(BaseFactory):
def step(self, actions): def step(self, actions):
_, _, _, info = super(GettingDirty, self).step(actions) _, _, _, info = super(GettingDirty, self).step(actions)
self.spawn_dirt() if not self.next_dirt_spawn:
self.spawn_dirt()
self.next_dirt_spawn = self._dirt_properties.spawn_frequency
else:
self.next_dirt_spawn -= 1
return self.state, self.cumulative_reward, self.done, info return self.state, self.cumulative_reward, self.done, info
def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool): def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
@ -87,6 +94,7 @@ class GettingDirty(BaseFactory):
dirt_slice = np.zeros((1, *self.state.shape[1:])) dirt_slice = np.zeros((1, *self.state.shape[1:]))
self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice
self.spawn_dirt() self.spawn_dirt()
self.next_dirt_spawn = self._dirt_properties.spawn_frequency
return self.state, r, self.done, {} return self.state, r, self.done, {}
def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict): def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
@ -109,8 +117,6 @@ class GettingDirty(BaseFactory):
if __name__ == '__main__': if __name__ == '__main__':
import random
render = True render = True
dirt_props = DirtProperties() dirt_props = DirtProperties()