【问题标题】:How do I save an optimizer state of JAX trained model?如何保存 JAX 训练模型的优化器状态?
【发布时间】:2020-10-27 08:43:07
【问题描述】:

我正在使用 mnist_vae 示例,但不知道如何正确保存/加载训练模型的权重。

enc_init_rng, dec_init_rng = random.split(random.PRNGKey(2))
_, init_encoder_params = encoder_init(enc_init_rng, (batch_size, 28 * 28))
_, init_decoder_params = decoder_init(dec_init_rng, (batch_size, 10))
init_params = init_encoder_params, init_decoder_params

opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=0.9)
opt_state = opt_init(init_params)

之后,我使用 opt_update 训练模型并希望保存它。但是,我还没有找到将优化器状态保存到磁盘的任何函数。

我尝试保存参数并用它们初始化opt_state,但不是所有信息都保存下来,结果opt_state_1不是原来的opt_state。

weights=get_params(opt_state)  
jnp.save(file, weights)  
weights = jnp.load(file,allow_pickle=True)  
opt_state_1 = opt_init(init_params)

如何正确保存我训练的模型?

【问题讨论】:

    标签: python machine-learning jax


    【解决方案1】:
    import pickle
    from jax.experimental import optimizers
    
    trained_params = optimizers.unpack_optimizer_state(opt_state)
    pickle.dump(trained_params, open(os.path.join(config["ckpt_path"], "best_ckpt.pkl"), "wb"))
    
    best_params = pickle.load(open(os.path.join(config["ckpt_path"], "best_ckpt.pkl"), "rb"))
    best_opt_state = optimizers.pack_optimizer_state(best_params)
    

    【讨论】:

    • 虽然此代码可以解决问题,including an explanation 说明如何以及为什么解决问题将真正有助于提高您的帖子质量,并可能导致更多的赞成票。请记住,您正在为将来的读者回答问题,而不仅仅是现在提出问题的人。请edit您的回答添加解释并说明适用的限制和假设。
    猜你喜欢
    • 2018-09-05
    • 1970-01-01
    • 1970-01-01
    • 2017-04-11
    • 2019-11-17
    • 2016-02-18
    • 2017-01-27
    • 2020-08-06
    相关资源
    最近更新 更多