Merge remote-tracking branch 'origin/AgentState_Object'

This commit is contained in:
steffen-illium 2021-05-14 15:22:18 +02:00
commit 6b1b14fa87
2 changed files with 18 additions and 9 deletions

View File

@ -93,7 +93,7 @@ class BaseFactory:
self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}} self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}}
self.reset() self.reset()
def reset(self): def reset(self) -> (np.ndarray, int, bool, dict):
self.done = False self.done = False
self.steps = 0 self.steps = 0
self.cumulative_reward = 0 self.cumulative_reward = 0

View File

@ -60,12 +60,12 @@ class GettingDirty(BaseFactory):
else: else:
raise RuntimeError('This should not happen!!!') raise RuntimeError('This should not happen!!!')
def reset(self) -> None: def reset(self) -> (np.ndarray, int, bool, dict):
# ToDo: When self.reset returns the new states and stuff, use it here! state, r, done, _ = super().reset() # state, reward, done, info ... =
super().reset() # state, agents, ... =
dirt_slice = np.zeros((1, *self.state.shape[1:])) dirt_slice = np.zeros((1, *self.state.shape[1:]))
self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice
self.spawn_dirt() self.spawn_dirt()
return self.state, r, self.done, {}
def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict): def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
this_step_reward = 0 this_step_reward = 0
@ -85,8 +85,17 @@ if __name__ == '__main__':
import random import random
dirt_props = DirtProperties() dirt_props = DirtProperties()
factory = GettingDirty(n_agents=1, dirt_properties=dirt_props) factory = GettingDirty(n_agents=1, dirt_properties=dirt_props)
random_actions = [random.randint(0, 8) for _ in range(2000)] monitor_list = list()
for random_action in random_actions: for epoch in range(100):
state, r, done, _ = factory.step(random_action) random_actions = [random.randint(0, 7) for _ in range(200)]
state, r, done, _ = factory.reset()
for action in random_actions:
state, r, done, info = factory.step(action)
monitor_list.append(factory.monitor)
print(f'Factory run done, reward is:\n {r}') print(f'Factory run done, reward is:\n {r}')
print(f'The following running stats have been recorded:\n{dict(factory.monitor)}') from pathlib import Path
import pickle
out_path = Path('debug_out')
out_path.mkdir(exist_ok=True, parents=True)
with (out_path / 'monitor.pick').open('rb') as f:
pickle.dump(monitor_list, f)