First, enable Internet access for the notebook and clone the google-research repo from GitHub:
!git clone https://github.com/google-research/google-research.git
Next, we need the compile and link flags for g++, so run the following snippets:
import tensorflow as tf
print(" ".join(tf.sysconfig.get_compile_flags()))
and
import tensorflow as tf
print(" ".join(tf.sysconfig.get_link_flags()))
In my notebook, I got the following flags:
-I/opt/conda/lib/python3.7/site-packages/tensorflow/include -D_GLIBCXX_USE_CXX11_ABI=0
-L/opt/conda/lib/python3.7/site-packages/tensorflow -l:libtensorflow_framework.so.2
After that, replace the variables ${TF_CFLAGS[@]} and ${TF_LFLAGS[@]} in the g++ build command with the output above:
!g++ -std=c++11 -shared google-research/tf_trees/neural_trees_ops.cc google-research/tf_trees/neural_trees_kernels.cc google-research/tf_trees/neural_trees_helpers.cc -o google-research/tf_trees/neural_trees_ops.so -fPIC -I/opt/conda/lib/python3.7/site-packages/tensorflow/include -D_GLIBCXX_USE_CXX11_ABI=0 -L/opt/conda/lib/python3.7/site-packages/tensorflow -l:libtensorflow_framework.so.2 -O2
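Rather than pasting the flags into the command by hand, the same command can be assembled in Python from the two flag lists. This is a minimal sketch: `build_op_command` is a hypothetical helper, the source file names and output path are copied from the command above, and the flag lists in the example are the ones printed in my notebook. It uses `shlex.quote` instead of `shlex.join` because the paths suggest Python 3.7, and `shlex.join` only exists from Python 3.8.

```python
import shlex


def build_op_command(repo_dir, compile_flags, link_flags):
    """Hypothetical helper: build the g++ argument list for the tf_trees op.

    compile_flags / link_flags are lists of strings, e.g. the return values of
    tf.sysconfig.get_compile_flags() and tf.sysconfig.get_link_flags().
    """
    src_dir = f"{repo_dir}/tf_trees"
    sources = [
        f"{src_dir}/neural_trees_ops.cc",
        f"{src_dir}/neural_trees_kernels.cc",
        f"{src_dir}/neural_trees_helpers.cc",
    ]
    # Same argument order as the hand-written command:
    # sources, output file, -fPIC, compile flags, link flags, -O2.
    return (
        ["g++", "-std=c++11", "-shared", *sources,
         "-o", f"{src_dir}/neural_trees_ops.so", "-fPIC"]
        + list(compile_flags)
        + list(link_flags)
        + ["-O2"]
    )


if __name__ == "__main__":
    # Flags copied from my notebook output; in practice pass
    # tf.sysconfig.get_compile_flags() and tf.sysconfig.get_link_flags().
    cflags = ["-I/opt/conda/lib/python3.7/site-packages/tensorflow/include",
              "-D_GLIBCXX_USE_CXX11_ABI=0"]
    lflags = ["-L/opt/conda/lib/python3.7/site-packages/tensorflow",
              "-l:libtensorflow_framework.so.2"]
    cmd = build_op_command("google-research", cflags, lflags)
    # Print a shell-quoted version for inspection.
    print(" ".join(shlex.quote(arg) for arg in cmd))
```

To actually compile from the notebook, the list can be passed directly to `subprocess.run(cmd, check=True)`, which avoids shell quoting issues entirely and raises if g++ fails.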
Finally, we need to add the repo to the system path:
import sys
sys.path.insert(1, '/kaggle/working/google-research')
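If the import later fails, a common cause is that the g++ step did not produce its output. A minimal sketch of a safer path setup, assuming the same layout as above: `add_repo_to_path` is a hypothetical helper that checks for the `.so` file at the path the g++ command writes to (`tf_trees/neural_trees_ops.so`) and only inserts the repo once, so re-running the cell does not grow `sys.path`.

```python
import os
import sys


def add_repo_to_path(repo_dir):
    """Hypothetical helper: put repo_dir on sys.path so `import tf_trees`
    resolves, after checking that the compiled op library exists."""
    so_path = os.path.join(repo_dir, "tf_trees", "neural_trees_ops.so")
    if not os.path.isfile(so_path):
        raise FileNotFoundError(
            f"compiled op not found: {so_path}; rerun the g++ step")
    if repo_dir not in sys.path:
        # Index 1, as in the snippet above, leaves sys.path[0] in place and
        # puts the repo ahead of the installed site-packages entries.
        sys.path.insert(1, repo_dir)
    return so_path


if __name__ == "__main__":
    repo = "/kaggle/working/google-research"
    if os.path.isdir(repo):
        print(add_repo_to_path(repo))
```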
Then run your snippet:
from tensorflow import keras
from tf_trees import TEL
tree_layer = TEL(output_logits_dim=2, trees_num=10, depth=3)
model = keras.Sequential()
model.add(keras.layers.BatchNormalization())
model.add(tree_layer)