mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
Monitor and Agent State Merge
This commit is contained in:
parent
35ef708b20
commit
83d77df216
@ -90,7 +90,7 @@ class BaseFactory:
|
||||
self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}}
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> (np.ndarray, int, bool, dict):
|
||||
self.done = False
|
||||
self.steps = 0
|
||||
self.cumulative_reward = 0
|
||||
|
@ -60,12 +60,12 @@ class GettingDirty(BaseFactory):
|
||||
else:
|
||||
raise RuntimeError('This should not happen!!!')
|
||||
|
||||
def reset(self) -> None:
|
||||
# ToDo: When self.reset returns the new states and stuff, use it here!
|
||||
super().reset() # state, agents, ... =
|
||||
def reset(self) -> (np.ndarray, int, bool, dict):
|
||||
state, r, done, _ = super().reset() # state, reward, done, info ... =
|
||||
dirt_slice = np.zeros((1, *self.state.shape[1:]))
|
||||
self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice
|
||||
self.spawn_dirt()
|
||||
return self.state, r, self.done, {}
|
||||
|
||||
def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
|
||||
this_step_reward = 0
|
||||
@ -85,8 +85,17 @@ if __name__ == '__main__':
|
||||
import random
|
||||
dirt_props = DirtProperties()
|
||||
factory = GettingDirty(n_agents=1, dirt_properties=dirt_props)
|
||||
random_actions = [random.randint(0, 8) for _ in range(2000)]
|
||||
for random_action in random_actions:
|
||||
state, r, done, _ = factory.step(random_action)
|
||||
print(f'Factory run done, reward is:\n {r}')
|
||||
print(f'The following running stats have been recorded:\n{dict(factory.monitor)}')
|
||||
monitor_list = list()
|
||||
for epoch in range(100):
|
||||
random_actions = [random.randint(0, 7) for _ in range(200)]
|
||||
state, r, done, _ = factory.reset()
|
||||
for action in random_actions:
|
||||
state, r, done, info = factory.step(action)
|
||||
monitor_list.append(factory.monitor)
|
||||
print(f'Factory run done, reward is:\n {r}')
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
out_path = Path('debug_out')
|
||||
out_path.mkdir(exist_ok=True, parents=True)
|
||||
with (out_path / 'monitor.pick').open('rb') as f:
|
||||
pickle.dump(monitor_list, f)
|
||||
|
Loading…
x
Reference in New Issue
Block a user