【Posted】: 2021-12-22 19:21:09
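【Question】:
Problem:
I am new to computer vision, and this is my second project. I am running an edited version of Matterport Mask RCNN that runs on tensorflow-gpu==2.7.0. (I later found out that using an older version would have been fine.) I am trying to use it with a pen dataset I created.
The problem is that whenever I load the trained weights into the model to continue training, the metrics spike. I also get bad predictions when I load the weights for inference. Why are my weights not being loaded or saved correctly? I save the weights with a callback and load them with the following: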
model = modellib.MaskRCNN(mode="inference",
                          config=inference_config,
                          model_dir=MODEL_DIR)
# Get path to saved weights
model_path = model.find_last()
# Load trained weights
print("Loading weights from ", model_path)
model.load_weights(model_path, by_name=True)
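Because `by_name=True` loads weights only into layers whose names match, one thing worth checking is whether the layer names in the checkpoint line up with those in the model. The following is a minimal sketch of that comparison in plain Python; `compare_layer_names` and the sample layer names are hypothetical and are not part of Mask RCNN or Keras.

```python
def compare_layer_names(model_layers, checkpoint_layers):
    """Return (matched, only_in_model, only_in_checkpoint) as sorted lists."""
    model_set = set(model_layers)
    ckpt_set = set(checkpoint_layers)
    matched = sorted(model_set & ckpt_set)
    only_in_model = sorted(model_set - ckpt_set)
    only_in_checkpoint = sorted(ckpt_set - model_set)
    return matched, only_in_model, only_in_checkpoint


# Hypothetical layer names, for illustration only.
model_layers = ["conv1", "bn_conv1", "mrcnn_mask"]
checkpoint_layers = ["conv1", "bn_conv1", "mrcnn_mask_1"]

matched, only_in_model, only_in_checkpoint = compare_layer_names(
    model_layers, checkpoint_layers
)
print("matched:", matched)
print("in model but not in checkpoint:", only_in_model)
print("in checkpoint but not in model:", only_in_checkpoint)
```

In a real run the first list could come from `[layer.name for layer in model.keras_model.layers]`. How to read the names from the saved file depends on the checkpoint format, so that part is left out here.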
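What I tried:
I tried saving the whole model by changing save_weights_only to False in the callback. I ran into the get_config() problem described in a thread and followed some of the solutions there, to no avail.
I also tried changing the image size and the number of epochs.
I tried saving the model with the following: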
from tensorflow import keras
# complete_filepath is a placeholder for the full path of the saved model
model.keras_model.save(complete_filepath)
model = keras.models.load_model('path/to/location')
This led to the same get_config() problem.
Resources:
Here is a list of what I am running:
# ITEM ########### VERSION ##########################
# Python # 3.9.7 #
# conda # 4.10.3 #
# CUDA # 11.4 #
# WindowsOS # 11 #
# cuDNN # 8.2.4 #
#####################################################
################################### PACKAGES ##################################
# packages in environment at C:\Users\ecsan\anaconda3\envs\Prototype:
# Command: conda list
# Name #################### Version ################ Build # Channel ############
# absl-py 1.0.0 pypi_0 pypi #
# alabaster 0.7.12 pypi_0 pypi #
# argon2-cffi 21.1.0 pypi_0 pypi #
# astunparse 1.6.3 pypi_0 pypi #
# attrs 21.2.0 pypi_0 pypi #
# babel 2.9.1 pypi_0 pypi #
# backcall 0.2.0 pypi_0 pypi #
# bleach 4.1.0 pypi_0 pypi #
# ca-certificates 2021.10.8 h5b45459_0 conda-forge #
# cachetools 4.2.4 pypi_0 pypi #
# certifi 2021.10.8 pypi_0 pypi #
# cffi 1.15.0 pypi_0 pypi #
# charset-normalizer 2.0.9 pypi_0 pypi #
# colorama 0.4.4 pypi_0 pypi #
# console_shortcut 0.1.1 4 #
# cycler 0.11.0 pypi_0 pypi #
# cython 0.29.25 pypi_0 pypi #
# debugpy 1.5.1 pypi_0 pypi #
# decorator 5.1.0 pypi_0 pypi #
# defusedxml 0.7.1 pypi_0 pypi #
# dill 0.3.4 pypi_0 pypi #
# docutils 0.17.1 pypi_0 pypi #
# entrypoints 0.3 pypi_0 pypi #
# flatbuffers 2.0 pypi_0 pypi #
# fonttools 4.28.3 pypi_0 pypi #
# gast 0.4.0 pypi_0 pypi #
# google-auth 2.3.3 pypi_0 pypi #
# google-auth-oauthlib 0.4.6 pypi_0 pypi #
# google-pasta 0.2.0 pypi_0 pypi #
# grpcio 1.42.0 pypi_0 pypi #
# h5py 3.6.0 pypi_0 pypi #
# idna 3.3 pypi_0 pypi #
# imageio 2.13.2 pypi_0 pypi #
# imagesize 1.3.0 pypi_0 pypi #
# imgaug 0.4.0 pypi_0 pypi #
# importlib-metadata 4.8.2 pypi_0 pypi #
# ipykernel 6.6.0 pypi_0 pypi #
# ipyparallel 8.0.0 pypi_0 pypi #
# ipython 7.30.1 pypi_0 pypi #
# ipython-genutils 0.2.0 pypi_0 pypi #
# ipywidgets 7.6.5 pypi_0 pypi #
# jedi 0.18.1 pypi_0 pypi #
# jinja2 3.0.3 pypi_0 pypi #
# joblib 1.1.0 pypi_0 pypi #
# jsonschema 4.2.1 pypi_0 pypi #
# jupyter-client 7.1.0 pypi_0 pypi #
# jupyter-core 4.9.1 pypi_0 pypi #
# jupyterlab-pygments 0.1.2 pypi_0 pypi #
# jupyterlab-widgets 1.0.2 pypi_0 pypi #
# keras 2.7.0 pypi_0 pypi #
# keras-preprocessing 1.1.2 pypi_0 pypi #
# kiwisolver 1.3.2 pypi_0 pypi #
# libclang 12.0.0 pypi_0 pypi #
# markdown 3.3.6 pypi_0 pypi #
# markupsafe 2.0.1 pypi_0 pypi #
# matplotlib 3.5.0 pypi_0 pypi #
# matplotlib-inline 0.1.3 pypi_0 pypi #
# mistune 0.8.4 pypi_0 pypi #
# nbclient 0.5.9 pypi_0 pypi #
# nbconvert 6.3.0 pypi_0 pypi #
# nbformat 5.1.3 pypi_0 pypi #
# nest-asyncio 1.5.4 pypi_0 #
# networkx 2.6.3 pypi_0 pypi #
# nose 1.3.7 pypi_0 pypi #
# notebook 6.4.6 pypi_0 pypi #
# numpy 1.19.5 pypi_0 pypi #
# oauthlib 3.1.1 pypi_0 pypi #
# opencv-python 4.5.4.60 pypi_0 pypi #
# openssl 3.0.0 h8ffe710_2 conda-forge #
# opt-einsum 3.3.0 pypi_0 pypi #
# packaging 21.3 pypi_0 pypi #
# pandocfilters 1.5.0 pypi_0 pypi #
# parso 0.8.3 pypi_0 pypi #
# pickleshare 0.7.5 pypi_0 pypi #
# pillow 8.4.0 pypi_0 pypi #
# pip 21.3.1 pyhd8ed1ab_0 conda-forge #
# prometheus-client 0.12.0 pypi_0 pypi #
# prompt-toolkit 3.0.23 pypi_0 pypi #
# protobuf 3.19.1 pypi_0 pypi #
# psutil 5.8.0 pypi_0 pypi #
# pyasn1 0.4.8 pypi_0 pypi #
# pyasn1-modules 0.2.8 pypi_0 pypi #
# pycparser 2.21 pypi_0 pypi #
# pygments 2.10.0 pypi_0 pypi #
# pyparsing 3.0.6 pypi_0 pypi #
# pyrsistent 0.18.0 pypi_0 pypi #
# python 3.9.7 h900ac77_3_cpython conda-forge #
# python-dateutil 2.8.2 pypi_0 pypi #
# python_abi 3.9 2_cp39 conda-forge #
# pytz 2021.3 pypi_0 pypi #
# pywavelets 1.2.0 pypi_0 pypi #
# pywin32 302 pypi_0 pypi #
# pywinpty 1.1.6 pypi_0 pypi #
# pyzmq 22.3.0 pypi_0 pypi #
# qtconsole 5.2.1 pypi_0 pypi #
# qtpy 1.11.3 pypi_0 pypi #
# requests 2.26.0 pypi_0 pypi #
# requests-oauthlib 1.3.0 pypi_0 pypi #
# rsa 4.8 pypi_0 pypi #
# scikit-image 0.18.3 pypi_0 pypi #
# scipy 1.7.3 pypi_0 pypi #
# send2trash 1.8.0 pypi_0 pypi #
# setuptools 59.4.0 py39hcbf5309_0 conda-forge #
# setuptools-scm 6.3.2 pypi_0 pypi #
# shapely 1.8.0 pypi_0 pypi #
# six 1.15.0 pypi_0 pypi #
# snowballstemmer 2.2.0 pypi_0 pypi #
# sphinx 4.3.1 pypi_0 pypi #
# sphinxcontrib-applehelp 1.0.2 pypi_0 pypi #
# sphinxcontrib-devhelp 1.0.2 pypi_0 pypi #
# sphinxcontrib-htmlhelp 2.0.0 pypi_0 pypi #
# sphinxcontrib-jsmath 1.0.1 pypi_0 pypi #
# sphinxcontrib-qthelp 1.0.3 pypi_0 pypi #
# sphinxcontrib-serializinghtml 1.1.5 pypi_0 pypi #
# sqlite 3.37.0 h8ffe710_0 conda-forge #
# tb-nightly 2.8.0a20211220 pypi_0 pypi #
# tensorboard 2.7.0 pypi_0 pypi #
# tensorboard-data-server 0.6.1 pypi_0 pypi #
# tensorboard-plugin-wit 1.8.0 pypi_0 pypi #
# tensorflow-estimator 2.7.0 pypi_0 pypi #
# tensorflow-gpu 2.7.0 pypi_0 pypi #
# tensorflow-io-gcs-filesystem 0.23.1 pypi_0 pypi #
# termcolor 1.1.0 pypi_0 pypi #
# terminado 0.12.1 pypi_0 pypi #
# testpath 0.5.0 pypi_0 pypi #
# tf-estimator-nightly 2.8.0.dev2021122009 pypi_0 pypi #
# tifffile 2021.11.2 pypi_0 pypi #
# tomli 1.2.2 pypi_0 pypi #
# tornado 6.1 pypi_0 pypi #
# tqdm 4.62.3 pypi_0 pypi #
# traitlets 5.1.1 pypi_0 pypi #
# typing-extensions 4.0.1 pypi_0 pypi #
# tzdata 2021e he74cb21_0 conda-forge #
# ucrt 10.0.20348.0 h57928b3_0 conda-forge #
# urllib3 1.26.7 pypi_0 pypi #
# vc 14.2 hb210afc_5 conda-forge #
# vs2015_runtime 14.29.30037 h902a5da_5 conda-forge #
# wcwidth 0.2.5 pypi_0 pypi #
# webencodings 0.5.1 pypi_0 pypi #
# werkzeug 2.0.2 pypi_0 pypi #
# wheel 0.37.0 pyhd8ed1ab_1 conda-forge #
# widgetsnbextension 3.5.2 pypi_0 pypi #
# wrapt 1.13.3 pypi_0 pypi #
# zipp 3.6.0 pypi_0 pypi #
###############################################################################
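Here is a link to my TensorBoard and an example of a bad prediction:
You should see the model learning and then a spike at the end, which is where I loaded the weights again and resumed training.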
https://tensorboard.dev/experiment/KkgugOP7RGu12lVCA6M29Q/
Here is the custom config I use for training:
from mrcnn.config import Config  # assumes the standard Matterport package layout


class CustomConfig(Config):
    """Configuration for training on the dataset.
    Derives from the base Config class and overrides some values.
    """
    DETECTION_MIN_CONFIDENCE = 0.7  # Skip detections with < 70% confidence
    # Give the configuration a recognizable name
    NAME = "PEN"
    # Train on 1 GPU and 8 images per GPU. We can put multiple images on each
    # GPU because the images are small. Batch size is 8 (GPUs * images/GPU).
    GPU_COUNT = 1
    IMAGES_PER_GPU = 8
    # Number of classes (including background)
    NUM_CLASSES = 1 + 1  # background + PEN
    # Use small images for faster training. Set the limits of the small side
    # and the large side, and that determines the image shape.
    IMAGE_MIN_DIM = 128
    IMAGE_MAX_DIM = 128
    # Use smaller anchors because our image and objects are small
    RPN_ANCHOR_SCALES = (8, 16, 32, 64, 128)  # anchor side in pixels
    # Reduce training ROIs per image because the images are small and have
    # few objects. Aim to allow ROI sampling to pick 33% positive ROIs.
    TRAIN_ROIS_PER_IMAGE = 32
    # Use a small epoch since the data is simple
    STEPS_PER_EPOCH = 300
    # Use small validation steps since the epoch is small
    VALIDATION_STEPS = 10


config = CustomConfig()
config.display()
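To make the numbers in this config concrete, here is a small worked calculation of the batch size and of how many images one epoch draws. It is only a sketch: the helper names `batch_size` and `images_per_epoch` are hypothetical, and it assumes validation uses the same batch size as training.

```python
# Values copied from CustomConfig above.
GPU_COUNT = 1
IMAGES_PER_GPU = 8
STEPS_PER_EPOCH = 300
VALIDATION_STEPS = 10


def batch_size(gpu_count, images_per_gpu):
    """Effective batch size: images processed per step."""
    return gpu_count * images_per_gpu


def images_per_epoch(steps, gpu_count, images_per_gpu):
    """Number of images drawn over the given number of steps (repeats included)."""
    return steps * batch_size(gpu_count, images_per_gpu)


train_images = images_per_epoch(STEPS_PER_EPOCH, GPU_COUNT, IMAGES_PER_GPU)
val_images = images_per_epoch(VALIDATION_STEPS, GPU_COUNT, IMAGES_PER_GPU)
print(batch_size(GPU_COUNT, IMAGES_PER_GPU))  # 8
print(train_images)  # 2400
print(val_images)  # 80
```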
Here is my inference config:
class InferenceConfig(CustomConfig):
    NAME = "PEN"
    NUM_CLASSES = 1 + 1  # background + PEN
    # Use small images for faster training. Set the limits of the small side
    # and the large side, and that determines the image shape.
    IMAGE_MIN_DIM = 128
    IMAGE_MAX_DIM = 128
    # Use smaller anchors because our image and objects are small
    RPN_ANCHOR_SCALES = (8, 16, 32, 64, 128)  # anchor side in pixels
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1
    DETECTION_MIN_CONFIDENCE = 0.9
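Since the inference config subclasses the training config, only a few values actually change between the two. A minimal sketch for listing them follows. It uses stand-in classes so it runs without Mask RCNN installed; `config_diff`, `TrainSettings`, and `InferenceSettings` are hypothetical names, and the stand-ins carry only a subset of the settings shown above.

```python
def config_diff(cfg_a, cfg_b):
    """Return {name: (value_a, value_b)} for upper-case attributes that differ."""
    names = {n for n in dir(cfg_a) if n.isupper()}
    names |= {n for n in dir(cfg_b) if n.isupper()}
    diff = {}
    for name in sorted(names):
        a = getattr(cfg_a, name, None)
        b = getattr(cfg_b, name, None)
        if a != b:
            diff[name] = (a, b)
    return diff


# Stand-in classes holding a few values from the two configs above.
class TrainSettings:
    NUM_CLASSES = 1 + 1
    IMAGE_MIN_DIM = 128
    IMAGE_MAX_DIM = 128
    RPN_ANCHOR_SCALES = (8, 16, 32, 64, 128)
    IMAGES_PER_GPU = 8
    DETECTION_MIN_CONFIDENCE = 0.7


class InferenceSettings(TrainSettings):
    IMAGES_PER_GPU = 1
    DETECTION_MIN_CONFIDENCE = 0.9


differences = config_diff(TrainSettings, InferenceSettings)
for name, (train_value, infer_value) in differences.items():
    print(f"{name}: train={train_value} inference={infer_value}")
```

The same function could be pointed at `CustomConfig` and `InferenceConfig` directly, since it only inspects upper-case class attributes.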
Let me know if you need more information. This is also my first post, so any guidance is appreciated.
【Discussion】:
Tags: python tensorflow keras