实际上,在这两种情况下,您都在尝试对浮点值进行矩阵乘法。在第一种情况下,您使用的是 float16,在第二种情况下,您使用的是 float32。
import tensorflow as tf
import time
a = tf.random.uniform(shape=(9180, 3049), seed = 10)
b = tf.random.uniform(shape=(3049, 1913), seed = 10)
第一次运行
x2 = a
y2 = b
s = time.time()
r2 = tf.matmul(x2,y2)
e = time.time()
print((e-s)*1000)
x1 = tf.cast(a ,tf.float16)
y1 = tf.cast(b ,tf.float16)
s = time.time()
r1 = tf.matmul(x1,y1)
e = time.time()
print((e-s)*1000)
输出:
184.76319313049316
0.0
重启内核后第二次运行。
x1 = tf.cast(a ,tf.float16)
y1 = tf.cast(b ,tf.float16)
s = time.time()
r1 = tf.matmul(x1,y1)
e = time.time()
print((e-s)*1000)
x2 = a
y2 = b
s = time.time()
r2 = tf.matmul(x2,y2)
e = time.time()
print((e-s)*1000)
输出:
183.03942680358887
1.0335445404052734
现在如果我再次运行相同的代码而不重新启动内核,即使更改了 a 和 b 的值。
x1 = tf.cast(a ,tf.float16)
y1 = tf.cast(b ,tf.float16)
s = time.time()
r1 = tf.matmul(x1,y1)
e = time.time()
print((e-s)*1000)
x2 = a
y2 = b
s = time.time()
r2 = tf.matmul(x2,y2)
e = time.time()
print((e-s)*1000)
输出:
0.0
0.0
所以本质上这不是 TensorFlow 的问题。 TensorFlow 以图的形式执行。当您第一次运行它时,它会使用上述数据结构初始化图形并对其进行优化以供进一步计算。看看this中的最终评论。
因此,您的第二次执行操作会更快