【问题标题】:What is the most efficient way to plot 3d array in Python?在 Python 中绘制 3d 数组的最有效方法是什么?
【发布时间】:2018-02-08 17:34:35
【问题描述】:

在 Python 中绘制 3d 数组最有效的方法是什么?

例如:

volume = np.random.rand(512, 512, 512)

其中数组项表示每个像素的灰度颜色。


以下代码运行速度太慢:

import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.gca(projection='3d')
volume = np.random.rand(20, 20, 20)
for x in range(len(volume[:, 0, 0])):
    for y in range(len(volume[0, :, 0])):
        for z in range(len(volume[0, 0, :])):
            ax.scatter(x, y, z, c = tuple([volume[x, y, z], volume[x, y, z], volume[x, y, z], 1]))
plt.show()

【问题讨论】:

标签: python matplotlib plot data-visualization scatter


【解决方案1】:

为获得更好的性能,请尽可能避免多次调用ax.scatter。 相反,将所有 x,y,z 坐标和颜色打包成一维数组(或 列表),然后调用ax.scatter一次:

ax.scatter(x, y, z, c=volume.ravel())

问题(就 CPU 时间和内存而言)随着 size**3 的增长而增长,其中 size 是立方体的边长。

此外,ax.scatter 将尝试渲染所有 size**3 点而不考虑 事实上,这些点中的大部分都被外部的那些点所掩盖 外壳。

这将有助于减少 volume 中的点数——也许通过 在渲染之前以某种方式对其进行总结或重新采样/插值。

我们还可以将所需的 CPU 和内存从 O(size**3) 减少到 O(size**2) 通过仅绘制外壳:

import functools
import itertools as IT
import numpy as np
import scipy.ndimage as ndimage
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def cartesian_product_broadcasted(*arrays):
    """
    http://stackoverflow.com/a/11146645/190597 (senderle)
    """
    broadcastable = np.ix_(*arrays)
    broadcasted = np.broadcast_arrays(*broadcastable)
    dtype = np.result_type(*arrays)
    rows, cols = functools.reduce(np.multiply, broadcasted[0].shape), len(broadcasted)
    out = np.empty(rows * cols, dtype=dtype)
    start, end = 0, rows
    for a in broadcasted:
        out[start:end] = a.reshape(-1)
        start, end = end, end + rows
    return out.reshape(cols, rows).T

# @profile  # used with `python -m memory_profiler script.py` to measure memory usage
def main():
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1, projection='3d')

    size = 512
    volume = np.random.rand(size, size, size)
    x, y, z = cartesian_product_broadcasted(*[np.arange(size, dtype='int16')]*3).T
    mask = ((x == 0) | (x == size-1) 
            | (y == 0) | (y == size-1) 
            | (z == 0) | (z == size-1))
    x = x[mask]
    y = y[mask]
    z = z[mask]
    volume = volume.ravel()[mask]

    ax.scatter(x, y, z, c=volume, cmap=plt.get_cmap('Greys'))
    plt.show()

if __name__ == '__main__':
    main()

但请注意,即使仅绘制外壳,也要实现与 size=512 我们仍然需要大约 1.3 GiB 的内存。还要注意,即使您有足够的总内存,但由于缺少 RAM,程序使用交换空间,那么程序的整体速度将 大幅减速。如果您发现自己处于这种情况,那么唯一的解决方案是找到一种更智能的方法来渲染可接受的图像使用更少的点,或者购买更多的 RAM。

【讨论】:

  • 此代码对最大 (100, 100, 100) 的数组运行速度足够快。我是否可以期望更快地构建 x、y、z 的方法有助于使其适用于 (512、512、512) 数组?
  • 对于这样大小的数组,我认为主要瓶颈是渲染 512**3(大约 1.34 亿)点,而不是创建坐标数组。我已经编辑了上面的帖子,以展示如何仅绘制立方体的外壳。这将 CPU 和内存的复杂性从 O(size**3) 降低到 O(size**2)。尽管如此,根据您机器的速度和资源,这可能需要相当长的时间来渲染。
  • 数组中的大多数项目都是零。所以,它应该在立方体内部显示相同的形状...
【解决方案2】:

首先,一个 512x512x512 点的密集网格是太多的数据,无法绘制,不是从技术角度来看,而是从观察绘图时能够从中看到任何有用的信息。您可能需要提取一些等值面,查看切片等。如果大多数点不可见,那么可能没问题,但您应该要求ax.scatter 仅显示非零点以使其更快。

也就是说,您可以通过以下方式更快地完成此操作。诀窍是消除所有 Python 循环,包括那些隐藏在库中的循环,例如 itertools

import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import matplotlib.pyplot as plt

# Make this bigger to generate a dense grid.
N = 8

# Create some random data.
volume = np.random.rand(N, N, N)

# Create the x, y, and z coordinate arrays.  We use 
# numpy's broadcasting to do all the hard work for us.
# We could shorten this even more by using np.meshgrid.
x = np.arange(volume.shape[0])[:, None, None]
y = np.arange(volume.shape[1])[None, :, None]
z = np.arange(volume.shape[2])[None, None, :]
x, y, z = np.broadcast_arrays(x, y, z)

# Turn the volumetric data into an RGB array that's
# just grayscale.  There might be better ways to make
# ax.scatter happy.
c = np.tile(volume.ravel()[:, None], [1, 3])

# Do the plotting in a single call.
fig = plt.figure()
ax = fig.gca(projection='3d')
ax.scatter(x.ravel(),
           y.ravel(),
           z.ravel(),
           c=c)

【讨论】:

  • 非常感谢您的回答!但它的工作速度与 unutbu 解决方案相同。仅适用于大小不超过 ~(100, 100, 100) 的数组。但我会尽量清理它,只留下非零点。
【解决方案3】:

使用来自itertoolsproduct 可以实现类似的解决方案:

from itertools import product
from matplotlib import pyplot as plt
N = 8
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(projection="3d")
space = np.array([*product(range(N), range(N), range(N))]) # all possible triplets of numbers from 0 to N-1
volume = np.random.rand(N, N, N) # generate random data
ax.scatter(space[:,0], space[:,1], space[:,2], c=space/8, s=volume*300)

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2011-04-17
    • 2014-10-19
    • 1970-01-01
    相关资源
    最近更新 更多