【发布时间】:2020-10-25 03:45:57
【问题描述】:
我尝试对代码进行矢量化以计算 Mandelbrot 集,但在打包/解包操作之后,我的数据类型混淆了,无法弄清楚发生了什么。这是一个示例代码:
#!/usr/bin/python3
import numpy as np
X0=-1.5
X1=0.5
WIDTH=10
MAXCOUNT=20
xlist = np.linspace(X0, X1, WIDTH+1)
line = np.array(xlist[:]+0.8j, dtype=np.complex128)
cptr = np.zeros(shape=(WIDTH+1,), dtype=np.int32)
print("line={}".format(line))
print("cptr={}".format(cptr))
for c in range(MAXCOUNT):
line, cptr = np.where((cptr==c) & (line*np.conj(line)<=4),
(line*line+xlist[:]+0.8j, cptr+1),
(line, cptr)
)
print("c = {:2}, line={}".format(c, line))
print("c = {:2}, cptr={}".format(c, cptr))
结果是
line=[-1.5+0.8j -1.3+0.8j -1.1+0.8j -0.9+0.8j -0.7+0.8j -0.5+0.8j -0.3+0.8j -0.1+0.8j 0.1+0.8j 0.3+0.8j 0.5+0.8j]
cptr=[0 0 0 0 0 0 0 0 0 0 0]
c = 0, line=[ 0.11-1.6j -0.25-1.28j -0.53-0.96j -0.73-0.64j -0.85-0.32j -0.89+0.j -0.85+0.32j -0.73+0.64j -0.53+0.96j -0.25+1.28j 0.11+1.6j ]
c = 0, cptr=[1.+0.j 1.+0.j 1.+0.j 1.+0.j 1.+0.j 1.+0.j 1.+0.j 1.+0.j 1.+0.j 1.+0.j 1.+0.j]
c = 1, line=[-4.0479+0.448j -2.8759+1.44j -1.7407+1.8176j -0.7767+1.7344j -0.0799+1.344j 0.2921+0.8j 0.3201+0.256j 0.0233-0.1344j -0.5407-0.2176j -1.2759+0.16j -2.0479+1.152j ]
c = 1, cptr=[2.+0.j 2.+0.j 2.+0.j 2.+0.j 2.+0.j 2.+0.j 2.+0.j 2.+0.j 2.+0.j 2.+0.j 2.+0.j]
[...]
c = 19, line=[-4.0479 +0.448j -2.8759 +1.44j -1.7407 +1.8176j -3.30488047-1.89421696j -2.49995199+0.5852288j -0.99385655-1.87331238j -0.99660079+2.05062098j 0.02382819-0.09983248j -1.46600521-1.78305502j 3.76538617+2.29032378j -2.0479 +1.152j ]
c = 19, cptr=[ 2.+0.j 2.+0.j 2.+0.j 3.+0.j 3.+0.j 4.+0.j 7.+0.j 20.+0.j 5.+0.j 4.+0.j 2.+0.j]
代码似乎产生了正确的结果,但由于某种原因,cptr 被转换为复杂事件,尽管我所做的只是根据np.where 中的条件加 1 或保持相同的值。
【问题讨论】:
-
(line, cptr)实际上被where用作np.array((line, cptr))。where作用于数组,而不是元组。line, cptr = where(...)解包一个数组,而不是一个元组。 -
好的...你能把它作为一个答案让我接受吗?
标签: python-3.x numpy vectorization