mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-04 00:21:36 +02:00
first icm steps
This commit is contained in:
@ -100,32 +100,28 @@ class QLearner(BaseLearner):
|
||||
|
||||
if __name__ == '__main__':
|
||||
from environments.factory.simple_factory import SimpleFactory, DirtProperties, MovementProperties
|
||||
from algorithms.common import BaseDDQN
|
||||
from algorithms.m_q_learner import MQLearner
|
||||
from algorithms.common import BaseDDQN, BaseICM
|
||||
from algorithms.m_q_learner import MQLearner, MQICMLearner
|
||||
from algorithms.vdn_learner import VDNLearner
|
||||
from algorithms.udr_learner import UDRLearner
|
||||
|
||||
N_AGENTS = 1
|
||||
|
||||
dirt_props = DirtProperties(clean_amount=1, gain_amount=0.1, max_global_amount=20,
|
||||
max_local_amount=1, spawn_frequency=5, max_spawn_ratio=0.05,
|
||||
dirt_smear_amount=0.0)
|
||||
move_props = MovementProperties(allow_diagonal_movement=True,
|
||||
allow_square_movement=True,
|
||||
allow_no_op=False)
|
||||
|
||||
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400, parse_doors=False,
|
||||
movement_properties=move_props, level_name='rooms', frames_to_stack=0,
|
||||
omit_agent_slice_in_obs=True, combin_agent_slices_in_obs=True, record_episodes=False
|
||||
)
|
||||
with (Path(f'../environments/factory/env_default_param.yaml')).open('r') as f:
|
||||
env_kwargs = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
env = SimpleFactory(**env_kwargs)
|
||||
obs_shape = np.prod(env.observation_space.shape)
|
||||
n_actions = env.action_space.n
|
||||
|
||||
dqn, target_dqn = BaseDDQN(backbone_dims=[obs_shape, 128, 128], advantage_dims=[128, n_actions], value_dims=[128,1], activation='leaky_relu'),\
|
||||
BaseDDQN(backbone_dims=[obs_shape, 128, 128], advantage_dims=[128, n_actions], value_dims=[128,1], activation='leaky_relu')
|
||||
dqn, target_dqn = BaseDDQN(backbone_dims=[obs_shape, 128, 128], advantage_dims=[128, n_actions], value_dims=[128, 1], activation='leaky_relu'),\
|
||||
BaseDDQN(backbone_dims=[obs_shape, 128, 128], advantage_dims=[128, n_actions], value_dims=[128, 1], activation='leaky_relu')
|
||||
|
||||
learner = MQLearner(dqn, target_dqn, env, 50000, target_update=5000, lr=0.0007, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10,
|
||||
train_every=('step', 4), eps_end=0.025, n_grad_steps=1, reg_weight=0.1, exploration_fraction=0.25, batch_size=64, weight_decay=1e-3)
|
||||
icm = BaseICM(backbone_dims=[obs_shape, 64, 32], head_dims=[2*32, 64, n_actions])
|
||||
|
||||
learner = MQICMLearner(dqn, target_dqn, env, 50000, icm=icm,
|
||||
target_update=5000, lr=0.0007, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10,
|
||||
train_every=('step', 4), eps_end=0.025, n_grad_steps=1, reg_weight=0.1, exploration_fraction=0.25,
|
||||
batch_size=64, weight_decay=1e-3
|
||||
)
|
||||
#learner.save(Path(__file__).parent / 'test' / 'testexperiment1337')
|
||||
learner.learn(100000)
|
||||
|
Reference in New Issue
Block a user