【Question Title】: Mask R-CNN is not loading weights properly for inference and re-training
【Posted】: 2021-12-22 19:21:09
【Question】:

Problem:

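I am new to computer vision, and this is my second project. I am running an edited version of Matterport's Mask R-CNN with tensorflow-gpu==2.7.0 (I later found out that an older version would have worked fine). I am trying to use it with a pen dataset I created.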

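The problem is that whenever I load the trained weights back into the model to continue training, the metrics spike. I also get wrong predictions when I load them for inference. Why are my weights not being saved or loaded correctly? I save the weights with a callback and load them like this: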

import mrcnn.model as modellib

# inference_config and MODEL_DIR are defined earlier in the notebook
model = modellib.MaskRCNN(mode="inference", 
                          config=inference_config,
                          model_dir=MODEL_DIR)

# Get path to saved weights
model_path = model.find_last()

# Load trained weights
print("Loading weights from ", model_path)
model.load_weights(model_path, by_name=True)
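Since `find_last()` decides which file gets loaded, it can help to confirm it points at the checkpoint you expect. The function below is a standard-library approximation of how the stock Matterport `MaskRCNN.find_last()` picks a file (newest run directory whose name starts with `config.NAME.lower()`, then the last `mask_rcnn*` file in it when sorted by name). `find_last_checkpoint` is a hypothetical helper, and an edited fork may behave differently.

```python
import os


def find_last_checkpoint(model_dir, name):
    """Approximate Matterport's MaskRCNN.find_last(): newest run dir, last checkpoint in it."""
    key = name.lower()
    # Run directories are named <name><YYYYMMDD>T<HHMM>, so sorting by name sorts by time.
    run_dirs = sorted(d for d in next(os.walk(model_dir))[1] if d.startswith(key))
    if not run_dirs:
        raise FileNotFoundError(f"No run directories starting with {key!r} in {model_dir}")
    run_dir = os.path.join(model_dir, run_dirs[-1])
    # Checkpoints are named mask_rcnn_<name>_<epoch:04d>.h5, so the last one is the highest epoch.
    checkpoints = sorted(f for f in next(os.walk(run_dir))[2] if f.startswith("mask_rcnn"))
    if not checkpoints:
        raise FileNotFoundError(f"No mask_rcnn* checkpoint files in {run_dir}")
    return os.path.join(run_dir, checkpoints[-1])


# Usage with the names from the question (MODEL_DIR and config defined elsewhere):
# print(find_last_checkpoint(MODEL_DIR, config.NAME))
```

Only the newest run directory is searched, so a later run that saved only an early checkpoint takes precedence over an older run with more training. Comparing the printed path with the checkpoint you expect is a cheap first check.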

What I've tried:

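I tried saving the whole model by setting save_weights_only to False in the callback. I then ran into the get_config() issue described in this thread and followed some of the solutions there, but none of them helped.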

I have also tried changing the image size and the number of epochs.

I tried saving the model with:

from tensorflow import keras

# model.keras_model is the tf.keras Model wrapped by the Matterport MaskRCNN class
model.keras_model.save('path/to/location')
model = keras.models.load_model('path/to/location')

This led to the same get_config() problem.

Resources:

Here is a list of what I'm running:

 # ITEM ########### VERSION ##########################
 # Python         # 3.9.7                            #
 # conda          # 4.10.3                           # 
 # CUDA           # 11.4                             #
 # WindowsOS      # 11                               #
 # cuDNN          # 8.2.4                            # 
 #####################################################
 
################################### PACKAGES ##################################
# packages in environment at C:\Users\ecsan\anaconda3\envs\Prototype:
# Command: conda list
# Name #################### Version ################ Build # Channel ############
# absl-py                   1.0.0                    pypi_0    pypi           #
# alabaster                 0.7.12                   pypi_0    pypi           #
# argon2-cffi               21.1.0                   pypi_0    pypi           #
# astunparse                1.6.3                    pypi_0    pypi           #
# attrs                     21.2.0                   pypi_0    pypi           #
# babel                     2.9.1                    pypi_0    pypi           #
# backcall                  0.2.0                    pypi_0    pypi           #
# bleach                    4.1.0                    pypi_0    pypi           #
# ca-certificates           2021.10.8            h5b45459_0    conda-forge    # 
# cachetools                4.2.4                    pypi_0    pypi           #
# certifi                   2021.10.8                pypi_0    pypi           #
# cffi                      1.15.0                   pypi_0    pypi           #
# charset-normalizer        2.0.9                    pypi_0    pypi           #
# colorama                  0.4.4                    pypi_0    pypi           #
# console_shortcut          0.1.1                         4                   #
# cycler                    0.11.0                   pypi_0    pypi           #
# cython                    0.29.25                  pypi_0    pypi           #
# debugpy                   1.5.1                    pypi_0    pypi           #
# decorator                 5.1.0                    pypi_0    pypi           #
# defusedxml                0.7.1                    pypi_0    pypi           #
# dill                      0.3.4                    pypi_0    pypi           #
# docutils                  0.17.1                   pypi_0    pypi           #
# entrypoints               0.3                      pypi_0    pypi           #
# flatbuffers               2.0                      pypi_0    pypi           #
# fonttools                 4.28.3                   pypi_0    pypi           #
# gast                      0.4.0                    pypi_0    pypi           #
# google-auth               2.3.3                    pypi_0    pypi           #
# google-auth-oauthlib      0.4.6                    pypi_0    pypi           #
# google-pasta              0.2.0                    pypi_0    pypi           #
# grpcio                    1.42.0                   pypi_0    pypi           #
# h5py                      3.6.0                    pypi_0    pypi           #
# idna                      3.3                      pypi_0    pypi           #
# imageio                   2.13.2                   pypi_0    pypi           #
# imagesize                 1.3.0                    pypi_0    pypi           #
# imgaug                    0.4.0                    pypi_0    pypi           #
# importlib-metadata        4.8.2                    pypi_0    pypi           #
# ipykernel                 6.6.0                    pypi_0    pypi           #
# ipyparallel               8.0.0                    pypi_0    pypi           #
# ipython                   7.30.1                   pypi_0    pypi           #
# ipython-genutils          0.2.0                    pypi_0    pypi           #
# ipywidgets                7.6.5                    pypi_0    pypi           #
# jedi                      0.18.1                   pypi_0    pypi           #
# jinja2                    3.0.3                    pypi_0    pypi           #
# joblib                    1.1.0                    pypi_0    pypi           #
# jsonschema                4.2.1                    pypi_0    pypi           #
# jupyter-client            7.1.0                    pypi_0    pypi           #
# jupyter-core              4.9.1                    pypi_0    pypi           #
# jupyterlab-pygments       0.1.2                    pypi_0    pypi           #
# jupyterlab-widgets        1.0.2                    pypi_0    pypi           #
# keras                     2.7.0                    pypi_0    pypi           #
# keras-preprocessing       1.1.2                    pypi_0    pypi           #
# kiwisolver                1.3.2                    pypi_0    pypi           #
# libclang                  12.0.0                   pypi_0    pypi           #
# markdown                  3.3.6                    pypi_0    pypi           #
# markupsafe                2.0.1                    pypi_0    pypi           #
# matplotlib                3.5.0                    pypi_0    pypi           #
# matplotlib-inline         0.1.3                    pypi_0    pypi           #
# mistune                   0.8.4                    pypi_0    pypi           #
# nbclient                  0.5.9                    pypi_0    pypi           #
# nbconvert                 6.3.0                    pypi_0    pypi           #
# nbformat                  5.1.3                    pypi_0    pypi           #
# nest-asyncio              1.5.4                    pypi_0                   #
# networkx                  2.6.3                    pypi_0    pypi           #
# nose                      1.3.7                    pypi_0    pypi           #
# notebook                  6.4.6                    pypi_0    pypi           #
# numpy                     1.19.5                   pypi_0    pypi           #
# oauthlib                  3.1.1                    pypi_0    pypi           #
# opencv-python             4.5.4.60                 pypi_0    pypi           #
# openssl                   3.0.0                h8ffe710_2    conda-forge    #
# opt-einsum                3.3.0                    pypi_0    pypi           #
# packaging                 21.3                     pypi_0    pypi           #
# pandocfilters             1.5.0                    pypi_0    pypi           #
# parso                     0.8.3                    pypi_0    pypi           #
# pickleshare               0.7.5                    pypi_0    pypi           #
# pillow                    8.4.0                    pypi_0    pypi           #
# pip                       21.3.1             pyhd8ed1ab_0    conda-forge    #
# prometheus-client         0.12.0                   pypi_0    pypi           #
# prompt-toolkit            3.0.23                   pypi_0    pypi           #
# protobuf                  3.19.1                   pypi_0    pypi           #
# psutil                    5.8.0                    pypi_0    pypi           #
# pyasn1                    0.4.8                    pypi_0    pypi           #
# pyasn1-modules            0.2.8                    pypi_0    pypi           #
# pycparser                 2.21                     pypi_0    pypi           #
# pygments                  2.10.0                   pypi_0    pypi           #
# pyparsing                 3.0.6                    pypi_0    pypi           #
# pyrsistent                0.18.0                   pypi_0    pypi           #
# python                    3.9.7        h900ac77_3_cpython    conda-forge    #
# python-dateutil           2.8.2                    pypi_0    pypi           #
# python_abi                3.9                      2_cp39    conda-forge    #
# pytz                      2021.3                   pypi_0    pypi           #
# pywavelets                1.2.0                    pypi_0    pypi           #
# pywin32                   302                      pypi_0    pypi           #
# pywinpty                  1.1.6                    pypi_0    pypi           #
# pyzmq                     22.3.0                   pypi_0    pypi           #
# qtconsole                 5.2.1                    pypi_0    pypi           #
# qtpy                      1.11.3                   pypi_0    pypi           #
# requests                  2.26.0                   pypi_0    pypi           #
# requests-oauthlib         1.3.0                    pypi_0    pypi           #
# rsa                       4.8                      pypi_0    pypi           #
# scikit-image              0.18.3                   pypi_0    pypi           #
# scipy                     1.7.3                    pypi_0    pypi           #
# send2trash                1.8.0                    pypi_0    pypi           #
# setuptools                59.4.0           py39hcbf5309_0    conda-forge    #
# setuptools-scm            6.3.2                    pypi_0    pypi           #
# shapely                   1.8.0                    pypi_0    pypi           #
# six                       1.15.0                   pypi_0    pypi           #
# snowballstemmer           2.2.0                    pypi_0    pypi           #
# sphinx                    4.3.1                    pypi_0    pypi           #
# sphinxcontrib-applehelp   1.0.2                    pypi_0    pypi           #
# sphinxcontrib-devhelp     1.0.2                    pypi_0    pypi           #
# sphinxcontrib-htmlhelp    2.0.0                    pypi_0    pypi           #
# sphinxcontrib-jsmath      1.0.1                    pypi_0    pypi           #
# sphinxcontrib-qthelp      1.0.3                    pypi_0    pypi           #
# sphinxcontrib-serializinghtml 1.1.5                pypi_0    pypi           #
# sqlite                    3.37.0               h8ffe710_0    conda-forge    #
# tb-nightly                2.8.0a20211220           pypi_0    pypi           #
# tensorboard               2.7.0                    pypi_0    pypi           #
# tensorboard-data-server   0.6.1                    pypi_0    pypi           #
# tensorboard-plugin-wit    1.8.0                    pypi_0    pypi           #
# tensorflow-estimator      2.7.0                    pypi_0    pypi           #
# tensorflow-gpu            2.7.0                    pypi_0    pypi           #
# tensorflow-io-gcs-filesystem 0.23.1                pypi_0    pypi           #
# termcolor                 1.1.0                    pypi_0    pypi           #
# terminado                 0.12.1                   pypi_0    pypi           #
# testpath                  0.5.0                    pypi_0    pypi           #
# tf-estimator-nightly      2.8.0.dev2021122009      pypi_0    pypi           #
# tifffile                  2021.11.2                pypi_0    pypi           #
# tomli                     1.2.2                    pypi_0    pypi           #
# tornado                   6.1                      pypi_0    pypi           #
# tqdm                      4.62.3                   pypi_0    pypi           #
# traitlets                 5.1.1                    pypi_0    pypi           #
# typing-extensions         4.0.1                    pypi_0    pypi           #
# tzdata                    2021e                he74cb21_0    conda-forge    #
# ucrt                      10.0.20348.0         h57928b3_0    conda-forge    #
# urllib3                   1.26.7                   pypi_0    pypi           #
# vc                        14.2                 hb210afc_5    conda-forge    #
# vs2015_runtime            14.29.30037          h902a5da_5    conda-forge    #
# wcwidth                   0.2.5                    pypi_0    pypi           #
# webencodings              0.5.1                    pypi_0    pypi           #
# werkzeug                  2.0.2                    pypi_0    pypi           #
# wheel                     0.37.0             pyhd8ed1ab_1    conda-forge    #
# widgetsnbextension        3.5.2                    pypi_0    pypi           #
# wrapt                     1.13.3                   pypi_0    pypi           #
# zipp                      3.6.0                    pypi_0    pypi           #
###############################################################################

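Here is a link to my TensorBoard and an example of a bad prediction:

You should see the model learning, followed by a spike at the end; the spike is where I loaded the weights again and resumed training.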

https://tensorboard.dev/experiment/KkgugOP7RGu12lVCA6M29Q/

Bad Prediction
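Whether training really resumes where it left off depends on the checkpoint path. In the stock Matterport `model.py`, `load_weights()` calls `set_log_dir()`, which parses the run timestamp and epoch from a path shaped like `<model_dir>/<name><YYYYMMDD>T<HHMM>/mask_rcnn_<name>_<epoch:04d>.h5`; if the path does not match, the epoch counter starts again at 0 and a new timestamped log directory is used. The sketch below reproduces that parsing with the standard library so a given path can be checked. `resume_epoch` is a hypothetical name, and the regex is an assumption modeled on the upstream code, so an edited fork may differ.

```python
import re

# Assumed checkpoint layout (stock Matterport naming), Windows or POSIX separators:
#   <model_dir>/<name><YYYYMMDD>T<HHMM>/mask_rcnn_<name>_<epoch:04d>.h5
CHECKPOINT_RE = re.compile(
    r".*[/\\][\w-]+(\d{4})(\d{2})(\d{2})T(\d{2})(\d{2})[/\\]mask_rcnn_[\w-]+(\d{4})\.h5$"
)


def resume_epoch(checkpoint_path):
    """Return the epoch index training would resume from, or 0 if the path doesn't match."""
    m = CHECKPOINT_RE.match(checkpoint_path)
    if m is None:
        # Unrecognized name: the counter resets and a fresh log directory is created.
        return 0
    # Saved epoch numbers are 1-based and Keras' initial_epoch is 0-based,
    # so the number in the file name is exactly the next epoch index to run.
    return int(m.group(6))


print(resume_epoch("/logs/pen20211222T1921/mask_rcnn_pen_0010.h5"))
```

Keep in mind that the `epochs` argument of Matterport's `model.train()` is documented upstream as the total number of epochs, not the number of additional epochs, so resuming from epoch 10 with `epochs=10` trains nothing further.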

Here is the custom config I use for training:

from mrcnn.config import Config

class CustomConfig(Config):
    """Configuration for training on the pen dataset.
    Derives from the base Config class and overrides some values.
    """

    # Skip detections with < 70% confidence
    DETECTION_MIN_CONFIDENCE = 0.7

    # Give the configuration a recognizable name
    NAME = "PEN"

    # Train on 1 GPU and 8 images per GPU. We can put multiple images on each
    # GPU because the images are small. Batch size is 8 (GPUs * images/GPU).
    GPU_COUNT = 1
    IMAGES_PER_GPU = 8

    # Number of classes (including background)
    NUM_CLASSES = 1 + 1  # background + PEN

    # Use small images for faster training. Set the limits of the small side
    # and the large side, which together determine the image shape.
    IMAGE_MIN_DIM = 128
    IMAGE_MAX_DIM = 128

    # Use smaller anchors because our image and objects are small
    RPN_ANCHOR_SCALES = (8, 16, 32, 64, 128)  # anchor side in pixels

    # Reduce training ROIs per image because the images are small and have
    # few objects. Aim to allow ROI sampling to pick 33% positive ROIs.
    TRAIN_ROIS_PER_IMAGE = 32

    # Use a small epoch since the data is simple
    STEPS_PER_EPOCH = 300

    # Use small validation steps since the epoch is small
    VALIDATION_STEPS = 10

config = CustomConfig()
config.display()
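The comments in this config imply a few derived values. As far as I can tell from the stock Matterport code, `Config.__init__` sets `BATCH_SIZE = IMAGES_PER_GPU * GPU_COUNT`, the default `"square"` resize mode pads images to `IMAGE_MAX_DIM` by `IMAGE_MAX_DIM`, the model rejects image sizes that cannot be halved six times (i.e. not divisible by 64), and ROI sampling keeps at most `int(TRAIN_ROIS_PER_IMAGE * ROI_POSITIVE_RATIO)` positive ROIs, with `ROI_POSITIVE_RATIO` defaulting to 0.33. The plain-Python sketch below recomputes those numbers; `derived_values` is a hypothetical helper, not part of `mrcnn`, and the formulas are assumptions based on the upstream code.

```python
def derived_values(gpu_count, images_per_gpu, image_max_dim,
                   train_rois_per_image, roi_positive_ratio=0.33, channels=3):
    """Recompute values the Matterport Config/model derive (assumed formulas)."""
    return {
        # Images processed per training step.
        "batch_size": gpu_count * images_per_gpu,
        # "square" resize mode pads every image to max_dim x max_dim.
        "image_shape": (image_max_dim, image_max_dim, channels),
        # The side must survive six halvings without fractions, i.e. divide by 64.
        "size_ok": image_max_dim % 2 ** 6 == 0,
        # Upper bound on positive ROIs sampled per image during training.
        "max_positive_rois": int(train_rois_per_image * roi_positive_ratio),
    }


train = derived_values(gpu_count=1, images_per_gpu=8, image_max_dim=128, train_rois_per_image=32)
infer = derived_values(gpu_count=1, images_per_gpu=1, image_max_dim=128, train_rois_per_image=32)
print(train)
print(infer["batch_size"])
```

With 128-pixel images the size check passes, and the training and inference configs differ only in batch size (8 vs 1), which is expected since the learned layer weights do not depend on batch size.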

Here is my inference config:

class InferenceConfig(CustomConfig):
    NAME = "PEN"

    NUM_CLASSES = 1 + 1  # background + PEN

    # Use small images for faster training. Set the limits of the small side
    # and the large side, which together determine the image shape.
    IMAGE_MIN_DIM = 128
    IMAGE_MAX_DIM = 128

    # Use smaller anchors because our image and objects are small
    RPN_ANCHOR_SCALES = (8, 16, 32, 64, 128)  # anchor side in pixels
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1
    DETECTION_MIN_CONFIDENCE = 0.9
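Because `InferenceConfig` subclasses `CustomConfig`, the repeated `NAME`, `NUM_CLASSES`, image-size, and anchor settings are inherited anyway. One way to confirm that only the intended values differ between the two configs (a change to anything that shapes the layers, such as `NUM_CLASSES`, would keep the saved head weights from loading cleanly) is to diff their uppercase attributes. The sketch below uses a hypothetical `BaseConfig` stand-in carrying a few upstream defaults plus simplified copies of the two configs, so it runs without `mrcnn`; `config_diff` is also a hypothetical helper.

```python
class BaseConfig:
    # Hypothetical stand-in for mrcnn.config.Config (a few upstream defaults only).
    NAME = None
    GPU_COUNT = 1
    IMAGES_PER_GPU = 2
    NUM_CLASSES = 1
    DETECTION_MIN_CONFIDENCE = 0.7


class CustomConfig(BaseConfig):  # simplified copy of the training config
    NAME = "PEN"
    IMAGES_PER_GPU = 8
    NUM_CLASSES = 1 + 1
    IMAGE_MIN_DIM = 128
    IMAGE_MAX_DIM = 128
    DETECTION_MIN_CONFIDENCE = 0.7


class InferenceConfig(CustomConfig):  # only the overrides that actually differ
    IMAGES_PER_GPU = 1
    DETECTION_MIN_CONFIDENCE = 0.9


def _same(x, y):
    try:
        return bool(x == y)
    except ValueError:
        # NumPy arrays (e.g. IMAGE_SHAPE on real Config objects) compare
        # elementwise, so fall back to comparing their printed form.
        return repr(x) == repr(y)


def config_diff(a, b):
    """Return {attr: (value_in_a, value_in_b)} for uppercase attributes that differ."""
    names = {n for obj in (a, b) for n in dir(obj) if n.isupper()}
    return {
        n: (getattr(a, n, None), getattr(b, n, None))
        for n in sorted(names)
        if not _same(getattr(a, n, None), getattr(b, n, None))
    }


print(config_diff(CustomConfig, InferenceConfig))
```

With the real classes you would pass the instances, e.g. `config_diff(config, inference_config)`; attributes derived in `Config.__init__`, such as `BATCH_SIZE`, will then show up as well.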

Let me know if you need more information. This is also my first post, so any guidance is appreciated.

【Discussion】:

    Tags: python tensorflow keras


    【Answer 1】:

    Please check the versions of Python, TensorFlow, and Keras you are using, because it (Matterport's Mask R-CNN) only works with Python 3.6.
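    To follow this advice, one way to record the interpreter and package versions actually active in the environment is the standard-library `importlib.metadata` module (Python 3.8+). The default package list below is an assumption chosen for this setup; `environment_versions` is a hypothetical helper.

```python
import sys
from importlib import metadata


def environment_versions(packages=("tensorflow", "tensorflow-gpu", "keras", "h5py")):
    """Map 'python' and each distribution name to its installed version (None if absent)."""
    versions = {"python": "{}.{}.{}".format(*sys.version_info[:3])}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in environment_versions().items():
    print(f"{name:15} {version}")
```

    Running it in the same conda environment and Jupyter kernel that trains the model avoids checking the wrong interpreter.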
