diff --git a/algorithms/common.py b/algorithms/common.py
index 2c2f678..ddc1136 100644
--- a/algorithms/common.py
+++ b/algorithms/common.py
@@ -71,6 +71,7 @@ def mlp_maker(dims, flatten=False, activation='elu', activation_last='identity')
     return nn.Sequential(OrderedDict(layers))
 
 
+
 class BaseDQN(nn.Module):
     def __init__(self, dims=[3*5*5, 64, 64, 9]):
         super(BaseDQN, self).__init__()
@@ -100,3 +101,22 @@ class BaseDDQN(BaseDQN):
         advantages = self.advantage_head(features)
         values = self.value_head(features)
         return values + (advantages - advantages.mean())
+
+
+class QTRANtestNet(nn.Module):
+    def __init__(self, backbone_dims=[3*5*5, 64, 64], q_head=[64, 9]):
+        super(QTRANtestNet, self).__init__()
+        self.backbone = mlp_maker(backbone_dims, flatten=True, activation_last='elu')
+        self.q_head = mlp_maker(q_head)
+
+    def forward(self, x):
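+        # returns both the per-action q-values and the backbone features; the features
+        # are meant to feed a joint action-value network (see QTRANLearner.local_qs)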
+        features = self.backbone(x)
+        qs = self.q_head(features)
+        return qs, features
+
+    @torch.no_grad()
+    def act(self, x) -> np.ndarray:
+        action = self.forward(x)[0].max(-1)[1].numpy()
+        return action
\ No newline at end of file
diff --git a/algorithms/q_learner.py b/algorithms/q_learner.py
index d6ee864..10d34a2 100644
--- a/algorithms/q_learner.py
+++ b/algorithms/q_learner.py
@@ -128,7 +128,7 @@ if __name__ == '__main__':
     from algorithms.common import BaseDDQN
     from algorithms.vdn_learner import VDNLearner
 
-    N_AGENTS = 1
+    N_AGENTS = 2
 
     dirt_props = DirtProperties(clean_amount=3, gain_amount=0.2, max_global_amount=30,
                                 max_local_amount=5, spawn_frequency=1, max_spawn_ratio=0.05)
@@ -138,7 +138,8 @@
     env = SimpleFactory(dirt_properties=dirt_props, movement_properties=move_props, n_agents=N_AGENTS, pomdp_radius=2,  max_steps=400, omit_agent_slice_in_obs=False, combin_agent_slices_in_obs=True)
 
     dqn, target_dqn = BaseDDQN(), BaseDDQN()
-    learner = QLearner(dqn, target_dqn, env, 40000, target_update=3500, lr=0.0007, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10,
+    learner = VDNLearner(dqn, target_dqn, env, 40000, target_update=3500, lr=0.0007, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10,
                        train_every_n_steps=4, eps_end=0.025, n_grad_steps=1, reg_weight=0.1, exploration_fraction=0.25, batch_size=64)
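+    # VDN: per-agent q-values from the shared BaseDDQN are summed into a joint value for the TD update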
     #learner.save(Path(__file__).parent / 'test' / 'testexperiment1337')
     learner.learn(100000)
diff --git a/algorithms/qtran_learner.py b/algorithms/qtran_learner.py
new file mode 100644
index 0000000..fc6cc24
--- /dev/null
+++ b/algorithms/qtran_learner.py
@@ -0,0 +1,53 @@
+import torch
+from algorithms.q_learner import QLearner
+
+
+class QTRANLearner(QLearner):
+    def __init__(self, *args, weight_opt=1., weight_nopt=1., **kwargs):
+        super(QTRANLearner, self).__init__(*args, **kwargs)
+        assert self.n_agents >= 2, 'QTRANLearner requires more than one agent, use QLearner instead'
+        self.weight_opt = weight_opt
+        self.weight_nopt = weight_nopt
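+        # presumably the weights for QTRAN's L_opt / L_nopt loss terms; not used in the loss yet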
+
+    def _training_routine(self, obs, next_obs, action):
+        # adapted from the inherited QLearner routine: the QTRAN net returns
+        # (q_values, features), so only the q-values are used for the TD terms
+        current_q_values, _ = self.q_net(obs)
+        current_q_values = torch.gather(current_q_values, dim=-1, index=action)
+        next_q_values_raw = self.target_q_net(next_obs)[0].max(dim=-1)[0].reshape(-1, 1).detach()
+        return current_q_values, next_q_values_raw
+
+    def local_qs(self, observations, actions):
+        Q_jt = torch.zeros((actions.shape[0], 1))  # float accumulator for the summed individual q-values (batch x 1)
+        features = []
+        for agent_i in range(self.n_agents):
+            q_values_agent_i, features_agent_i = self.q_net(observations[:, agent_i])  # Individual action-value network
+            q_values_agent_i = torch.gather(q_values_agent_i, dim=-1, index=actions[:, agent_i].unsqueeze(-1))
+            Q_jt += q_values_agent_i
+            features.append(features_agent_i)
+        feature_sum = torch.stack(features, 0).sum(0)  # (n_agents x hdim) -> hdim; not used yet, meant for the joint network
+        return Q_jt
+
+    def train(self):
+        if len(self.buffer) < self.batch_size: return
+        for _ in range(self.n_grad_steps):
+            experience = self.buffer.sample(self.batch_size, cer=self.train_every_n_steps)
+
+            Q_jt_prime = self.local_qs(experience.observation, experience.action)  # sum of individual q-vals
+            Q_jt = None
+            V_jt = None
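+            # todo: the joint action-value (Q_jt) and state-value (V_jt) networks are still
+            # missing, so Q_jt_prime and the opt/nopt weights do not enter the loss below yet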
+
+            pred_q, target_q_raw = torch.zeros((self.batch_size, 1)), torch.zeros((self.batch_size, 1))
+            for agent_i in range(self.n_agents):
+                q_values, next_q_values_raw = self._training_routine(experience.observation[:, agent_i],
+                                                                     experience.next_observation[:, agent_i],
+                                                                     experience.action[:, agent_i].unsqueeze(-1))
+                pred_q += q_values
+                target_q_raw += next_q_values_raw
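+            # TD target on the summed q-values: r + (1 - done) * gamma * sum_i max_a Q_target_i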
+            target_q = experience.reward + (1 - experience.done) * self.gamma * target_q_raw
+            loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - target_q, 2))
+            self._backprop_loss(loss)
\ No newline at end of file
diff --git a/algorithms/vdn_learner.py b/algorithms/vdn_learner.py
index f50c6ab..504adb0 100644
--- a/algorithms/vdn_learner.py
+++ b/algorithms/vdn_learner.py
@@ -1,4 +1,6 @@
+from typing import Union
 import torch
+import numpy as np
 from algorithms.q_learner import QLearner
 
 
@@ -7,6 +9,23 @@ class VDNLearner(QLearner):
         super(VDNLearner, self).__init__(*args, **kwargs)
         assert self.n_agents >= 2, 'VDN requires more than one agent, use QLearner instead'
 
+    def get_action(self, obs) -> Union[int, np.ndarray]:
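+        # per-agent epsilon-greedy: every agent draws its own epsilon and greedy agents share a
+        # single forward pass through q_net (only computed if at least one agent acts greedily)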
+        o = torch.from_numpy(obs).unsqueeze(0) if self.n_agents <= 1 else torch.from_numpy(obs)
+        eps = np.random.rand(self.n_agents)
+        greedy = eps > self.eps
+        agent_actions = None
+        actions = []
+        for i in range(self.n_agents):
+            if greedy[i]:
+                if agent_actions is None: agent_actions = self.q_net.act(o.float())
+                action = agent_actions[i]
+            else:
+                action = self.env.action_space.sample()
+            actions.append(action)
+        return np.array(actions)
+
     def train(self):
         if len(self.buffer) < self.batch_size: return
         for _ in range(self.n_grad_steps):