【问题标题】:Transfer Learning with MobileV2Net使用 MobileV2Net 进行迁移学习
【发布时间】:2019-10-08 01:25:14
【问题描述】:

我正在尝试使用来自 https://www.tensorflow.org/tutorials/images/transfer_learning 的 MobileV2Net 实施迁移学习。

上述教程使用 MobileV2Net 模型作为基础模型,并使用类型为 tensorflow.python.data.ops.dataset_ops._OptionsDataset 的“cats_vs_dog”数据集。 就我而言,我想使用 MobileV2Net 作为基础模型,冻结不同 C.N.N 层的所有权重,添加一个全连接层并对其进行微调。我使用的数据集是 tiny_imagenet 。以下是我的代码:

 ##After pre-processing the data : 
(x_train, y_train), (x_valid, y_valid),(x_test, y_test) = data 

#type(x_train) = numpy.ndarray
#len(x_train) = 1750
##Converting the data to use the pipleine that comes with tf.Data.Dataset
raw_train = tf.data.Dataset.from_tensor_slices((x_train,y_train))
raw_validation = tf.data.Dataset.from_tensor_slices((x_valid, y_valid))
raw_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
#print(raw_train) gives 
<DatasetV1Adapter shapes: ((64, 64, 3), ()), types: (tf.float64, tf.int64)>

## Now i follow everything from the link (given above in problem statement) : 
IMG_SIZE = 160 # All images will be resized to 160x160

def format_example(image, label):
  image = tf.cast(image, tf.float32)
  image = (image/127.5) - 1
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  return image, label

train = raw_train.map(format_example)
validation = raw_validation.map(format_example)
test = raw_test.map(format_example

#print(train) gives
#<DatasetV1Adapter shapes: ((160, 160, 3), ()), types: (tf.float32, tf.int64)>

train_batches = train.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
validation_batches = validation.batch(BATCH_SIZE)
test_batches = test.batch(BATCH_SIZE)

#print(train_batches) gives : 
<DatasetV1Adapter shapes: ((?, 160, 160, 3), (?,)), types: (tf.float32, tf.int64)>
##The corresponding command in the tutorial (which works on cats vs dogs dataset gives) :
<BatchDataset shapes: ((None, 160, 160, 3), (None,)), types: (tf.float32, tf.int64)>

我也尝试使用 padded_batch() 而不是 batch() 但下面仍然进入无限循环。


##Goes to infinite loop
for image_batch, label_batch in train_batches.take(1):
  print("hello")
  pass
image_batch.shape ## Does not reach here 

##The same command in the tutorial gives :
hello
TensorShape([32, 160, 160, 3])

##Further in my case : 
#print(train_batches.take(1)) gives 
<DatasetV1Adapter shapes: ((?, 160, 160, 3), (?,)), types: (tf.float32, tf.int64)>
##In tutorial it gives : 
<TakeDataset shapes: ((None, 160, 160, 3), (None,)), types: (tf.float32, tf.int64)>

image_batch 稍后在代码中使用。

##Load the pre trained Model : 
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')
##This feature extractor converts each 160x160x3 image to a 5x5x1280 block of features. See what ##it does to the example batch of images:
feature_batch = base_model(image_batch)
print(feature_batch.shape) ## ((32, 5, 5, 1280))

##Freezing the convolution base 
base_model.trainable = False

##Adding a classification head : 
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape) ## (32, 1280)

prediction_layer = keras.layers.Dense(1)
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape) ##(32, 1)

model = tf.keras.Sequential([
  base_model,
  global_average_layer,
  prediction_layer
])

我从来没有使用过 tensorflow,有什么想法可以让它发挥作用吗?

【问题讨论】:

    标签: tensorflow deep-learning conv-neural-network tensorflow-datasets transfer-learning


    【解决方案1】:

    Padded batch vs batch:如果数据集中的元素具有不同的形状,则使用padded batch,而batch 要求其中的元素应具有相同的形状。

    您的代码的问题是您没有遇到您所描述的无限循环。您使用的数据集是微型 imagenet,包含 100,000 张图像,并且需要时间来遍历所有图像一次。如果您不想等那么久,您可以在 for 循环中将 pass 更改为 break,它会在第一次迭代后退出循环。

    还有另一个称为repeat 的操作。这用于按照您在其 count 参数中指定的次数重复您的数据集。但是,如果将其设置为 -1,数据集将继续循环,在这种情况下,您的数据集将进入无限循环。

    【讨论】:

    • 我没有在 100,000 张图像上进行训练。火车组的长度为 1750(现已添加到 O.P 中)。我再次运行代码并等待了一段时间(20 分钟),但循环没有结束。如果我尝试break,循环肯定会在第一次通过后结束,但这并不能解决获得TensorShape([32, 160, 160, 3]) 的目的。如果我使用repeat 和参数count=1750,它仍然会进入无限循环。
    猜你喜欢
    • 1970-01-01
    • 2020-08-06
    • 2023-01-18
    • 2018-05-31
    • 2016-05-18
    • 2019-11-08
    • 2020-10-10
    • 2018-04-12
    • 2020-03-12
    相关资源
    最近更新 更多