mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-11-02 13:37:27 +01:00
frame stack
This commit is contained in:
@@ -6,6 +6,8 @@ import gym
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
import yaml
|
||||
|
||||
from environments import helpers as h
|
||||
|
||||
|
||||
@@ -191,6 +193,7 @@ class BaseFactory(gym.Env):
|
||||
abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs
|
||||
obs = obs_padded
|
||||
else:
|
||||
assert not self.omit_agent_slice_in_obs
|
||||
obs = self._state
|
||||
if self.omit_agent_slice_in_obs:
|
||||
if obs.shape != (3, 5, 5):
|
||||
@@ -315,7 +318,9 @@ class BaseFactory(gym.Env):
|
||||
raise NotImplementedError
|
||||
|
||||
def save_params(self, filepath: Path):
|
||||
d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') or not key.startswith('__')}
|
||||
d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')}
|
||||
filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with filepath.open('wb') as f:
|
||||
# yaml.dump(d, f)
|
||||
pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
@@ -14,14 +14,15 @@ from environments.factory.renderer import Renderer, Entity
|
||||
DIRT_INDEX = -1
|
||||
CLEAN_UP_ACTION = 'clean_up'
|
||||
|
||||
|
||||
@dataclass
|
||||
class DirtProperties:
|
||||
clean_amount = 2 # How much does the robot clean with one action.
|
||||
max_spawn_ratio = 0.2 # On max how much tiles does the dirt spawn in percent.
|
||||
gain_amount = 0.5 # How much dirt does spawn per tile
|
||||
spawn_frequency = 5 # Spawn Frequency in Steps
|
||||
max_local_amount = 1 # Max dirt amount per tile.
|
||||
max_global_amount = 20 # Max dirt amount in the whole environment.
|
||||
clean_amount: int = 2 # How much does the robot clean with one action.
|
||||
max_spawn_ratio: float = 0.2 # On max how much tiles does the dirt spawn in percent.
|
||||
gain_amount: float = 0.5 # How much dirt does spawn per tile
|
||||
spawn_frequency: int = 5 # Spawn Frequency in Steps
|
||||
max_local_amount: int = 1 # Max dirt amount per tile.
|
||||
max_global_amount: int = 20 # Max dirt amount in the whole environment.
|
||||
|
||||
|
||||
class SimpleFactory(BaseFactory):
|
||||
@@ -93,11 +94,11 @@ class SimpleFactory(BaseFactory):
|
||||
|
||||
def step(self, actions):
|
||||
_, r, done, info = super(SimpleFactory, self).step(actions)
|
||||
if not self.next_dirt_spawn:
|
||||
if not self._next_dirt_spawn:
|
||||
self.spawn_dirt()
|
||||
self.next_dirt_spawn = self.dirt_properties.spawn_frequency
|
||||
self._next_dirt_spawn = self.dirt_properties.spawn_frequency
|
||||
else:
|
||||
self.next_dirt_spawn -= 1
|
||||
self._next_dirt_spawn -= 1
|
||||
obs = self._return_state()
|
||||
return obs, r, done, info
|
||||
|
||||
@@ -117,7 +118,7 @@ class SimpleFactory(BaseFactory):
|
||||
dirt_slice = np.zeros((1, *self._state.shape[1:]))
|
||||
self._state = np.concatenate((self._state, dirt_slice)) # dirt is now the last slice
|
||||
self.spawn_dirt()
|
||||
self.next_dirt_spawn = self.dirt_properties.spawn_frequency
|
||||
self._next_dirt_spawn = self.dirt_properties.spawn_frequency
|
||||
obs = self._return_state()
|
||||
return obs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user