【问题标题】:How to plot confidence interval in Python?如何在 Python 中绘制置信区间?
【发布时间】:2020-05-01 23:28:24
【问题描述】:

我最近开始使用 Python,但我不明白如何为给定数据(或一组数据)绘制置信区间。我已经有一个函数可以根据我传递给它的置信度来计算给定一组测量值的上限​​和下限,但我不知道如何使用这两个值来绘制置信区间。我知道这里已经有人问过这个问题,但我没有找到有用的答案。

【问题讨论】:

标签: python python-3.x confidence-interval


【解决方案1】:

有几种方法可以完成您的要求:

仅使用matplotlib

from matplotlib import pyplot as plt
import numpy as np

#some example data
x = np.linspace(0.1, 9.9, 20)
y = 3.0 * x
#some confidence interval
ci = 1.96 * np.std(y)/np.sqrt(len(x))

fig, ax = plt.subplots()
ax.plot(x,y)
ax.fill_between(x, (y-ci), (y+ci), color='b', alpha=.1)

fill_between 可以满足您的需求。有关如何使用此功能的更多信息,请参阅:https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.fill_between.html

输出

或者,选择seaborn,它支持使用lineplotregplot, 见:https://seaborn.pydata.org/generated/seaborn.lineplot.html

【讨论】:

  • 为什么要除以平均值?在ci = 1.96 * np.std(y)/np.mean(y)。不应该是样本量的平方根吗?根据维基百科:en.wikipedia.org/wiki/Confidence_interval#Basic_steps
  • @CGFoX 这只是一个玩具示例。我同意,你会使用标准错误。为了说明,我使用了不正确的平均值。使用拟合参数和未知 SD 的 t 分布计算线性回归的置信区间确实更加复杂,这里假设它是正常的,因此 1.96 表示 95 % 的置信度。
【解决方案2】:

假设我们在这三个类别中具有三个类别以及某个估计量的置信区间的上下限:

data_dict = {}
data_dict['category'] = ['category 1','category 2','category 3']
data_dict['lower'] = [0.1,0.2,0.15]
data_dict['upper'] = [0.22,0.3,0.21]
dataset = pd.DataFrame(data_dict)

您可以使用以下代码绘制每个类别的置信区间:

for lower,upper,y in zip(dataset['lower'],dataset['upper'],range(len(dataset))):
    plt.plot((lower,upper),(y,y),'ro-',color='orange')
plt.yticks(range(len(dataset)),list(dataset['category']))

结果如下图:

【讨论】:

    【解决方案3】:

    对于跨类别的 CI,建立在 @omer sagi 建议的基础上,假设我们有一个 pandas 数据框,其列包含类别(如 category 1category 2category 3)和另一个具有连续数据(例如某种rating)的函数,这是一个使用pd.groupby()scipy.stats 绘制具有置信区间的组间均值差异的函数:

    import pandas as pd
    import numpy as np
    import scipy.stats as st
    
    def plot_diff_in_means(data: pd.DataFrame, col1: str, col2: str):
        """
        given data, plots difference in means with confidence intervals across groups
        col1: categorical data with groups
        col2: continuous data for the means
        """
        n = data.groupby(col1)[col2].count()
        # n contains a pd.Series with sample size for each category
    
        cat = list(data.groupby(col1, as_index=False)[col2].count()[col1])
        # cat has names of the categories, like 'category 1', 'category 2'
    
        mean = data.groupby(col1)[col2].agg('mean')
        # the average value of col2 across the categories
    
        std = data.groupby(col1)[col2].agg(np.std)
        se = std / np.sqrt(n)
        # standard deviation and standard error
    
        lower = st.t.interval(alpha = 0.95, df=n-1, loc = mean, scale = se)[0]
        upper = st.t.interval(alpha = 0.95, df =n-1, loc = mean, scale = se)[1]
        # calculates the upper and lower bounds using scipy
    
        for upper, mean, lower, y in zip(upper, mean, lower, cat):
            plt.plot((lower, mean, upper), (y, y, y), 'b.-')
            # for 'b.-': 'b' means 'blue', '.' means dot, '-' means solid line
        plt.yticks(
            range(len(n)), 
            list(data.groupby(col1, as_index = False)[col2].count()[col1])
            )
    

    给定一个假设数据:

    cat = ['a'] * 10 + ['b'] * 10 + ['c'] * 10
    a = np.linspace(0.1, 5.0, 10)
    b = np.linspace(0.5, 7.0, 10)
    c = np.linspace(7.5, 20.0, 10)
    rating = np.concatenate([a, b, c])
    
    dat_dict = dict()
    dat_dict['cat'] = cat
    dat_dict['rating'] = rating
    test_dat = pd.DataFrame(dat_dict)
    

    看起来像这样(当然还有更多行):

    cat rating
    a 0.10000
    a 0.64444
    b 0.50000
    b 0.12222
    c 7.50000
    c 8.88889

    我们可以使用该函数来绘制与 CI 的均值差异:

    plot_diff_in_means(data = test_dat, col1 = 'cat', col2 = 'rating')
    

    这为我们提供了以下图表:

    【讨论】:

      【解决方案4】:
      import matplotlib.pyplot as plt
      import statistics
      from math import sqrt
      
      
      def plot_confidence_interval(x, values, z=1.96, color='#2187bb', horizontal_line_width=0.25):
          mean = statistics.mean(values)
          stdev = statistics.stdev(values)
          confidence_interval = z * stdev / sqrt(len(values))
      
          left = x - horizontal_line_width / 2
          top = mean - confidence_interval
          right = x + horizontal_line_width / 2
          bottom = mean + confidence_interval
          plt.plot([x, x], [top, bottom], color=color)
          plt.plot([left, right], [top, top], color=color)
          plt.plot([left, right], [bottom, bottom], color=color)
          plt.plot(x, mean, 'o', color='#f44336')
      
          return mean, confidence_interval
      
      
      plt.xticks([1, 2, 3, 4], ['FF', 'BF', 'FFD', 'BFD'])
      plt.title('Confidence Interval')
      plot_confidence_interval(1, [10, 11, 42, 45, 44])
      plot_confidence_interval(2, [10, 21, 42, 45, 44])
      plot_confidence_interval(3, [20, 2, 4, 45, 44])
      plot_confidence_interval(4, [30, 31, 42, 45, 44])
      plt.show()
      

      结果:

      【讨论】:

        猜你喜欢
        • 2020-11-16
        • 2023-04-07
        • 2016-12-24
        • 2013-02-05
        • 2018-10-02
        • 1970-01-01
        • 2018-02-05
        • 1970-01-01
        • 1970-01-01
        相关资源
        最近更新 更多