【发布时间】:2019-09-26 01:08:25
【问题描述】:
对于以下两个示例,把 tf.data.Dataset 的迭代放进 tf.function 中(由 AutoGraph 转换为图内循环),相比在 tf.function 外部用 Python 循环迭代,是否会带来性能提升?
在 tf.function 外部迭代数据集(只有 train_step 被编译):
import tensorflow as tf
class MyModel(tf.keras.Model):
    """Toy Keras model that multiplies its input by a [1, 1] tensor of ones."""

    def call(self, inputs):
        # Element-wise (broadcasting) product with a freshly built ones tensor.
        ones = tf.ones([1, 1])
        return ones * inputs
# Two independent instances of the toy model; train_step chains them.
model, model2 = MyModel(), MyModel()
@tf.function
def train_step(data):
    """Run one forward pass: feed ``data`` through ``model`` and then ``model2``.

    Only this step is compiled with tf.function; the loop that feeds it runs
    as ordinary eager Python code.
    """
    return model2(model(data))
# Single-element dataset holding one [1, 1] tensor of ones. It is iterated
# eagerly in Python, so each element triggers a separate train_step call.
dataset = tf.data.Dataset.from_tensors(
    tensors=tf.ones(shape=[1, 1]),
)
for data in iter(dataset):
    train_step(data)
在 tf.function 内部迭代数据集(整个循环都被编译):
import tensorflow as tf
class MyModel(tf.keras.Model):
    """Toy Keras model that multiplies its input by a [1, 1] tensor of ones."""

    def call(self, inputs):
        # Element-wise (broadcasting) product with a freshly built ones tensor.
        ones = tf.ones([1, 1])
        return ones * inputs
# Two independent instances of the toy model; train's helper chains them.
model, model2 = MyModel(), MyModel()
@tf.function
def train():
    """Build the dataset and drive every step from inside one tf.function.

    Dataset construction, iteration and the per-element forward passes are
    all traced as part of ``train``. AutoGraph is expected to lower the
    ``for ... in dataset`` loop to an in-graph loop (per the tf.function
    guide) -- confirm for the TF version in use.
    """

    def train_step(data):
        # Plain nested helper (not separately decorated); traced inside train.
        return model2(model(data))

    ds = tf.data.Dataset.from_tensors(tf.ones([1, 1]))
    # Iterate the Dataset object directly (not iter(ds)): AutoGraph handles
    # datasets and iterators differently.
    for element in ds:
        train_step(element)


# A single call traces and executes the whole loop.
train()
【问题讨论】:
标签: python tensorflow tensorflow2.0