【问题标题】:Using tensorflow.keras model in pyspark UDF generates a pickle error在 pyspark UDF 中使用 tensorflow.keras 模型会产生 pickle 错误
【发布时间】:2020-07-20 14:55:30
【问题描述】:

我想在 pysark pandas_udf 中使用 tensorflow.keras 模型。但是,在将模型发送给工作人员之前对其进行序列化时,我会收到一个 pickle 错误。我不确定我是否使用了最好的方法来执行我想要的,因此我将展示一个最小但完整的示例。

包:

  • tensorflow-2.2.0(但所有以前的版本也会触发错误)
  • pyspark-2.4.5

导入语句是:

import pandas as pd
import numpy as np

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

from pyspark.sql import SparkSession, functions as F, types as T

Pyspark UDF 是 pandas_udf:

def compute_output_pandas_udf(model):
    '''Spark pandas udf for model prediction.'''

    @F.pandas_udf(T.DoubleType(), F.PandasUDFType.SCALAR)
    def compute_output(inputs1, inputs2, inputs3):
        pdf = pd.DataFrame({
            'input1': inputs1,
            'input2': inputs2,
            'input3': inputs3
        })
        pdf['predicted_output'] = model.predict(pdf.values)
        return pdf['predicted_output']

    return compute_output

主要代码:

# Model parameters
weights = np.array([[0.5], [0.4], [0.3]])
bias = np.array([1.25])
activation = 'linear'
input_dim, output_dim = weights.shape

# Initialize model
model = Sequential()
layer = Dense(output_dim, input_dim=input_dim, activation=activation)
model.add(layer)
layer.set_weights([weights, bias])

# Initialize Spark session
spark = SparkSession.builder.appName('test').getOrCreate()

# Create pandas df with inputs and run model
pdf = pd.DataFrame({
    'input1': np.random.randn(200),
    'input2': np.random.randn(200),
    'input3': np.random.randn(200)
})
pdf['predicted_output'] = model.predict(pdf[['input1', 'input2', 'input3']].values)

# Create spark df with inputs and run model using udf
sdf = spark.createDataFrame(pdf)
sdf = sdf.withColumn('predicted_output', compute_output_pandas_udf(model)('input1', 'input2', 'input3'))
sdf.limit(5).show()

调用compute_output_pandas_udf(model)时会触发此错误:

PicklingError: Could not serialize object: TypeError: can't pickle _thread.RLock objects

我发现这个page 关于酸洗 keras 模型并在 tensorflow.keras 上进行了尝试,但是当在 UDF 中调用模型的 predict 函数时出现以下错误(因此序列化有效但反序列化不是?):

AttributeError: 'Sequential' object has no attribute '_distribution_strategy'

有人知道如何进行吗?提前谢谢!

PS:请注意,我没有直接使用 keras 库中的模型,因为我会定期出现另一个错误,而且似乎更难以解决。但是,模型的序列化不会像 tensorflow.keras 模型那样产生错误。

【问题讨论】:

    标签: apache-spark tensorflow keras pyspark user-defined-functions


    【解决方案1】:

    最简单的解决方案是broadcast 模型的权重并在pandas_udf 中加载权重。这是一个演示示例:

    import pandas as pd
    import numpy as np
    from tensorflow.keras.layers import Input, Dense
    from tensorflow.keras.models import Model
    
    # spark = SparkSession.builder.xxx.getOrCreate()
    # sc = spark.sparkContext
    
    def build_model():
        inputs = Input(shape=(3,), name='inputs')
        d1 = Dense(20, name='dense_01')(inputs)
        d2 = Dense(50, name='dense_02')(d1)
        o = Dense(1, activation='sigmoid', name='output')(d2)
        net = Model(inputs=inputs, outputs=o)
    
        return net
    
    net = build_model()
    ws = net.get_weights()
    bc_model_state = sc.broadcast(ws)
    
    @pandas_udf(FloatType())
    def batch_predict(data):  # input: pd.Series; output: pd.Series
        mdl = build_model()
        mdl.set_weights(bc_model_state.value)
    
        prediction = mdl.predict(data.values)
        return pd.Series(prediction[:, 0])
    

    该方案不仅适用于tensorflow keras模型,也适用于pytorch模型。检查this

    【讨论】:

      【解决方案2】:

      如果我们使用解决方案直接在 tensorflow.keras.models.Model 类中扩展 getstatesetstate 方法,就像http://zachmoshe.com/2017/04/03/pickling-keras-models.html ,然后工作人员无法反序列化模型,因为他们没有该类的扩展。

      然后,解决方案是使用 Erp12 在此 post 中建议的包装类。

      class ModelWrapperPickable:
      
          def __init__(self, model):
              self.model = model
      
          def __getstate__(self):
              model_str = ''
              with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
                  tensorflow.keras.models.save_model(self.model, fd.name, overwrite=True)
                  model_str = fd.read()
              d = { 'model_str': model_str }
              return d
      
          def __setstate__(self, state):
              with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
                  fd.write(state['model_str'])
                  fd.flush()
                  self.model = tensorflow.keras.models.load_model(fd.name)
      

      UDF 变为:

      def compute_output_pandas_udf(model_wrapper):
          '''Spark pandas udf for model prediction.'''
      
          @F.pandas_udf(T.DoubleType(), F.PandasUDFType.SCALAR)
          def compute_output(inputs1, inputs2, inputs3):
              pdf = pd.DataFrame({
                  'input1': inputs1,
                  'input2': inputs2,
                  'input3': inputs3
              })
              pdf['predicted_output'] = model_wrapper.model.predict(pdf.values)
              return pdf['predicted_output']
      
          return compute_output
      

      以及主要代码:

      # Model parameters
      weights = np.array([[0.5], [0.4], [0.3]])
      bias = np.array([1.25])
      activation = 'linear'
      input_dim, output_dim = weights.shape
      
      # Initialize keras model
      model = Sequential()
      layer = Dense(output_dim, input_dim=input_dim, activation=activation)
      model.add(layer)
      layer.set_weights([weights, bias])
      # Initialize model wrapper
      model_wrapper= ModelWrapperPickable(model)
      
      # Initialize Spark session
      spark = SparkSession.builder.appName('test').getOrCreate()
      
      # Create pandas df with inputs and run model
      pdf = pd.DataFrame({
          'input1': np.random.randn(200),
          'input2': np.random.randn(200),
          'input3': np.random.randn(200)
      })
      pdf['predicted_output'] = model_wrapper.model.predict(pdf[['input1', 'input2', 'input3']].values)
      
      # Create spark df with inputs and run model using udf
      sdf = spark.createDataFrame(pdf)
      sdf = sdf.withColumn('predicted_output', compute_output_pandas_udf(model_wrapper)('input1', 'input2', 'input3'))
      sdf.limit(5).show()
      

      【讨论】:

        猜你喜欢
        • 2015-07-19
        • 2014-04-11
        • 1970-01-01
        • 1970-01-01
        • 2021-12-24
        • 1970-01-01
        • 2021-12-22
        • 1970-01-01
        • 1970-01-01
        相关资源
        最近更新 更多