cleaned up if else mess simple_factory_getting_dirty.py:47

steffen-illium 2021-05-18 18:29:02 +02:00
parent cc5df76ef7
commit 38ffb746e3
3 changed files with 13 additions and 8 deletions


@@ -1,6 +1,7 @@
 from typing import List, Union, Iterable
 import gym
+from gym import spaces
 import numpy as np
 from pathlib import Path
@@ -34,7 +35,11 @@ class BaseFactory(gym.Env):
     @property
     def action_space(self):
-        return self._registered_actions
+        return spaces.Discrete(self._registered_actions)
+
+    @property
+    def observation_space(self):
+        return spaces.Box(low=-1, high=1, shape=self.state.shape, dtype=np.float32)
 
     @property
     def movement_actions(self):
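
With action_space wrapped in spaces.Discrete and observation_space exposed as a spaces.Box, the env now advertises proper gym spaces. A minimal, illustrative sketch of how such spaces behave; the concrete sizes below are made up and not taken from the factory level:

import numpy as np
from gym import spaces

action_space = spaces.Discrete(8)                    # e.g. 8 registered actions
observation_space = spaces.Box(low=-1, high=1, shape=(2, 5, 5), dtype=np.float32)

action = action_space.sample()                       # random int in [0, 8)
obs = np.zeros((2, 5, 5), dtype=np.float32)          # stand-in for self.state
assert observation_space.contains(obs)               # the state fits the declared Box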
@@ -81,15 +86,15 @@ class BaseFactory(gym.Env):
         self.agent_states.append(agent_state)
         # state.shape = level, agent 1,..., agent n,
         self.state = np.concatenate((np.expand_dims(self.level, axis=0), agents), axis=0)
-        # Returns State, Reward, Done, Info
-        return self.state, 0, self.done, {}
+        # Returns State
+        return self.state
 
     def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
         raise NotImplementedError
 
     def step(self, actions):
-        actions = [actions] if isinstance(actions, int) else actions
+        actions = [actions] if isinstance(actions, int) or np.isscalar(actions) else actions
         assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]'
         self.steps += 1
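
The np.isscalar guard in step() covers numpy integer actions, which are not instances of the builtin int and previously slipped past the single-action wrapping. An illustrative standalone sketch of that normalization (normalize_actions is a hypothetical helper, not the actual method):

import numpy as np

def normalize_actions(actions):
    # mirrors the guard above: wrap bare scalars, pass iterables through
    return [actions] if isinstance(actions, int) or np.isscalar(actions) else actions

print(isinstance(np.int64(3), int))    # False -> the old check missed numpy ints
print(np.isscalar(np.int64(3)))        # True
print(normalize_actions(np.int64(3)))  # [3]
print(normalize_actions([0, 1]))       # [0, 1]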


@@ -18,8 +18,8 @@ class SimpleFactory(BaseFactory):
         dirt_slice = np.zeros((1, *self.state.shape[1:]))
         self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice
         self.spawn_dirt()
-        # Always: This should return state, r, done, info
-        return self.state, r, done, _
+        # Always: This should return state
+        return self.state
 
     def calculate_reward(self, agent_states):
         for agent_state in agent_states:
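
Dropping the 4-tuple from reset() matches the classic gym contract: reset() returns only the initial observation, while step() returns (observation, reward, done, info). A sketch of a rollout loop under that contract, using a stock gym env as a stand-in for the factory:

import gym

env = gym.make('CartPole-v0')       # stand-in env; the factory follows the same contract
obs = env.reset()                   # reset -> first observation only
done = False
while not done:
    action = env.action_space.sample()
    obs, reward, done, info = env.step(action)   # step -> (obs, reward, done, info)
env.close()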


@@ -104,12 +104,12 @@ class GettingDirty(BaseFactory):
             raise RuntimeError('This should not happen!!!')
 
     def reset(self) -> (np.ndarray, int, bool, dict):
-        state, r, done, _ = super().reset() # state, reward, done, info ... =
+        _ = super().reset() # state, reward, done, info ... =
         dirt_slice = np.zeros((1, *self.state.shape[1:]))
         self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice
         self.spawn_dirt()
         self.next_dirt_spawn = self._dirt_properties.spawn_frequency
-        return self.state, r, self.done, {}
+        return self.state
 
     def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
         # TODO: What reward to use?
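
For reference, the dirt handling in reset() just stacks one extra all-zero slice onto the (slices, height, width) state tensor before dirt is spawned into it. A standalone sketch with made-up shapes:

import numpy as np

level = np.zeros((5, 5))                                 # walls/floor slice
agents = np.zeros((1, 5, 5))                             # one agent slice
state = np.concatenate((np.expand_dims(level, axis=0), agents), axis=0)   # (2, 5, 5)

dirt_slice = np.zeros((1, *state.shape[1:]))             # (1, 5, 5)
state = np.concatenate((state, dirt_slice))              # dirt is now the last slice -> (3, 5, 5)
print(state.shape)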