【发布时间】:2017-04-17 22:56:17
【问题描述】:
我正在尝试学习使用 sympy 来优化 C 中数学表达式的数值评估。一方面我知道 sympy 可以生成 C 代码来评估一个表达式,如下所示:
from mpmath import *
from sympy.utilities.codegen import codegen
from sympy import *
x,y,z = symbols('x y z')
[(c_name, c_code), (h_name, c_header)] = codegen([('x', sin(x))], 'C')
然后您可以将 c_code 打印到目标文件。另一方面,我知道 cse 可以用来简化表达式如下:
from mpmath import *
from sympy.utilities.codegen import codegen
from sympy import *
x,y,z, B1, B2, B3, B4 = symbols('x y z B1 B2 B3 B4 ')
cse([3.0*B2 + 8.0*B3*x**2 + 3.0*B3*x*y + 4.0*B3*x*z + B3*y**2 + B3*z**2 + B4*x**4 + B4*x**3*y + B4*x**3*z + B4*x**2*y**2 + B4*x**2*y*z + B4*x**2*z**2, 7.0*B3*x*y + 2*B3*x*z + B3*(x**2 + y**2) + B4*x**3*y + B4*x**2*y**2 + B4*x**2*y*z + B4*x*y**3 + B4*x*y**2*z + B4*x*y*z**2, B3*x*y + 8.0*B3*x*z + B3*(x**2 + z**2) + B4*x**3*z + B4*x**2*y*z + B4*x**2*z**2 + B4*x*y**2*z + B4*x*y*z**2 + B4*x*z**3, 3.0*B2 + B3*x**2 + 3.0*B3*x*y + B3*x*z + 8.0*B3*y**2 + 3.0*B3*y*z + B3*z**2 + B4*x**2*y**2 + B4*x*y**3 + B4*x*y**2*z + B4*y**4 + B4*y**3*z + B4*y**2*z**2, B3*x*y + 2*B3*x*z + 6.0*B3*y*z + B3*(y**2 + z**2) + B4*x**2*y*z + B4*x*y**2*z + B4*x*y*z**2 + B4*y**3*z + B4*y**2*z**2 + B4*y*z**3, 3.0*B2 + B3*x**2 + B3*x*y + 3.0*B3*x*z + B3*y**2 + 3.0*B3*y*z + 8.0*B3*z**2 + B4*x**2*z**2 + B4*x*y*z**2 + B4*x*z**3 + B4*y**2*z**2 + B4*y*z**3 + B4*z**4])
得到输出:
([(x0, z**2),
(x1, B3*x0),
(x2, B3*x),
(x3, x2*y),
(x4, 3.0*x3),
(x5, 3.0*B2),
(x6, y**2),
(x7, B3*x6),
(x8, x2*z),
(x9, x**2),
(x10, B3*x9),
(x11, B4*x**3),
(x12, x11*y),
(x13, x11*z),
(x14, B4*y),
(x15, x14*x9*z),
(x16, B4*x9),
(x17, x16*x6),
(x18, x0*x16),
(x19, 2*x8),
(x20, y**3),
(x21, B4*x),
(x22, x20*x21),
(x23, x0*x21*y),
(x24, x21*x6*z),
(x25, z**3),
(x26, x21*x25),
(x27, B3*y*z),
(x28, x10 + 3.0*x27),
(x29, B4*x20*z),
(x30, B4*x0*x6),
(x31, x14*x25)],
[B4*x**4 + x1 + 8.0*x10 + x12 + x13 + x15 + x17 + x18 + x4 + x5 + x7 + 4.0*x8,
B3*(x6 + x9) + x12 + x15 + x17 + x19 + x22 + x23 + x24 + 7.0*x3,
B3*(x0 + x9) + x13 + x15 + x18 + x23 + x24 + x26 + x3 + 8.0*x8,
B4*y**4 + x1 + x17 + x22 + x24 + x28 + x29 + x30 + x4 + x5 + 8.0*x7 + x8,
B3*(x0 + x6) + x15 + x19 + x23 + x24 + 6.0*x27 + x29 + x3 + x30 + x31,
B4*z**4 + 8.0*x1 + x18 + x23 + x26 + x28 + x3 + x30 + x31 + x5 + x7 + 3.0*x8])
我的问题是如何正确地将以前的结果转换为 C 代码?有时可以用于转换字符串中的简化表达式并对此类字符串进行操作,如何做到这一点?目的是自动化 CSE 之后的代码生成过程,以便将其应用于许多表达式。
编辑:
根据下面的答案,感谢 Wrzlprmft,生成相应 C 代码 sn-p 的代码是:
from sympy.printing import ccode
from sympy import symbols, cse, numbered_symbols
x,y,z, B1, B2, B3, B4 = symbols('x y z B1 B2 B3 B4 ')
results = [3.0*B2 + 8.0*B3*x**2 + 3.0*B3*x*y + 4.0*B3*x*z + B3*y**2 + B3*z**2 + B4*x**4 + B4*x**3*y + B4*x**3*z + B4*x**2*y**2 + B4*x**2*y*z + B4*x**2*z**2, 7.0*B3*x*y + 2*B3*x*z + B3*(x**2 + y**2) + B4*x**3*y + B4*x**2*y**2 + B4*x**2*y*z + B4*x*y**3 + B4*x*y**2*z + B4*x*y*z**2, B3*x*y + 8.0*B3*x*z + B3*(x**2 + z**2) + B4*x**3*z + B4*x**2*y*z + B4*x**2*z**2 + B4*x*y**2*z + B4*x*y*z**2 + B4*x*z**3, 3.0*B2 + B3*x**2 + 3.0*B3*x*y + B3*x*z + 8.0*B3*y**2 + 3.0*B3*y*z + B3*z**2 + B4*x**2*y**2 + B4*x*y**3 + B4*x*y**2*z + B4*y**4 + B4*y**3*z + B4*y**2*z**2, B3*x*y + 2*B3*x*z + 6.0*B3*y*z + B3*(y**2 + z**2) + B4*x**2*y*z + B4*x*y**2*z + B4*x*y*z**2 + B4*y**3*z + B4*y**2*z**2 + B4*y*z**3, 3.0*B2 + B3*x**2 + B3*x*y + 3.0*B3*x*z + B3*y**2 + 3.0*B3*y*z + 8.0*B3*z**2 + B4*x**2*z**2 + B4*x*y*z**2 + B4*x*z**3 + B4*y**2*z**2 + B4*y*z**3 + B4*z**4]
CSE_results = cse(results,numbered_symbols("helper_"))
with open("snippet.c", "w") as output:
for helper in CSE_results[0]:
output.write("double ")
output.write(ccode(helper[1],helper[0]))
output.write("\n")
for i,result in enumerate(CSE_results[1]):
output.write(ccode(result,"result_%d"%i))
output.write("\n")
【问题讨论】:
-
由于SymPy issue 8997,如果不编写自己的代码打印机,您的约束使这成为不可能。请注意,您的 C 编译器可以自行处理此问题以获得足够高的优化标志(请参阅上一个链接上的最后一个 cmets)。您的 C 编译器也可以处理 CSE。
-
感谢您的回答。好的,约束已被删除,是的,编译器可以做到这一点,但如果人类编写脚本来自动发出代码,为什么不这样做呢?
-
如果一个人编写了一个脚本来自动发出代码,为什么不这样做呢? – 与您不手动编写 C 代码的原因相同。无论如何,我可能稍后再写一个答案。