mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 15:26:43 +02:00
2 agents running
This commit is contained in:
parent
05c54ad8d8
commit
c1f2ddf3cd
@ -1,4 +1,4 @@
|
|||||||
from typing import List, Union
|
from typing import List, Union, Iterable
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -90,7 +90,7 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
def step(self, actions):
|
def step(self, actions):
|
||||||
actions = [actions] if isinstance(actions, int) else actions
|
actions = [actions] if isinstance(actions, int) else actions
|
||||||
assert isinstance(actions, list), f'"actions" has to be in [{int, list}]'
|
assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]'
|
||||||
self.steps += 1
|
self.steps += 1
|
||||||
|
|
||||||
# Move this into a separate function?
|
# Move this into a separate function?
|
||||||
|
@ -106,7 +106,10 @@ class GettingDirty(BaseFactory):
|
|||||||
current_dirt_amount = self.state[DIRT_INDEX].sum()
|
current_dirt_amount = self.state[DIRT_INDEX].sum()
|
||||||
dirty_tiles = len(np.nonzero(self.state[DIRT_INDEX]))
|
dirty_tiles = len(np.nonzero(self.state[DIRT_INDEX]))
|
||||||
|
|
||||||
|
try:
|
||||||
this_step_reward = -(dirty_tiles / current_dirt_amount)
|
this_step_reward = -(dirty_tiles / current_dirt_amount)
|
||||||
|
except ZeroDivisionError:
|
||||||
|
this_step_reward = 0
|
||||||
|
|
||||||
for agent_state in agent_states:
|
for agent_state in agent_states:
|
||||||
collisions = agent_state.collisions
|
collisions = agent_state.collisions
|
||||||
@ -127,10 +130,10 @@ if __name__ == '__main__':
|
|||||||
render = True
|
render = True
|
||||||
|
|
||||||
dirt_props = DirtProperties()
|
dirt_props = DirtProperties()
|
||||||
factory = GettingDirty(n_agents=1, dirt_properties=dirt_props)
|
factory = GettingDirty(n_agents=2, dirt_properties=dirt_props)
|
||||||
monitor_list = list()
|
monitor_list = list()
|
||||||
for epoch in range(100):
|
for epoch in range(100):
|
||||||
random_actions = [random.randint(0, 8) for _ in range(200)]
|
random_actions = [(random.randint(0, 8), random.randint(0, 8)) for _ in range(200)]
|
||||||
env_state, reward, done_bool, _ = factory.reset()
|
env_state, reward, done_bool, _ = factory.reset()
|
||||||
for agent_i_action in random_actions:
|
for agent_i_action in random_actions:
|
||||||
env_state, reward, done_bool, info_obj = factory.step(agent_i_action)
|
env_state, reward, done_bool, info_obj = factory.step(agent_i_action)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user