jkmiao

OCR 文字区域检测及识别

# coding=utf-8

from PIL import Image, ImageFilter, ImageEnhance
from skimage.filters import threshold_otsu
import skimage.morphology as sm
from skimage.measure import regionprops
import matplotlib.pyplot as plt
import numpy as np
import pytesseract
import re
import os
import time
import logging
# Route INFO-and-above records, formatted as the bare message text, to train_output.log.
logging.basicConfig(level=logging.INFO, format="%(message)s", filename='train_output.log')

class ShopCert(object):
    """Locate and OCR an uppercase alphanumeric code on a scanned image.

    Pipeline: crop the image to the area where the code is expected
    (``cut_region``), find long flat text blobs in that area
    (``detect_text``), clean each blob (``clear_noise``) and run Tesseract
    on it (``predict``).  Presumably the code is the registration number of
    a shop certificate -- confirm against the data set.
    """

    # Images whose height is below this many pixels are upscaled first.
    MIN_HEIGHT = 1200
    # Height that too-short images are scaled up to.
    TARGET_HEIGHT = 1600.0
    # Boxes whose width/height ratio exceeds this are read with 'chi_sim'.
    WIDE_BOX_RATIO = 12.5
    # One digit followed by 13-20 digits or capital letters.
    NUMBER_RE = re.compile(r'[0-9][A-Z0-9]{13,20}')
    WHITESPACE_RE = re.compile(r'\s')
    # Directory where the successfully recognised box images are written.
    OUTPUT_DIR = 'img/clearNoise'

    @staticmethod
    def _plan_region(w, h):
        """Return ``((new_w, new_h), crop_box)`` for an image of size (w, h).

        The crop box is expressed in the coordinates of the (possibly
        upscaled) image, so it covers the same relative area whether or not
        the image was resized.
        """
        if h < ShopCert.MIN_HEIGHT:
            factor = max(1, ShopCert.TARGET_HEIGHT / h)
            w, h = int(w * factor), int(h * factor)
        if w < h:
            # Portrait image: search the upper-right part.
            box = (w * 0.4, h * 0.18, w * 0.96, h * 0.6)
        else:
            # Landscape (or square) image: search almost the whole page.
            box = (w * 0.1, h * 0.18, w * 0.96, h * 0.9)
        return (w, h), box

    @staticmethod
    def _find_candidates(text):
        """Return every code-like substring of *text*, ignoring whitespace."""
        compact = ShopCert.WHITESPACE_RE.sub('', text)
        return ShopCert.NUMBER_RE.findall(compact)

    def cut_region(self, img):
        """Narrow the search area with fixed, rule-based proportions.

        Args:
            img: a PIL image.

        Returns:
            The cropped PIL image; images shorter than ``MIN_HEIGHT`` are
            upscaled before cropping.
        """
        size, box = self._plan_region(*img.size)
        if size != tuple(img.size):
            # ANTIALIAS is the historical name of the LANCZOS filter and is
            # missing from recent Pillow releases; prefer the current name.
            resample = getattr(Image, 'LANCZOS', None)
            if resample is None:
                resample = Image.ANTIALIAS
            img = img.resize(size, resample)
        return img.crop(box)

    def detect_text(self, img):
        """Detect candidate text regions.

        The image is binarised with Otsu's threshold (darker pixels become
        foreground), closed with a 5x10 structuring element so neighbouring
        characters merge into one blob, and small objects are removed.

        Args:
            img: a PIL image.

        Returns:
            A list of PIL images, one crop per connected region whose
            height is at most 20% of its width.
        """
        gray = np.array(img.convert('L'))
        mask = 1 * (gray < threshold_otsu(gray))
        mask = sm.binary_closing(mask, np.ones((5, 10)))
        mask = sm.remove_small_objects(mask, 600)
        labelled = sm.label(mask)

        crops = []
        for region in regionprops(labelled):
            minr, minc, maxr, maxc = region.bbox
            w, h = maxc - minc, maxr - minr
            if h > w * 0.2:
                # Too tall relative to its width to be a single text line.
                continue
            # Pad the bounding box slightly so characters are not clipped.
            crops.append(img.crop((minc - 5, minr - 3, maxc + 5, maxr + 3)))
        return crops

    def clear_noise(self, box):
        """Reduce noise: grayscale, clip values above 200, boost contrast."""
        box = box.convert('L')
        box = box.point(lambda x: 200 if x > 200 else x)
        return ImageEnhance.Contrast(box).enhance(2.5)

    def predict(self, fname, lang='eng'):
        """Run OCR on the image file *fname* and return the recognised code.

        Args:
            fname: path of the image file.
            lang: Tesseract language used for ordinary boxes.  Boxes whose
                width/height ratio exceeds ``WIDE_BOX_RATIO`` are always
                read with ``'chi_sim'``.

        Returns:
            The first matched code longer than 14 characters, or the string
            ``'error'`` when no box yields one.  On success the cleaned box
            image is also saved under ``OUTPUT_DIR``.
        """
        img = Image.open(fname)
        region = self.cut_region(img)
        stem = os.path.basename(fname).split('.')[0]

        for box in self.detect_text(region):
            box = self.clear_noise(box)
            w, h = box.size
            box_lang = 'chi_sim' if float(w) / h > self.WIDE_BOX_RATIO else lang
            # NOTE(review): newer Tesseract releases spell this option
            # '--psm 7' -- confirm against the installed version.
            text = pytesseract.image_to_string(box, lang=box_lang, config='-psm 7')
            for line in self._find_candidates(text):
                print('line %s' % line)
                if len(line) > 14:
                    if not os.path.isdir(self.OUTPUT_DIR):
                        os.makedirs(self.OUTPUT_DIR)
                    box.save('%s/%s_%s.jpg' % (self.OUTPUT_DIR, stem, line))
                    return line
                print('error line %s' % line)
        return 'error'


def show_pic(path='img/origin2/'):
    """Display every file in *path* after contrast boost and median filter.

    Each image is shown in its own matplotlib figure, titled with the file
    name and its index.
    """
    fnames = [os.path.join(path, fname) for fname in os.listdir(path)]
    for i, fname in enumerate(fnames, 0):
        print(fname)
        img = Image.open(fname)
        img = ImageEnhance.Contrast(img).enhance(2.0)
        img = img.filter(ImageFilter.MedianFilter).convert('L')
        plt.figure(figsize=(10, 12), dpi=300)
        plt.imshow(img, plt.cm.gray)
        plt.title(os.path.basename(fname) + '_%d' % i)
        plt.show()


def main(path='img/origin2/'):
    """Evaluate ``ShopCert.predict`` on every ``*jpg`` file in *path*.

    The expected code is the part of the file name before the first
    underscore.  Progress is printed, mismatches and the final score are
    logged.

    Returns:
        The fraction of files recognised correctly (0.0 when *path*
        contains no matching file).
    """
    recognizer = ShopCert()
    fnames = sorted(os.path.join(path, fname)
                    for fname in os.listdir(path) if fname.endswith('jpg'))

    arguments = 'mode: L; enhance:2.0; h:0.5; dh:0.15'
    logging.info('%s', arguments)
    logging.info("%s: %s", 'imgname', 'result')

    start_time = time.time()
    cnt = 0
    # Defined up front so the summary below also works for an empty folder.
    idx = 0
    acc = 0.0
    for idx, fname in enumerate(fnames, 1):
        print('%s %s' % (idx, fname))
        y_true = os.path.basename(fname).split('_')[0]
        y_pred = recognizer.predict(fname)
        if y_true == y_pred:
            cnt += 1
            print(fname)
        else:
            print('***' * 20)
            print('error')
            logging.info("%s: %s", fname, y_pred)
        print('y_true %s' % y_true)
        print('y_pred %s' % y_pred)
        acc = float(cnt) / idx
        print('%s %s' % (acc, cnt))
        print('%s %s' % ('==' * 20, idx))

    logging.info('%.3f %d/%d', acc, cnt, idx)
    print('%s %s' % ('cost time: ', time.time() - start_time))
    logging.info('accuracy: %.2f', acc)
    return acc


if __name__ == '__main__':
    main()

 

分类:

技术点:

相关文章: