【Posted】: 2021-11-06 14:43:58
【Question】:
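I have an sklearn Pipeline made up of a set of preprocessing stages and one estimator, a KerasClassifier (from tensorflow.keras.wrappers.scikit_learn import KerasClassifier).
My overall goal is to tune and log the whole sklearn pipeline in mlflow (in a databricks environment). I get a confusing TypeError that I don't know how to resolve: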
TypeError: can't pickle _thread.RLock objects
The following code (without the tuning stage) returns the error above:
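import cloudpickle
import mlflow
import numpy as np
import sklearn
import tensorflow as tf
from mlflow.models.signature import infer_signature
from mlflow.utils.environment import _mlflow_conda_env
from sklearn.pipeline import Pipeline
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier

# Pin the library versions used for training in the logged environment.
conda_env = _mlflow_conda_env(
    additional_conda_deps=None,
    additional_pip_deps=[
        "cloudpickle=={}".format(cloudpickle.__version__),
        "scikit-learn=={}".format(sklearn.__version__),
        "numpy=={}".format(np.__version__),
        "tensorflow=={}".format(tf.__version__),
    ],
    additional_conda_channels=None,
)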
search_space = {
    "estimator__dense_l1": 20,
    "estimator__dense_l2": 20,
    "estimator__learning_rate": 0.1,
    "estimator__optimizer": "Adam",
}
def create_model(n):
    model = Sequential()
    model.add(Dense(int(n["estimator__dense_l1"]), activation="relu"))
    model.add(Dense(int(n["estimator__dense_l2"]), activation="relu"))
    model.add(Dense(1, activation="sigmoid"))
    model.compile(
        loss="binary_crossentropy",
        optimizer=n["estimator__optimizer"],
        metrics=["accuracy"],
    )
    return model
# preprocessor, X_train, y_train, X_test and y_test are defined elsewhere (not shown).
mlflow.sklearn.autolog()
with mlflow.start_run(nested=True) as run:
    classifier = KerasClassifier(build_fn=create_model, n=search_space)
    # fit the pipeline
    clf = Pipeline(steps=[("preprocessor", preprocessor),
                          ("estimator", classifier)])
    h = clf.fit(
        X_train,
        y_train.values,
        estimator__validation_split=0.2,
        estimator__epochs=10,
        estimator__verbose=2,
    )
    # log scores
    acc_score = clf.score(X=X_test, y=y_test)
    mlflow.log_metric("accuracy", acc_score)
    signature = infer_signature(X_test, clf.predict(X_test))
    # Log the model with a signature that defines the schema of the model's inputs and outputs.
    mlflow.sklearn.log_model(
        sk_model=clf, artifact_path="model",
        signature=signature,
        conda_env=conda_env
    )
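The TypeError itself can be reproduced in isolation, without mlflow or Keras. The sketch below is only an illustration of the error class, under the unconfirmed assumption that the fitted KerasClassifier holds an object containing a thread lock and that logging the pipeline pickles it. The names `fake_estimator_state` and `try_pickle` are hypothetical.

```python
import pickle
import threading

# Hypothetical stand-in for the internal state of a fitted estimator:
# an ordinary dict that happens to hold a thread lock.
fake_estimator_state = {
    "weights": [0.1, 0.2, 0.3],
    "lock": threading.RLock(),
}


def try_pickle(obj):
    """Return None if obj pickles cleanly, otherwise the TypeError message."""
    try:
        pickle.dumps(obj)
        return None
    except TypeError as exc:
        return str(exc)


message = try_pickle(fake_estimator_state)
# The exact wording depends on the Python version, but it names the RLock.
print(message)
```

If this is what happens in the pipeline, the preprocessing steps are not the problem; any single unpicklable member anywhere in the object graph is enough to make the whole dump fail.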
I also got this warning before the error:
WARNING mlflow.sklearn.utils: Truncated the value of the key `steps`. Truncated value: `[('preprocessor', ColumnTransformer(n_jobs=None, remainder='drop', sparse_threshold=0.3,
transformer_weights=None,
transformers=[('num',
Pipeline(memory=None,
Note that the whole pipeline runs outside of mlflow. Can anyone help?
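One direction sometimes suggested for this class of error is to control what gets pickled through `__getstate__` and `__setstate__`, so that the unpicklable member is replaced by plain data on save and rebuilt on load. The sketch below shows only that pattern in plain Python. `FakeModel`, `PicklableWrapper`, and the weight-dictionary format are hypothetical, and whether the same approach works for a KerasClassifier passed to `mlflow.sklearn.log_model` is an untested assumption.

```python
import pickle
import threading


class FakeModel:
    """Hypothetical model: plain weights plus a lock that cannot be pickled."""

    def __init__(self, weights):
        self.weights = dict(weights)
        self._lock = threading.RLock()

    def predict(self, x):
        with self._lock:
            return self.weights["w"] * x + self.weights["b"]


class PicklableWrapper:
    """Wraps a FakeModel and pickles only its weights, never the lock."""

    def __init__(self, model):
        self.model = model

    def __getstate__(self):
        # Replace the live model with data that pickle can handle.
        return {"weights": self.model.weights}

    def __setstate__(self, state):
        # Rebuild the model (with a fresh lock) from the saved weights.
        self.model = FakeModel(state["weights"])

    def predict(self, x):
        return self.model.predict(x)


wrapper = PicklableWrapper(FakeModel({"w": 2.0, "b": 1.0}))
restored = pickle.loads(pickle.dumps(wrapper))
print(restored.predict(3.0))
```

For a real Keras model, the "plain data" would have to be something the model can be rebuilt from, such as its architecture and weights; that part is not shown here.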
【Discussion】:
Tags: python tensorflow scikit-learn databricks mlflow