mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
2 agents running
This commit is contained in:
parent
05c54ad8d8
commit
c1f2ddf3cd
@ -1,4 +1,4 @@
|
||||
from typing import List, Union
|
||||
from typing import List, Union, Iterable
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
@ -90,7 +90,7 @@ class BaseFactory(gym.Env):
|
||||
|
||||
def step(self, actions):
|
||||
actions = [actions] if isinstance(actions, int) else actions
|
||||
assert isinstance(actions, list), f'"actions" has to be in [{int, list}]'
|
||||
assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]'
|
||||
self.steps += 1
|
||||
|
||||
# Move this in a seperate function?
|
||||
|
@ -106,7 +106,10 @@ class GettingDirty(BaseFactory):
|
||||
current_dirt_amount = self.state[DIRT_INDEX].sum()
|
||||
dirty_tiles = len(np.nonzero(self.state[DIRT_INDEX]))
|
||||
|
||||
try:
|
||||
this_step_reward = -(dirty_tiles / current_dirt_amount)
|
||||
except ZeroDivisionError:
|
||||
this_step_reward = 0
|
||||
|
||||
for agent_state in agent_states:
|
||||
collisions = agent_state.collisions
|
||||
@ -127,10 +130,10 @@ if __name__ == '__main__':
|
||||
render = True
|
||||
|
||||
dirt_props = DirtProperties()
|
||||
factory = GettingDirty(n_agents=1, dirt_properties=dirt_props)
|
||||
factory = GettingDirty(n_agents=2, dirt_properties=dirt_props)
|
||||
monitor_list = list()
|
||||
for epoch in range(100):
|
||||
random_actions = [random.randint(0, 8) for _ in range(200)]
|
||||
random_actions = [(random.randint(0, 8), random.randint(0, 8)) for _ in range(200)]
|
||||
env_state, reward, done_bool, _ = factory.reset()
|
||||
for agent_i_action in random_actions:
|
||||
env_state, reward, done_bool, info_obj = factory.step(agent_i_action)
|
||||
|
Loading…
x
Reference in New Issue
Block a user