【发布时间】:2020-10-27 08:43:07
【问题描述】:
我正在使用 mnist_vae 示例,但不知道如何正确保存/加载训练模型的权重。
enc_init_rng, dec_init_rng = random.split(random.PRNGKey(2))
_, init_encoder_params = encoder_init(enc_init_rng, (batch_size, 28 * 28))
_, init_decoder_params = decoder_init(dec_init_rng, (batch_size, 10))
init_params = init_encoder_params, init_decoder_params
opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=0.9)
opt_state = opt_init(init_params)
之后,我使用 opt_update 训练模型并希望保存它。但是,我还没有找到将优化器状态保存到磁盘的任何函数。
我尝试保存参数并用它们初始化opt_state,但不是所有信息都保存下来,结果opt_state_1不是原来的opt_state。
weights=get_params(opt_state)
jnp.save(file, weights)
weights = jnp.load(file,allow_pickle=True)
opt_state_1 = opt_init(init_params)
如何正确保存我训练的模型?
【问题讨论】:
标签: python machine-learning jax