cleaned up if else mess simple_factory_getting_dirty.py:47

This commit is contained in:
steffen-illium
2021-05-18 18:29:02 +02:00
parent cc5df76ef7
commit 38ffb746e3
3 changed files with 13 additions and 8 deletions

View File

@ -1,6 +1,7 @@
from typing import List, Union, Iterable from typing import List, Union, Iterable
import gym import gym
from gym import spaces
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
@ -34,7 +35,11 @@ class BaseFactory(gym.Env):
@property @property
def action_space(self): def action_space(self):
return self._registered_actions return spaces.Discrete(self._registered_actions)
@property
def observation_space(self):
return spaces.Box(low=-1, high=1, shape=self.state.shape, dtype=np.float32)
@property @property
def movement_actions(self): def movement_actions(self):
@ -81,15 +86,15 @@ class BaseFactory(gym.Env):
self.agent_states.append(agent_state) self.agent_states.append(agent_state)
# state.shape = level, agent 1,..., agent n, # state.shape = level, agent 1,..., agent n,
self.state = np.concatenate((np.expand_dims(self.level, axis=0), agents), axis=0) self.state = np.concatenate((np.expand_dims(self.level, axis=0), agents), axis=0)
# Returns State, Reward, Done, Info # Returns State
return self.state, 0, self.done, {} return self.state
def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool): def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
raise NotImplementedError raise NotImplementedError
def step(self, actions): def step(self, actions):
actions = [actions] if isinstance(actions, int) else actions actions = [actions] if isinstance(actions, int) or np.isscalar(actions) else actions
assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]' assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]'
self.steps += 1 self.steps += 1

View File

@ -18,8 +18,8 @@ class SimpleFactory(BaseFactory):
dirt_slice = np.zeros((1, *self.state.shape[1:])) dirt_slice = np.zeros((1, *self.state.shape[1:]))
self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice
self.spawn_dirt() self.spawn_dirt()
# Always: This should return state, r, done, info # Always: This should return state
return self.state, r, done, _ return self.state
def calculate_reward(self, agent_states): def calculate_reward(self, agent_states):
for agent_state in agent_states: for agent_state in agent_states:

View File

@ -104,12 +104,12 @@ class GettingDirty(BaseFactory):
raise RuntimeError('This should not happen!!!') raise RuntimeError('This should not happen!!!')
def reset(self) -> (np.ndarray, int, bool, dict): def reset(self) -> (np.ndarray, int, bool, dict):
state, r, done, _ = super().reset() # state, reward, done, info ... = _ = super().reset() # state, reward, done, info ... =
dirt_slice = np.zeros((1, *self.state.shape[1:])) dirt_slice = np.zeros((1, *self.state.shape[1:]))
self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice
self.spawn_dirt() self.spawn_dirt()
self.next_dirt_spawn = self._dirt_properties.spawn_frequency self.next_dirt_spawn = self._dirt_properties.spawn_frequency
return self.state, r, self.done, {} return self.state
def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict): def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
# TODO: What reward to use? # TODO: What reward to use?