mirror of
				https://github.com/illiumst/marl-factory-grid.git
				synced 2025-10-31 04:37:25 +01:00 
			
		
		
		
	first icm steps
This commit is contained in:
		| @@ -24,7 +24,7 @@ class BaseLearner: | ||||
|         self.n_agents = n_agents | ||||
|         self.n_grad_steps = n_grad_steps | ||||
|         self.train_every = train_every | ||||
|         self.stack_n_frames = deque(stack_n_frames) | ||||
|         self.stack_n_frames = deque(maxlen=stack_n_frames) | ||||
|         self.device = 'cpu' | ||||
|         self.n_updates = 0 | ||||
|         self.step = 0 | ||||
| @@ -51,6 +51,9 @@ class BaseLearner: | ||||
|     def on_episode_end(self, n_steps): | ||||
|         pass | ||||
|  | ||||
|     def on_all_done(self): | ||||
|         pass | ||||
|  | ||||
|     def train(self): | ||||
|         pass | ||||
|  | ||||
| @@ -93,6 +96,7 @@ class BaseLearner: | ||||
|                         f' eps: {self.eps:.4f}\tRunning loss: {sum(list(self.running_loss)) / len(self.running_loss):.4f}\tUpdates:{self.n_updates}') | ||||
|             except Exception as e: | ||||
|                 pass | ||||
|         self.on_all_done() | ||||
|  | ||||
|  | ||||
| class BaseBuffer: | ||||
| @@ -180,18 +184,17 @@ class BaseDDQN(BaseDQN): | ||||
|         return values + (advantages - advantages.mean()) | ||||
|  | ||||
|  | ||||
| class QTRANtestNet(nn.Module): | ||||
|     def __init__(self, backbone_dims=[3*5*5, 64, 64], q_head=[64, 9]): | ||||
|         super(QTRANtestNet, self).__init__() | ||||
|         self.backbone = mlp_maker(backbone_dims, flatten=True, activation_last='elu') | ||||
|         self.q_head = mlp_maker(q_head) | ||||
| class BaseICM(nn.Module): | ||||
|     def __init__(self, backbone_dims=[2*3*5*5, 64, 64], head_dims=[2*64, 64, 9]): | ||||
|         super(BaseICM, self).__init__() | ||||
|         self.backbone = mlp_maker(backbone_dims, flatten=True) | ||||
|         self.icm = mlp_maker(head_dims) | ||||
|         self.ce = nn.CrossEntropyLoss() | ||||
|  | ||||
|     def forward(self, x): | ||||
|         features = self.backbone(x) | ||||
|         qs = self.q_head(features) | ||||
|         return qs, features | ||||
|  | ||||
|     @torch.no_grad() | ||||
|     def act(self, x) -> np.ndarray: | ||||
|         action = self.forward(x)[0].max(-1)[1].numpy() | ||||
|         return action | ||||
|     def forward(self, s0, s1, a): | ||||
|         phi_s0 = self.backbone(s0) | ||||
|         phi_s1 = self.backbone(s1) | ||||
|         cat = torch.cat((phi_s0, phi_s1), dim=1) | ||||
|         a_prime = torch.softmax(self.icm(cat), dim=-1) | ||||
|         ce = self.ce(a_prime, a) | ||||
|         return dict(prediction=a_prime, loss=ce) | ||||
| @@ -50,4 +50,22 @@ class MQLearner(QLearner): | ||||
|  | ||||
|             # Compute loss | ||||
|             loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - m_q_target, 2)) | ||||
|             self._backprop_loss(loss) | ||||
|             self._backprop_loss(loss) | ||||
|  | ||||
| from tqdm import trange | ||||
class MQICMLearner(MQLearner):
    """MQLearner that also fits an ``icm`` module once learning has finished.

    The ICM module is trained in ``on_all_done`` on batches sampled from
    ``self.buffer``, using its own Adam optimizer with default settings.
    """

    def __init__(self, *args, icm, **kwargs):
        """Forward all learner arguments to the parent and register the ICM.

        Args:
            *args: positional arguments for ``MQLearner``.
            icm: module called as ``icm(s0, s1, a)`` that returns a mapping
                with a ``'loss'`` entry; its parameters are optimised here.
            **kwargs: keyword arguments for ``MQLearner``.
        """
        super(MQICMLearner, self).__init__(*args, **kwargs)
        self.icm = icm
        self.icm_optimizer = torch.optim.Adam(self.icm.parameters())

    def on_all_done(self, n_steps=50000, batch_size=128, log_every=100):
        """Train the ICM module on transitions sampled from the buffer.

        The defaults reproduce the previously hard-coded schedule, so a call
        without arguments behaves as before.

        Args:
            n_steps: number of optimisation steps to run.
            batch_size: number of transitions requested per step.
            log_every: print the current loss every this many steps; a falsy
                value disables printing.
        """
        for step in trange(n_steps):
            # NOTE(review): the meaning of the second positional argument to
            # ``sample`` is not visible in this file section; kept as 0.
            batch = self.buffer.sample(batch_size, 0)
            s0, s1, a = batch.observation, batch.next_observation, batch.action
            # NOTE(review): ``squeeze()`` drops every size-1 dim, including the
            # batch dim when batch_size == 1; confirm the action shape first.
            loss = self.icm(s0, s1, a.squeeze())['loss']
            self.icm_optimizer.zero_grad()
            loss.backward()
            self.icm_optimizer.step()
            if log_every and step % log_every == 0:
                print(loss.item())
|   | ||||
| @@ -100,32 +100,28 @@ class QLearner(BaseLearner): | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     from environments.factory.simple_factory import SimpleFactory, DirtProperties, MovementProperties | ||||
|     from algorithms.common import BaseDDQN | ||||
|     from algorithms.m_q_learner import MQLearner | ||||
|     from algorithms.common import BaseDDQN, BaseICM | ||||
|     from algorithms.m_q_learner import MQLearner, MQICMLearner | ||||
|     from algorithms.vdn_learner import VDNLearner | ||||
|     from algorithms.udr_learner import UDRLearner | ||||
|  | ||||
|     N_AGENTS = 1 | ||||
|  | ||||
|     dirt_props = DirtProperties(clean_amount=1, gain_amount=0.1, max_global_amount=20, | ||||
|                                 max_local_amount=1, spawn_frequency=5, max_spawn_ratio=0.05, | ||||
|                                 dirt_smear_amount=0.0) | ||||
|     move_props = MovementProperties(allow_diagonal_movement=True, | ||||
|                                     allow_square_movement=True, | ||||
|                                     allow_no_op=False) | ||||
|  | ||||
|     env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400, parse_doors=False, | ||||
|                         movement_properties=move_props, level_name='rooms', frames_to_stack=0, | ||||
|                         omit_agent_slice_in_obs=True, combin_agent_slices_in_obs=True, record_episodes=False | ||||
|                     ) | ||||
|     with (Path(f'../environments/factory/env_default_param.yaml')).open('r') as f: | ||||
|         env_kwargs = yaml.load(f, Loader=yaml.FullLoader) | ||||
|  | ||||
|     env = SimpleFactory(**env_kwargs) | ||||
|     obs_shape = np.prod(env.observation_space.shape) | ||||
|     n_actions = env.action_space.n | ||||
|  | ||||
|     dqn, target_dqn = BaseDDQN(backbone_dims=[obs_shape, 128, 128], advantage_dims=[128, n_actions], value_dims=[128,1], activation='leaky_relu'),\ | ||||
|                       BaseDDQN(backbone_dims=[obs_shape, 128, 128], advantage_dims=[128, n_actions], value_dims=[128,1], activation='leaky_relu') | ||||
|     dqn, target_dqn = BaseDDQN(backbone_dims=[obs_shape, 128, 128], advantage_dims=[128, n_actions], value_dims=[128, 1], activation='leaky_relu'),\ | ||||
|                       BaseDDQN(backbone_dims=[obs_shape, 128, 128], advantage_dims=[128, n_actions], value_dims=[128, 1], activation='leaky_relu') | ||||
|  | ||||
|     learner = MQLearner(dqn, target_dqn, env, 50000, target_update=5000, lr=0.0007, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10, | ||||
|                         train_every=('step', 4), eps_end=0.025, n_grad_steps=1, reg_weight=0.1, exploration_fraction=0.25, batch_size=64, weight_decay=1e-3) | ||||
|     icm = BaseICM(backbone_dims=[obs_shape, 64, 32], head_dims=[2*32, 64, n_actions]) | ||||
|  | ||||
|     learner = MQICMLearner(dqn, target_dqn, env, 50000, icm=icm, | ||||
|                            target_update=5000, lr=0.0007, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10, | ||||
|                            train_every=('step', 4), eps_end=0.025, n_grad_steps=1, reg_weight=0.1, exploration_fraction=0.25, | ||||
|                            batch_size=64, weight_decay=1e-3 | ||||
|                            ) | ||||
|     #learner.save(Path(__file__).parent / 'test' / 'testexperiment1337') | ||||
|     learner.learn(100000) | ||||
|   | ||||
| @@ -15,7 +15,7 @@ class Entity(NamedTuple): | ||||
|     value_operation: str = 'none' | ||||
|     state: str = None | ||||
|     id: int = 0 | ||||
|     aux:Any=None | ||||
|     aux: Any = None | ||||
|  | ||||
|  | ||||
| class Renderer: | ||||
|   | ||||
							
								
								
									
										0
									
								
								studies/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								studies/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										9
									
								
								studies/sat_mad.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								studies/sat_mad.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,9 @@ | ||||
| import numpy as np | ||||
|  | ||||
|  | ||||
class SatMad:
    """Placeholder class for the ``sat_mad`` study; it holds no state yet."""

    def __init__(self):
        """Create an empty instance; nothing is initialised."""
|  | ||||
| if __name__ == '__main__': | ||||
|     pass | ||||
		Reference in New Issue
	
	Block a user
	 romue
					romue