diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index 7577bd5..3a1a03a 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -1,6 +1,7 @@ from typing import List, Union, Iterable import gym +from gym import spaces import numpy as np from pathlib import Path @@ -34,7 +35,11 @@ class BaseFactory(gym.Env): @property def action_space(self): - return self._registered_actions + return spaces.Discrete(self._registered_actions) + + @property + def observation_space(self): + return spaces.Box(low=-1, high=1, shape=self.state.shape, dtype=np.float32) @property def movement_actions(self): @@ -81,15 +86,15 @@ class BaseFactory(gym.Env): self.agent_states.append(agent_state) # state.shape = level, agent 1,..., agent n, self.state = np.concatenate((np.expand_dims(self.level, axis=0), agents), axis=0) - # Returns State, Reward, Done, Info + # Returns State - return self.state, 0, self.done, {} + return self.state def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool): raise NotImplementedError def step(self, actions): - actions = [actions] if isinstance(actions, int) else actions + actions = [actions] if isinstance(actions, int) or np.isscalar(actions) else actions assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]' self.steps += 1 diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py index b9b743e..b954b43 100644 --- a/environments/factory/simple_factory.py +++ b/environments/factory/simple_factory.py @@ -18,8 +18,8 @@ class SimpleFactory(BaseFactory): dirt_slice = np.zeros((1, *self.state.shape[1:])) self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice self.spawn_dirt() - # Always: This should return state, r, done, info - return self.state, r, done, _ + # Always: This should return state + return self.state def calculate_reward(self, agent_states): for agent_state in agent_states: diff --git a/environments/factory/simple_factory_getting_dirty.py b/environments/factory/simple_factory_getting_dirty.py index b2cc1a3..2906305 100644 --- a/environments/factory/simple_factory_getting_dirty.py +++ b/environments/factory/simple_factory_getting_dirty.py @@ -104,12 +104,12 @@ class GettingDirty(BaseFactory): raise RuntimeError('This should not happen!!!') def reset(self) -> (np.ndarray, int, bool, dict): - state, r, done, _ = super().reset() # state, reward, done, info ... = + _ = super().reset() # state, reward, done, info ... = dirt_slice = np.zeros((1, *self.state.shape[1:])) self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice self.spawn_dirt() self.next_dirt_spawn = self._dirt_properties.spawn_frequency - return self.state, r, self.done, {} + return self.state def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict): # TODO: What reward to use?