【问题标题】:How to set seed for jitter in seaborn stripplot?如何在seaborn stripplot中设置抖动种子?
【发布时间】:2020-09-08 16:36:04
【问题描述】:

我正在尝试准确地再现带状图,以便我可以画线并可靠地在它们上面写字。但是,当我生成带抖动的带状图时,抖动是随机的,会阻止我实现目标。

我盲目地尝试了一些在其他 Stack Overflow 帖子中发现的 rcParams,例如 mpl.rcParams['svg.hashsalt'],但没有奏效。我还尝试为random.seed() 设置种子,但没有成功。

我正在运行的代码如下所示。

import seaborn as sns
import matplotlib.pyplot as plt
import random

plt.figure(figsize=(14,9))

random.seed(123)

catagories = []
values = []

for i in range(0,200):
    n = random.randint(1,3)
    catagories.append(n)

for i in range(0,200):
    n = random.randint(1,100)
    values.append(n)

sns.stripplot(catagories, values, size=5)
plt.title('Random Jitter')
plt.xticks([0,1,2],[1,2,3])
plt.show()

此代码生成一个stripplot,就像我想要的那样。但是,如果您两次运行代码,由于抖动,您将获得不同的点位置。我正在制作的图表需要抖动才能看起来不荒谬,但我想写在图表上。但是在运行代码之前没有办法知道点的确切位置,然后每次运行代码都会改变。

有什么方法可以为 seaborn stripplots 中的抖动埋下种子,使它们完全可复制?

【问题讨论】:

  • 我没有弄清楚这一点,但通过使用sns.swarmplot() 而不是sns.stripplot() 进行了妥协。看起来并不完全符合预期,但可以达到其目的,因为 swarmplot 中的交错不像 stripplot 中的抖动那样随机。

标签: python matplotlib seaborn random-seed jitter


【解决方案1】:
  • 抖动由scipy.stats.uniform确定
  • uniformclass uniform_gen(scipy.stats._distn_infrastructure.rv_continuous)
  • 这是class rv_continuous(rv_generic)的子类
  • 其中有一个seed参数,并使用np.random
  • 因此,使用np.random.seed()
    • 需要在每个绘图之前调用它。在示例中,np.random.seed(123) 必须在循环内。

来自 Stripplot 文档字符串

jitter : float, ``True``/``1`` is special-cased, optional
    Amount of jitter (only along the categorical axis) to apply. This
    can be useful when you have many points and they overlap, so that
    it is easier to see the distribution. You can specify the amount
    of jitter (half the width of the uniform random variable support),
    or just use ``True`` for a good default.

来自class _StripPlottercategorical.py

  • 抖动用scipy.stats.uniform计算
from scipy import stats

class _StripPlotter(_CategoricalScatterPlotter):
    """1-d scatterplot with categorical organization."""
    def __init__(self, x, y, hue, data, order, hue_order,
                 jitter, dodge, orient, color, palette):
        """Initialize the plotter."""
        self.establish_variables(x, y, hue, data, orient, order, hue_order)
        self.establish_colors(color, palette, 1)

        # Set object attributes
        self.dodge = dodge
        self.width = .8

        if jitter == 1:  # Use a good default for `jitter = True`
            jlim = 0.1
        else:
            jlim = float(jitter)
        if self.hue_names is not None and dodge:
            jlim /= len(self.hue_names)
        self.jitterer = stats.uniform(-jlim, jlim * 2).rvs

来自 rv_continuous 文档字符串

    seed : {None, int, `~np.random.RandomState`, `~np.random.Generator`}, optional
        This parameter defines the object to use for drawing random variates.
        If `seed` is `None` the `~np.random.RandomState` singleton is used.
        If `seed` is an int, a new ``RandomState`` instance is used, seeded
        with seed.
        If `seed` is already a ``RandomState`` or ``Generator`` instance,
        then that object is used.
        Default is None.

将您的代码与np.random.seed 一起使用

  • 所有的情节点都是一样的
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(2, 3, figsize=(12, 12))
for x in range(6):

    np.random.seed(123)

    catagories = []
    values = []

    for i in range(0,200):
        n = np.random.randint(1,3)
        catagories.append(n)

    for i in range(0,200):
        n = np.random.randint(1,100)
        values.append(n)

    row = x // 3
    col = x % 3
    axcurr = axes[row, col]

    sns.stripplot(catagories, values, size=5, ax=axcurr)
    axcurr.set_title(f'np.random jitter {x+1}')
plt.show()

只使用random

  • 情节点四处移动
import seaborn as sns
import matplotlib.pyplot as plt
import random

fig, axes = plt.subplots(2, 3, figsize=(12, 12))
for x in range(6):

    random.seed(123)

    catagories = []
    values = []

    for i in range(0,200):
        n = random.randint(1,3)
        catagories.append(n)

    for i in range(0,200):
        n = random.randint(1,100)
        values.append(n)

    row = x // 3
    col = x % 3
    axcurr = axes[row, col]

    sns.stripplot(catagories, values, size=5, ax=axcurr)
    axcurr.set_title(f'random jitter {x+1}')
plt.show()

对数据使用random,对绘图使用np.random.seed

fig, axes = plt.subplots(2, 3, figsize=(12, 12))
for x in range(6):

    random.seed(123)

    catagories = []
    values = []

    for i in range(0,200):
        n = random.randint(1,3)
        catagories.append(n)

    for i in range(0,200):
        n = random.randint(1,100)
        values.append(n)

    row = x // 3
    col = x % 3
    axcurr = axes[row, col]

    np.random.seed(123)
    sns.stripplot(catagories, values, size=5, ax=axcurr)
    axcurr.set_title(f'np.random jitter {x+1}')
plt.show()

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2021-05-01
    • 2021-11-17
    • 1970-01-01
    • 1970-01-01
    • 2017-08-26
    • 1970-01-01
    • 2016-06-03
    相关资源
    最近更新 更多