【发布时间】:2020-08-06 05:27:57
【问题描述】:
我没有得到Dataset 张量,而是得到了这个,我不知道如何处理:
Tensor("StatefulPartitionedCall:0", shape=(), dtype=float64)
一切都很顺利(我认为),但这是我尝试打印损失时得到的结果。这是我正在玩的代码:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from sklearn.datasets import load_breast_cancer
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense
tf.keras.backend.set_floatx('float64')
x, y = load_breast_cancer(return_X_y=True)
data = tf.data.Dataset.from_tensors((x, y)).shuffle(len(x))
train_data = data.take(int(8e-1*len(x))).batch(32)
test_data = data.skip(int(8e-1*len(x)))
class DenseNet(Model):
def __init__(self):
super(DenseNet, self).__init__()
self.D1 = Dense(8, activation=tf.keras.activations.selu)
self.D2 = Dense(16, activation=tf.keras.activations.elu)
self.D3 = Dense(32, activation=tf.keras.activations.relu)
self.D4 = Dense(1)
def __call__(self, x):
x = self.D1(x)
x = self.D2(x)
x = self.D3(x)
out = self.D4(x)
return out
network = DenseNet()
optimizer = tf.keras.optimizers.Adam()
@tf.function
def compute_loss(labels, logits):
labels = tf.cast(tf.one_hot(labels, depth=1), tf.float64)
return tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))
@tf.function
def compute_accuracy(labels, logits):
labels = tf.cast(tf.one_hot(labels, depth=2), tf.float64)
return tf.reduce_mean(tf.cast(tf.equal(logits, labels), tf.float32))
@tf.function
def train_step(inputs, targets):
with tf.GradientTape() as tape:
logits = network(inputs)
loss = compute_loss(labels=targets, logits=logits)
gradients = tape.gradient(loss, network.trainable_variables)
optimizer.apply_gradients(zip(gradients, network.trainable_variables))
accuracy = compute_accuracy(labels=targets, logits=logits)
return loss, accuracy
@tf.function
def train():
for inputs, labels in train_data:
loss, acc = train_step(inputs, labels)
print(loss, acc)
def main(epochs=5):
for i in range(1, epochs + 1):
train()
if __name__ == '__main__':
main(epochs=10)
【问题讨论】:
标签: python tensorflow keras tensorflow2.0 tf.keras