mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
dirt spawn frequency
This commit is contained in:
parent
2acf91b395
commit
c3224cf971
@ -1,6 +1,7 @@
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -9,13 +10,15 @@ from environments import helpers as h
|
||||
|
||||
from environments.factory.renderer import Renderer
|
||||
|
||||
|
||||
DIRT_INDEX = -1
|
||||
|
||||
|
||||
@dataclass
class DirtProperties:
    """Tunable parameters controlling dirt behaviour in the environment.

    Every attribute carries a type annotation so that ``@dataclass``
    registers it as a real field. Without annotations the decorator sees
    no fields at all: the generated ``__init__`` accepts no keyword
    arguments and ``__repr__``/``__eq__`` ignore the values.
    """

    # presumably the amount of dirt removed by one clean action -- confirm
    # against the code that consumes it.
    clean_amount: float = 0.25
    # presumably an upper bound on the share of tiles that may receive dirt
    # in one spawn -- confirm against spawn_dirt().
    max_spawn_ratio: float = 0.1
    # presumably the amount of dirt added to a tile per spawn -- confirm
    # against spawn_dirt().
    gain_amount: float = 0.1
    # Value the environment's ``next_dirt_spawn`` countdown is reset to
    # after each dirt spawn.
    spawn_frequency: int = 5
|
||||
|
||||
|
||||
class GettingDirty(BaseFactory):
|
||||
@ -62,7 +65,11 @@ class GettingDirty(BaseFactory):
|
||||
|
||||
def step(self, actions):
    """Advance the environment by one step and spawn dirt on a countdown.

    Delegates to the parent ``step`` and keeps only its ``info`` value;
    state, reward and done flag are read from this instance afterwards.

    Dirt is spawned only when the ``next_dirt_spawn`` countdown has reached
    zero, after which the countdown is reset to
    ``self._dirt_properties.spawn_frequency``. Spawning unconditionally on
    every step as well would make ``spawn_frequency`` have no effect.

    :param actions: forwarded unchanged to the parent ``step``.
    :return: ``(self.state, self.cumulative_reward, self.done, info)``.
    """
    _, _, _, info = super(GettingDirty, self).step(actions)
    if not self.next_dirt_spawn:
        # Countdown expired: spawn dirt and restart the countdown.
        self.spawn_dirt()
        self.next_dirt_spawn = self._dirt_properties.spawn_frequency
    else:
        self.next_dirt_spawn -= 1
    return self.state, self.cumulative_reward, self.done, info
|
||||
|
||||
def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
||||
@ -87,6 +94,7 @@ class GettingDirty(BaseFactory):
|
||||
dirt_slice = np.zeros((1, *self.state.shape[1:]))
|
||||
self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice
|
||||
self.spawn_dirt()
|
||||
self.next_dirt_spawn = self._dirt_properties.spawn_frequency
|
||||
return self.state, r, self.done, {}
|
||||
|
||||
def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
|
||||
@ -109,8 +117,6 @@ class GettingDirty(BaseFactory):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import random
|
||||
|
||||
render = True
|
||||
|
||||
dirt_props = DirtProperties()
|
||||
|
Loading…
x
Reference in New Issue
Block a user