【Posted】: 2021-04-18 16:43:22
【Question】:
I have TensorFlow 1.14 and I want to compute some classification metrics.
I am using tf.keras.metrics as follows:
tf.keras.metrics.Accuracy(tf.argmax(tf.nn.softmax(support_pred, dim=1), axis=1),
tf.argmax(support_y, axis=1))
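This gives me the error:
{TypeError}Using a tf.Tensor as a Python bool is not allowed. Use if t is not None: instead of if t: to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.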
I tried using tf.contrib.metrics instead, but it only has precision_at_recall and recall_at_precision, not standalone precision and recall.
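For standalone per-class precision and recall (plus overall accuracy), one stateless option is to build a confusion matrix from the argmax class ids with tf.math.confusion_matrix and read the metrics off its diagonal, column sums and row sums. This is a sketch assuming TF 1.14 graph mode (it uses tf.compat.v1.Session, which 1.14 also provides); the helper precision_recall_from_labels and the fixed label tensors are hypothetical stand-ins for tf.argmax(support_y, axis=1) and tf.argmax(support_pred, axis=1).

```python
import tensorflow as tf  # sketch written for TF 1.14 graph mode


def precision_recall_from_labels(y_true, y_pred, num_classes):
    """Hypothetical helper: per-class precision/recall and accuracy from class ids."""
    # Rows are true classes, columns are predicted classes.
    cm = tf.cast(
        tf.math.confusion_matrix(y_true, y_pred, num_classes=num_classes),
        tf.float32)
    true_pos = tf.linalg.diag_part(cm)
    predicted = tf.reduce_sum(cm, axis=0)  # how often each class was predicted
    actual = tf.reduce_sum(cm, axis=1)     # how often each class actually occurs
    # tf.maximum avoids 0/0 for absent classes; true_pos is 0 there anyway.
    precision = true_pos / tf.maximum(predicted, 1.0)
    recall = true_pos / tf.maximum(actual, 1.0)
    accuracy = tf.reduce_sum(true_pos) / tf.reduce_sum(cm)
    return precision, recall, accuracy


with tf.Graph().as_default():
    # Hypothetical fixed class ids standing in for the argmax tensors.
    y_true = tf.constant([0, 1, 2, 2, 1], dtype=tf.int64)
    y_pred = tf.constant([0, 2, 2, 2, 1], dtype=tf.int64)
    p_op, r_op, acc_op = precision_recall_from_labels(y_true, y_pred, num_classes=3)
    with tf.compat.v1.Session() as sess:
        precision, recall, accuracy = sess.run([p_op, r_op, acc_op])

print(precision, recall, accuracy)
```

Unlike the tf.keras.metrics classes, this creates no variables, so no initializer is needed; it covers a single batch only, and accumulating over several batches would mean summing the confusion matrices yourself.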
Edit 1
I tried the following, but it did not work:
import tensorflow as tf
a = tf.random.uniform((32, 10), 0, 1, dtype=tf.float32)
b = tf.random.uniform((32, 10), 0, 1, dtype=tf.float32)
a_softmax = tf.nn.softmax(a)
b_softmax = tf.nn.softmax(b)
a_argmax = tf.argmax(a_softmax, axis=-1)
b_argmax = tf.argmax(b_softmax, axis=-1)
acc = tf.keras.metrics.Accuracy()(a_argmax, b_argmax)
with tf.Session() as sess:
    sess.run([acc])
It gives me the following error:
Traceback (most recent call last):
File "C:\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1356, in _do_call
return fn(*args)
File "C:\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1341, in _run_fn
options, feed_dict, fetch_list, target_list, run_metadata)
File "C:\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1429, in _call_tf_sessionrun
run_metadata)
tensorflow.python.framework.errors_impl.NotFoundError: Container localhost does not exist. (Could not find resource: localhost/total)
[[{{node AssignAddVariableOp}}]]
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "C:/Users/96171/Desktop/dementia_cleanedup/dementia/maml_finn_copy/try_tf.py", line 15, in <module>
sess.run(acc)
File "C:\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 950, in run
run_metadata_ptr)
File "C:\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1173, in _run
feed_dict_tensor, options, run_metadata)
File "C:\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1350, in _do_run
run_metadata)
File "C:\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1370, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.NotFoundError: Container localhost does not exist. (Could not find resource: localhost/total)
[[node AssignAddVariableOp (defined at /Users/96171/Desktop/dementia_cleanedup/dementia/maml_finn_copy/try_tf.py:12) ]]
Original stack trace for 'AssignAddVariableOp':
File "/Users/96171/Desktop/dementia_cleanedup/dementia/maml_finn_copy/try_tf.py", line 12, in <module>
acc = tf.keras.metrics.Accuracy()(a_argmax, b_argmax)
File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\metrics.py", line 170, in __call__
update_op = self.update_state(*args, **kwargs) # pylint: disable=not-callable
File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\utils\metrics_utils.py", line 73, in decorated
update_op = update_state_fn(*args, **kwargs)
File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\metrics.py", line 551, in update_state
matches, sample_weight=sample_weight)
File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\metrics.py", line 314, in update_state
update_total_op = self.total.assign_add(value_sum)
File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\ops\resource_variable_ops.py", line 1108, in assign_add
name=name)
File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\ops\gen_resource_variable_ops.py", line 68, in assign_add_variable_op
"AssignAddVariableOp", resource=resource, value=value, name=name)
File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 788, in _apply_op_helper
op_def=op_def)
File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\util\deprecation.py", line 507, in new_func
return func(*args, **kwargs)
File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\ops.py", line 3616, in create_op
op_def=op_def)
File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\ops.py", line 2005, in __init__
self._traceback = tf_stack.extract_stack()
Process finished with exit code 1
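Judging from the two messages, two separate things may be going on. In the first snippet the label tensors are passed to the Accuracy constructor, which in TF 1.14 takes name and dtype; the tensors belong in a call on the metric instance, metric(y_true, y_pred). Edit 1 calls the metric correctly but never runs an initializer, and "Could not find resource: localhost/total" points at the metric's uninitialized total variable. As far as I can tell, TF 1.x Keras metrics do not register their variables in the global or local variable collections, so tf.global_variables_initializer() may not cover them either; the sketch below therefore initializes acc_metric.variables explicitly. It assumes TF 1.14 graph mode and uses fixed, hypothetical class ids instead of random tensors so the result can be checked by hand.

```python
import tensorflow as tf  # sketch written for TF 1.14 graph mode

with tf.Graph().as_default():
    # Hypothetical fixed class ids standing in for a_argmax / b_argmax.
    y_true = tf.constant([0, 1, 2, 2, 1], dtype=tf.int64)
    y_pred = tf.constant([0, 2, 2, 2, 1], dtype=tf.int64)

    # Build the metric object first; its constructor only takes name/dtype.
    acc_metric = tf.keras.metrics.Accuracy()
    # Calling the instance adds the update op and returns the running result.
    acc = acc_metric(y_true, y_pred)

    with tf.compat.v1.Session() as sess:
        # Initialize the metric's own variables (total, count) explicitly.
        sess.run(tf.compat.v1.variables_initializer(acc_metric.variables))
        accuracy = sess.run(acc)

print(accuracy)  # 4 of the 5 class ids match
```

Each sess.run(acc) also runs the update, so the value accumulates across repeated runs until the metric's variables are initialized again.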
【Discussion】:
Tags: python tensorflow keras