【问题标题】:Reproduce Fisher linear discriminant figure再现Fisher线性判别图
【发布时间】:2016-02-23 23:23:38
【问题描述】:

许多书籍都用下图来说明 Fisher 线性判别分析的思想（此图来自 Bishop 的 Pattern Recognition and Machine Learning，第 188 页）

我想知道如何在 R（或任何其他语言）中重现此图。下面是我在 R 中的初步尝试：我模拟了两组数据，并使用 abline() 函数绘制线性判别式。欢迎提出任何建议。

set.seed(2014)
library(MASS)
library(DiscriMiner) # For scatter matrices (betweenCov / withinCov)

# Simulate bivariate normal distribution with 2 classes
mu1 <- c(2, -4)
mu2 <- c(2, 6)
rho <- 0.8
s1 <- 1
s2 <- 3
Sigma <- matrix(
  c(s1^2, rho * s1 * s2, rho * s1 * s2, s2^2),
  byrow = TRUE, nrow = 2
)
n <- 50
X1 <- mvrnorm(n, mu = mu1, Sigma = Sigma)
X2 <- mvrnorm(n, mu = mu2, Sigma = Sigma)
y <- rep(c(0, 1), each = n)
X <- rbind(X1, X2)
X <- scale(X)

# Scatter matrices
B <- betweenCov(variables = X, group = y)
W <- withinCov(variables = X, group = y)

# Fisher direction w: leading eigenvector of W^{-1} B. eigen() orders
# eigenvalues decreasingly, so column 1 is the discriminant axis.
w <- eigen(solve(W) %*% B)$vectors[, 1]

# The decision boundary is the line orthogonal to w, i.e.
# w[1] * x1 + w[2] * x2 = c. It passes through the midpoint of the two
# class means.
class_means <- rbind(colMeans(X[y == 0, ]), colMeans(X[y == 1, ]))
midpoint <- colMeans(class_means)
slope <- -w[1] / w[2]
intercept <- midpoint[2] - slope * midpoint[1]

par(pty = "s")
plot(X, col = y + 1, pch = 16)
# abline(): `a` is the intercept, `b` is the slope.
abline(a = intercept, b = slope, lwd = 2, lty = 2)

我目前（尚未完成）的尝试

我在下面粘贴了我当前的解决方案。主要问题是如何根据决策边界旋转(和移动)密度图。仍然欢迎任何建议。

library(ggplot2)
library(grid)
library(MASS)

set.seed(2014) # make the simulated data reproducible

# Simulation parameters
mu1 <- c(5, -9)
mu2 <- c(4, 9)
rho <- 0.5
s1 <- 1
s2 <- 3
Sigma <- matrix(
  c(s1^2, rho * s1 * s2, rho * s1 * s2, s2^2),
  byrow = TRUE, nrow = 2
)
n <- 50
# Multivariate normal sampling
X1 <- mvrnorm(n, mu = mu1, Sigma = Sigma)
X2 <- mvrnorm(n, mu = mu2, Sigma = Sigma)
# Combine into data frame (columns are named X1, X2 by data.frame())
y <- rep(c(0, 1), each = n)
X <- rbind(X1, X2)
X <- scale(X)
X <- data.frame(X, class = y)

# Apply lda()
m1 <- lda(class ~ X1 + X2, data = X)

# Decision boundary: the discriminant score x %*% scaling - const is zero
# on the boundary, i.e. scaling[1] * X1 + scaling[2] * X2 = const, where
# const is the score of the prior-weighted grand mean.
gmean <- m1$prior %*% m1$means
const <- as.numeric(gmean %*% m1$scaling)
slope <- -m1$scaling[1] / m1$scaling[2]
intercept <- const / m1$scaling[2]

# Projected values (LD1 scores from predict())
LD <- data.frame(predict(m1)$x, class = y)

# Scatterplot with the decision boundary
p1 <- ggplot(X, aes(X1, X2, color = as.factor(class))) +
  geom_point() +
  theme_bw() +
  theme(legend.position = "none") +
  scale_x_continuous(limits = c(-5, 5)) +
  scale_y_continuous(limits = c(-5, 5)) +
  geom_abline(intercept = intercept, slope = slope)

# Density plot of the projected values, each scaled to a maximum of 1
p2 <- ggplot(LD, aes(x = LD1)) +
  geom_density(aes(fill = as.factor(class), y = after_stat(scaled))) +
  theme_bw() +
  theme(legend.position = "none")

# Draw the density plot as an inset over the scatterplot
grid.newpage()
print(p1)
vp <- viewport(width = 0.7, height = 0.6, x = 0.5, y = 0.3, just = "centre")
pushViewport(vp)
print(p2, vp = vp)

【问题讨论】:

    标签: r ggplot2 machine-learning statistics classification


    【解决方案1】:

    基本上，你需要沿着分类器的方向投影数据，为每个类绘制一个直方图，然后旋转直方图，使其 x 轴平行于分类器方向。为了得到较好的效果，需要对直方图的参数进行一些试错。下面是一个在 Matlab 中实现的示例，使用的是朴素分类器（仅以两类均值之差作为投影方向）。对于 Fisher 分类器，做法完全相同，只需换用不同的分类器方向 w。我修改了你代码中的参数，使得到的图与你给出的图更相似。

    rng('default')
    n = 1000;
    mu1 = [1, 3]';
    mu2 = [4, 1]';
    rho = 0.3;
    s1 = 0.8;
    s2 = 0.5;
    % Covariance with standard deviations s1, s2 and correlation rho:
    % the off-diagonal term must be rho*s1*s2.
    Sigma = [s1^2, rho*s1*s2; rho*s1*s2, s2^2];
    X1 = mvnrnd(mu1, Sigma, n);
    X2 = mvnrnd(mu2, Sigma, n);
    X = [X1; X2];
    Y = [zeros(n,1); ones(n,1)];
    scatter(X1(:,1), X1(:,2), [], 'b');
    hold on
    scatter(X2(:,1), X2(:,2), [], 'r');
    axis equal
    m1 = mean(X(1:n,:))';
    m2 = mean(X(n+1:end,:))';
    plot(m1(1), m1(2), 'bx', 'markersize', 18)
    plot(m2(1), m2(2), 'rx', 'markersize', 18)
    plot([m1(1), m2(1)], [m1(2), m2(2)], 'g')
    %% choose the projection direction w
    use_fisher = false; % false: difference of class means; true: Fisher direction
    if use_fisher
        % Fisher direction: inverse within-class scatter times mean difference
        Sw = cov(X1) + cov(X2);
        w = Sw \ (m2 - m1);
    else
        w = m2 - m1;
    end
    w = w / norm(w);
    % unit vector orthogonal to w; histogram bar heights are drawn along it
    w_perp = [-w(2); w(1)];
    % project data onto w
    X1_projected = X1 * w;
    X2_projected = X2 * w;
    %% draw each class histogram along w
    % A point (a, b) in (along-w, along-w_perp) coordinates maps to
    % a*w + b*w_perp in data coordinates. That is the same as rotating an
    % axis-aligned histogram about the origin by the angle of w. It works for
    % any quadrant of w, whereas atan(w(2)/w(1)) does not. It also does not
    % rely on rotating the children of Bar objects.
    scale = 4; % set manually
    projections = {X1_projected, X2_projected};
    colors = {'b', 'r'};
    for k = 1:numel(projections)
        [counts, edges] = histcounts(projections{k}, 10);
        heights = scale * counts / sum(counts); % normalize, then scale
        for j = 1:numel(counts)
            a = [edges(j), edges(j+1), edges(j+1), edges(j)];
            b = [0, 0, heights(j), heights(j)];
            corners = w * a + w_perp * b; % 2x4: columns are bar corners
            patch(corners(1,:), corners(2,:), colors{k}, 'EdgeColor', 'none');
        end
    end
    

    【讨论】:

      猜你喜欢
      • 2015-01-01
      • 2013-12-10
      • 2014-06-10
      • 2013-06-19
      • 2018-07-08
      • 2022-06-22
      • 1970-01-01
      • 2012-09-22
      • 2015-01-16
      相关资源
      最近更新 更多