下面是PointNet论文中分类模型的结构（见论文图2）：
但是论文正文并没有详细解释模型的细节，尤其是T-Net的结构；这部分可以参考论文的补充材料（supplementary material）。如果找不到，可以留言向我索取。
话不多说，下面直接给出代码，它基本还原了论文中的PointNet分类模型（论文中对特征变换矩阵的正交正则化损失未实现）。
第一部分：数据处理模块（保存为getData.py，第二部分通过import getData导入它）
更新：修复了原先batch size为1时代码会出错的问题。
import h5py
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
# Root directory of the ShapeNet part-segmentation HDF5 release; adjust for your machine.
main_path = "E:/DataSets/shapenet_part_seg_hdf5_data/hdf5_data/"
# Text files listing the HDF5 shard names (relative to main_path) for each split.
train_txt_path = f"{main_path}train_hdf5_file_list.txt"
valid_txt_path = f"{main_path}val_hdf5_file_list.txt"
def get_data(train=True):
    """Load every HDF5 shard listed for one split and concatenate them.

    Args:
        train: if True, read the shard names listed in ``train_txt_path``;
            otherwise those listed in ``valid_txt_path``. Shard names are
            resolved relative to ``main_path``.

    Returns:
        ``(clouds, labels)``. ``clouds`` is a float tensor built from each
        shard's ``"data"`` dataset, concatenated along dim 0. ``labels`` is an
        int64 tensor built from each shard's ``"label"`` dataset, concatenated
        and flattened to 1-D.

    Raises:
        ValueError: if the list file names no shards.
    """
    data_txt_path = train_txt_path if train else valid_txt_path
    with open(data_txt_path, "r") as f:
        file_names = f.read().split()
    if not file_names:
        raise ValueError(f"no HDF5 files listed in {data_txt_path}")
    clouds_li = []
    labels_li = []
    for file_name in file_names:
        # Open read-only and close the handle when done. Dataset.value was
        # removed in h5py 3.0, so read the whole array with [()] instead.
        with h5py.File(main_path + file_name, "r") as h5:
            pts = h5["data"][()]
            lbl = h5["label"][()]
        clouds_li.append(torch.Tensor(pts))
        labels_li.append(torch.Tensor(lbl))
    clouds = torch.cat(clouds_li)
    labels = torch.cat(labels_li)
    # reshape(-1) instead of squeeze(): squeeze() would turn a split that
    # holds a single sample into a 0-d tensor.
    return clouds, labels.long().reshape(-1)
class PointDataSet(Dataset):
    """Map-style dataset over the point clouds and labels returned by ``get_data``.

    Args:
        train: forwarded to ``get_data`` to select the training or the
            validation split.
    """

    def __init__(self, train=True):
        self.x_data, self.y_data = get_data(train=train)
        # Number of samples: the size of dim 0 of the cloud tensor.
        self.lenth = self.x_data.size(0)

    def __getitem__(self, index):
        """Return the ``(cloud, label)`` pair stored at ``index``."""
        cloud = self.x_data[index]
        label = self.y_data[index]
        return cloud, label

    def __len__(self):
        """Return the number of samples in this split."""
        return self.lenth
def get_dataLoader(train=True, batch_size=16):
    """Build a DataLoader over one split of the point-cloud data.

    Args:
        train: True for the training split (shuffled every epoch); False for
            the validation split (kept in file order).
        batch_size: samples per batch; defaults to 16 as before.

    Returns:
        A ``DataLoader`` wrapping ``PointDataSet(train=train)``.
    """
    point_data_set = PointDataSet(train=train)
    return DataLoader(dataset=point_data_set, batch_size=batch_size, shuffle=train)
第二部分：模型及其训练（需与getData.py放在同一目录下）
import torch
import torch.nn as nn
import getData
import datetime
class PointNet(nn.Module):
    """PointNet classification network (Qi et al., 2017).

    Pipeline: input T-Net (3x3 transform) -> shared MLP (64, 64) ->
    feature T-Net (64x64 transform) -> shared MLP (64, 128, 1024) ->
    max over points -> FC head (512, 256, num_classes).

    The shared per-point MLPs are implemented as Conv2d layers over a
    [B, C, N, W] layout (a (1, 3) kernel for the first layer, 1x1 after).

    Args:
        point_num: number of points per cloud the model was configured for.
            Kept for backward compatibility; pooling is global, so clouds with
            a different number of points are accepted as well.
        num_classes: number of output classes (default 16).

    Note:
        ``forward`` returns raw class scores (logits), which is what
        ``nn.CrossEntropyLoss`` expects; apply ``softmax(dim=1)`` yourself if
        probabilities are needed. The paper's orthogonality regularisation of
        the 64x64 feature transform is not implemented here.
    """

    def __init__(self, point_num, num_classes=16):
        super().__init__()
        self.point_num = point_num
        # Input T-Net trunk: per-point MLP 3 -> 64 -> 128 -> 1024, then max over points.
        self.inputTransform = nn.Sequential(
            nn.Conv2d(1, 64, (1, 3)),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 1024, 1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
            # Global max over the point axis. Identical to
            # MaxPool2d((point_num, 1)) when N == point_num, but works for any N.
            nn.AdaptiveMaxPool2d((1, 1)),
        )
        # Input T-Net head: regresses the 9 entries of the 3x3 transform.
        self.inputFC = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 9),
        )
        # First shared MLP: lifts each transformed point to 64 features.
        self.mlp1 = nn.Sequential(
            nn.Conv2d(1, 64, (1, 3)),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        # Feature T-Net trunk: per-point MLP 64 -> 64 -> 128 -> 1024, then max over points.
        self.featureTransform = nn.Sequential(
            nn.Conv2d(64, 64, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 1024, 1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
            nn.AdaptiveMaxPool2d((1, 1)),
        )
        # Feature T-Net head: regresses the 64*64 entries of the feature transform.
        self.featureFC = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 64 * 64),
        )
        # Second shared MLP: 64 -> 64 -> 128 -> 1024 features per point.
        self.mlp2 = nn.Sequential(
            nn.Conv2d(64, 64, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 1024, 1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
        )
        # Classification head. Dropout(p=0.7) was tried here but lowered accuracy
        # on ShapeNet, so it is omitted. No Softmax: CrossEntropyLoss applies
        # log-softmax itself, and a second softmax would flatten the gradients.
        self.fc = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_classes),
        )
        # Start both T-Nets at the identity transform (zero weights, identity bias
        # in the last FC layer). Previously only the 3x3 input transform was
        # initialised this way; the 64x64 feature transform began as a random matrix.
        self._init_identity(self.inputFC[4], 3)
        self._init_identity(self.featureFC[4], 64)

    @staticmethod
    def _init_identity(linear, k):
        """Make ``linear`` output the flattened k x k identity for any input."""
        nn.init.zeros_(linear.weight)
        with torch.no_grad():
            linear.bias.copy_(torch.eye(k).view(-1))

    def forward(self, x):
        """Classify a batch of point clouds.

        Args:
            x: float tensor of shape [B, N, 3] (batch, points, XYZ).

        Returns:
            Tensor of shape [B, num_classes] with unnormalised class scores.
        """
        batch_size = x.size(0)
        # Input transform: predict one 3x3 matrix per cloud and apply it to the points.
        t_net = self.inputTransform(x.unsqueeze(1))  # [B, 1024, 1, 1]
        # flatten(1) keeps the batch dim even when B == 1 (unlike squeeze()).
        t_net = self.inputFC(t_net.flatten(1))  # [B, 9]
        t_net = t_net.view(batch_size, 3, 3)  # [B, 3, 3]
        # Batched matmul replaces the per-sample mm loop: [B, N, 3] @ [B, 3, 3].
        x = torch.bmm(x, t_net)  # [B, N, 3]

        x = self.mlp1(x.unsqueeze(1))  # [B, 64, N, 1]

        # Feature transform: predict one 64x64 matrix per cloud and apply it per point.
        t_net = self.featureTransform(x)  # [B, 1024, 1, 1]
        t_net = self.featureFC(t_net.flatten(1))  # [B, 64*64]
        t_net = t_net.view(batch_size, 64, 64)  # [B, 64, 64]
        x = x.squeeze(3).permute(0, 2, 1)  # [B, N, 64]
        x = torch.bmm(x, t_net)  # [B, N, 64]
        x = x.permute(0, 2, 1).unsqueeze(-1)  # [B, 64, N, 1]

        x = self.mlp2(x)  # [B, 1024, N, 1]
        x, _ = torch.max(x, 2)  # [B, 1024, 1]: global max over points
        return self.fc(x.squeeze(2))  # [B, num_classes]
EPOCHES = 100
POINT_NUM = 2048

# Training/evaluation entry point; guarded so that importing this module
# (e.g. to reuse PointNet) does not start a 100-epoch run.
if __name__ == "__main__":
    # Fall back to the CPU when no GPU is available instead of failing on .cuda().
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader = getData.get_dataLoader(train=True)
    test_loader = getData.get_dataLoader(train=False)
    net = PointNet(POINT_NUM).to(device)
    optimizer = torch.optim.Adam(net.parameters(), weight_decay=0.001)
    # CrossEntropyLoss applies log-softmax internally and expects raw logits.
    loss_function = nn.CrossEntropyLoss()
    for epoch in range(EPOCHES):
        time_start = datetime.datetime.now()

        net.train()
        for cloud, label in train_loader:
            cloud, label = cloud.to(device), label.to(device)
            loss = loss_function(net(cloud), label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Evaluate without building autograd graphs (saves memory and time).
        net.eval()
        correct_total = 0
        with torch.no_grad():
            for cloud, label in test_loader:
                cloud, label = cloud.to(device), label.to(device)
                pre = net(cloud).argmax(dim=1)
                correct_total += (pre == label).sum().item()

        elapsed = datetime.datetime.now() - time_start
        # total_seconds() also counts whole days, unlike timedelta.seconds.
        time_span_str = str(int(elapsed.total_seconds()))
        print(str(epoch + 1) + "迭代期准确率:" + str(correct_total / len(test_loader.dataset)) + "耗时" + time_span_str + "S")
按照上面的配置，在所使用的ShapeNet数据集上，验证集准确率可以达到93%以上。如发现问题或bug，请留言。