Mirror of https://github.com/illiumst/marl-factory-grid.git (synced 2025-05-23 07:16:44 +02:00)

Commit ebf49cadea ("first icm steps"), parent aebe3b2f60
@@ -24,7 +24,7 @@ class BaseLearner:
         self.n_agents = n_agents
         self.n_grad_steps = n_grad_steps
         self.train_every = train_every
-        self.stack_n_frames = deque(stack_n_frames)
+        self.stack_n_frames = deque(maxlen=stack_n_frames)
         self.device = 'cpu'
         self.n_updates = 0
         self.step = 0
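Note: the one functional change above is the frame-stack buffer. collections.deque(stack_n_frames) would try to iterate over the integer and raises TypeError, while deque(maxlen=stack_n_frames) creates a bounded buffer that evicts the oldest frame on overflow. A minimal sketch of that behaviour, independent of the repo's classes:

    from collections import deque

    # maxlen must be passed as a keyword; the first positional argument is an
    # iterable to copy from, which is why deque(3) fails but deque(maxlen=3) works.
    frames = deque(maxlen=3)
    for i in range(5):
        frames.append(i)
    print(list(frames))  # [2, 3, 4] -- older frames are dropped automatically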
@@ -51,6 +51,9 @@ class BaseLearner:
     def on_episode_end(self, n_steps):
         pass
 
+    def on_all_done(self):
+        pass
+
     def train(self):
         pass
 
@@ -93,6 +96,7 @@ class BaseLearner:
                       f' eps: {self.eps:.4f}\tRunning loss: {sum(list(self.running_loss)) / len(self.running_loss):.4f}\tUpdates:{self.n_updates}')
         except Exception as e:
             pass
+        self.on_all_done()
 
 
 class BaseBuffer:
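The two hunks above add a lifecycle hook: BaseLearner.on_all_done() is a no-op that is called once, presumably at the end of the learn loop after its try/except, so subclasses can attach post-training work without touching the loop itself. A minimal sketch of the pattern, assuming a learn() shaped roughly like BaseLearner's:

    class Learner:
        def learn(self, n_steps):
            for step in range(n_steps):
                pass  # collect experience, update networks, ...
            self.on_all_done()  # fires exactly once, after the last step

        def on_all_done(self):
            pass  # default: nothing to do


    class PostTrainingLearner(Learner):
        def on_all_done(self):
            print("training finished; fit auxiliary models here")


    PostTrainingLearner().learn(10)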
@@ -180,18 +184,17 @@ class BaseDDQN(BaseDQN):
         return values + (advantages - advantages.mean())
 
 
-class QTRANtestNet(nn.Module):
-    def __init__(self, backbone_dims=[3*5*5, 64, 64], q_head=[64, 9]):
-        super(QTRANtestNet, self).__init__()
-        self.backbone = mlp_maker(backbone_dims, flatten=True, activation_last='elu')
-        self.q_head = mlp_maker(q_head)
-
-    def forward(self, x):
-        features = self.backbone(x)
-        qs = self.q_head(features)
-        return qs, features
-
-    @torch.no_grad()
-    def act(self, x) -> np.ndarray:
-        action = self.forward(x)[0].max(-1)[1].numpy()
-        return action
+class BaseICM(nn.Module):
+    def __init__(self, backbone_dims=[2*3*5*5, 64, 64], head_dims=[2*64, 64, 9]):
+        super(BaseICM, self).__init__()
+        self.backbone = mlp_maker(backbone_dims, flatten=True)
+        self.icm = mlp_maker(head_dims)
+        self.ce = nn.CrossEntropyLoss()
+
+    def forward(self, s0, s1, a):
+        phi_s0 = self.backbone(s0)
+        phi_s1 = self.backbone(s1)
+        cat = torch.cat((phi_s0, phi_s1), dim=1)
+        a_prime = torch.softmax(self.icm(cat), dim=-1)
+        ce = self.ce(a_prime, a)
+        return dict(prediction=a_prime, loss=ce)
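This hunk replaces QTRANtestNet with the first piece of the Intrinsic Curiosity Module: an inverse-dynamics model that encodes s_t and s_t+1 with a shared backbone, concatenates the two feature vectors, and predicts which action was taken between them. One PyTorch detail worth noting: nn.CrossEntropyLoss already applies log-softmax internally and expects raw logits plus integer class indices, so the usual formulation feeds the head output straight into the loss rather than softmaxing it first. A self-contained sketch under that convention (hypothetical names, plain nn.Linear instead of the repo's mlp_maker):

    import torch
    import torch.nn as nn

    class InverseDynamicsHead(nn.Module):
        # Hypothetical stand-in for BaseICM: phi(s0) and phi(s1) -> action logits.
        def __init__(self, obs_dim=2 * 3 * 5 * 5, feat_dim=64, n_actions=9):
            super().__init__()
            self.backbone = nn.Sequential(nn.Flatten(), nn.Linear(obs_dim, feat_dim), nn.ReLU())
            self.head = nn.Linear(2 * feat_dim, n_actions)
            self.ce = nn.CrossEntropyLoss()  # applies log-softmax itself -> give it logits

        def forward(self, s0, s1, a):
            phi = torch.cat((self.backbone(s0), self.backbone(s1)), dim=1)
            logits = self.head(phi)
            return dict(prediction=logits, loss=self.ce(logits, a.long()))

    icm = InverseDynamicsHead()
    s0, s1 = torch.randn(8, 150), torch.randn(8, 150)
    a = torch.randint(0, 9, (8,))
    print(icm(s0, s1, a)['loss'].item())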
@@ -50,4 +50,22 @@ class MQLearner(QLearner):
 
         # Compute loss
         loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - m_q_target, 2))
         self._backprop_loss(loss)
+
+
+from tqdm import trange
+
+class MQICMLearner(MQLearner):
+    def __init__(self, *args, icm, **kwargs):
+        super(MQICMLearner, self).__init__(*args, **kwargs)
+        self.icm = icm
+        self.icm_optimizer = torch.optim.Adam(self.icm.parameters())
+
+    def on_all_done(self):
+        for b in trange(50000):
+            batch = self.buffer.sample(128, 0)
+            s0, s1, a = batch.observation, batch.next_observation, batch.action
+            loss = self.icm(s0, s1, a.squeeze())['loss']
+            self.icm_optimizer.zero_grad()
+            loss.backward()
+            self.icm_optimizer.step()
+            if b%100 == 0:
+                print(loss.item())
@@ -100,32 +100,28 @@ class QLearner(BaseLearner):
 
 if __name__ == '__main__':
     from environments.factory.simple_factory import SimpleFactory, DirtProperties, MovementProperties
-    from algorithms.common import BaseDDQN
-    from algorithms.m_q_learner import MQLearner
+    from algorithms.common import BaseDDQN, BaseICM
+    from algorithms.m_q_learner import MQLearner, MQICMLearner
     from algorithms.vdn_learner import VDNLearner
-    from algorithms.udr_learner import UDRLearner
 
     N_AGENTS = 1
 
-    dirt_props = DirtProperties(clean_amount=1, gain_amount=0.1, max_global_amount=20,
-                                max_local_amount=1, spawn_frequency=5, max_spawn_ratio=0.05,
-                                dirt_smear_amount=0.0)
-    move_props = MovementProperties(allow_diagonal_movement=True,
-                                    allow_square_movement=True,
-                                    allow_no_op=False)
-
-    env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400, parse_doors=False,
-                        movement_properties=move_props, level_name='rooms', frames_to_stack=0,
-                        omit_agent_slice_in_obs=True, combin_agent_slices_in_obs=True, record_episodes=False
-                        )
+    with (Path(f'../environments/factory/env_default_param.yaml')).open('r') as f:
+        env_kwargs = yaml.load(f, Loader=yaml.FullLoader)
+
+    env = SimpleFactory(**env_kwargs)
 
     obs_shape = np.prod(env.observation_space.shape)
     n_actions = env.action_space.n
 
-    dqn, target_dqn = BaseDDQN(backbone_dims=[obs_shape, 128, 128], advantage_dims=[128, n_actions], value_dims=[128,1], activation='leaky_relu'),\
-                      BaseDDQN(backbone_dims=[obs_shape, 128, 128], advantage_dims=[128, n_actions], value_dims=[128,1], activation='leaky_relu')
+    dqn, target_dqn = BaseDDQN(backbone_dims=[obs_shape, 128, 128], advantage_dims=[128, n_actions], value_dims=[128, 1], activation='leaky_relu'),\
+                      BaseDDQN(backbone_dims=[obs_shape, 128, 128], advantage_dims=[128, n_actions], value_dims=[128, 1], activation='leaky_relu')
 
-    learner = MQLearner(dqn, target_dqn, env, 50000, target_update=5000, lr=0.0007, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10,
-                        train_every=('step', 4), eps_end=0.025, n_grad_steps=1, reg_weight=0.1, exploration_fraction=0.25, batch_size=64, weight_decay=1e-3)
+    icm = BaseICM(backbone_dims=[obs_shape, 64, 32], head_dims=[2*32, 64, n_actions])
+
+    learner = MQICMLearner(dqn, target_dqn, env, 50000, icm=icm,
+                           target_update=5000, lr=0.0007, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10,
+                           train_every=('step', 4), eps_end=0.025, n_grad_steps=1, reg_weight=0.1, exploration_fraction=0.25,
+                           batch_size=64, weight_decay=1e-3
+                           )
     #learner.save(Path(__file__).parent / 'test' / 'testexperiment1337')
     learner.learn(100000)
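The __main__ block now builds the environment from a YAML file instead of hard-coded DirtProperties/MovementProperties; the new lines rely on Path and yaml being imported elsewhere in the module. For reference, a minimal version of that loading step (hypothetical config path; yaml.safe_load is the usual choice when the file only holds plain scalars, lists and dicts, and is equivalent to yaml.load with SafeLoader):

    from pathlib import Path
    import yaml

    cfg_path = Path('environments/factory/env_default_param.yaml')  # hypothetical location
    with cfg_path.open('r') as f:
        env_kwargs = yaml.safe_load(f)
    print(env_kwargs)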
@@ -15,7 +15,7 @@ class Entity(NamedTuple):
     value_operation: str = 'none'
     state: str = None
     id: int = 0
-    aux:Any=None
+    aux: Any = None
 
 
 class Renderer:
New files:
  studies/__init__.py  (0 lines)
  studies/sat_mad.py   (9 lines)
@@ -0,0 +1,9 @@
+import numpy as np
+
+
+class SatMad(object):
+    def __init__(self):
+        pass
+
+if __name__ == '__main__':
+    pass