dirt spawn frequency

This commit is contained in:
steffen-illium 2021-05-18 11:06:43 +02:00
parent 2acf91b395
commit c3224cf971

View File

@ -1,6 +1,7 @@
from collections import OrderedDict
from dataclasses import dataclass
from typing import List
import random
import numpy as np
@ -9,13 +10,15 @@ from environments import helpers as h
from environments.factory.renderer import Renderer
DIRT_INDEX = -1
@dataclass
class DirtProperties:
clean_amount = 0.25
max_spawn_ratio = 0.1
gain_amount = 0.1
spawn_frequency = 5
class GettingDirty(BaseFactory):
@ -62,7 +65,11 @@ class GettingDirty(BaseFactory):
def step(self, actions):
_, _, _, info = super(GettingDirty, self).step(actions)
self.spawn_dirt()
if not self.next_dirt_spawn:
self.spawn_dirt()
self.next_dirt_spawn = self._dirt_properties.spawn_frequency
else:
self.next_dirt_spawn -= 1
return self.state, self.cumulative_reward, self.done, info
def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
@ -87,6 +94,7 @@ class GettingDirty(BaseFactory):
dirt_slice = np.zeros((1, *self.state.shape[1:]))
self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice
self.spawn_dirt()
self.next_dirt_spawn = self._dirt_properties.spawn_frequency
return self.state, r, self.done, {}
def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
@ -109,8 +117,6 @@ class GettingDirty(BaseFactory):
if __name__ == '__main__':
import random
render = True
dirt_props = DirtProperties()