Posted: 2020-09-24 07:20:30
Question:
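I'm trying to do some image segmentation for OCR. My mask images are 3-class images, like this: [image]
My original images are grayscale images, like this: [image]
But when I try to fit the model, I get this error:
could not broadcast input array from shape (128,128,3) into shape (128,128)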
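The error comes from NumPy refusing to copy an array into a slot of a different shape. A minimal sketch that reproduces it, with dummy arrays standing in for real images (the arrays and their values are made up for illustration):

```python
import numpy as np

# Buffer shaped like x and y in the question: (batch, 128, 128), no channel axis
batch = np.zeros((2, 128, 128), dtype="float32")

# Stand-in for an image loaded with color_mode="rgb": shape (128, 128, 3)
rgb_img = np.ones((128, 128, 3), dtype="float32")

message = None
try:
    batch[0] = rgb_img  # a (128, 128, 3) array cannot go into a (128, 128) slot
except ValueError as err:
    message = str(err)
print(message)

# With a matching trailing channel axis, the same assignment succeeds
batch_rgb = np.zeros((2, 128, 128, 3), dtype="float32")
batch_rgb[0] = rgb_img
print(batch_rgb.shape)  # (2, 128, 128, 3)
```

In other words, the buffer and the loaded image must agree on whether there is a channel axis and how large it is.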
This is the code I use to create the dataset:
```python
import os
import random

import numpy as np
from tensorflow import keras
from tensorflow.keras.preprocessing.image import load_img

# input_dir and target_dir (the image and mask folders) are defined elsewhere
img_size = (128, 128)
batch_size = 32

input_img_paths = sorted(
    [os.path.join(input_dir, fname)
     for fname in os.listdir(input_dir)
     if fname.endswith(".jpg")])
target_img_paths = sorted(
    [os.path.join(target_dir, fname)
     for fname in os.listdir(target_dir)
     if fname.endswith(".jpg") and not fname.startswith(".")])


class OxfordPets(keras.utils.Sequence):
    """Helper to iterate over the data (as Numpy arrays)."""

    def __init__(self, batch_size, img_size, input_img_paths, target_img_paths):
        self.batch_size = batch_size
        self.img_size = img_size
        self.input_img_paths = input_img_paths
        self.target_img_paths = target_img_paths

    def __len__(self):
        return len(self.target_img_paths) // self.batch_size

    def __getitem__(self, idx):
        """Returns tuple (input, target) correspond to batch #idx."""
        i = idx * self.batch_size
        batch_input_img_paths = self.input_img_paths[i : i + self.batch_size]
        batch_target_img_paths = self.target_img_paths[i : i + self.batch_size]
        # x has shape (batch_size, 128, 128): no channel axis
        x = np.zeros((self.batch_size,) + self.img_size, dtype="float32")
        for j, path in enumerate(batch_input_img_paths):
            # load_img defaults to color_mode="rgb", so img is 128x128x3
            img = load_img(path, target_size=self.img_size)
            x[j] = img
        # y also has shape (batch_size, 128, 128)
        y = np.zeros((self.batch_size,) + self.img_size, dtype="float32")
        for j, path in enumerate(batch_target_img_paths):
            img = load_img(path, target_size=self.img_size, color_mode="rgb")
            y[j] = img
        return x, y
```
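In the Keras example this class is adapted from, both buffers carry an explicit channel axis: the input buffer ends in `(3,)`, and the target buffer ends in `(1,)`, with masks loaded via `color_mode="grayscale"` and wrapped in `np.expand_dims(img, 2)`. The sketch below isolates that shape handling in a hypothetical helper, `make_batch`, which works on already-loaded arrays; it is not a Keras API. Which channel count the input needs depends on the model: if it follows the TensorFlow tutorial, its MobileNetV2 encoder is built with `input_shape=[128, 128, 3]`, so loading the grayscale images in the default `rgb` mode (which replicates the gray values across three channels) would fit, while a model built for single-channel input would use `channels=1`.

```python
import numpy as np


def make_batch(images, img_size=(128, 128), channels=1, dtype="float32"):
    """Stack (H, W) or (H, W, C) arrays into one (N, H, W, channels) batch.

    Hypothetical helper for illustration: it mirrors the buffer logic of
    __getitem__ but takes already-loaded arrays instead of file paths.
    """
    batch = np.zeros((len(images),) + tuple(img_size) + (channels,), dtype=dtype)
    for j, img in enumerate(images):
        arr = np.asarray(img, dtype=dtype)
        if arr.ndim == 2:
            # Grayscale (H, W) -> (H, W, 1), like np.expand_dims(img, 2)
            arr = arr[..., np.newaxis]
        if arr.shape != batch.shape[1:]:
            raise ValueError(
                f"image {j} has shape {arr.shape}, expected {batch.shape[1:]}"
            )
        batch[j] = arr
    return batch


# Grayscale images or single-channel masks -> trailing axis of size 1
gray = [np.zeros((128, 128)), np.ones((128, 128))]
print(make_batch(gray, channels=1).shape)  # (2, 128, 128, 1)

# Images loaded in rgb mode -> trailing axis of size 3
rgb = [np.zeros((128, 128, 3))]
print(make_batch(rgb, channels=3).shape)  # (1, 128, 128, 3)
```

Inside `__getitem__`, the same idea amounts to allocating `(self.batch_size,) + self.img_size + (channels,)` and giving every loaded image a matching last axis before assigning it.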
```python
val_samples = 150
random.Random(1337).shuffle(input_img_paths)
random.Random(1337).shuffle(target_img_paths)
train_input_img_paths = input_img_paths[:-val_samples]
train_target_img_paths = target_img_paths[:-val_samples]
val_input_img_paths = input_img_paths[-val_samples:]
val_target_img_paths = target_img_paths[-val_samples:]

# Instantiate data Sequences for each split
train_gen = OxfordPets(
    batch_size, img_size, train_input_img_paths, train_target_img_paths
)
val_gen = OxfordPets(batch_size, img_size, val_input_img_paths, val_target_img_paths)
```
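Shuffling both lists with a fresh `random.Random(1337)` applies the same permutation to each, because the permutation depends only on the seed and the list length. Image/mask pairs therefore stay aligned only if both lists have the same length and `sorted()` put each image at the same index as its mask. A small sketch of that idea, using a hypothetical `paired_shuffle` helper that also checks the lengths (the file names are made up):

```python
import random


def paired_shuffle(inputs, targets, seed=1337):
    """Shuffle two equal-length lists with one shared permutation (sketch)."""
    if len(inputs) != len(targets):
        raise ValueError(
            f"{len(inputs)} inputs but {len(targets)} targets; pairs would misalign"
        )
    inputs, targets = list(inputs), list(targets)
    # Two generators seeded identically produce the same sequence of swaps
    random.Random(seed).shuffle(inputs)
    random.Random(seed).shuffle(targets)
    return inputs, targets


imgs = [f"img_{i:03d}.jpg" for i in range(10)]
masks = [f"mask_{i:03d}.jpg" for i in range(10)]
shuffled_imgs, shuffled_masks = paired_shuffle(imgs, masks)
# Each image still sits next to the mask with the same number
print(all(a[4:] == b[5:] for a, b in zip(shuffled_imgs, shuffled_masks)))  # True
```

Shuffling a single list of `(input, target)` tuples would be a sturdier alternative, since it does not rely on two generators staying in step.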
But when I try to fit it like this:
```python
model_history = model.fit(
    train_gen,
    epochs=30,
    steps_per_epoch=50,
    validation_steps=25,
    validation_data=val_gen,
)
```
I get the error above. I'm trying to adapt this Keras solution, https://keras.io/examples/vision/oxford_pets_image_segmentation/?fbclid=IwAR2wFYju-N0X7FUaWkhvOVaAAaVqLdOryBwg7xDC0Rji9LQ5F2jYOkeNnns, to the example on the TensorFlow page, https://www.tensorflow.org/tutorials/images/segmentation.
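Because `__len__` uses integer division, each Sequence only serves full batches. A quick check of the numbers used in the `fit` call (plain arithmetic; `batches_per_epoch` is a hypothetical helper, not a Keras API):

```python
def batches_per_epoch(num_samples, batch_size):
    """Same rule as OxfordPets.__len__: full batches only, remainder dropped."""
    return num_samples // batch_size


val_samples, batch_size = 150, 32
val_batches = batches_per_epoch(val_samples, batch_size)
print(val_batches)  # 4 (the last 150 - 4 * 32 = 22 samples are never served)

# steps_per_epoch=50 asks for 50 full batches of training data per epoch
min_train_images = 50 * batch_size
print(min_train_images)  # 1600
```

So `validation_steps=25` asks for more batches than the 4-batch validation Sequence provides. How Keras reacts to that may depend on the TensorFlow version, so it is worth checking once the shape error is fixed.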
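My impression is that the problem is related to the fact that the original images are grayscale. How can I fix this error? Any suggestions would be great!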
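On the target side, the TensorFlow segmentation tutorial trains with `SparseCategoricalCrossentropy`, which expects one integer class id per pixel (a mask of shape `(128, 128, 1)`), whereas the masks here are loaded as RGB images. If the three classes are stored as colors, a nearest-color lookup is one way to convert them; the palette below is a hypothetical placeholder, since the actual mask colors are not given. Nearest-color matching (rather than exact equality) is used because the masks are saved as `.jpg`, and JPEG compression can shift pixel colors slightly. If the masks instead store the class ids directly as gray levels, loading them with `color_mode="grayscale"` and adding a channel axis may be enough.

```python
import numpy as np

# Hypothetical palette: one RGB color per class. Replace with the real mask colors.
PALETTE = np.array(
    [
        [0, 0, 0],  # class 0
        [255, 0, 0],  # class 1
        [0, 0, 255],  # class 2
    ],
    dtype=np.float32,
)


def rgb_mask_to_class_ids(mask_rgb, palette=PALETTE):
    """Map an (H, W, 3) color mask to an (H, W, 1) uint8 array of class ids.

    Each pixel gets the index of the nearest palette color (squared
    Euclidean distance), which tolerates small JPEG color shifts.
    """
    mask = np.asarray(mask_rgb, dtype=np.float32)
    # (H, W, 1, 3) - (K, 3) broadcasts to (H, W, K, 3)
    dist = ((mask[..., np.newaxis, :] - palette) ** 2).sum(axis=-1)  # (H, W, K)
    ids = dist.argmin(axis=-1).astype(np.uint8)  # (H, W)
    return ids[..., np.newaxis]  # (H, W, 1)


# A 2x2 mask whose colors are slightly off, as JPEG might leave them
mask = np.array(
    [[[0, 0, 0], [250, 5, 3]],
     [[2, 1, 240], [4, 4, 4]]],
    dtype=np.uint8,
)
print(rgb_mask_to_class_ids(mask)[..., 0].tolist())  # [[0, 1], [2, 0]]
```

In `__getitem__`, this would pair with a target buffer of shape `(self.batch_size,) + self.img_size + (1,)` and `y[j] = rgb_mask_to_class_ids(img)`.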
Tags: tensorflow deep-learning ocr