diff --git a/algorithms/common.py b/algorithms/common.py
index 97d166c..876f689 100644
--- a/algorithms/common.py
+++ b/algorithms/common.py
@@ -24,7 +24,7 @@ class BaseLearner:
         self.n_agents = n_agents
         self.n_grad_steps = n_grad_steps
         self.train_every = train_every
-        self.stack_n_frames = deque(stack_n_frames)
+        self.stack_n_frames = deque(maxlen=stack_n_frames)
         self.device = 'cpu'
         self.n_updates = 0
         self.step = 0
@@ -51,6 +51,9 @@ class BaseLearner:
     def on_episode_end(self, n_steps):
         pass
 
+    def on_all_done(self):
+        pass
+
     def train(self):
         pass
 
@@ -93,6 +96,7 @@ class BaseLearner:
                           f' eps: {self.eps:.4f}\tRunning loss: {sum(list(self.running_loss)) / len(self.running_loss):.4f}\tUpdates:{self.n_updates}')
             except Exception as e:
                 pass
+        self.on_all_done()
 
 
 class BaseBuffer:
@@ -180,18 +184,17 @@ class BaseDDQN(BaseDQN):
         return values + (advantages - advantages.mean())
 
 
-class QTRANtestNet(nn.Module):
-    def __init__(self, backbone_dims=[3*5*5, 64, 64], q_head=[64, 9]):
-        super(QTRANtestNet, self).__init__()
-        self.backbone = mlp_maker(backbone_dims, flatten=True, activation_last='elu')
-        self.q_head = mlp_maker(q_head)
+class BaseICM(nn.Module):
+    def __init__(self, backbone_dims=[2*3*5*5, 64, 64], head_dims=[2*64, 64, 9]):
+        super(BaseICM, self).__init__()
+        self.backbone = mlp_maker(backbone_dims, flatten=True)
+        self.icm = mlp_maker(head_dims)
+        self.ce = nn.CrossEntropyLoss()
 
-    def forward(self, x):
-        features = self.backbone(x)
-        qs = self.q_head(features)
-        return qs, features
-
-    @torch.no_grad()
-    def act(self, x) -> np.ndarray:
-        action = self.forward(x)[0].max(-1)[1].numpy()
-        return action
\ No newline at end of file
+    def forward(self, s0, s1, a):
+        phi_s0 = self.backbone(s0)
+        phi_s1 = self.backbone(s1)
+        cat = torch.cat((phi_s0, phi_s1), dim=1)
+        logits = self.icm(cat)
+        ce = self.ce(logits, a)  # CrossEntropyLoss expects raw logits, so no softmax before the loss
+        return dict(prediction=torch.softmax(logits, dim=-1), loss=ce)
\ No newline at end of file
diff --git a/algorithms/m_q_learner.py b/algorithms/m_q_learner.py
index ded972e..e4f85eb 100644
--- a/algorithms/m_q_learner.py
+++ b/algorithms/m_q_learner.py
@@ -50,4 +50,22 @@ class MQLearner(QLearner):
 
         # Compute loss
         loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - m_q_target, 2))
-        self._backprop_loss(loss)
\ No newline at end of file
+        self._backprop_loss(loss)
+
+from tqdm import trange
+class MQICMLearner(MQLearner):
+    def __init__(self, *args, icm, **kwargs):
+        super(MQICMLearner, self).__init__(*args, **kwargs)
+        self.icm = icm
+        self.icm_optimizer = torch.optim.Adam(self.icm.parameters())
+
+    def on_all_done(self):
+        for b in trange(50000):
+            batch = self.buffer.sample(128, 0)
+            s0, s1, a = batch.observation, batch.next_observation, batch.action
+            loss = self.icm(s0, s1, a.squeeze())['loss']
+            self.icm_optimizer.zero_grad()
+            loss.backward()
+            self.icm_optimizer.step()
+            if b % 100 == 0:
+                print(loss.item())
diff --git a/algorithms/q_learner.py b/algorithms/q_learner.py
index 0cd04f0..93ea949 100644
--- a/algorithms/q_learner.py
+++ b/algorithms/q_learner.py
@@ -100,32 +100,28 @@
 
 if __name__ == '__main__':
     from environments.factory.simple_factory import SimpleFactory, DirtProperties, MovementProperties
-    from algorithms.common import BaseDDQN
-    from algorithms.m_q_learner import MQLearner
+    from algorithms.common import BaseDDQN, BaseICM
+    from algorithms.m_q_learner import MQLearner, MQICMLearner
     from algorithms.vdn_learner import VDNLearner
-    from algorithms.udr_learner import UDRLearner
 
     N_AGENTS = 1
 
-    dirt_props = DirtProperties(clean_amount=1, gain_amount=0.1, max_global_amount=20,
-                                max_local_amount=1, spawn_frequency=5, max_spawn_ratio=0.05,
-                                dirt_smear_amount=0.0)
-    move_props = MovementProperties(allow_diagonal_movement=True,
-                                    allow_square_movement=True,
-                                    allow_no_op=False)
-
-    env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400, parse_doors=False,
-                        movement_properties=move_props, level_name='rooms', frames_to_stack=0,
-                        omit_agent_slice_in_obs=True, combin_agent_slices_in_obs=True, record_episodes=False
-                        )
+    with (Path(f'../environments/factory/env_default_param.yaml')).open('r') as f:
+        env_kwargs = yaml.load(f, Loader=yaml.FullLoader)
+    env = SimpleFactory(**env_kwargs)
 
     obs_shape = np.prod(env.observation_space.shape)
     n_actions = env.action_space.n
 
-    dqn, target_dqn = BaseDDQN(backbone_dims=[obs_shape, 128, 128], advantage_dims=[128, n_actions], value_dims=[128,1], activation='leaky_relu'),\
-                      BaseDDQN(backbone_dims=[obs_shape, 128, 128], advantage_dims=[128, n_actions], value_dims=[128,1], activation='leaky_relu')
+    dqn, target_dqn = BaseDDQN(backbone_dims=[obs_shape, 128, 128], advantage_dims=[128, n_actions], value_dims=[128, 1], activation='leaky_relu'),\
+                      BaseDDQN(backbone_dims=[obs_shape, 128, 128], advantage_dims=[128, n_actions], value_dims=[128, 1], activation='leaky_relu')
 
-    learner = MQLearner(dqn, target_dqn, env, 50000, target_update=5000, lr=0.0007, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10,
-                        train_every=('step', 4), eps_end=0.025, n_grad_steps=1, reg_weight=0.1, exploration_fraction=0.25, batch_size=64, weight_decay=1e-3)
+    icm = BaseICM(backbone_dims=[obs_shape, 64, 32], head_dims=[2*32, 64, n_actions])
+
+    learner = MQICMLearner(dqn, target_dqn, env, 50000, icm=icm,
+                           target_update=5000, lr=0.0007, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10,
+                           train_every=('step', 4), eps_end=0.025, n_grad_steps=1, reg_weight=0.1, exploration_fraction=0.25,
+                           batch_size=64, weight_decay=1e-3
+                           )
     #learner.save(Path(__file__).parent / 'test' / 'testexperiment1337')
     learner.learn(100000)
diff --git a/environments/factory/renderer.py b/environments/factory/renderer.py
index 48868b4..bf25f77 100644
--- a/environments/factory/renderer.py
+++ b/environments/factory/renderer.py
@@ -15,7 +15,7 @@ class Entity(NamedTuple):
     value_operation: str = 'none'
     state: str = None
     id: int = 0
-    aux:Any=None
+    aux: Any = None
 
 
 class Renderer:
diff --git a/studies/__init__.py b/studies/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/studies/sat_mad.py b/studies/sat_mad.py
new file mode 100644
index 0000000..46754a9
--- /dev/null
+++ b/studies/sat_mad.py
@@ -0,0 +1,9 @@
+import numpy as np
+
+
+class SatMad(object):
+    def __init__(self):
+        pass
+
+if __name__ == '__main__':
+    pass
\ No newline at end of file
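
A minimal usage sketch for the new BaseICM inverse-dynamics head, mirroring what MQICMLearner.on_all_done() does with replay-buffer samples but on random tensors so it runs standalone. It assumes mlp_maker(dims, flatten=True) builds a plain fully connected stack as elsewhere in algorithms/common.py; the observation size, action count, and batch size below are placeholders, not values read from SimpleFactory.

import torch

from algorithms.common import BaseICM

# Placeholder sizes (assumptions, not taken from the environment).
obs_shape, n_actions, batch_size = 3 * 5 * 5, 9, 64

# Same layout as in the q_learner.py __main__ above: 32-dim state embeddings,
# concatenated pairwise before the inverse-model head.
icm = BaseICM(backbone_dims=[obs_shape, 64, 32], head_dims=[2 * 32, 64, n_actions])

s0 = torch.randn(batch_size, obs_shape)           # observations at step t
s1 = torch.randn(batch_size, obs_shape)           # observations at step t + 1
a = torch.randint(0, n_actions, (batch_size,))    # actions taken in between

out = icm(s0, s1, a)                              # dict(prediction=..., loss=...)
out['loss'].backward()                            # cross-entropy of the inverse model
print(out['prediction'].shape)                    # torch.Size([64, 9])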