	moved renderer.py to base, added initial salina experiments
@@ -1,4 +1,4 @@
-def make(env_str, n_agents=1, pomdp_r=2, max_steps=400):
+def make(env_str, n_agents=1, pomdp_r=2, max_steps=400, stack_n_frames=3):
     import yaml
     from pathlib import Path
     from environments.factory.combined_factories import DirtItemFactory
@@ -9,7 +9,8 @@ def make(env_str, n_agents=1, pomdp_r=2, max_steps=400):
     with (Path(__file__).parent / 'levels' / 'parameters' / f'{env_str}.yaml').open('r') as stream:
         dictionary = yaml.load(stream, Loader=yaml.FullLoader)

-    obs_props = ObservationProperties(render_agents=AgentRenderOptions.COMBINED, frames_to_stack=0, pomdp_r=pomdp_r)
+    obs_props = ObservationProperties(render_agents=AgentRenderOptions.COMBINED,
+                                      frames_to_stack=stack_n_frames, pomdp_r=pomdp_r)

     factory_kwargs = dict(n_agents=n_agents, max_steps=max_steps, obs_prop=obs_props,
                           mv_prop=MovementProperties(**dictionary['movement_props']),
@@ -17,4 +18,4 @@ def make(env_str, n_agents=1, pomdp_r=2, max_steps=400):
                           record_episodes=False, verbose=False, **dictionary['factory_props']
                           )

-    return DirtFactory(**factory_kwargs)
+    return DirtFactory(**factory_kwargs).__enter__()

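The new stack_n_frames argument wires frame stacking into ObservationProperties, and make() now hands the factory back via __enter__(), i.e. as an already-entered context manager. A minimal usage sketch, assuming the 'DirtyFactory-v0' parameter file referenced by the salina experiment below and the usual gym-style reset/step API of the factory:

    from environments.factory import make

    env = make('DirtyFactory-v0', n_agents=1, pomdp_r=2, max_steps=400, stack_n_frames=3)
    obs = env.reset()                      # egocentric observation(s) with 3 stacked frames
    # one sampled action per agent, mirroring the loop this commit removes below
    obs, reward, done, info = env.step([env.action_space.sample() for _ in range(1)])
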
@@ -544,7 +544,7 @@ class BaseFactory(gym.Env):

     def render(self, mode='human'):
         if not self._renderer:  # lazy init
-            from environments.factory.renderer import Renderer, RenderEntity
+            from environments.factory.base.renderer import Renderer, RenderEntity
             global Renderer, RenderEntity
             height, width = self._obs_cube.shape[1:]
             self._renderer = Renderer(width, height, view_radius=self._pomdp_r, fps=5)
@@ -562,7 +562,7 @@ class BaseFactory(gym.Env):
                 doors.append(RenderEntity(name, door.pos, 1, 'none', state, i + 1))
         additional_assets = self.render_additional_assets()

-        self._renderer.render(walls + doors + additional_assets + agents)
+        return self._renderer.render(walls + doors + additional_assets + agents)

     def save_params(self, filepath: Path):
         # noinspection PyProtectedMember

@@ -7,6 +7,8 @@ import pygame
 from typing import NamedTuple, Any
 import time

+import torch
+

 class RenderEntity(NamedTuple):
     name: str
@@ -22,7 +24,7 @@ class Renderer:
     BG_COLOR = (178, 190, 195)         # (99, 110, 114)
     WHITE = (223, 230, 233)            # (200, 200, 200)
     AGENT_VIEW_COLOR = (9, 132, 227)
-    ASSETS = Path(__file__).parent / 'assets'
+    ASSETS = Path(__file__).parent.parent / 'assets'

     def __init__(self, grid_w=16, grid_h=16, cell_size=40, fps=7,  grid_lines=True, view_radius=2):
        self.grid_h = grid_h
@@ -121,6 +123,8 @@ class Renderer:

         pygame.display.flip()
         self.clock.tick(self.fps)
+        rgb_obs = pygame.surfarray.array3d(self.screen)
+        return torch.from_numpy(rgb_obs).permute(2, 0, 1)


 if __name__ == '__main__':
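
With this change Renderer.render() (and, via the BaseFactory.render() change above, the environment's render call) returns the drawn frame as a torch tensor instead of nothing. pygame.surfarray.array3d() yields a (width, height, 3) array, so the permute(2, 0, 1) produces a (3, width, height) tensor. A hedged consumer sketch, reusing the env from the make() sketch above and assuming imageio is available (it is not a dependency this commit introduces):

    import imageio

    frame = env.render()                            # torch.Tensor of shape (3, W, H)
    image = frame.permute(2, 1, 0).cpu().numpy()    # back to (H, W, 3) for image tooling
    imageio.imwrite('render_frame.png', image.astype('uint8'))
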
@@ -1,11 +1,11 @@
-from typing import Union, NamedTuple, Dict
+from typing import Union, NamedTuple

 import numpy as np

 from environments.factory.base.base_factory import BaseFactory
 from environments.factory.base.objects import Agent, Action, Entity
 from environments.factory.base.registers import EntityObjectRegister, ObjectRegister
-from environments.factory.renderer import RenderEntity
+from environments.factory.base.renderer import RenderEntity
 from environments.helpers import Constants as c

 from environments import helpers as h

@@ -1,6 +1,5 @@
 import time
 from enum import Enum
-from pathlib import Path
 from typing import List, Union, NamedTuple, Dict
 import random

@@ -12,8 +11,7 @@ from environments.factory.base.base_factory import BaseFactory
 from environments.factory.base.objects import Agent, Action, Entity, Tile
 from environments.factory.base.registers import Entities, MovingEntityObjectRegister

-from environments.factory.renderer import RenderEntity
-from environments.logging.recorder import RecorderCallback
+from environments.factory.base.renderer import RenderEntity
 from environments.utility_classes import ObservationProperties

 CLEAN_UP_ACTION = h.EnvActions.CLEAN_UP

@@ -10,9 +10,9 @@ from environments.helpers import Constants as c
 from environments import helpers as h
 from environments.factory.base.objects import Agent, Entity, Action, Tile, MoveableEntity
 from environments.factory.base.registers import Entities, EntityObjectRegister, ObjectRegister, \
-    MovingEntityObjectRegister, Register
+    MovingEntityObjectRegister

-from environments.factory.renderer import RenderEntity
+from environments.factory.base.renderer import RenderEntity


 NO_ITEM = 0

@@ -1,29 +1,100 @@
 from environments.factory import make
-import salina
+from salina import Workspace, TAgent
+from salina.agents.gyma import AutoResetGymAgent, GymAgent
+from salina.agents import Agents, TemporalAgent
+from salina.rl.functional import _index
 import torch
-from gym.wrappers import FrameStack
+import torch.nn as nn
+from torch.nn.utils import spectral_norm
+import torch.optim as optim
+from torch.distributions import Categorical


-class MyAgent(salina.TAgent):
-    def __init__(self):
-        super(MyAgent, self).__init__()
+class A2CAgent(TAgent):
+    def __init__(self, observation_size, hidden_size, n_actions):
+        super().__init__()
+        self.model = nn.Sequential(
+            nn.Flatten(),
+            nn.Linear(observation_size, hidden_size),
+            nn.ELU(),
+            nn.Linear(hidden_size, hidden_size),
+            nn.ELU(),
+            nn.Linear(hidden_size, n_actions),
+        )
+        self.critic_model = nn.Sequential(
+            nn.Flatten(),
+            nn.Linear(observation_size, hidden_size),
+            nn.ELU(),
+            spectral_norm(nn.Linear(hidden_size, 1)),
+        )

-    def forward(self, t, **kwargs):
-        self.set(('timer', t), torch.tensor([t]))
+    def forward(self, t, stochastic, **kwargs):
+        observation = self.get(("env/env_obs", t))
+        scores = self.model(observation)
+        probs = torch.softmax(scores, dim=-1)
+        critic = self.critic_model(observation).squeeze(-1)
+        if stochastic:
+            action = torch.distributions.Categorical(probs).sample()
+        else:
+            action = probs.argmax(1)
+
+        self.set(("action", t), action)
+        self.set(("action_probs", t), probs)
+        self.set(("critic", t), critic)


 if __name__ == '__main__':
-    n_agents = 1
-    env = make('DirtyFactory-v0', n_agents=n_agents)
-    env = FrameStack(env, num_stack=3)
-    env.reset()
-    agent = MyAgent()
-    workspace = salina.Workspace()
-    agent(workspace, t=0, n_steps=10)
+    # Setup agents and workspace
+    env_agent = AutoResetGymAgent(make, dict(env_str='DirtyFactory-v0'), n_envs=1)
+    a2c_agent = A2CAgent(3*4*5*5, 96, 10)
+    workspace = Workspace()

-    print(workspace)
+    eval_agent = Agents(GymAgent(make, dict(env_str='DirtyFactory-v0'), n_envs=1), a2c_agent)
+    for i in range(100):
+        eval_agent(workspace, t=i, save_render=True, stochastic=True)
+
+    assert False
+    # combine agents
+    acquisition_agent = TemporalAgent(Agents(env_agent, a2c_agent))
+    acquisition_agent.seed(0)

-    for i in range(1000):
-        state, *_ = env.step([env.unwrapped.action_space.sample() for _ in range(n_agents)])
-        #env.render()
+    # optimizers & other parameters
+    optimizer = optim.Adam(a2c_agent.parameters(), lr=1e-3)
+    n_timesteps = 10

+    # Decision making loop
+    for epoch in range(200000):
+        workspace.zero_grad()
+        if epoch > 0:
+            workspace.copy_n_last_steps(1)
+            acquisition_agent(workspace, t=1, n_steps=n_timesteps-1, stochastic=True)
+        else:
+            acquisition_agent(workspace, t=0, n_steps=n_timesteps, stochastic=True)
+        #for k in workspace.keys():
+        #    print(f'{k} ==> {workspace[k].size()}')
+        critic, done, action_probs, reward, action = workspace[
+            "critic", "env/done", "action_probs", "env/reward", "action"
+        ]
+
+        target = reward[1:] + 0.99 * critic[1:].detach() * (1 - done[1:].float())
+        td = target - critic[:-1]
+        td_error = td ** 2
+        critic_loss = td_error.mean()
+        entropy_loss = Categorical(action_probs).entropy().mean()
+        action_logp = _index(action_probs, action).log()
+        a2c_loss = action_logp[:-1] * td.detach()
+        a2c_loss = a2c_loss.mean()
+        loss = (
+            -0.001 * entropy_loss
+            + 1.0 * critic_loss
+            - 0.1 * a2c_loss
+        )
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        # Compute the cumulated reward on final_state
+        creward = workspace["env/cumulated_reward"]
+        creward = creward[done]
+        if creward.size()[0] > 0:
+            print(f"Cumulative reward at A2C step #{(1+epoch)*n_timesteps}: {creward.mean().item()}")
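
A note on the magic numbers in A2CAgent(3*4*5*5, 96, 10): the flattened observation size only lines up with make()'s defaults, i.e. stack_n_frames=3 stacked frames, a pomdp_r=2 egocentric view (a 5x5 window), and an assumed 4 observation planes per frame; the 10 actions are likewise taken on faith from the factory's action space. A small sanity check under those assumptions:

    stack_n_frames = 3            # frames_to_stack set in make()
    n_obs_planes = 4              # assumed per-frame observation channels, not spelled out in this commit
    pomdp_r = 2                   # default in make()
    side = 2 * pomdp_r + 1        # 5x5 egocentric window
    observation_size = stack_n_frames * n_obs_planes * side * side
    assert observation_size == 3 * 4 * 5 * 5 == 300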