lomooo

在看官方教程时,无意中发现别人写的一个脚本,非常简洁。

官方教程地址:http://pytorch.org/tutorials/beginner/data_loading_tutorial.html#sphx-glr-beginner-data-loading-tutorial-py

脚本使用的是dlib自带的人脸特征点检测器(68个特征点的预训练模型),初期用来测试还是不错的。

 

 1 """Create a sample face landmarks dataset.
 2 
 3 Adapted from dlib/python_examples/face_landmark_detection.py
 4 See this file for more explanation.
 5 
 6 Download a trained facial shape predictor from:
 7     http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
 8 """
 9 import dlib
10 import glob
11 import csv
12 from skimage import io
13 
14 detector = dlib.get_frontal_face_detector()
15 predictor = dlib.shape_predictor(\'shape_predictor_68_face_landmarks.dat\')
16 num_landmarks = 68
17 
18 with open(\'face_landmarks.csv\', \'w\', newline=\'\') as csvfile:
19     csv_writer = csv.writer(csvfile)
20 
21     header = [\'image_name\']
22     for i in range(num_landmarks):
23         header += [\'part_{}_x\'.format(i), \'part_{}_y\'.format(i)]
24 
25     csv_writer.writerow(header)
26 
27     for f in glob.glob(\'*.jpg\'):
28         img = io.imread(f)
29         dets = detector(img, 1)  # face detection
30 
31         # ignore all the files with no or more than one faces detected.
32         if len(dets) == 1:
33             row = [f]
34 
35             d = dets[0]
36             # Get the landmarks/parts for the face in box d.
37             shape = predictor(img, d)
38             for i in range(num_landmarks):
39                 part_i_x = shape.part(i).x
40                 part_i_y = shape.part(i).y
41                 row += [part_i_x, part_i_y]
42 
43             csv_writer.writerow(row)
View Code

附上使用matplotlib显示特征点的脚本:

 1 from __future__ import print_function, division
 2 import os
 3 import torch
 4 import pandas as pd
 5 from skimage import io, transform
 6 import numpy as np
 7 import matplotlib.pyplot as plt
 8 from torch.utils.data import Dataset, DataLoader
 9 from torchvision import transforms, utils
10 
11 # Ignore warnings
12 import warnings
13 warnings.filterwarnings("ignore")
14 
15 plt.ion()   # interactive mode
16 
17 landmarks_frame = pd.read_csv(\'faces/face_landmarks.csv\')
18 
19 n = 5
20 img_name = landmarks_frame.iloc[n, 0]
21 landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
22 landmarks = landmarks.astype(\'float\').reshape(-1, 2)
23 
24 print(\'Image name: {}\'.format(img_name))
25 print(\'Landmarks shape: {}\'.format(landmarks.shape))
26 print(\'First 4 Landmarks: {}\'.format(landmarks[:4]))
27 
28 def show_landmarks(image, landmarks):
29     """Show image with landmarks"""
30     plt.imshow(image)
31     plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker=\'.\', c=\'r\')
32     plt.pause(0.001)  # pause a bit so that plots are updated
33 
34 plt.figure()
35 show_landmarks(io.imread(os.path.join(\'faces/\', img_name)),
36                landmarks)
37 plt.show()
View Code

 效果图:

分类:

技术点:

相关文章: