mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 15:26:43 +02:00
Merge remote-tracking branch 'origin/AgentState_Object'
This commit is contained in:
commit
6b1b14fa87
@ -93,7 +93,7 @@ class BaseFactory:
|
|||||||
self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}}
|
self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}}
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self) -> (np.ndarray, int, bool, dict):
|
||||||
self.done = False
|
self.done = False
|
||||||
self.steps = 0
|
self.steps = 0
|
||||||
self.cumulative_reward = 0
|
self.cumulative_reward = 0
|
||||||
|
@ -60,12 +60,12 @@ class GettingDirty(BaseFactory):
|
|||||||
else:
|
else:
|
||||||
raise RuntimeError('This should not happen!!!')
|
raise RuntimeError('This should not happen!!!')
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> (np.ndarray, int, bool, dict):
|
||||||
# ToDo: When self.reset returns the new states and stuff, use it here!
|
state, r, done, _ = super().reset() # state, reward, done, info ... =
|
||||||
super().reset() # state, agents, ... =
|
|
||||||
dirt_slice = np.zeros((1, *self.state.shape[1:]))
|
dirt_slice = np.zeros((1, *self.state.shape[1:]))
|
||||||
self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice
|
self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice
|
||||||
self.spawn_dirt()
|
self.spawn_dirt()
|
||||||
|
return self.state, r, self.done, {}
|
||||||
|
|
||||||
def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
|
def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
|
||||||
this_step_reward = 0
|
this_step_reward = 0
|
||||||
@ -85,8 +85,17 @@ if __name__ == '__main__':
|
|||||||
import random
|
import random
|
||||||
dirt_props = DirtProperties()
|
dirt_props = DirtProperties()
|
||||||
factory = GettingDirty(n_agents=1, dirt_properties=dirt_props)
|
factory = GettingDirty(n_agents=1, dirt_properties=dirt_props)
|
||||||
random_actions = [random.randint(0, 8) for _ in range(2000)]
|
monitor_list = list()
|
||||||
for random_action in random_actions:
|
for epoch in range(100):
|
||||||
state, r, done, _ = factory.step(random_action)
|
random_actions = [random.randint(0, 7) for _ in range(200)]
|
||||||
|
state, r, done, _ = factory.reset()
|
||||||
|
for action in random_actions:
|
||||||
|
state, r, done, info = factory.step(action)
|
||||||
|
monitor_list.append(factory.monitor)
|
||||||
print(f'Factory run done, reward is:\n {r}')
|
print(f'Factory run done, reward is:\n {r}')
|
||||||
print(f'The following running stats have been recorded:\n{dict(factory.monitor)}')
|
from pathlib import Path
|
||||||
|
import pickle
|
||||||
|
out_path = Path('debug_out')
|
||||||
|
out_path.mkdir(exist_ok=True, parents=True)
|
||||||
|
with (out_path / 'monitor.pick').open('rb') as f:
|
||||||
|
pickle.dump(monitor_list, f)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user