mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
cleaned up if else mess simple_factory_getting_dirty.py:47
This commit is contained in:
parent
cc5df76ef7
commit
38ffb746e3
@ -1,6 +1,7 @@
|
||||
from typing import List, Union, Iterable
|
||||
|
||||
import gym
|
||||
from gym import spaces
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
@ -34,7 +35,11 @@ class BaseFactory(gym.Env):
|
||||
|
||||
@property
|
||||
def action_space(self):
|
||||
return self._registered_actions
|
||||
return spaces.Discrete(self._registered_actions)
|
||||
|
||||
@property
|
||||
def observation_space(self):
|
||||
return spaces.Box(low=-1, high=1, shape=self.state.shape, dtype=np.float32)
|
||||
|
||||
@property
|
||||
def movement_actions(self):
|
||||
@ -81,15 +86,15 @@ class BaseFactory(gym.Env):
|
||||
self.agent_states.append(agent_state)
|
||||
# state.shape = level, agent 1,..., agent n,
|
||||
self.state = np.concatenate((np.expand_dims(self.level, axis=0), agents), axis=0)
|
||||
# Returns State, Reward, Done, Info
|
||||
# Returns State
|
||||
|
||||
return self.state, 0, self.done, {}
|
||||
return self.state
|
||||
|
||||
def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
||||
raise NotImplementedError
|
||||
|
||||
def step(self, actions):
|
||||
actions = [actions] if isinstance(actions, int) else actions
|
||||
actions = [actions] if isinstance(actions, int) or np.isscalar(actions) else actions
|
||||
assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]'
|
||||
self.steps += 1
|
||||
|
||||
|
@ -18,8 +18,8 @@ class SimpleFactory(BaseFactory):
|
||||
dirt_slice = np.zeros((1, *self.state.shape[1:]))
|
||||
self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice
|
||||
self.spawn_dirt()
|
||||
# Always: This should return state, r, done, info
|
||||
return self.state, r, done, _
|
||||
# Always: This should return state
|
||||
return self.state
|
||||
|
||||
def calculate_reward(self, agent_states):
|
||||
for agent_state in agent_states:
|
||||
|
@ -104,12 +104,12 @@ class GettingDirty(BaseFactory):
|
||||
raise RuntimeError('This should not happen!!!')
|
||||
|
||||
def reset(self) -> (np.ndarray, int, bool, dict):
|
||||
state, r, done, _ = super().reset() # state, reward, done, info ... =
|
||||
_ = super().reset() # state, reward, done, info ... =
|
||||
dirt_slice = np.zeros((1, *self.state.shape[1:]))
|
||||
self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice
|
||||
self.spawn_dirt()
|
||||
self.next_dirt_spawn = self._dirt_properties.spawn_frequency
|
||||
return self.state, r, self.done, {}
|
||||
return self.state
|
||||
|
||||
def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
|
||||
# TODO: What reward to use?
|
||||
|
Loading…
x
Reference in New Issue
Block a user