【问题标题】:10 fold cross-validation in one-against-all SVM (using LibSVM)一对多 SVM 中的 10 倍交叉验证(使用 LibSVM)
【发布时间】:2012-12-11 01:49:42
【问题描述】:

我想在 MATLAB 中的 one-against-all support vector machine 分类中进行 10 倍交叉验证。

我试图以某种方式混合这两个相关的答案:

但由于我是 MATLAB 及其语法的新手,所以直到现在我才设法让它工作。

另一方面,我在LibSVM README 文件中只看到以下几行关于交叉验证的内容,但在那里我找不到任何相关示例:

option -v 将数据随机分成n份,计算cross 验证准确度/均方误差。

有关输出的含义,请参阅 libsvm 常见问题解答。

谁能给我一个 10 折交叉验证和一对一分类的例子?

【问题讨论】:

  • 如 carlosdc 所述,第二个链接展示了 Bioinformatics toolbox(不是 libsvm)中的 SVM 函数
  • 仅供参考,从 R2013a 开始,MATLAB 的 svm 函数已从 Bioinformatics 工具箱移至 Statistics 工具箱(我认为它们应该放在首位!)

标签: matlab machine-learning classification svm libsvm


【解决方案1】:

我们这样做主要有两个原因cross-validation

  • 作为一种测试方法,它为我们的模型的泛化能力提供了几乎无偏的估计(通过避免过度拟合)
  • 作为model selection 的一种方式(例如:在训练数据上找到最佳的Cgamma 参数,示例见this post

对于我们感兴趣的第一种情况,该过程涉及为每个折叠训练 k 模型,然后在整个训练集上训练一个最终模型。 我们报告了 k-folds 的平均准确率。

现在,由于我们使用 one-vs-all 方法来处理多类问题,因此每个模型都包含N 支持向量机(每个类一个)。


以下是实现一对多方法的包装函数:

function mdl = libsvmtrain_ova(y, X, opts)
    if nargin < 3, opts = ''; end

    %# classes
    labels = unique(y);
    numLabels = numel(labels);

    %# train one-against-all models
    models = cell(numLabels,1);
    for k=1:numLabels
        models{k} = libsvmtrain(double(y==labels(k)), X, strcat(opts,' -b 1 -q'));
    end
    mdl = struct('models',{models}, 'labels',labels);
end

function [pred,acc,prob] = libsvmpredict_ova(y, X, mdl)
    %# classes
    labels = mdl.labels;
    numLabels = numel(labels);

    %# get probability estimates of test instances using each 1-vs-all model
    prob = zeros(size(X,1), numLabels);
    for k=1:numLabels
        [~,~,p] = libsvmpredict(double(y==labels(k)), X, mdl.models{k}, '-b 1 -q');
        prob(:,k) = p(:, mdl.models{k}.Label==1);
    end

    %# predict the class with the highest probability
    [~,pred] = max(prob, [], 2);
    %# compute classification accuracy
    acc = mean(pred == y);
end

以下是支持交叉验证的函数:

function acc = libsvmcrossval_ova(y, X, opts, nfold, indices)
    if nargin < 3, opts = ''; end
    if nargin < 4, nfold = 10; end
    if nargin < 5, indices = crossvalidation(y, nfold); end

    %# N-fold cross-validation testing
    acc = zeros(nfold,1);
    for i=1:nfold
        testIdx = (indices == i); trainIdx = ~testIdx;
        mdl = libsvmtrain_ova(y(trainIdx), X(trainIdx,:), opts);
        [~,acc(i)] = libsvmpredict_ova(y(testIdx), X(testIdx,:), mdl);
    end
    acc = mean(acc);    %# average accuracy
end

function indices = crossvalidation(y, nfold)
    %# stratified n-fold cros-validation
    %#indices = crossvalind('Kfold', y, nfold);  %# Bioinformatics toolbox
    cv = cvpartition(y, 'kfold',nfold);          %# Statistics toolbox
    indices = zeros(size(y));
    for i=1:nfold
        indices(cv.test(i)) = i;
    end
end

最后,用一个简单的demo来说明用法:

%# laod dataset
S = load('fisheriris');
data = zscore(S.meas);
labels = grp2idx(S.species);

%# cross-validate using one-vs-all approach
opts = '-s 0 -t 2 -c 1 -g 0.25';    %# libsvm training options
nfold = 10;
acc = libsvmcrossval_ova(labels, data, opts, nfold);
fprintf('Cross Validation Accuracy = %.4f%%\n', 100*mean(acc));

%# compute final model over the entire dataset
mdl = libsvmtrain_ova(labels, data, opts);

将其与 libsvm 默认使用的一对一方法进行比较:

acc = libsvmtrain(labels, data, sprintf('%s -v %d -q',opts,nfold));
model = libsvmtrain(labels, data, strcat(opts,' -q'));

【讨论】:

  • 请注意,我已将 libsvm 函数重命名为 libsvmtrainlibsvmpredict 以避免与生物信息学工具箱中具有相同名称部分的函数(即 svmtrain)发生名称冲突
  • libsvmtrain_ova 函数中,我在这一行收到错误Undefined function or method 'libsvmtrain' for input arguments of type 'double'.models{k} = libsvmtrain(double(y==labels(k)), X, strcat(opts,' -b 1 -q'));
  • @Ezati:正如我在上面的评论中所说,我重命名了 libsvm MEX 函数以避免与生物信息学工具箱混淆。在您的情况下,您可以在上面的代码中简单地将libsvmtrain 替换为svmtrainlibsvmpredictsvmpredict
  • 对不起,我一开始没有注意到你的评论..现在一切都好 :) 非常感谢你,我希望我能给你一个 +100
【解决方案2】:

您可能会感到困惑,这两个问题之一与 LIBSVM 无关。你应该尝试调整this answer而忽略另一个。

您应该选择折叠,然后按照链接的问题完成其余部分。假设数据已加载到data,标签已加载到labels

n = size(data,1);
ns = floor(n/10);
for fold=1:10,
    if fold==1,
        testindices= ((fold-1)*ns+1):fold*ns;
        trainindices = fold*ns+1:n;
    else
        if fold==10,
            testindices= ((fold-1)*ns+1):n;
            trainindices = 1:(fold-1)*ns;
        else
            testindices= ((fold-1)*ns+1):fold*ns;
            trainindices = [1:(fold-1)*ns,fold*ns+1:n];
         end
    end
    % use testindices only for testing and train indices only for testing
    trainLabel = label(trainindices);
    trainData = data(trainindices,:);
    testLabel = label(testindices);
    testData = data(testindices,:)
    %# train one-against-all models
    model = cell(numLabels,1);
    for k=1:numLabels
        model{k} = svmtrain(double(trainLabel==k), trainData, '-c 1 -g 0.2 -b 1');
    end

    %# get probability estimates of test instances using each model
    prob = zeros(size(testData,1),numLabels);
    for k=1:numLabels
        [~,~,p] = svmpredict(double(testLabel==k), testData, model{k}, '-b 1');
        prob(:,k) = p(:,model{k}.Label==1);    %# probability of class==k
    end

    %# predict the class with the highest probability
    [~,pred] = max(prob,[],2);
    acc = sum(pred == testLabel) ./ numel(testLabel)    %# accuracy
    C = confusionmat(testLabel, pred)                   %# confusion matrix
end

【讨论】:

  • prob = zeros(numTest,numLabels); 行,你的意思是nsnumTest。是吗?
  • 不,我的意思是您正在测试的数据点的数量。我已经编辑了代码。
  • 那么-v 选项呢?我们不需要使用它吗?
  • 从我们的问题来看,您似乎需要一对一而不是一对一(这是 -v 在多类问题的情况下实现)
  • 但是here 它说-v 用于交叉验证,而不是一对一或一对一。我说的对吗?
猜你喜欢
  • 2012-12-14
  • 2011-07-14
  • 2015-01-11
  • 2011-07-09
  • 2018-09-12
  • 2012-11-01
  • 2014-04-02
  • 1970-01-01
  • 2013-04-12
相关资源
最近更新 更多