【发布时间】:2019-02-18 17:34:09
【问题描述】:
我是回归新手,我在 matlab 中编写了一个非常简单的代码,它使用 lasso 函数只是为了看看我是否了解 lasso 的 MSE 是如何计算的。但我得到的 mse 与套索的输出不同。我可能遗漏了一些东西,如果有人能告诉我哪里错了,我将不胜感激。为了计算 MSE,我使用了此链接中的以下公式:https://www.mathworks.com/help/stats/lasso.html
这是我写的matlab代码:
clear;
close all;
clc;
% Checking lasso MSE from this link:
% https://www.mathworks.com/help/stats/lasso.html
n = 10;
p = 3;
X = 20*rand(n,p);
min_val = -20;
max_val = 20;
y = min_val + (max_val - min_val)*rand(n,1);
lambda_vals = [0.2, 0.8, 1, 1.5];
[beta_vectors , FitInfo] = lasso(X, y, 'Lambda', lambda_vals);
eps = 10^-10;
num_of_lambda_vals = length(lambda_vals);
for i=1:num_of_lambda_vals
current_calculated_mse = sum((y - FitInfo.Intercept(i) - X*beta_vectors(:,i)).^2)/(2*n) +...
lambda_vals(i)*sum(abs(beta_vectors(:,i)));
current_mse = FitInfo.MSE(i);
fprintf('current_calculated_mse = %f\n',current_calculated_mse);
fprintf('current_mse = %f\n',current_mse);
sqr_diff_mses = (current_calculated_mse-current_mse)^2;
if (sqr_diff_mses > eps)
fprintf('The calculated MSE is wrong!\n');
end
fprintf('\n');
end
如果您运行代码,它将打印出计算 MSE 错误。 谁能告诉我的代码有什么问题?
谢谢
【问题讨论】:
-
我建议您首先整理代码并删除您所做的一堆不必要的冗长操作。例如。 Beta_0不需要repmat'ed,或者你不需要
temp,你可以直接在同一行计算平方根。你的 for 循环不应该超过 5 或 6 行,现在它只会让它更难理解。 -
我更新了代码并缩短了它。循环现在只有几行(大部分是 fprintf 行)。
-
您确定您对
x^{T}*beta的实现是正确的吗?我对 Lasso 没有经验,但我觉得那部分不太对(你似乎没有转置 X) -
我什么都不确定。这就是我在这里发帖的原因。如果我转置 X,那么我将无法将它与 beta 相乘。我不知道为什么公式中的转置符号,但我需要忽略它,否则我将无法进行计算。
-
@David 正如我在回答中提到的那样:您的代码很好,您只是使用了错误的方程式!
标签: regression matlab lasso-regression