Cleaned up if/else mess (simple_factory_getting_dirty.py:47)
This commit is contained in:
@ -1,6 +1,7 @@
|
|||||||
from typing import List, Union, Iterable
|
from typing import List, Union, Iterable
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
|
from gym import spaces
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@ -34,7 +35,11 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def action_space(self):
|
def action_space(self):
|
||||||
return self._registered_actions
|
return spaces.Discrete(self._registered_actions)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def observation_space(self):
|
||||||
|
return spaces.Box(low=-1, high=1, shape=self.state.shape, dtype=np.float32)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def movement_actions(self):
|
def movement_actions(self):
|
||||||
@ -81,15 +86,15 @@ class BaseFactory(gym.Env):
|
|||||||
self.agent_states.append(agent_state)
|
self.agent_states.append(agent_state)
|
||||||
# state.shape = level, agent 1,..., agent n,
|
# state.shape = level, agent 1,..., agent n,
|
||||||
self.state = np.concatenate((np.expand_dims(self.level, axis=0), agents), axis=0)
|
self.state = np.concatenate((np.expand_dims(self.level, axis=0), agents), axis=0)
|
||||||
# Returns State, Reward, Done, Info
|
# Returns State
|
||||||
|
|
||||||
return self.state, 0, self.done, {}
|
return self.state
|
||||||
|
|
||||||
def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def step(self, actions):
|
def step(self, actions):
|
||||||
actions = [actions] if isinstance(actions, int) else actions
|
actions = [actions] if isinstance(actions, int) or np.isscalar(actions) else actions
|
||||||
assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]'
|
assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]'
|
||||||
self.steps += 1
|
self.steps += 1
|
||||||
|
|
||||||
|
@ -18,8 +18,8 @@ class SimpleFactory(BaseFactory):
|
|||||||
dirt_slice = np.zeros((1, *self.state.shape[1:]))
|
dirt_slice = np.zeros((1, *self.state.shape[1:]))
|
||||||
self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice
|
self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice
|
||||||
self.spawn_dirt()
|
self.spawn_dirt()
|
||||||
# Always: This should return state, r, done, info
|
# Always: This should return state
|
||||||
return self.state, r, done, _
|
return self.state
|
||||||
|
|
||||||
def calculate_reward(self, agent_states):
|
def calculate_reward(self, agent_states):
|
||||||
for agent_state in agent_states:
|
for agent_state in agent_states:
|
||||||
|
@ -104,12 +104,12 @@ class GettingDirty(BaseFactory):
|
|||||||
raise RuntimeError('This should not happen!!!')
|
raise RuntimeError('This should not happen!!!')
|
||||||
|
|
||||||
def reset(self) -> (np.ndarray, int, bool, dict):
|
def reset(self) -> (np.ndarray, int, bool, dict):
|
||||||
state, r, done, _ = super().reset() # state, reward, done, info ... =
|
_ = super().reset() # state, reward, done, info ... =
|
||||||
dirt_slice = np.zeros((1, *self.state.shape[1:]))
|
dirt_slice = np.zeros((1, *self.state.shape[1:]))
|
||||||
self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice
|
self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice
|
||||||
self.spawn_dirt()
|
self.spawn_dirt()
|
||||||
self.next_dirt_spawn = self._dirt_properties.spawn_frequency
|
self.next_dirt_spawn = self._dirt_properties.spawn_frequency
|
||||||
return self.state, r, self.done, {}
|
return self.state
|
||||||
|
|
||||||
def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
|
def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
|
||||||
# TODO: What reward to use?
|
# TODO: What reward to use?
|
||||||
|
Reference in New Issue
Block a user