Mirror of https://github.com/illiumst/marl-factory-grid.git
Synced 2025-05-23 07:16:44 +02:00
deleted policy adaption, added IAC
This commit is contained in:
parent 0fe90f3ac0
commit 65056b2c61
@@ -1,3 +0,0 @@
from environments.policy_adaption.natural_rl_environment import matting
from environments.policy_adaption.natural_rl_environment import imgsource
from environments.policy_adaption.natural_rl_environment import natural_env
@@ -1,120 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import cv2
import skvideo.io


class ImageSource(object):
    """
    Source of natural images to be added to a simulated environment.
    """
    def get_image(self):
        """
        Returns:
            an RGB image of [h, w, 3] with a fixed shape.
        """
        pass

    def reset(self):
        """ Called when an episode ends. """
        pass


class FixedColorSource(ImageSource):
    def __init__(self, shape, color):
        """
        Args:
            shape: [h, w]
            color: a 3-tuple
        """
        self.arr = np.zeros((shape[0], shape[1], 3))
        self.arr[:, :] = color

    def get_image(self):
        return np.copy(self.arr)


class RandomColorSource(ImageSource):
    def __init__(self, shape):
        """
        Args:
            shape: [h, w]
        """
        self.shape = shape
        self.reset()

    def reset(self):
        self._color = np.random.randint(0, 256, size=(3,))

    def get_image(self):
        arr = np.zeros((self.shape[0], self.shape[1], 3))
        arr[:, :] = self._color
        return arr


class NoiseSource(ImageSource):
    def __init__(self, shape, strength=50):
        """
        Args:
            shape: [h, w]
            strength (int): the strength of noise, in range [0, 255]
        """
        self.shape = shape
        self.strength = strength

    def get_image(self):
        return np.maximum(np.random.randn(
            self.shape[0], self.shape[1], 3) * self.strength, 0)


class RandomImageSource(ImageSource):
    def __init__(self, shape, filelist):
        """
        Args:
            shape: [h, w]
            filelist: a list of image files
        """
        self.shape_wh = shape[::-1]
        self.filelist = filelist
        self.reset()

    def reset(self):
        fname = np.random.choice(self.filelist)
        im = cv2.imread(fname, cv2.IMREAD_COLOR)
        im = im[:, :, ::-1]
        im = cv2.resize(im, self.shape_wh)
        self._im = im

    def get_image(self):
        return self._im


class RandomVideoSource(ImageSource):
    def __init__(self, shape, filelist):
        """
        Args:
            shape: [h, w]
            filelist: a list of video files
        """
        self.shape_wh = shape[::-1]
        self.filelist = filelist
        self.reset()

    def reset(self):
        fname = np.random.choice(self.filelist)
        self.frames = skvideo.io.vread(fname)
        self.frame_idx = 0

    def get_image(self):
        if self.frame_idx >= self.frames.shape[0]:
            self.reset()
        im = self.frames[self.frame_idx][:, :, ::-1]
        self.frame_idx += 1
        im = im[:, :, ::-1]
        im = cv2.resize(im, self.shape_wh)
        return im
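Editor's aside (not part of the commit): these image sources are consumed together with a background mask; the wrapper further down does `obs[mask] = img[mask]`. A minimal NumPy sketch of that interaction, with made-up shapes and values:

import numpy as np

# Hypothetical 84x84 RGB observation whose background is pure black.
obs = np.zeros((84, 84, 3), dtype=np.uint8)
# Per-channel mask of the background color, as BackgroundMattingWithColor.get_mask returns it.
mask = (obs == (0, 0, 0))
# A background image, e.g. what RandomColorSource.get_image would produce.
background = np.ones((84, 84, 3)) * np.random.randint(0, 256, size=(3,))
# Replace only the masked (background) pixels; foreground pixels stay untouched.
obs[mask] = background[mask].astype(np.uint8)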
@@ -1,32 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


class BackgroundMatting(object):
    """
    Produce a mask of a given image which will be replaced by natural signals.
    """
    def get_mask(self, img):
        """
        Take an image of [H, W, 3]. Returns a mask of [H, W]
        """
        raise NotImplementedError()


class BackgroundMattingWithColor(BackgroundMatting):
    """
    Produce a mask by masking the given color. This is a simple strategy
    but effective for many games.
    """
    def __init__(self, color):
        """
        Args:
            color: a (r, g, b) tuple
        """
        self._color = color

    def get_mask(self, img):
        return img == self._color
@@ -1,119 +0,0 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import argparse
import glob
import gym
from gym.utils import play

from .matting import BackgroundMattingWithColor
from .imgsource import (
    RandomImageSource,
    RandomColorSource,
    NoiseSource,
    RandomVideoSource,
)


class ReplaceBackgroundEnv(gym.ObservationWrapper):

    viewer = None

    def __init__(self, env, bg_matting, natural_source):
        """
        The source must produce a image with a shape that's compatible to
        `env.observation_space`.
        """
        super(ReplaceBackgroundEnv, self).__init__(env)
        self._bg_matting = bg_matting
        self._natural_source = natural_source

    def observation(self, obs):
        mask = self._bg_matting.get_mask(obs)
        img = self._natural_source.get_image()
        obs[mask] = img[mask]
        self._last_ob = obs
        return obs

    def reset(self):
        self._natural_source.reset()
        return super(ReplaceBackgroundEnv, self).reset()

    # modified from gym/envs/atari/atari_env.py
    # This makes the monitor work
    def render(self, mode="human"):
        img = self._last_ob
        if mode == "rgb_array":
            return img
        elif mode == "human":
            from gym.envs.classic_control import rendering

            if self.viewer is None:
                self.viewer = rendering.SimpleImageViewer()
            self.viewer.imshow(img)
            return env.viewer.isopen


def make(name='Pong-v0', imgsource='color', files=None):
    env = gym.make(name)  # gravitar, breakout, MsPacman, Space Invaders
    shape2d = env.observation_space.shape[:2]
    color = (0, 0, 0) if 'Pong' not in name else (144, 72, 17)
    if imgsource == 'video':
        imgsource = RandomVideoSource(shape2d, ['/Users/romue/PycharmProjects/EDYS/environments/policy_adaption/natural_rl_environment/videos/stars.mp4'])
    elif imgsource == "color":
        imgsource = RandomColorSource(shape2d)
    elif imgsource == "noise":
        imgsource = NoiseSource(shape2d)
    elif imgsource == "images":
        imgsource = RandomImageSource(shape2d, files)
    else:
        raise NotImplementedError(f'{imgsource} is not supported, use one of {{video, color, noise}}')
    wrapped_env = ReplaceBackgroundEnv(
        env, BackgroundMattingWithColor(color), imgsource
    )
    return wrapped_env


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", help="The gym environment to base on")
    parser.add_argument("--imgsource", choices=["color", "noise", "images", "videos"])
    parser.add_argument(
        "--resource-files", help="A glob pattern to obtain images or videos"
    )
    parser.add_argument("--dump-video", help="If given, a directory to dump video")
    args = parser.parse_args()

    env = gym.make(args.env)
    shape2d = env.observation_space.shape[:2]

    if args.imgsource:
        if args.imgsource == "color":
            imgsource = RandomColorSource(shape2d)
        elif args.imgsource == "noise":
            imgsource = NoiseSource(shape2d)
        else:
            files = glob.glob(os.path.expanduser(args.resource_files))
            assert len(files), "Pattern {} does not match any files".format(
                args.resource_files
            )
            if args.imgsource == "images":
                imgsource = RandomImageSource(shape2d, files)
            else:
                imgsource = RandomVideoSource(shape2d, files)
        wrapped_env = ReplaceBackgroundEnv(
            env, BackgroundMattingWithColor((0, 0, 0)), imgsource
        )
    else:
        wrapped_env = env

    if args.dump_video:
        assert os.path.isdir(args.dump_video)
        wrapped_env = gym.wrappers.Monitor(wrapped_env, args.dump_video)
    play.play(wrapped_env, zoom=4)
Binary file not shown.
Binary file not shown.
@@ -1,8 +0,0 @@
import gym
import glob
from environments.policy_adaption.natural_rl_environment.imgsource import *
from environments.policy_adaption.natural_rl_environment.natural_env import *

if __name__ == "__main__":
    env = make('SpaceInvaders-v0', 'video')  # gravitar, breakout, MsPacman, Space Invaders
    play.play(env, zoom=4)
algorithms/utils.py (new file, 30 lines)
@@ -0,0 +1,30 @@
import re
import torch
import yaml
from pathlib import Path
from salina import instantiate_class
from salina import TAgent


def load_yaml_file(path: Path):
    with path.open() as stream:
        cfg = yaml.load(stream, Loader=yaml.FullLoader)
    return cfg


def add_env_props(cfg):
    env = instantiate_class(cfg['env'].copy())
    cfg['agent'].update(dict(observation_size=env.observation_space.shape,
                             n_actions=env.action_space.n))


class CombineActionsAgent(TAgent):
    def __init__(self, pattern=r'^agent\d_action$'):
        super().__init__()
        self.pattern = pattern

    def forward(self, t, **kwargs):
        keys = list(self.workspace.keys())
        action_keys = sorted([k for k in keys if bool(re.match(self.pattern, k))])
        actions = torch.cat([self.get((k, t)) for k in action_keys], 0)
        self.set((f'action', t), actions.unsqueeze(0))
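Editor's aside (not part of the commit): CombineActionsAgent gathers the per-agent action keys that each A2CAgent writes (see studies/sat_mad.py below) back into a single 'action' entry. A standalone sketch of the key matching and concatenation, using a plain dict in place of a salina workspace:

import re
import torch

pattern = r'^agent\d_action$'
# Hypothetical workspace contents at one timestep: one action tensor per agent.
workspace = {'agent0_action': torch.tensor([2]),
             'agent1_action': torch.tensor([5]),
             'env/reward': torch.tensor([0.0])}  # non-matching keys are ignored
action_keys = sorted(k for k in workspace if re.match(pattern, k))
combined = torch.cat([workspace[k] for k in action_keys], 0).unsqueeze(0)
print(action_keys)  # ['agent0_action', 'agent1_action']
print(combined)     # tensor([[2, 5]]), the joint action written under the single key 'action'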
@@ -1,4 +1,4 @@
-def make(env_str, n_agents=1, pomdp_r=2, max_steps=400, stack_n_frames=3):
+def make(env_name, n_agents=1, pomdp_r=2, max_steps=400, stack_n_frames=3):
     import yaml
     from pathlib import Path
     from environments.factory.combined_factories import DirtItemFactory
@@ -6,7 +6,7 @@ def make(env_str, n_agents=1, pomdp_r=2, max_steps=400, stack_n_frames=3):
     from environments.factory.factory_dirt import DirtProperties, DirtFactory
     from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions

-    with (Path(__file__).parent / 'levels' / 'parameters' / f'{env_str}.yaml').open('r') as stream:
+    with (Path(__file__).parent / 'levels' / 'parameters' / f'{env_name}.yaml').open('r') as stream:
         dictionary = yaml.load(stream, Loader=yaml.FullLoader)

     obs_props = ObservationProperties(render_agents=AgentRenderOptions.COMBINED,
@@ -1,100 +1,133 @@
-from environments.factory import make
-from salina import Workspace, TAgent
-from salina.agents.gyma import AutoResetGymAgent, GymAgent
+from salina.agents.gyma import AutoResetGymAgent
 from salina.agents import Agents, TemporalAgent
-from salina.rl.functional import _index
+from salina.rl.functional import _index, gae
 import torch
 import torch.nn as nn
-from torch.nn.utils import spectral_norm
-import torch.optim as optim
 from torch.distributions import Categorical
+from salina import TAgent, Workspace, get_arguments, get_class, instantiate_class
+from pathlib import Path
+import numpy as np
+from tqdm import tqdm
+import time
+from algorithms.utils import add_env_props, load_yaml_file, CombineActionsAgent


 class A2CAgent(TAgent):
-    def __init__(self, observation_size, hidden_size, n_actions):
+    def __init__(self, observation_size, hidden_size, n_actions, agent_id=-1, marl=False):
         super().__init__()
+        observation_size = np.prod(observation_size)
+        self.agent_id = agent_id
+        self.marl = marl
         self.model = nn.Sequential(
             nn.Flatten(),
             nn.Linear(observation_size, hidden_size),
             nn.ELU(),
             nn.Linear(hidden_size, hidden_size),
             nn.ELU(),
-            nn.Linear(hidden_size, n_actions),
-        )
-        self.critic_model = nn.Sequential(
-            nn.Flatten(),
-            nn.Linear(observation_size, hidden_size),
-            nn.ELU(),
-            spectral_norm(nn.Linear(hidden_size, 1)),
+            nn.Linear(hidden_size, hidden_size),
+            nn.ELU()
         )
+        self.action_head = nn.Linear(hidden_size, n_actions)
+        self.critic_head = nn.Linear(hidden_size, 1)

+    def get_obs(self, t):
+        observation = self.get(("env/env_obs", t))
+        if self.marl:
+            observation = observation.permute(2, 0, 1, 3, 4, 5)
+            observation = observation[self.agent_id]
+        return observation
+
     def forward(self, t, stochastic, **kwargs):
-        observation = self.get(("env/env_obs", t))
-        scores = self.model(observation)
+        observation = self.get_obs(t)
+        features = self.model(observation)
+        scores = self.action_head(features)
         probs = torch.softmax(scores, dim=-1)
-        critic = self.critic_model(observation).squeeze(-1)
+        critic = self.critic_head(features).squeeze(-1)
         if stochastic:
             action = torch.distributions.Categorical(probs).sample()
         else:
             action = probs.argmax(1)
-        self.set(("action", t), action)
-        self.set(("action_probs", t), probs)
-        self.set(("critic", t), critic)
+        agent_str = f'agent{self.agent_id}_'
+        self.set((f'{agent_str}action', t), action)
+        self.set((f'{agent_str}action_probs', t), probs)
+        self.set((f'{agent_str}critic', t), critic)


 if __name__ == '__main__':
-    # Setup agents and workspace
-    env_agent = AutoResetGymAgent(make, dict(env_str='DirtyFactory-v0'), n_envs=1)
-    a2c_agent = A2CAgent(3*4*5*5, 96, 10)
+    # Setup workspace
+    uid = time.time()
     workspace = Workspace()
+    n_agents = 1

-    eval_agent = Agents(GymAgent(make, dict(env_str='DirtyFactory-v0'), n_envs=1), a2c_agent)
-    for i in range(100):
-        eval_agent(workspace, t=i, save_render=True, stochastic=True)
+    # load config
+    cfg = load_yaml_file(Path(__file__).parent / 'sat_mad.yaml')
+    add_env_props(cfg)
+    cfg['env'].update({'n_agents': n_agents})

+    # instantiate agent and env
+    env_agent = AutoResetGymAgent(
+        get_class(cfg['env']),
+        get_arguments(cfg['env']),
+        n_envs=1
+    )
+
+    a2c_agents = [instantiate_class({**cfg['agent'],
+                                     'agent_id': agent_id,
+                                     'marl': n_agents > 1})
+                  for agent_id in range(n_agents)]

-    assert False
     # combine agents
-    acquisition_agent = TemporalAgent(Agents(env_agent, a2c_agent))
-    acquisition_agent.seed(0)
+    acquisition_agent = TemporalAgent(Agents(env_agent, *a2c_agents, CombineActionsAgent()))
+    acquisition_agent.seed(69)

     # optimizers & other parameters
-    optimizer = optim.Adam(a2c_agent.parameters(), lr=1e-3)
-    n_timesteps = 10
+    cfg_optim = cfg['algorithm']['optimizer']
+    optimizers = [get_class(cfg_optim)(a2c_agent.parameters(), **get_arguments(cfg_optim))
+                  for a2c_agent in a2c_agents]
+    n_timesteps = cfg['algorithm']['n_timesteps']

     # Decision making loop
-    for epoch in range(200000):
-        workspace.zero_grad()
-        if epoch > 0:
-            workspace.copy_n_last_steps(1)
-            acquisition_agent(workspace, t=1, n_steps=n_timesteps-1, stochastic=True)
-        else:
-            acquisition_agent(workspace, t=0, n_steps=n_timesteps, stochastic=True)
-        #for k in workspace.keys():
-        #    print(f'{k} ==> {workspace[k].size()}')
-        critic, done, action_probs, reward, action = workspace[
-            "critic", "env/done", "action_probs", "env/reward", "action"
-        ]
-
-        target = reward[1:] + 0.99 * critic[1:].detach() * (1 - done[1:].float())
-        td = target - critic[:-1]
-        td_error = td ** 2
-        critic_loss = td_error.mean()
-        entropy_loss = Categorical(action_probs).entropy().mean()
-        action_logp = _index(action_probs, action).log()
-        a2c_loss = action_logp[:-1] * td.detach()
-        a2c_loss = a2c_loss.mean()
-        loss = (
-            -0.001 * entropy_loss
-            + 1.0 * critic_loss
-            - 0.1 * a2c_loss
-        )
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
-
-        # Compute the cumulated reward on final_state
-        creward = workspace["env/cumulated_reward"]
-        creward = creward[done]
-        if creward.size()[0] > 0:
-            print(f"Cumulative reward at A2C step #{(1+epoch)*n_timesteps}: {creward.mean().item()}")
+    best = -float('inf')
+    with tqdm(range(int(cfg['algorithm']['max_epochs'] / n_timesteps))) as pbar:
+        for epoch in pbar:
+            workspace.zero_grad()
+            if epoch > 0:
+                workspace.copy_n_last_steps(1)
+                acquisition_agent(workspace, t=1, n_steps=n_timesteps-1, stochastic=True)
+            else:
+                acquisition_agent(workspace, t=0, n_steps=n_timesteps, stochastic=True)
+
+            for agent_id in range(n_agents):
+                critic, done, action_probs, reward, action = workspace[
+                    f"agent{agent_id}_critic", "env/done",
+                    f'agent{agent_id}_action_probs', "env/reward",
+                    f"agent{agent_id}_action"
+                ]
+                td = gae(critic, reward, done, 0.99, 0.3)
+                td_error = td ** 2
+                critic_loss = td_error.mean()
+                entropy_loss = Categorical(action_probs).entropy().mean()
+                action_logp = _index(action_probs, action).log()
+                a2c_loss = action_logp[:-1] * td.detach()
+                a2c_loss = a2c_loss.mean()
+                loss = (
+                    -0.001 * entropy_loss
+                    + 1.0 * critic_loss
+                    - 0.1 * a2c_loss
+                )
+                optimizer = optimizers[agent_id]
+                optimizer.zero_grad()
+                loss.backward()
+                #torch.nn.utils.clip_grad_norm_(a2c_agents[agent_id].parameters(), 2)
+                optimizer.step()
+
+                # Compute the cumulated reward on final_state
+                creward = workspace["env/cumulated_reward"]
+                creward = creward[done]
+                if creward.size()[0] > 0:
+                    cum_r = creward.mean().item()
+                    if cum_r > best:
+                        # torch.save(a2c_agent.state_dict(), Path(__file__).parent / f'agent_{uid}.pt')
+                        best = cum_r
+                pbar.set_description(f"Cum. r: {cum_r:.2f}, Best r. so far: {best:.2f}", refresh=True)
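Editor's aside (not part of the commit): per agent, the loop above assembles the usual A2C objective with the hard-coded coefficients shown; writing the GAE output td as the advantage estimate A_t, the optimized loss is roughly

\mathcal{L} = 1.0 \cdot \overline{A_t^{\,2}} \;-\; 0.1 \cdot \overline{\log \pi(a_t \mid s_t)\, A_t} \;-\; 0.001 \cdot \overline{H\!\left[\pi(\cdot \mid s_t)\right]}

with A_t detached in the policy term and the overline denoting the mean over the collected timesteps.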
studies/sat_mad.yaml (new file, 26 lines)
@@ -0,0 +1,26 @@
agent:
  classname: studies.sat_mad.A2CAgent
  observation_size: 4*5*5
  hidden_size: 128
  n_actions: 10

env:
  classname: environments.factory.make
  env_name: "DirtyFactory-v0"
  n_agents: 1
  pomdp_r: 2
  max_steps: 400
  stack_n_frames: 3

algorithm:
  max_epochs: 1000000
  n_envs: 1
  n_timesteps: 16
  discount_factor: 0.99
  entropy_coef: 0.01
  critic_coef: 1.0
  gae: 0.3
  optimizer:
    classname: torch.optim.Adam
    lr: 0.0003
    weight_decay: 0.0
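Editor's aside (not part of the commit): the classname entries in this config are resolved with salina's get_class / get_arguments / instantiate_class helpers, as studies/sat_mad.py above does. A minimal sketch of how the optimizer block maps to a torch optimizer, using a hypothetical stand-in module:

from pathlib import Path
import torch.nn as nn
from salina import get_class, get_arguments
from algorithms.utils import load_yaml_file

cfg = load_yaml_file(Path('studies') / 'sat_mad.yaml')
cfg_optim = cfg['algorithm']['optimizer']  # {'classname': 'torch.optim.Adam', 'lr': 0.0003, 'weight_decay': 0.0}

model = nn.Linear(4, 2)  # hypothetical stand-in for an A2CAgent
# get_class resolves 'torch.optim.Adam'; get_arguments returns the remaining keys as kwargs.
optimizer = get_class(cfg_optim)(model.parameters(), **get_arguments(cfg_optim))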
studies/viz_salina.py (new file, 39 lines)
@@ -0,0 +1,39 @@
from salina.agents import Agents, TemporalAgent
import torch
from salina import Workspace, get_arguments, get_class, instantiate_class
from pathlib import Path
from salina.agents.gyma import GymAgent
import time
from algorithms.utils import load_yaml_file, add_env_props

if __name__ == '__main__':
    # Setup workspace
    uid = time.time()
    workspace = Workspace()
    weights = Path('/Users/romue/PycharmProjects/EDYS/studies/agent_1636994369.145843.pt')

    cfg = load_yaml_file(Path(__file__).parent / 'sat_mad.yaml')
    add_env_props(cfg)
    cfg['env'].update({'n_agents': 2})

    # instantiate agent and env
    env_agent = GymAgent(
        get_class(cfg['env']),
        get_arguments(cfg['env']),
        n_envs=1
    )

    agents = []
    for _ in range(2):
        a2c_agent = instantiate_class(cfg['agent'])
        if weights:
            a2c_agent.load_state_dict(torch.load(weights))
        agents.append(a2c_agent)

    # combine agents
    acquisition_agent = TemporalAgent(Agents(env_agent, *agents))
    acquisition_agent.seed(42)

    acquisition_agent(workspace, t=0, n_steps=400, stochastic=False, save_render=True)