【问题标题】:How to convert all layers of a pretrained Keras model to a different dtype (from float32 to float16)?如何将预训练的 Keras 模型的所有层转换为不同的 dtype(从 float32 到 float16)?
【发布时间】:2018-06-02 09:23:47
【问题描述】:

我正在尝试将我的 (float32) 模型的精度更改为 float16,以查看它对性能的影响有多大。 加载模型(base_model)后,我尝试了这个:

from keras import backend as K
K.set_floatx('float16')
weights_list = base_model.layers[1].get_weights()
print('Original:')
print(weights_list[0].dtype)
new_weights = [K.cast_to_floatx(weights_list[0])]
print('New Weights:')
print(new_weights[0].dtype)
print('Setting New Weights')
base_model.layers[1].set_weights(new_weights)
new_weights_list = base_model.layers[1].get_weights()
print(new_weights_list[0].dtype)

输出:

Original:
float32
New Weights:
float16
Setting New Weights
float32

使用此代码,将一层内的权重转换为float16,并将模型中的权重设置为新的权重,但使用get_weights后,数据类型又变回了float32。有没有办法设置图层的dtype?据我所知,K.cast_to_floatx 用于 numpy 数组,而 K.cast 用于张量。我是否需要使用新的 dtype 构建全新的空模型并将重铸的权重放入新模型中?

或者是否有一些更直接的方法来加载所有图层都具有 dtype 'float32' 的模型,并将所有图层转换为具有 dtype 'float16'?这是一个融入 mlmodel 的功能,所以我认为在 Keras 中它不会特别困难。

【问题讨论】:

  • 你做到了吗?
  • """After"" loading a model",这很可能是原因。 K.set_floatx() ""before"" 加载模型怎么样?
  • 你能重新创建模型吗?仔细检查base_model.summary()应该是可能的。
  • 我的印象是K.set_floatx() 在您从头开始构建的新模型上设置了浮点数,但是加载的模型使用与保存时相同的 dtype 加载。我没有深入了解制作一个新的空模型并用重铸的重量填充它是否可行。

标签: python-3.x numpy keras


【解决方案1】:

有同样的问题,并得到这个工作。什么对我有用:

  • 保存到文件并重新加载
  • 铸造所有权重并重新分配给原始模型

以下是对我有用的

  • 创建具有相同架构的新模型并手动设置其权重

MWE:

>>> from keras import backend as K
>>> from keras.models import Sequential
>>> from keras.layers import Dense, Dropout, Activation
>>> import numpy as np
>>> 
>>> def make_model():
...     model = Sequential()
...     model.add(Dense(64, activation='relu', input_dim=20))
...     model.add(Dropout(0.5))
...     model.add(Dense(64, activation='relu'))
...     model.add(Dropout(0.5))
...     model.add(Dense(10, activation='softmax'))
...     return model
... 
>>> K.set_floatx('float64')
>>> model = make_model()
>>> 
>>> K.set_floatx('float32')
>>> ws = model.get_weights()
>>> wsp = [w.astype(K.floatx()) for w in ws]
>>> model_quant = make_model()
>>> model_quant.set_weights(wsp)
>>> xp = x.astype(K.floatx())
>>> 
>>> print(np.unique([w.dtype for w in model.get_weights()]))
[dtype('float64')]
>>> print(np.unique([w.dtype for w in model_quant.get_weights()]))
[dtype('float32')]

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2017-11-11
    • 2020-09-12
    • 1970-01-01
    • 2018-12-21
    • 2021-05-20
    • 2019-03-08
    • 2019-11-28
    • 1970-01-01
    相关资源
    最近更新 更多