mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-21 11:21:35 +02:00
Merge remote-tracking branch 'origin/main'
This commit is contained in:
@ -1,12 +1,29 @@
|
||||
from environments.factory import make
|
||||
import random
|
||||
import salina
|
||||
import torch
|
||||
from gym.wrappers import FrameStack
|
||||
|
||||
n_agents = 4
|
||||
env = make('DirtyFactory-v0', n_agents=n_agents)
|
||||
env = FrameStack(env, num_stack=3)
|
||||
state, *_ = env.reset()
|
||||
|
||||
for i in range(1000):
|
||||
state, *_ = env.step([env.unwrapped.action_space.sample() for _ in range(n_agents)])
|
||||
env.render()
|
||||
class MyAgent(salina.TAgent):
|
||||
def __init__(self):
|
||||
super(MyAgent, self).__init__()
|
||||
|
||||
def forward(self, t, **kwargs):
|
||||
self.set(('timer', t), torch.tensor([t]))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
n_agents = 1
|
||||
env = make('DirtyFactory-v0', n_agents=n_agents)
|
||||
env = FrameStack(env, num_stack=3)
|
||||
env.reset()
|
||||
agent = MyAgent()
|
||||
workspace = salina.Workspace()
|
||||
agent(workspace, t=0, n_steps=10)
|
||||
|
||||
print(workspace)
|
||||
|
||||
|
||||
for i in range(1000):
|
||||
state, *_ = env.step([env.unwrapped.action_space.sample() for _ in range(n_agents)])
|
||||
#env.render()
|
Reference in New Issue
Block a user