【发布时间】:2021-04-15 05:05:29
【问题描述】:
我在 Docker 20.12 中使用 TensorRT,在 Ubuntu 18 中使用 Tensorflow-gpu 2.4
我有下面的代码
N=1024
target_size = (N, N)
warnings.filterwarnings('ignore')
LR= 1e-4
E, BS = 2,4
def get_unet(img_rows, img_cols):
inputs = Input((img_rows, img_cols, 1))
conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
conv1 = BatchNormalization()(conv1)
conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
conv1 = BatchNormalization()(conv1)
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
conv2 = BatchNormalization()(conv2)
conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
conv2 = BatchNormalization()(conv2)
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
conv3 = BatchNormalization()(conv3)
conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)
conv3 = BatchNormalization()(conv3)
pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool3)
conv4 = BatchNormalization()(conv4)
conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv4)
conv4 = BatchNormalization()(conv4)
pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(pool4)
conv5 = BatchNormalization()(conv5)
conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv5)
conv5 = BatchNormalization()(conv5)
conv5 = Dropout(0.5)(conv5)
up6 = concatenate([UpSampling2D(size=(2, 2))(conv5), conv4], axis=3)
conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(up6)
conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv6)
up7 = concatenate([UpSampling2D(size=(2, 2))(conv6), conv3], axis=3)
conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(up7)
conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv7)
up8 = concatenate([UpSampling2D(size=(2, 2))(conv7), conv2], axis=3)
conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(up8)
conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv8)
up9 = concatenate([UpSampling2D(size=(2, 2))(conv8), conv1], axis=3)
conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(up9)
conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv9)
conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9)
model = Model(inputs=[inputs], outputs=[conv10])
model.compile(optimizer=Adam(lr=LR),
loss='binary_crossentropy',
metrics=['binary_crossentropy'])
return model
train_images=np.zeros([1, N,N,1], dtype=float)
annot_train=np.zeros([1, N,N,1], dtype=float)
test_images=np.zeros([1, N,N,1], dtype=float)
annot_test=np.zeros([1, N,N,1], dtype=float)
img = cv2.imread('owlResized.bmp', 0)/255.0
label = cv2.imread('owlResized.bmp', 0)/255.0
train_images[0,:,:,0], annot_train[0,:,:,0] =img, label
test_images[0,:,:,0], annot_test[0,:,:,0] =img, label
C = np.concatenate([annot_test, annot_train])
I = np.concatenate([test_images, train_images])
unet = get_unet(N, N)
history = unet.fit(I, C, verbose=2, epochs=E, batch_size=BS, validation_split=0.1)
tf.compat.v1.disable_eager_execution()
with tf.Session() as sess:
y1 = tf.image.resize_bilinear(C, target_size, align_corners=True, half_pixel_centers=False)
_1 = tf.identity(y1, name="output")
y2 = tf.image.resize_bilinear(I, target_size, align_corners=True, half_pixel_centers=False)
_2 = tf.identity(y1, name="output")
init_op = tf.global_variables_initializer()
sess.run(init_op)
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['output'])
with tf.io.gfile.FastGFile('tfmodel.pb', mode='wb') as f:
f.write(constant_graph.SerializeToString())
os.system("python3 -m tf2onnx.convert --opset 11 --input tfmodel.pb --inputs input:0 --outputs output:0 --output tfmodel.onnx")
然后我得到了错误
with tf.io.gfile.FastGFile('tfmodel.pb', mode='wb') as f:
module 'tensorflow._api.v2.io.gfile' has no attribute 'FastGFile'
然后我添加了tensorboard,如下:
import tensorboard as tb
tf.io.gfile = tb.compat.tensorflow_stub.io.gfile
我现在明白了
with tf.io.gfile.FastGFile('tfmodel.pb', mode='wb') as f:
AttributeError: module 'tensorboard.compat.tensorflow_stub.io.gfile' has no attribute 'FastGFile'
谁能帮忙?
【问题讨论】:
标签: tensorflow tensorflow2.0 onnx tensorrt