Mirror of https://github.com/illiumst/marl-factory-grid.git (synced 2025-10-31)
deleted policy adaption, added IAC
Commit by Robert Müller
		| @@ -1,3 +0,0 @@ | ||||
| from environments.policy_adaption.natural_rl_environment import matting | ||||
| from environments.policy_adaption.natural_rl_environment import imgsource | ||||
| from environments.policy_adaption.natural_rl_environment import natural_env | ||||
| @@ -1,120 +0,0 @@ | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import numpy as np | ||||
| import cv2 | ||||
| import skvideo.io | ||||
|  | ||||
|  | ||||
| class ImageSource(object): | ||||
|     """ | ||||
|     Source of natural images to be added to a simulated environment. | ||||
|     """ | ||||
|     def get_image(self): | ||||
|         """ | ||||
|         Returns: | ||||
|             an RGB image of [h, w, 3] with a fixed shape. | ||||
|         """ | ||||
|         pass | ||||
|  | ||||
|     def reset(self): | ||||
|         """ Called when an episode ends. """ | ||||
|         pass | ||||
|  | ||||
|  | ||||
| class FixedColorSource(ImageSource): | ||||
|     def __init__(self, shape, color): | ||||
|         """ | ||||
|         Args: | ||||
|             shape: [h, w] | ||||
|             color: a 3-tuple | ||||
|         """ | ||||
|         self.arr = np.zeros((shape[0], shape[1], 3)) | ||||
|         self.arr[:, :] = color | ||||
|  | ||||
|     def get_image(self): | ||||
|         return np.copy(self.arr) | ||||
|  | ||||
|  | ||||
| class RandomColorSource(ImageSource): | ||||
|     def __init__(self, shape): | ||||
|         """ | ||||
|         Args: | ||||
|             shape: [h, w] | ||||
|         """ | ||||
|         self.shape = shape | ||||
|         self.reset() | ||||
|  | ||||
|     def reset(self): | ||||
|         self._color = np.random.randint(0, 256, size=(3,)) | ||||
|  | ||||
|     def get_image(self): | ||||
|         arr = np.zeros((self.shape[0], self.shape[1], 3)) | ||||
|         arr[:, :] = self._color | ||||
|         return arr | ||||
|  | ||||
|  | ||||
| class NoiseSource(ImageSource): | ||||
|     def __init__(self, shape, strength=50): | ||||
|         """ | ||||
|         Args: | ||||
|             shape: [h, w] | ||||
|             strength (int): the strength of noise, in range [0, 255] | ||||
|         """ | ||||
|         self.shape = shape | ||||
|         self.strength = strength | ||||
|  | ||||
|     def get_image(self): | ||||
|         return np.maximum(np.random.randn( | ||||
|             self.shape[0], self.shape[1], 3) * self.strength, 0) | ||||
|  | ||||
|  | ||||
| class RandomImageSource(ImageSource): | ||||
|     def __init__(self, shape, filelist): | ||||
|         """ | ||||
|         Args: | ||||
|             shape: [h, w] | ||||
|             filelist: a list of image files | ||||
|         """ | ||||
|         self.shape_wh = shape[::-1] | ||||
|         self.filelist = filelist | ||||
|         self.reset() | ||||
|  | ||||
|     def reset(self): | ||||
|         fname = np.random.choice(self.filelist) | ||||
|         im = cv2.imread(fname, cv2.IMREAD_COLOR) | ||||
|         im = im[:, :, ::-1] | ||||
|         im = cv2.resize(im, self.shape_wh) | ||||
|         self._im = im | ||||
|  | ||||
|     def get_image(self): | ||||
|         return self._im | ||||
|  | ||||
|  | ||||
| class RandomVideoSource(ImageSource): | ||||
|     def __init__(self, shape, filelist): | ||||
|         """ | ||||
|         Args: | ||||
|             shape: [h, w] | ||||
|             filelist: a list of video files | ||||
|         """ | ||||
|         self.shape_wh = shape[::-1] | ||||
|         self.filelist = filelist | ||||
|         self.reset() | ||||
|  | ||||
|     def reset(self): | ||||
|         fname = np.random.choice(self.filelist) | ||||
|         self.frames = skvideo.io.vread(fname) | ||||
|         self.frame_idx = 0 | ||||
|  | ||||
|     def get_image(self): | ||||
|         if self.frame_idx >= self.frames.shape[0]: | ||||
|             self.reset() | ||||
|         im = self.frames[self.frame_idx][:, :, ::-1] | ||||
|         self.frame_idx += 1 | ||||
|         im = im[:, :, ::-1] | ||||
|         im = cv2.resize(im, self.shape_wh) | ||||
|         return im | ||||
| @@ -1,32 +0,0 @@ | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
|  | ||||
| class BackgroundMatting(object): | ||||
|     """ | ||||
|     Produce a mask of a given image which will be replaced by natural signals. | ||||
|     """ | ||||
|     def get_mask(self, img): | ||||
|         """ | ||||
|         Take an image of [H, W, 3]. Returns a mask of [H, W] | ||||
|         """ | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|  | ||||
| class BackgroundMattingWithColor(BackgroundMatting): | ||||
|     """ | ||||
|     Produce a mask by masking the given color. This is a simple strategy | ||||
|     but effective for many games. | ||||
|     """ | ||||
|     def __init__(self, color): | ||||
|         """ | ||||
|         Args: | ||||
|             color: a (r, g, b) tuple | ||||
|         """ | ||||
|         self._color = color | ||||
|  | ||||
|     def get_mask(self, img): | ||||
|         return img == self._color | ||||
| @@ -1,119 +0,0 @@ | ||||
| #!/usr/bin/env python | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import os | ||||
| import argparse | ||||
| import glob | ||||
| import gym | ||||
| from gym.utils import play | ||||
|  | ||||
| from .matting import BackgroundMattingWithColor | ||||
| from .imgsource import ( | ||||
|     RandomImageSource, | ||||
|     RandomColorSource, | ||||
|     NoiseSource, | ||||
|     RandomVideoSource, | ||||
| ) | ||||
|  | ||||
|  | ||||
| class ReplaceBackgroundEnv(gym.ObservationWrapper): | ||||
|  | ||||
|     viewer = None | ||||
|  | ||||
|     def __init__(self, env, bg_matting, natural_source): | ||||
|         """ | ||||
|         The source must produce an image with a shape that's compatible with | ||||
|         `env.observation_space`. | ||||
|         """ | ||||
|         super(ReplaceBackgroundEnv, self).__init__(env) | ||||
|         self._bg_matting = bg_matting | ||||
|         self._natural_source = natural_source | ||||
|  | ||||
|     def observation(self, obs): | ||||
|         mask = self._bg_matting.get_mask(obs) | ||||
|         img = self._natural_source.get_image() | ||||
|         obs[mask] = img[mask] | ||||
|         self._last_ob = obs | ||||
|         return obs | ||||
|  | ||||
|     def reset(self): | ||||
|         self._natural_source.reset() | ||||
|         return super(ReplaceBackgroundEnv, self).reset() | ||||
|  | ||||
|     # modified from gym/envs/atari/atari_env.py | ||||
|     # This makes the monitor work | ||||
|     def render(self, mode="human"): | ||||
|         img = self._last_ob | ||||
|         if mode == "rgb_array": | ||||
|             return img | ||||
|         elif mode == "human": | ||||
|             from gym.envs.classic_control import rendering | ||||
|  | ||||
|             if self.viewer is None: | ||||
|                 self.viewer = rendering.SimpleImageViewer() | ||||
|             self.viewer.imshow(img) | ||||
|             return self.viewer.isopen | ||||
|  | ||||
|  | ||||
| def make(name='Pong-v0', imgsource='color', files=None): | ||||
|     env = gym.make(name)  # gravitar, breakout, MsPacman, Space Invaders | ||||
|     shape2d = env.observation_space.shape[:2] | ||||
|     color = (0, 0, 0) if 'Pong' not in name else (144, 72, 17) | ||||
|     if imgsource == 'video': | ||||
|         imgsource = RandomVideoSource(shape2d, ['/Users/romue/PycharmProjects/EDYS/environments/policy_adaption/natural_rl_environment/videos/stars.mp4']) | ||||
|     elif imgsource == "color": | ||||
|         imgsource = RandomColorSource(shape2d) | ||||
|     elif imgsource == "noise": | ||||
|         imgsource = NoiseSource(shape2d) | ||||
|     elif imgsource == "images": | ||||
|         imgsource = RandomImageSource(shape2d, files) | ||||
|     else: | ||||
|         raise NotImplementedError(f'{imgsource} is not supported, use one of {{video, color, noise, images}}') | ||||
|     wrapped_env = ReplaceBackgroundEnv( | ||||
|         env, BackgroundMattingWithColor(color), imgsource | ||||
|     ) | ||||
|     return wrapped_env | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser() | ||||
|     parser.add_argument("--env", help="The gym environment to base on") | ||||
|     parser.add_argument("--imgsource", choices=["color", "noise", "images", "videos"]) | ||||
|     parser.add_argument( | ||||
|         "--resource-files", help="A glob pattern to obtain images or videos" | ||||
|     ) | ||||
|     parser.add_argument("--dump-video", help="If given, a directory to dump video") | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     env = gym.make(args.env) | ||||
|     shape2d = env.observation_space.shape[:2] | ||||
|  | ||||
|     if args.imgsource: | ||||
|         if args.imgsource == "color": | ||||
|             imgsource = RandomColorSource(shape2d) | ||||
|         elif args.imgsource == "noise": | ||||
|             imgsource = NoiseSource(shape2d) | ||||
|         else: | ||||
|             files = glob.glob(os.path.expanduser(args.resource_files)) | ||||
|             assert len(files), "Pattern {} does not match any files".format( | ||||
|                 args.resource_files | ||||
|             ) | ||||
|             if args.imgsource == "images": | ||||
|                 imgsource = RandomImageSource(shape2d, files) | ||||
|             else: | ||||
|                 imgsource = RandomVideoSource(shape2d, files) | ||||
|  | ||||
|         wrapped_env = ReplaceBackgroundEnv( | ||||
|             env, BackgroundMattingWithColor((0, 0, 0)), imgsource | ||||
|         ) | ||||
|     else: | ||||
|         wrapped_env = env | ||||
|  | ||||
|     if args.dump_video: | ||||
|         assert os.path.isdir(args.dump_video) | ||||
|         wrapped_env = gym.wrappers.Monitor(wrapped_env, args.dump_video) | ||||
|     play.play(wrapped_env, zoom=4) | ||||
										
Binary file not shown.
Binary file not shown.
										
									
								
							| @@ -1,8 +0,0 @@ | ||||
| import gym | ||||
| import glob | ||||
| from environments.policy_adaption.natural_rl_environment.imgsource import * | ||||
| from environments.policy_adaption.natural_rl_environment.natural_env import * | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     env = make('SpaceInvaders-v0', 'video')  # gravitar, breakout, MsPacman, Space Invaders | ||||
|     play.play(env, zoom=4) | ||||
							
								
								
									
algorithms/utils.py (new file, +30 lines)
									
								
							| @@ -0,0 +1,30 @@ | ||||
| import re | ||||
| import torch | ||||
| import yaml | ||||
| from pathlib import Path | ||||
| from salina import instantiate_class | ||||
| from salina import TAgent | ||||
|  | ||||
|  | ||||
| def load_yaml_file(path: Path): | ||||
|     with path.open() as stream: | ||||
|         cfg = yaml.load(stream, Loader=yaml.FullLoader) | ||||
|     return cfg | ||||
|  | ||||
|  | ||||
| def add_env_props(cfg): | ||||
|     env = instantiate_class(cfg['env'].copy()) | ||||
|     cfg['agent'].update(dict(observation_size=env.observation_space.shape, | ||||
|                              n_actions=env.action_space.n)) | ||||
|  | ||||
|  | ||||
| class CombineActionsAgent(TAgent): | ||||
|     def __init__(self, pattern=r'^agent\d_action$'): | ||||
|         super().__init__() | ||||
|         self.pattern = pattern | ||||
|  | ||||
|     def forward(self, t, **kwargs): | ||||
|         keys = list(self.workspace.keys()) | ||||
|         action_keys = sorted([k for k in keys if bool(re.match(self.pattern, k))]) | ||||
|         actions = torch.cat([self.get((k, t)) for k in action_keys], 0) | ||||
|         self.set((f'action', t), actions.unsqueeze(0)) | ||||
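Not part of the commit: a minimal, self-contained sketch of what `CombineActionsAgent` does, with the salina workspace replaced by a plain dict of tensors so the regex filtering and concatenation can be run in isolation. Key names follow the `agent{i}_action` convention used above.

import re
import torch

pattern = r'^agent\d_action$'
workspace = {
    'agent1_action': torch.tensor([3]),    # per-agent discrete actions at step t
    'agent0_action': torch.tensor([7]),
    'env/reward':    torch.tensor([0.0]),  # ignored: does not match the pattern
}

# keys are filtered by the regex and sorted, so agent0 precedes agent1, ...
action_keys = sorted(k for k in workspace if re.match(pattern, k))
combined = torch.cat([workspace[k] for k in action_keys], 0)

print(action_keys)             # ['agent0_action', 'agent1_action']
print(combined.unsqueeze(0))   # tensor([[7, 3]]) -> stored as the joint 'action'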
| @@ -1,4 +1,4 @@ | ||||
| def make(env_str, n_agents=1, pomdp_r=2, max_steps=400, stack_n_frames=3): | ||||
| def make(env_name, n_agents=1, pomdp_r=2, max_steps=400, stack_n_frames=3): | ||||
|     import yaml | ||||
|     from pathlib import Path | ||||
|     from environments.factory.combined_factories import DirtItemFactory | ||||
| @@ -6,7 +6,7 @@ def make(env_str, n_agents=1, pomdp_r=2, max_steps=400, stack_n_frames=3): | ||||
|     from environments.factory.factory_dirt import DirtProperties, DirtFactory | ||||
|     from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions | ||||
|  | ||||
|     with (Path(__file__).parent / 'levels' / 'parameters' / f'{env_str}.yaml').open('r') as stream: | ||||
|     with (Path(__file__).parent / 'levels' / 'parameters' / f'{env_name}.yaml').open('r') as stream: | ||||
|         dictionary = yaml.load(stream, Loader=yaml.FullLoader) | ||||
|  | ||||
|     obs_props = ObservationProperties(render_agents=AgentRenderOptions.COMBINED, | ||||
|   | ||||
| @@ -1,100 +1,133 @@ | ||||
| from environments.factory import make | ||||
| from salina import Workspace, TAgent | ||||
| from salina.agents.gyma import AutoResetGymAgent, GymAgent | ||||
| from salina.agents.gyma import AutoResetGymAgent | ||||
| from salina.agents import Agents, TemporalAgent | ||||
| from salina.rl.functional import _index | ||||
| from salina.rl.functional import _index, gae | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from torch.nn.utils import spectral_norm | ||||
| import torch.optim as optim | ||||
| from torch.distributions import Categorical | ||||
| from salina import TAgent, Workspace, get_arguments, get_class, instantiate_class | ||||
| from pathlib import Path | ||||
| import numpy as np | ||||
| from tqdm import tqdm | ||||
| import time | ||||
| from algorithms.utils import add_env_props, load_yaml_file, CombineActionsAgent | ||||
|  | ||||
|  | ||||
| class A2CAgent(TAgent): | ||||
|     def __init__(self, observation_size, hidden_size, n_actions): | ||||
|     def __init__(self, observation_size, hidden_size, n_actions, agent_id=-1, marl=False): | ||||
|         super().__init__() | ||||
|         observation_size = np.prod(observation_size) | ||||
|         self.agent_id = agent_id | ||||
|         self.marl = marl | ||||
|         self.model = nn.Sequential( | ||||
|             nn.Flatten(), | ||||
|             nn.Linear(observation_size, hidden_size), | ||||
|             nn.ELU(), | ||||
|             nn.Linear(hidden_size, hidden_size), | ||||
|             nn.ELU(), | ||||
|             nn.Linear(hidden_size, n_actions), | ||||
|         ) | ||||
|         self.critic_model = nn.Sequential( | ||||
|             nn.Flatten(), | ||||
|             nn.Linear(observation_size, hidden_size), | ||||
|             nn.ELU(), | ||||
|             spectral_norm(nn.Linear(hidden_size, 1)), | ||||
|             nn.Linear(hidden_size, hidden_size), | ||||
|             nn.ELU() | ||||
|         ) | ||||
|         self.action_head = nn.Linear(hidden_size, n_actions) | ||||
|         self.critic_head = nn.Linear(hidden_size, 1) | ||||
|  | ||||
|     def get_obs(self, t): | ||||
|         observation = self.get(("env/env_obs", t)) | ||||
|         if self.marl: | ||||
|             observation = observation.permute(2, 0, 1, 3, 4, 5) | ||||
|             observation = observation[self.agent_id] | ||||
|         return observation | ||||
|  | ||||
|     def forward(self, t, stochastic, **kwargs): | ||||
|         observation = self.get(("env/env_obs", t)) | ||||
|         scores = self.model(observation) | ||||
|         observation = self.get_obs(t) | ||||
|         features = self.model(observation) | ||||
|         scores = self.action_head(features) | ||||
|         probs = torch.softmax(scores, dim=-1) | ||||
|         critic = self.critic_model(observation).squeeze(-1) | ||||
|         critic = self.critic_head(features).squeeze(-1) | ||||
|         if stochastic: | ||||
|             action = torch.distributions.Categorical(probs).sample() | ||||
|         else: | ||||
|             action = probs.argmax(1) | ||||
|  | ||||
|         self.set(("action", t), action) | ||||
|         self.set(("action_probs", t), probs) | ||||
|         self.set(("critic", t), critic) | ||||
|         agent_str = f'agent{self.agent_id}_' | ||||
|         self.set((f'{agent_str}action', t), action) | ||||
|         self.set((f'{agent_str}action_probs', t), probs) | ||||
|         self.set((f'{agent_str}critic', t), critic) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     # Setup agents and workspace | ||||
|     env_agent = AutoResetGymAgent(make, dict(env_str='DirtyFactory-v0'), n_envs=1) | ||||
|     a2c_agent = A2CAgent(3*4*5*5, 96, 10) | ||||
|     # Setup workspace | ||||
|     uid = time.time() | ||||
|     workspace = Workspace() | ||||
|     n_agents = 1 | ||||
|  | ||||
|     eval_agent = Agents(GymAgent(make, dict(env_str='DirtyFactory-v0'), n_envs=1), a2c_agent) | ||||
|     for i in range(100): | ||||
|         eval_agent(workspace, t=i, save_render=True, stochastic=True) | ||||
|     # load config | ||||
|     cfg = load_yaml_file(Path(__file__).parent / 'sat_mad.yaml') | ||||
|     add_env_props(cfg) | ||||
|     cfg['env'].update({'n_agents': n_agents}) | ||||
|  | ||||
|     # instantiate agent and env | ||||
|     env_agent = AutoResetGymAgent( | ||||
|         get_class(cfg['env']), | ||||
|         get_arguments(cfg['env']), | ||||
|         n_envs=1 | ||||
|     ) | ||||
|  | ||||
|     a2c_agents = [instantiate_class({**cfg['agent'], | ||||
|                                      'agent_id': agent_id, | ||||
|                                      'marl':     n_agents > 1}) | ||||
|                   for agent_id in range(n_agents)] | ||||
|  | ||||
|     assert False | ||||
|     # combine agents | ||||
|     acquisition_agent = TemporalAgent(Agents(env_agent, a2c_agent)) | ||||
|     acquisition_agent.seed(0) | ||||
|     acquisition_agent = TemporalAgent(Agents(env_agent, *a2c_agents, CombineActionsAgent())) | ||||
|     acquisition_agent.seed(69) | ||||
|  | ||||
|     # optimizers & other parameters | ||||
|     optimizer = optim.Adam(a2c_agent.parameters(), lr=1e-3) | ||||
|     n_timesteps = 10 | ||||
|     cfg_optim = cfg['algorithm']['optimizer'] | ||||
|     optimizers = [get_class(cfg_optim)(a2c_agent.parameters(), **get_arguments(cfg_optim)) | ||||
|                   for a2c_agent in a2c_agents] | ||||
|     n_timesteps = cfg['algorithm']['n_timesteps'] | ||||
|  | ||||
|     # Decision making loop | ||||
|     for epoch in range(200000): | ||||
|         workspace.zero_grad() | ||||
|         if epoch > 0: | ||||
|             workspace.copy_n_last_steps(1) | ||||
|             acquisition_agent(workspace, t=1, n_steps=n_timesteps-1, stochastic=True) | ||||
|         else: | ||||
|             acquisition_agent(workspace, t=0, n_steps=n_timesteps, stochastic=True) | ||||
|         #for k in workspace.keys(): | ||||
|         #    print(f'{k} ==> {workspace[k].size()}') | ||||
|         critic, done, action_probs, reward, action = workspace[ | ||||
|             "critic", "env/done", "action_probs", "env/reward", "action" | ||||
|         ] | ||||
|     best = -float('inf') | ||||
|     with tqdm(range(int(cfg['algorithm']['max_epochs'] / n_timesteps))) as pbar: | ||||
|         for epoch in pbar: | ||||
|             workspace.zero_grad() | ||||
|             if epoch > 0: | ||||
|                 workspace.copy_n_last_steps(1) | ||||
|                 acquisition_agent(workspace, t=1, n_steps=n_timesteps-1, stochastic=True) | ||||
|             else: | ||||
|                 acquisition_agent(workspace, t=0, n_steps=n_timesteps,  stochastic=True) | ||||
|  | ||||
|         target = reward[1:] + 0.99 * critic[1:].detach() * (1 - done[1:].float()) | ||||
|         td = target - critic[:-1] | ||||
|         td_error = td ** 2 | ||||
|         critic_loss = td_error.mean() | ||||
|         entropy_loss = Categorical(action_probs).entropy().mean() | ||||
|         action_logp = _index(action_probs, action).log() | ||||
|         a2c_loss = action_logp[:-1] * td.detach() | ||||
|         a2c_loss = a2c_loss.mean() | ||||
|         loss = ( | ||||
|             -0.001 * entropy_loss | ||||
|             + 1.0 * critic_loss | ||||
|             - 0.1 * a2c_loss | ||||
|         ) | ||||
|         optimizer.zero_grad() | ||||
|         loss.backward() | ||||
|         optimizer.step() | ||||
|             for agent_id in range(n_agents): | ||||
|                 critic, done, action_probs, reward, action = workspace[ | ||||
|                     f"agent{agent_id}_critic", "env/done", | ||||
|                     f'agent{agent_id}_action_probs', "env/reward", | ||||
|                     f"agent{agent_id}_action" | ||||
|                 ] | ||||
|                 td = gae(critic, reward, done, 0.99, 0.3) | ||||
|                 td_error = td ** 2 | ||||
|                 critic_loss = td_error.mean() | ||||
|                 entropy_loss = Categorical(action_probs).entropy().mean() | ||||
|                 action_logp = _index(action_probs, action).log() | ||||
|                 a2c_loss = action_logp[:-1] * td.detach() | ||||
|                 a2c_loss = a2c_loss.mean() | ||||
|                 loss = ( | ||||
|                     -0.001 * entropy_loss | ||||
|                     + 1.0 * critic_loss | ||||
|                     - 0.1 * a2c_loss | ||||
|                 ) | ||||
|                 optimizer = optimizers[agent_id] | ||||
|                 optimizer.zero_grad() | ||||
|                 loss.backward() | ||||
|                 #torch.nn.utils.clip_grad_norm_(a2c_agents[agent_id].parameters(), 2) | ||||
|                 optimizer.step() | ||||
|  | ||||
|                 # Compute the cumulated reward on final_state | ||||
|                 creward = workspace["env/cumulated_reward"] | ||||
|                 creward = creward[done] | ||||
|                 if creward.size()[0] > 0: | ||||
|                     cum_r = creward.mean().item() | ||||
|                     if cum_r > best: | ||||
|                     #    torch.save(a2c_agent.state_dict(), Path(__file__).parent / f'agent_{uid}.pt') | ||||
|                         best = cum_r | ||||
|                     pbar.set_description(f"Cum. r: {cum_r:.2f}, Best r. so far: {best:.2f}", refresh=True) | ||||
|  | ||||
|         # Compute the cumulated reward on final_state | ||||
|         creward = workspace["env/cumulated_reward"] | ||||
|         creward = creward[done] | ||||
|         if creward.size()[0] > 0: | ||||
|             print(f"Cumulative reward at A2C step #{(1+epoch)*n_timesteps}: {creward.mean().item()}") | ||||
							
								
								
									
studies/sat_mad.yaml (new file, +26 lines)
									
								
							| @@ -0,0 +1,26 @@ | ||||
| agent: | ||||
|   classname:        studies.sat_mad.A2CAgent | ||||
|   observation_size: 4*5*5 | ||||
|   hidden_size:      128 | ||||
|   n_actions:        10 | ||||
|  | ||||
| env: | ||||
|   classname:      environments.factory.make | ||||
|   env_name:       "DirtyFactory-v0" | ||||
|   n_agents:       1 | ||||
|   pomdp_r:        2 | ||||
|   max_steps:      400 | ||||
|   stack_n_frames: 3 | ||||
|  | ||||
| algorithm: | ||||
|   max_epochs:             1000000 | ||||
|   n_envs:                 1 | ||||
|   n_timesteps:            16 | ||||
|   discount_factor:        0.99 | ||||
|   entropy_coef:           0.01 | ||||
|   critic_coef:            1.0 | ||||
|   gae:                    0.3 | ||||
|   optimizer: | ||||
|     classname:            torch.optim.Adam | ||||
|     lr:                   0.0003 | ||||
|     weight_decay:         0.0 | ||||
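Not part of the commit: a sketch of how this config is meant to be consumed, reusing the helpers added in algorithms/utils.py; the relative path assumes the script is run from the repository root.

from pathlib import Path
from salina import instantiate_class
from algorithms.utils import load_yaml_file, add_env_props

cfg = load_yaml_file(Path('studies/sat_mad.yaml'))
add_env_props(cfg)  # overwrites observation_size / n_actions with values taken from the instantiated env
agent = instantiate_class({**cfg['agent'], 'agent_id': 0, 'marl': False})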
							
								
								
									
studies/viz_salina.py (new file, +39 lines)
									
								
							| @@ -0,0 +1,39 @@ | ||||
| from salina.agents import Agents, TemporalAgent | ||||
| import torch | ||||
| from salina import Workspace, get_arguments, get_class, instantiate_class | ||||
| from pathlib import Path | ||||
| from salina.agents.gyma import GymAgent | ||||
| import time | ||||
| from algorithms.utils import load_yaml_file, add_env_props | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     # Setup workspace | ||||
|     uid = time.time() | ||||
|     workspace = Workspace() | ||||
|     weights = Path('/Users/romue/PycharmProjects/EDYS/studies/agent_1636994369.145843.pt') | ||||
|  | ||||
|     cfg = load_yaml_file(Path(__file__).parent / 'sat_mad.yaml') | ||||
|     add_env_props(cfg) | ||||
|     cfg['env'].update({'n_agents': 2}) | ||||
|  | ||||
|     # instantiate agent and env | ||||
|     env_agent = GymAgent( | ||||
|         get_class(cfg['env']), | ||||
|         get_arguments(cfg['env']), | ||||
|         n_envs=1 | ||||
|     ) | ||||
|  | ||||
|     agents = [] | ||||
|     for _ in range(2): | ||||
|         a2c_agent = instantiate_class(cfg['agent']) | ||||
|         if weights: | ||||
|             a2c_agent.load_state_dict(torch.load(weights)) | ||||
|         agents.append(a2c_agent) | ||||
|  | ||||
|     # combine agents | ||||
|     acquisition_agent = TemporalAgent(Agents(env_agent, *agents)) | ||||
|     acquisition_agent.seed(42) | ||||
|  | ||||
|     acquisition_agent(workspace, t=0, n_steps=400, stochastic=False, save_render=True) | ||||
|  | ||||
|  | ||||