【问题标题】:Scipy Multivariate Normal: How to draw deterministic samples?Scipy Multivariate Normal:如何绘制确定性样本?
【发布时间】:2017-08-18 23:52:28
【问题描述】:

我正在使用Scipy.stats.multivariate_normal 从多元正态分布中抽取样本。像这样:

from scipy.stats import multivariate_normal
# Assume we have means and covs
mn = multivariate_normal(mean = means, cov = covs)
# Generate some samples
samples = mn.rvs()

每次运行的样本都不同。如何始终获得相同的样本? 我期待的是这样的:

mn = multivariate_normal(mean = means, cov = covs, seed = aNumber)

samples = mn.rsv(seed = aNumber)

【问题讨论】:

    标签: python random scipy statistics probability


    【解决方案1】:

    有两种方式:

    1. rvs() 方法接受 random_state 参数。它的价值可以 是整数种子,或numpy.random.RandomState 的实例。在 在这个例子中,我使用了一个整数种子:

      In [46]: mn = multivariate_normal(mean=[0,0,0], cov=[1, 5, 25])
      
      In [47]: mn.rvs(size=5, random_state=12345)
      Out[47]: 
      array([[-0.51943872,  1.07094986, -1.0235383 ],
             [ 1.39340583,  4.39561899, -2.77865152],
             [ 0.76902257,  0.63000355,  0.46453938],
             [-1.29622111,  2.25214387,  6.23217368],
             [ 1.35291684,  0.51186476,  1.37495817]])
      
      In [48]: mn.rvs(size=5, random_state=12345)
      Out[48]: 
      array([[-0.51943872,  1.07094986, -1.0235383 ],
             [ 1.39340583,  4.39561899, -2.77865152],
             [ 0.76902257,  0.63000355,  0.46453938],
             [-1.29622111,  2.25214387,  6.23217368],
             [ 1.35291684,  0.51186476,  1.37495817]])
      
    2. 您可以为 numpy 的全局随机数生成器设置种子。如果没有给出random_state,这是multivariate_normal.rvs() 使用的生成器:

      In [54]: mn = multivariate_normal(mean=[0,0,0], cov=[1, 5, 25])
      
      In [55]: np.random.seed(123)
      
      In [56]: mn.rvs(size=5)
      Out[56]: 
      array([[  0.2829785 ,   2.23013222,  -5.42815302],
             [  1.65143654,  -1.2937895 ,  -7.53147357],
             [  1.26593626,  -0.95907779, -12.13339622],
             [ -0.09470897,  -1.51803558,  -4.33370201],
             [ -0.44398196,  -1.4286283 ,   7.45694813]])
      
      In [57]: np.random.seed(123)
      
      In [58]: mn.rvs(size=5)
      Out[58]: 
      array([[  0.2829785 ,   2.23013222,  -5.42815302],
             [  1.65143654,  -1.2937895 ,  -7.53147357],
             [  1.26593626,  -0.95907779, -12.13339622],
             [ -0.09470897,  -1.51803558,  -4.33370201],
             [ -0.44398196,  -1.4286283 ,   7.45694813]])
      

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2014-10-06
      • 2018-12-14
      • 1970-01-01
      • 2019-07-12
      • 2019-05-19
      • 2018-02-21
      • 1970-01-01
      相关资源
      最近更新 更多