【发布时间】:2019-10-13 00:15:12
【问题描述】:
我正在尝试创建一个新节点并设置其属性。
例如打印一个图节点我看到它的属性是:
attr {
key: "T"
value {
type: DT_FLOAT
}
}
我可以创建一个像这样的节点:
node = tf.NodeDef(name='MyConstTensor', op='Const',
attr={'value': tf.AttrValue(tensor=tensor_proto),
'dtype': tf.AttrValue(type=dt)})
但是如何添加key: "T" 属性?即在这种情况下,tf.AttrValue 里面应该是什么?
看着attr_value.proto我试过了:
node = tf.NodeDef()
node.name = 'MySub'
node.op = 'Sub'
node.input.extend(['MyConstTensor', 'conv2'])
node.attr["key"].s = 'T' # TypeError: 'T' has type str, but expected one of: bytes
更新:
我发现在 Tensorflow 中应该这样写:
node.attr["T"].type = b'float32'
但这给出了一个错误:
TypeError: b'float32' 具有字节类型,但应为以下之一:int、long
而且我不确定哪个 int 值对应于 float32。
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/attr_value.proto#L23
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/attr_value.proto#L35
【问题讨论】:
-
您可以查看github.com/tensorflow/tensorflow/issues/616 了解如何保存/加载图表。
-
@knh190 有什么关系?
-
@knh190 我知道如何保存图形,问题是关于如何在
tf.NodeDef创建处添加属性。
标签: python tensorflow protocol-buffers