diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index 2553245..b33c628 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -90,7 +90,7 @@ class BaseFactory: self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}} self.reset() - def reset(self): + def reset(self) -> (np.ndarray, int, bool, dict): self.done = False self.steps = 0 self.cumulative_reward = 0 diff --git a/environments/factory/simple_factory_getting_dirty.py b/environments/factory/simple_factory_getting_dirty.py index 1062091..4490e76 100644 --- a/environments/factory/simple_factory_getting_dirty.py +++ b/environments/factory/simple_factory_getting_dirty.py @@ -60,12 +60,12 @@ class GettingDirty(BaseFactory): else: raise RuntimeError('This should not happen!!!') - def reset(self) -> None: - # ToDo: When self.reset returns the new states and stuff, use it here! - super().reset() # state, agents, ... = + def reset(self) -> (np.ndarray, int, bool, dict): + state, r, done, _ = super().reset() # state, reward, done, info ... = dirt_slice = np.zeros((1, *self.state.shape[1:])) self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice self.spawn_dirt() + return self.state, r, self.done, {} def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict): this_step_reward = 0 @@ -85,8 +85,17 @@ if __name__ == '__main__': import random dirt_props = DirtProperties() factory = GettingDirty(n_agents=1, dirt_properties=dirt_props) - random_actions = [random.randint(0, 8) for _ in range(2000)] - for random_action in random_actions: - state, r, done, _ = factory.step(random_action) - print(f'Factory run done, reward is:\n {r}') - print(f'The following running stats have been recorded:\n{dict(factory.monitor)}') + monitor_list = list() + for epoch in range(100): + random_actions = [random.randint(0, 7) for _ in range(200)] + state, r, done, _ = factory.reset() + for action in random_actions: + state, r, done, info = factory.step(action) + monitor_list.append(factory.monitor) + print(f'Factory run done, reward is:\n {r}') + from pathlib import Path + import pickle + out_path = Path('debug_out') + out_path.mkdir(exist_ok=True, parents=True) + with (out_path / 'monitor.pick').open('rb') as f: + pickle.dump(monitor_list, f)