Mirror of https://github.com/illiumst/marl-factory-grid.git (synced 2025-10-31 12:37:27 +01:00)
	added RegDQN
 algorithms/__init__.py |  0  (new file)
 algorithms/dqn_reg.py  | 52  (new file)
algorithms/dqn_reg.py
@@ -0,0 +1,52 @@
import numpy as np
import torch
import stable_baselines3 as sb3
from stable_baselines3.common import logger


class RegDQN(sb3.dqn.DQN):
    def __init__(self, *args, reg_weight=0.1, **kwargs):
        super().__init__(*args, **kwargs)
        self.reg_weight = reg_weight

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        # Update learning rate according to schedule
        self._update_learning_rate(self.policy.optimizer)

        losses = []
        for _ in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            with torch.no_grad():
                # Compute the next Q-values using the target network
                next_q_values = self.q_net_target(replay_data.next_observations)
                # Follow greedy policy: use the one with the highest value
                next_q_values, _ = next_q_values.max(dim=1)
                # Avoid potential broadcast issue
                next_q_values = next_q_values.reshape(-1, 1)
                # 1-step TD target
                target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values

            # Get current Q-values estimates
            current_q_values = self.q_net(replay_data.observations)

            # Retrieve the q-values for the actions from the replay buffer
            current_q_values = torch.gather(current_q_values, dim=1, index=replay_data.actions.long())

            delta = current_q_values - target_q_values
            loss = torch.mean(self.reg_weight * current_q_values + torch.pow(delta, 2))
            losses.append(loss.item())

            # Optimize the policy
            self.policy.optimizer.zero_grad()
            loss.backward()
            # Clip gradient norm
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

        # Increase update counter
        self._n_updates += gradient_steps

        logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        logger.record("train/loss", np.mean(losses))
Author: romue