mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
dirt spawn frequency
This commit is contained in:
parent
2acf91b395
commit
c3224cf971
@ -1,6 +1,7 @@
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List
|
from typing import List
|
||||||
|
import random
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -9,13 +10,15 @@ from environments import helpers as h
|
|||||||
|
|
||||||
from environments.factory.renderer import Renderer
|
from environments.factory.renderer import Renderer
|
||||||
|
|
||||||
|
|
||||||
DIRT_INDEX = -1
|
DIRT_INDEX = -1
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DirtProperties:
|
class DirtProperties:
|
||||||
clean_amount = 0.25
|
clean_amount = 0.25
|
||||||
max_spawn_ratio = 0.1
|
max_spawn_ratio = 0.1
|
||||||
gain_amount = 0.1
|
gain_amount = 0.1
|
||||||
|
spawn_frequency = 5
|
||||||
|
|
||||||
|
|
||||||
class GettingDirty(BaseFactory):
|
class GettingDirty(BaseFactory):
|
||||||
@ -62,7 +65,11 @@ class GettingDirty(BaseFactory):
|
|||||||
|
|
||||||
def step(self, actions):
|
def step(self, actions):
|
||||||
_, _, _, info = super(GettingDirty, self).step(actions)
|
_, _, _, info = super(GettingDirty, self).step(actions)
|
||||||
self.spawn_dirt()
|
if not self.next_dirt_spawn:
|
||||||
|
self.spawn_dirt()
|
||||||
|
self.next_dirt_spawn = self._dirt_properties.spawn_frequency
|
||||||
|
else:
|
||||||
|
self.next_dirt_spawn -= 1
|
||||||
return self.state, self.cumulative_reward, self.done, info
|
return self.state, self.cumulative_reward, self.done, info
|
||||||
|
|
||||||
def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
||||||
@ -87,6 +94,7 @@ class GettingDirty(BaseFactory):
|
|||||||
dirt_slice = np.zeros((1, *self.state.shape[1:]))
|
dirt_slice = np.zeros((1, *self.state.shape[1:]))
|
||||||
self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice
|
self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice
|
||||||
self.spawn_dirt()
|
self.spawn_dirt()
|
||||||
|
self.next_dirt_spawn = self._dirt_properties.spawn_frequency
|
||||||
return self.state, r, self.done, {}
|
return self.state, r, self.done, {}
|
||||||
|
|
||||||
def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
|
def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
|
||||||
@ -109,8 +117,6 @@ class GettingDirty(BaseFactory):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
import random
|
|
||||||
|
|
||||||
render = True
|
render = True
|
||||||
|
|
||||||
dirt_props = DirtProperties()
|
dirt_props = DirtProperties()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user