【问题标题】:Why did numpy converts my integer array to complex?为什么 numpy 将我的整数数组转换为复数?
【发布时间】:2020-10-25 03:45:57
【问题描述】:

我尝试对代码进行矢量化以计算 Mandelbrot 集,但在打包/解包操作之后,我的数据类型混淆了,无法弄清楚发生了什么。这是一个示例代码:

#!/usr/bin/python3

import numpy as np

X0=-1.5
X1=0.5
WIDTH=10
MAXCOUNT=20

xlist = np.linspace(X0, X1, WIDTH+1)
line = np.array(xlist[:]+0.8j, dtype=np.complex128)
cptr = np.zeros(shape=(WIDTH+1,), dtype=np.int32)
print("line={}".format(line))
print("cptr={}".format(cptr))
for c in range(MAXCOUNT):
    line, cptr = np.where((cptr==c) & (line*np.conj(line)<=4),
                          (line*line+xlist[:]+0.8j, cptr+1),
                          (line, cptr)
                         )
    print("c = {:2}, line={}".format(c, line))
    print("c = {:2}, cptr={}".format(c, cptr))

结果是

line=[-1.5+0.8j -1.3+0.8j -1.1+0.8j -0.9+0.8j -0.7+0.8j -0.5+0.8j -0.3+0.8j -0.1+0.8j  0.1+0.8j  0.3+0.8j  0.5+0.8j]
cptr=[0 0 0 0 0 0 0 0 0 0 0]
c =  0, line=[ 0.11-1.6j  -0.25-1.28j -0.53-0.96j -0.73-0.64j -0.85-0.32j -0.89+0.j -0.85+0.32j -0.73+0.64j -0.53+0.96j -0.25+1.28j  0.11+1.6j ]
c =  0, cptr=[1.+0.j 1.+0.j 1.+0.j 1.+0.j 1.+0.j 1.+0.j 1.+0.j 1.+0.j 1.+0.j 1.+0.j 1.+0.j]
c =  1, line=[-4.0479+0.448j  -2.8759+1.44j   -1.7407+1.8176j -0.7767+1.7344j -0.0799+1.344j   0.2921+0.8j     0.3201+0.256j   0.0233-0.1344j -0.5407-0.2176j -1.2759+0.16j   -2.0479+1.152j ]
c =  1, cptr=[2.+0.j 2.+0.j 2.+0.j 2.+0.j 2.+0.j 2.+0.j 2.+0.j 2.+0.j 2.+0.j 2.+0.j 2.+0.j]

[...]

c = 19, line=[-4.0479    +0.448j      -2.8759    +1.44j       -1.7407    +1.8176j -3.30488047-1.89421696j -2.49995199+0.5852288j  -0.99385655-1.87331238j -0.99660079+2.05062098j  0.02382819-0.09983248j -1.46600521-1.78305502j 3.76538617+2.29032378j -2.0479    +1.152j     ]
c = 19, cptr=[ 2.+0.j  2.+0.j  2.+0.j  3.+0.j  3.+0.j  4.+0.j  7.+0.j 20.+0.j  5.+0.j 4.+0.j  2.+0.j]

代码似乎产生了正确的结果,但由于某种原因,cptr 被转换为复杂事件,尽管我所做的只是根据np.where 中的条件加 1 或保持相同的值。

【问题讨论】:

  • (line, cptr) 实际上被where 用作np.array((line, cptr))where 作用于数组,而不是元组。 line, cptr = where(...) 解包一个数组,而不是一个元组。
  • 好的...你能把它作为一个答案让我接受吗?

标签: python-3.x numpy vectorization


【解决方案1】:

根据 hpaulj 的评论,问题出在 numpy.where 仅在 numpy.arrays 上运行并解压缩数组而不是元组这一事实。然后解决方案是对numpy.where 进行两次不同的调用:

#!/usr/bin/python3

import numpy as np

X0=-1.5
X1=0.5
WIDTH=10
MAXCOUNT=20

xlist = np.linspace(X0, X1, WIDTH+1)
line = np.array(xlist[:]+0.8j, dtype=np.complex128)
cptr = np.zeros(shape=(WIDTH+1,), dtype=np.int32)
print("line={}".format(line))
print("cptr={}".format(cptr))
for c in range(MAXCOUNT):
    flags = (cptr==c) & (line*np.conj(line)<=4)
    line = np.where(flags, line*line+xlist[:]+0.8j, line)
    cptr = np.where(flags, cptr+1, cptr)
    print("c = {:2}, line={}".format(c, line))
    print("c = {:2}, cptr={}".format(c, cptr))

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2012-08-28
    • 1970-01-01
    • 1970-01-01
    • 2018-10-14
    • 1970-01-01
    • 2021-11-29
    • 2021-12-29
    相关资源
    最近更新 更多