Stable Baseline Running
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import abc
|
||||
from typing import List, Union, Iterable
|
||||
|
||||
import gym
|
||||
@@ -43,10 +42,6 @@ class BaseFactory(gym.Env):
|
||||
def observation_space(self):
|
||||
return spaces.Box(low=-1, high=1, shape=self.state.shape, dtype=np.float32)
|
||||
|
||||
@property
|
||||
def monitor_as_df_list(self):
|
||||
return [x.to_pd_dataframe() for x in self._monitor_list]
|
||||
|
||||
@property
|
||||
def movement_actions(self):
|
||||
return (int(self.allow_vertical_movement) + int(self.allow_horizontal_movement)) * 4
|
||||
@@ -61,7 +56,6 @@ class BaseFactory(gym.Env):
|
||||
self.allow_vertical_movement = True
|
||||
self.allow_horizontal_movement = True
|
||||
self.allow_no_OP = True
|
||||
self._monitor_list = list()
|
||||
self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions()
|
||||
self.level = h.one_hot_level(
|
||||
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
|
||||
@@ -77,7 +71,6 @@ class BaseFactory(gym.Env):
|
||||
self.steps = 0
|
||||
self.cumulative_reward = 0
|
||||
self.monitor = FactoryMonitor(self)
|
||||
self._monitor_list.append(self.monitor)
|
||||
self.agent_states = []
|
||||
# Agent placement ...
|
||||
agents = np.zeros((self.n_agents, *self.level.shape), dtype=np.int8)
|
||||
@@ -92,7 +85,6 @@ class BaseFactory(gym.Env):
|
||||
# state.shape = level, agent 1,..., agent n,
|
||||
self.state = np.concatenate((np.expand_dims(self.level, axis=0), agents), axis=0)
|
||||
# Returns State
|
||||
|
||||
return self.state
|
||||
|
||||
def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
||||
@@ -122,11 +114,11 @@ class BaseFactory(gym.Env):
|
||||
|
||||
self.agent_states = states
|
||||
reward, info = self.calculate_reward(states)
|
||||
self.cumulative_reward += reward
|
||||
|
||||
if self.steps >= self.max_steps:
|
||||
self.done = True
|
||||
return self.state, self.cumulative_reward, self.done, info
|
||||
self.monitor.add('step_reward', reward)
|
||||
return self.state, reward, self.done, info
|
||||
|
||||
def _is_moving_action(self, action):
|
||||
return action < self.movement_actions
|
||||
|
||||
Reference in New Issue
Block a user