【问题标题】:Differential Operator usable in Matrix form, in Python module Sympy可在 Python 模块 Sympy 中以矩阵形式使用的微分运算符
【发布时间】:2013-03-05 23:56:49
【问题描述】:

我们需要两个微分算子矩阵[B][C]如:

B = sympy.Matrix([[ D(x), D(y) ],
                  [ D(y), D(x) ]])

C = sympy.Matrix([[ D(x), D(y) ]])

ans = B * sympy.Matrix([[x*y**2],
                        [x**2*y]])
print ans
[x**2 + y**2]
[      4*x*y]

ans2 = ans * C
print ans2
[2*x, 2*y]
[4*y, 4*x]

这也可以用于计算矢量场的旋度,例如:

culr  = sympy.Matrix([[ D(x), D(y), D(z) ]])
field = sympy.Matrix([[ x**2*y, x*y*z, -x**2*y**2 ]])

要使用 Sympy 解决这个问题,必须创建以下 Python 类:

import sympy

class D( sympy.Derivative ):
    def __init__( self, var ):
        super( D, self ).__init__()
        self.var = var

    def __mul__(self, other):
        return sympy.diff( other, self.var )

当微分运算符的矩阵在左边相乘时,这个类单独解决。这里diff只有在知道要微分的函数时才会执行。

为了解决微分运算符矩阵在右侧相乘时的问题,核心类Expr 中的__mul__ 方法必须按以下方式更改:

class Expr(Basic, EvalfMixin):
    # ...
    def __mul__(self, other):
        import sympy
        if other.__class__.__name__ == 'D':
            return sympy.diff( self, other.var )
        else:
            return Mul(self, other)
    #...

它工作得很好,但是 Sympy 中应该有更好的本地解决方案来处理这个问题。 有人知道它可能是什么吗?

【问题讨论】:

标签: python matrix sympy differentiation automatic-differentiation


【解决方案1】:

此解决方案应用来自其他答案和from here 的提示。 D 运算符可以定义如下:

  • 只考虑从左边相乘,所以D(t)*2*t**3 = 6*t**22*t**3*D(t) 什么都不做
  • D 一起使用的所有表达式和符号必须有is_commutative = False
  • 使用evaluateExpr() 在给定表达式的上下文中求值
    • 从右到左沿着表达式查找D 运算符并将mydiff()* 应用于相应的右侧部分

*:mydiff 用于代替diff 以允许创建更高阶的D,例如mydiff(D(t), t) = D(t,t)

D 中的 __mul__() 中的 diff 仅供参考,因为在当前解决方案中,evaluateExpr() 实际上是在做区分工作。一个python mudule被创建并保存为d.py

import sympy
from sympy.core.decorators import call_highest_priority
from sympy import Expr, Matrix, Mul, Add, diff
from sympy.core.numbers import Zero

class D(Expr):
    _op_priority = 11.
    is_commutative = False
    def __init__(self, *variables, **assumptions):
        super(D, self).__init__()
        self.evaluate = False
        self.variables = variables

    def __repr__(self):
        return 'D%s' % str(self.variables)

    def __str__(self):
        return self.__repr__()

    @call_highest_priority('__mul__')
    def __rmul__(self, other):
        return Mul(other, self)

    @call_highest_priority('__rmul__')
    def __mul__(self, other):
        if isinstance(other, D):
            variables = self.variables + other.variables
            return D(*variables)
        if isinstance(other, Matrix):
            other_copy = other.copy()
            for i, elem in enumerate(other):
                other_copy[i] = self * elem
            return other_copy

        if self.evaluate:
            return diff(other, *self.variables)
        else:
            return Mul(self, other)

    def __pow__(self, other):
        variables = self.variables
        for i in range(other-1):
            variables += self.variables
        return D(*variables)

def mydiff(expr, *variables):
    if isinstance(expr, D):
        expr.variables += variables
        return D(*expr.variables)
    if isinstance(expr, Matrix):
        expr_copy = expr.copy()
        for i, elem in enumerate(expr):
            expr_copy[i] = diff(elem, *variables)
        return expr_copy
    return diff(expr, *variables)

def evaluateMul(expr):
    end = 0
    if expr.args:
        if isinstance(expr.args[-1], D):
            if len(expr.args[:-1])==1:
                cte = expr.args[0]
                return Zero()
            end = -1
    for i in range(len(expr.args)-1+end, -1, -1):
        arg = expr.args[i]
        if isinstance(arg, Add):
            arg = evaluateAdd(arg)
        if isinstance(arg, Mul):
            arg = evaluateMul(arg)
        if isinstance(arg, D):
            left = Mul(*expr.args[:i])
            right = Mul(*expr.args[i+1:])
            right = mydiff(right, *arg.variables)
            ans = left * right
            return evaluateMul(ans)
    return expr

def evaluateAdd(expr):
    newargs = []
    for arg in expr.args:
        if isinstance(arg, Mul):
            arg = evaluateMul(arg)
        if isinstance(arg, Add):
            arg = evaluateAdd(arg)
        if isinstance(arg, D):
            arg = Zero()
        newargs.append(arg)
    return Add(*newargs)

#courtesy: https://stackoverflow.com/a/48291478/1429450
def disableNonCommutivity(expr):
    replacements = {s: sympy.Dummy(s.name) for s in expr.free_symbols}
    return expr.xreplace(replacements)

def evaluateExpr(expr):
    if isinstance(expr, Matrix):
        for i, elem in enumerate(expr):
            elem = elem.expand()
            expr[i] = evaluateExpr(elem)
        return disableNonCommutivity(expr)
    expr = expr.expand()
    if isinstance(expr, Mul):
        expr = evaluateMul(expr)
    elif isinstance(expr, Add):
        expr = evaluateAdd(expr)
    elif isinstance(expr, D):
        expr = Zero()
    return disableNonCommutivity(expr)

示例 1:向量场的卷曲。请注意,使用commutative=False 定义变量很重要,因为它们在Mul().args 中的顺序会影响结果,请参阅this other question

from d import D, evaluateExpr
from sympy import Matrix
sympy.var('x', commutative=False)
sympy.var('y', commutative=False)
sympy.var('z', commutative=False)
curl  = Matrix( [[ D(x), D(y), D(z) ]] )
field = Matrix( [[ x**2*y, x*y*z, -x**2*y**2 ]] )       
evaluateExpr( curl.cross( field ) )
# [-x*y - 2*x**2*y, 2*x*y**2, -x**2 + y*z]

示例 2:结构分析中使用的典型 Ritz 近似。

from d import D, evaluateExpr
from sympy import sin, cos, Matrix
sin.is_commutative = False
cos.is_commutative = False
g1 = []
g2 = []
g3 = []
sympy.var('x', commutative=False)
sympy.var('t', commutative=False)
sympy.var('r', commutative=False)
sympy.var('A', commutative=False)
m=5
n=5
for j in xrange(1,n+1):
    for i in xrange(1,m+1):
        g1 += [sin(i*x)*sin(j*t),                 0,                 0]
        g2 += [                0, cos(i*x)*sin(j*t),                 0]
        g3 += [                0,                 0, sin(i*x)*cos(j*t)]
g = Matrix( [g1, g2, g3] )

B = Matrix(\
    [[     D(x),        0,        0],
     [    1/r*A,        0,        0],
     [ 1/r*D(t),        0,        0],
     [        0,     D(x),        0],
     [        0,    1/r*A, 1/r*D(t)],
     [        0, 1/r*D(t), D(x)-1/x],
     [        0,        0,        1],
     [        0,        1,        0]])

ans = evaluateExpr(B*g)

创建了一个print_to_file() 函数来快速检查大表达式。

import sympy
import subprocess
def print_to_file( guy, append=False ):
    flag = 'w'
    if append: flag = 'a'
    outfile = open(r'print.txt', flag)
    outfile.write('\n')
    outfile.write( sympy.pretty(guy, wrap_line=False) )
    outfile.write('\n')
    outfile.close()
    subprocess.Popen( [r'notepad.exe', r'print.txt'] )

print_to_file( B*g )
print_to_file( ans, append=True )

【讨论】:

  • 没有expand,我只剩下分组术语,这使得treatAdd()treatMul()更难应用
  • @asmeurer 你知道如何避免args 被排序吗? I posted a question for this issue too...
  • 如果您尝试添加更多功能,您的方法将会一团糟。不要为每种类型创建单独的函数,而是为递归调用自身的每种类型创建一个具有不同分支的函数。整个事情会简单得多。
  • 谢谢,我会继续努力,当我找到更好的解决方案时更新帖子
  • @asmeurer 我已经更新了答案,看起来这些treatExpr()treatMul()treatAdd() 看起来仍然很复杂,但它现在正在工作......
【解决方案2】:

Differential operators 不存在于 SymPy 的核心中,即使它们存在“一个运算符的乘法”而不是“一个运算符的应用”,也是一种不支持的符号滥用。 SymPy。

[1] 另一个问题是 SymPy 表达式只能从 sympy.Basic 的子类构建,因此您的 class D 在输入为 sympy_expr+D(z) 时很可能会引发错误。这就是(expression*D(z)) * (another_expr) 失败的原因。 (expression*D(z)) 无法构建。

此外,如果D 的参数不是单个Symbol,则不清楚您对该运算符的期望。

最后,diff(f(x), x)(其中f 是一个符号未知函数)返回一个未计算的表达式,正如您所观察到的,因为当f 未知时,没有其他可以合理返回的东西。稍后,当您替换 expr.subs(f(x), sin(x)) 时,将计算导数(最坏的情况下您可能需要调用 expr.doit())。

[2] 没有优雅的短的解决方案来解决您的问题。我建议解决您的问题的一种方法是覆盖Expr__mul__ 方法:而不是仅仅将表达式树相乘,它会检查左侧表达式树是否包含D 的实例并将应用它们。显然,如果您想添加新对象,这不会扩展。这是 sympy 设计中长期存在的已知问题。

编辑:[1] 只是为了允许创建包含D 的表达式。 [2] 对于包含更多内容的表达式来说是必要的,而不仅仅是一个 D 才能工作。

【讨论】:

  • 使用 D 类作为 sympy.Derivative 的子类适用于这种情况(请参阅更新后的问题)
  • 只是子类Expr,没有理由(实际上很混乱)子类Derivative
  • 正如我已经提到的,您的解决方案不适用于stuff+D。您应该已经添加了答案,而不是修改了问题,因此更容易发表评论。
  • 更新后的问题不是答案,只是包含了更多信息,以提供有关所需内容的一些见解。您在stuff+D 中指出的问题也在此处得到验证。谢谢!
  • 子类化Expr 确实更有意义。这样stuff+D就不会报错,比如sympy.expand( (D(x)+1/x)*x**2 )的结果就是x**2*D(x) + x。现在我正在研究一种重新进行乘法运算的方法,以便再次调用diff
【解决方案3】:

如果你想让正确的乘法起作用,你需要从 object 继承。这将导致x*D 回退到D.__rmul__。不过,我无法想象这是高优先级,因为运算符永远不会从右侧应用。

【讨论】:

  • 有一种方法可以强制使用 __mul____rmul__,正如 Julien Rioux here 所解释的那样
【解决方案4】:

目前还不可能制作一个始终自动工作的运算符。要真正完全工作,您需要http://code.google.com/p/sympy/issues/detail?id=1941。另请参阅https://github.com/sympy/sympy/wiki/Canonicalization(请随意编辑该页面)。

但是,您可以使用该 stackoverflow 问题中的想法创建一个大部分时间都可以工作的类,对于它无法处理的情况,编写一个简单的函数,该函数通过表达式并将运算符应用于它没有的地方'还没有被应用。

顺便说一句,将微分运算符视为“乘法”需要考虑的一件事是它是非关联的。即(D*f)*g = g*Df,而D*(f*g) = g*Df + f*Dg。所以当你做一些事情时你需要小心,它不会“吃掉”一个表达的一部分,而不是整个事情。例如,D*2*x 会因此而给出0。 SymPy 到处都假设乘法是关联的,所以它很可能在某些时候做错了。

如果这成为一个问题,我建议转储自动应用程序,并只使用一个通过并应用它的函数(正如我上面提到的,无论如何你都需要它)。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2019-01-31
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2020-03-11
    • 1970-01-01
    相关资源
    最近更新 更多