【问题标题】:How to decompose a 2x2 affine matrix with sympy?如何用 sympy 分解 2x2 仿射矩阵?
【发布时间】:2022-01-18 06:32:00
【问题描述】:

我试图用 sympy 展示仿射矩阵的分解 显示在以下 stackexchange 帖子中:

https://math.stackexchange.com/questions/612006/decomposing-an-affine-transformation

我设置了两个矩阵A_paramsA_matrix,前者代表 原始矩阵值,后者是从其构造的矩阵 基础参数。

import sympy
import itertools as it
import ubelt as ub
domain = {'real': True}

theta = sympy.symbols('theta', **domain)
sx, sy = sympy.symbols('sx, sy', nonzero=True, **domain)
m = sympy.symbols('m', **domain)

S = sympy.Matrix([  # scale
    [sx,  0],
    [ 0, sy]])

H = sympy.Matrix([  # shear
    [1, m],
    [0, 1]])

R = sympy.Matrix([  # rotation
    [sympy.cos(theta), -sympy.sin(theta)],
    [sympy.sin(theta),  sympy.cos(theta)]])


A_params = sympy.simplify((R @ H @ S))
a11, a12, a21, a22 = sympy.symbols(
    'a11, a12, a21, a22', real=True)
A_matrix = sympy.Matrix([[a11, a12], [a21, a22]])


print(ub.hzcat(['A_matrix = ', sympy.pretty(A_matrix)]))
print(ub.hzcat(['A_params = ', sympy.pretty(A_params)]))
A_matrix = ⎡a₁₁  a₁₂⎤
           ⎢        ⎥
           ⎣a₂₁  a₂₂⎦
A_params = ⎡sx⋅cos(θ)  sy⋅(m⋅cos(θ) - sin(θ))⎤
           ⎢                                 ⎥
           ⎣sx⋅sin(θ)  sy⋅(m⋅sin(θ) + cos(θ))⎦

据我了解,我应该能够简单地将这两个矩阵设置为 相等,然后求解感兴趣的参数。但是,我越来越 意想不到的结果。

首先,如果我只是尝试解决“sx”,我不会得到任何结果。

## Option 1: Matrix equality
mat_equation = sympy.Eq(A_matrix, A_params)
soln_sx = sympy.solve(mat_equation, sx)
print('soln_sx = {!r}'.format(soln_sx))

## Option 2: List of equations
lhs_iter = it.chain.from_iterable(A_matrix.tolist())
rhs_iter = it.chain.from_iterable(A_params.tolist())
equations = [sympy.Eq(lhs, rhs) for lhs, rhs in zip(lhs_iter, rhs_iter)]
soln_sx = sympy.solve(equations, sx)
print('soln_sx = {!r}'.format(soln_sx))
soln_sx = []
soln_sx = []

但是如果我尝试同时解决所有变量,我会得到一个结果 但它不符合我的预期

solve_for = (sx, theta, sy, m)
solutions = sympy.solve(mat_equation, *solve_for)
for sol, symbol in zip(solutions[0], solve_for):
    sol = sympy.simplify(sol)
    print('sol({!r}) = {!r}'.format(symbol, sol))
    # sympy.pretty_print(sol)
sol(sx) = -(a11**2 + a11*sqrt(a11**2 + a21**2) + a21**2)/(a11 + sqrt(a11**2 + a21**2))
sol(theta) = -2*atan((a11 + sqrt(a11**2 + a21**2))/a21)
sol(sy) = (-8*a11**6*a22 + 8*a11**5*a12*a21 - 8*a11**5*a22*sqrt(a11**2 + a21**2) + 8*a11**4*a12*a21*sqrt(a11**2 + a21**2) - 12*a11**4*a21**2*a22 + 12*a11**3*a12*a21**3 - 8*a11**3*a21**2*a22*sqrt(a11**2 + a21**2) + 8*a11**2*a12*a21**3*sqrt(a11**2 + a21**2) - 4*a11**2*a21**4*a22 + 4*a11*a12*a21**5 - a11*a21**4*a22*sqrt(a11**2 + a21**2) + a12*a21**5*sqrt(a11**2 + a21**2))/(8*a11**6 + 8*a11**5*sqrt(a11**2 + a21**2) + 16*a11**4*a21**2 + 12*a11**3*a21**2*sqrt(a11**2 + a21**2) + 9*a11**2*a21**4 + 4*a11*a21**4*sqrt(a11**2 + a21**2) + a21**6)
sol(m) = (a11*a12 + a21*a22)/(a11*a22 - a12*a21)

在努力完成上述工作后,我想看看 如果我至少可以验证 stackexchange 的解决方案。所以我把它编码了 象征性地:

# This is the guided solution by Stéphane Laurent
recon_sx = sympy.sqrt(a11 * a11 + a21 * a21)
recon_theta = sympy.atan2(a21, a11)
recon_sin_t = sympy.sin(recon_theta)
recon_cos_t = sympy.cos(recon_theta)

recon_msy = a12 * recon_sin_t + a22 * recon_cos_t

condition2 = sympy.simplify(sympy.Eq(recon_sin_t, 0))
condition1 = sympy.simplify(sympy.Not(condition2))
sy_cond1 = (recon_msy * recon_cos_t - a12) / recon_sin_t
sy_cond2 = (a22 - recon_msy * recon_sin_t) / recon_cos_t

recon_sy = sympy.Piecewise((sy_cond1, condition1), (sy_cond2, condition2))

recon_m = recon_msy / recon_sy

recon_S = sympy.Matrix([  # scale
    [recon_sx,  0],
    [ 0, recon_sy]])

recon_H = sympy.Matrix([  # shear
    [1, recon_m],
    [0, 1]])

recon_R = sympy.Matrix([  # rotation
    [sympy.cos(recon_theta), -sympy.sin(recon_theta)],
    [sympy.sin(recon_theta),  sympy.cos(recon_theta)]])

# Recombine the components
A_recon = sympy.simplify((recon_R @ recon_H @ recon_S))
print(ub.hzcat(['A_recon = ', sympy.pretty(A_recon)]))

结果与我的预期非常相似,但事实并非如此 似乎一直简化到可以以编程方式进行的程度 验证。

A_recon = ⎡     ⎧                                       a₂₁            ⎤
          ⎢     ⎪            a₁₂              for ──────────────── ≠ 0 ⎥
          ⎢     ⎪                                    _____________     ⎥
          ⎢     ⎪                                   ╱    2      2      ⎥
          ⎢a₁₁  ⎨                                 ╲╱  a₁₁  + a₂₁       ⎥
          ⎢     ⎪                                                      ⎥
          ⎢     ⎪a₁₁⋅a₂₂ + a₁₂⋅a₂₁ - a₂₁⋅a₂₂                           ⎥
          ⎢     ⎪───────────────────────────         otherwise         ⎥
          ⎢     ⎩            a₁₁                                       ⎥
          ⎢                                                            ⎥
          ⎢     ⎧-a₁₁⋅a₁₂ + a₁₁⋅a₂₂ + a₁₂⋅a₂₁            a₂₁           ⎥
          ⎢     ⎪────────────────────────────  for ──────────────── ≠ 0⎥
          ⎢     ⎪            a₂₁                      _____________    ⎥
          ⎢a₂₁  ⎨                                    ╱    2      2     ⎥
          ⎢     ⎪                                  ╲╱  a₁₁  + a₂₁      ⎥
          ⎢     ⎪                                                      ⎥
          ⎣     ⎩            a₂₂                      otherwise        ⎦

我的想法是条件搞砸了,所以我尝试了一下 使用两种情况:

recon_sy2 = sy_cond1
recon_m2 = recon_msy / recon_sy2

recon_S2 = sympy.Matrix([  # scale
    [recon_sx,  0],
    [ 0, recon_sy2]])

recon_H2 = sympy.Matrix([  # shear
    [1, recon_m2],
    [0, 1]])


recon_sy3 = sy_cond2
recon_m3 = recon_msy / recon_sy3

recon_S3 = sympy.Matrix([  # scale
    [recon_sx,  0],
    [ 0, recon_sy3]])

recon_H3 = sympy.Matrix([  # shear
    [1, recon_m3],
    [0, 1]])


# Recombine the components
A_recon2 = sympy.simplify((recon_R @ recon_H2 @ recon_S2))
A_recon3 = sympy.simplify((recon_R @ recon_H3 @ recon_S3))
print('')
print(ub.hzcat(['A_recon2 = ', sympy.pretty(A_recon2)]))
print('')
print(ub.hzcat(['A_recon3 = ', sympy.pretty(A_recon3)]))
A_recon2 = ⎡a₁₁              a₁₂             ⎤
           ⎢                                 ⎥
           ⎢     -a₁₁⋅a₁₂ + a₁₁⋅a₂₂ + a₁₂⋅a₂₁⎥
           ⎢a₂₁  ────────────────────────────⎥
           ⎣                 a₂₁             ⎦

A_recon3 = ⎡     a₁₁⋅a₂₂ + a₁₂⋅a₂₁ - a₂₁⋅a₂₂⎤
           ⎢a₁₁  ───────────────────────────⎥
           ⎢                 a₁₁            ⎥
           ⎢                                ⎥
           ⎣a₂₁              a₂₂            ⎦

但这似乎不允许进一步简化。

我不太明白 a22/a12 如何从顶部/底部方程中弹出 分别,但如果这种分解是正确的,它们应该是正确的,但是这些 结果让我担心它不是。

所以我的问题有两个:

  1. 任何同情大师可以帮助我获得分解工作的基本解决方案吗?

  2. 参考 SE 帖子中的分解是否错误?或者我不包括一个 允许简化的约束?如果是这样,我将如何在 sympy 中做到这一点?

更新

当所有变量都被联合求解时,我可以通过在 sympy.solve 的方程上使用 sympy.radsimp 获得更进一步的结果(仍然不确定为什么它不能自己求解 sx)。

solve_for = (sx, theta, sy, m)
solutions = sympy.solve(mat_equation, *solve_for, dict=True)
# minimal=True, quick=True, cubics=False, quartics=False, quintics=False, check=False)
for sol in solutions:
    for sym, symsol0 in sol.items():
        symsol = sympy.radsimp(symsol0)
        symsol = sympy.trigsimp(symsol)
        symsol = sympy.simplify(symsol)
        symsol = sympy.radsimp(symsol)
        print('\n=====')
        print('sym = {!r}'.format(sym))
        print('symsol  = {!r}'.format(symsol))
        print('--')
        sympy.pretty_print(symsol, wrap_line=False)
        print('--')
        print('=====\n')
=====
sym = sx
symsol  = -sqrt(a11**2 + a21**2)
--
    _____________
   ╱    2      2 
-╲╱  a₁₁  + a₂₁  
--
=====


=====
sym = theta
symsol  = 2*atan((a11 + sqrt(a11**2 + a21**2))/a21)
--
      ⎛         _____________⎞
      ⎜        ╱    2      2 ⎟
      ⎜a₁₁ + ╲╱  a₁₁  + a₂₁  ⎟
2⋅atan⎜──────────────────────⎟
      ⎝         a₂₁          ⎠
--
=====


=====
sym = m
symsol  = (a11*a12 + a21*a22)/(a11*a22 - a12*a21)
--
a₁₁⋅a₁₂ + a₂₁⋅a₂₂
─────────────────
a₁₁⋅a₂₂ - a₁₂⋅a₂₁
--
=====


=====
sym = sy
symsol  = (-a11*a22*sqrt(a11**2 + a21**2) + a12*a21*sqrt(a11**2 + a21**2))/(a11**2 + a21**2)
--
             _____________              _____________
            ╱    2      2              ╱    2      2 
- a₁₁⋅a₂₂⋅╲╱  a₁₁  + a₂₁   + a₁₂⋅a₂₁⋅╲╱  a₁₁  + a₂₁  
─────────────────────────────────────────────────────
                        2      2                     
                     a₁₁  + a₂₁                      
--
=====

但是 sx 的解决方案更接近我想要的(虽然它是一个负根,我认为这在技术上是正确的,但我的印象是 sympy 只处理主根)。

主要问题仍然悬而未决。 (虽然我更有信心原来的 SE 帖子是正确的)。

看起来它是在说“m”在分母中有行列式,这很有趣。 (分子是行的点积)。

更新2

我开始认为 sympy 或 Se 帖子中存在一些错误。我开始进行数字检查,它给出了我认为无法调和的错误(即轮换后相同)。

数值校验码是

params = [sx, theta, sy, m]
params_rand = {p: np.random.rand() for p in params}
A_params_rand = A_params.subs(params_rand)
matrix_rand = {lhs: rhs for lhs, rhs in zip(elements, ub.flatten(A_params_rand.tolist()))}
A_matrix_rand = A_matrix.subs(matrix_rand)
A_solved_rand = A_solved_recon.subs(matrix_rand)
A_recon_rand = A_recon.subs(matrix_rand)

mat1 = np.array(A_matrix_rand.tolist()).astype(float)
mat2 = np.array(A_params_rand.tolist()).astype(float)
mat3 = np.array(A_recon_rand.tolist()).astype(float)
assert np.all(np.isclose(mat1, mat2))

print(mat2 - mat3)

mat4 = np.array(A_solved_rand.tolist()).astype(float)

随机值似乎总是在矩阵中的a22处产生一些错误,所以我认为从手动输入的分解中对矩阵的sympy重建是错误的,或者分解本身是错误的。任何帮助都会非常有价值。

【问题讨论】:

  • atan2 函数具有参数yx,按此顺序排列。我会尝试recon_theta = sympy.atan2(a11, a21)

标签: sympy


【解决方案1】:

与同事讨论后,发现我在代码中犯了一个简单的错误。我交换了 sin 和 cos 术语。在使用@Stéphane Laurent 的分解时,修复此问题可以正确重建矩阵:

import sympy
import ubelt as ub

domain = {'real': True}

theta = sympy.symbols('theta', **domain)
sx, sy = sympy.symbols('sx, sy', **domain)
m = sympy.symbols('m', **domain)
params = [sx, theta, sy, m]

S = sympy.Matrix([  # scale
    [sx,  0],
    [ 0, sy]])

H = sympy.Matrix([  # shear
    [1, m],
    [0, 1]])

R = sympy.Matrix((  # rotation
    [sympy.cos(theta), -sympy.sin(theta)],
    [sympy.sin(theta),  sympy.cos(theta)]))

A_params = sympy.simplify((R @ H @ S))
a11, a12, a21, a22 = sympy.symbols(
    'a11, a12, a21, a22', real=True)
A_matrix = sympy.Matrix(((a11, a12), (a21, a22)))

print(ub.hzcat(['A_matrix = ', sympy.pretty(A_matrix)]))
print(ub.hzcat(['A_params = ', sympy.pretty(A_params)]))


# This is the guided solution by Stéphane Laurent
recon_sx = sympy.sqrt(a11 * a11 + a21 * a21)
recon_theta = sympy.atan2(a21, a11)
recon_sin_t = sympy.sin(recon_theta)
recon_cos_t = sympy.cos(recon_theta)

recon_msy = a12 * recon_cos_t + a22 * recon_sin_t


# condition2 = sympy.simplify(sympy.Eq(recon_sin_t, 0))
# condition1 = sympy.simplify(sympy.Not(condition2))
condition1 = sympy.Gt(recon_sin_t ** 2, recon_cos_t ** 2)
condition2 = sympy.Le(recon_sin_t ** 2, recon_cos_t ** 2)
sy_cond1 = (recon_msy * recon_cos_t - a12) / recon_sin_t
sy_cond2 = (a22 - recon_msy * recon_sin_t) / recon_cos_t
recon_sy = sympy.Piecewise((sy_cond1, condition1), (sy_cond2, condition2))
recon_m = sympy.simplify(recon_msy / recon_sy)


# Substitute the decomposition into the "A_params" to reconstruct "A_matrix"
recon_symbols = {
    sx: recon_sx,
    theta: recon_theta,
    m: recon_m,
    sy: recon_sy
}

for sym, symval in recon_symbols.items():
    # symval = sympy.radsimp(symval)
    symval = sympy.trigsimp(symval)
    symval = sympy.simplify(symval)
    if not isinstance(symval, sympy.Piecewise):
        symval = sympy.radsimp(symval)
    print('\n=====')
    print('sym = {!r}'.format(sym))
    print('symval  = {!r}'.format(symval))
    print('--')
    sympy.pretty_print(symval)
    print('=====\n')

A_recon = A_params.subs(recon_symbols)
A_recon = sympy.simplify(A_recon)
print(ub.hzcat(['A_recon = ', sympy.pretty(A_recon)]))

使用 Laurent 明确定义的分解的重建输出:

A_matrix = ⎡a₁₁  a₁₂⎤
           ⎢        ⎥
           ⎣a₂₁  a₂₂⎦
A_params = ⎡sx⋅cos(θ)  sy⋅(m⋅cos(θ) - sin(θ))⎤
           ⎢                                 ⎥
           ⎣sx⋅sin(θ)  sy⋅(m⋅sin(θ) + cos(θ))⎦

=====
sym = sx
symval  = sqrt(a11**2 + a21**2)
--
   _____________
  ╱    2      2
╲╱  a₁₁  + a₂₁
=====


=====
sym = theta
symval  = atan2(a21, a11)
--
atan2(a₂₁, a₁₁)
=====


=====
sym = m
symval  = (a11*a12 + a21*a22)/(a11*a22 - a12*a21)
--
a₁₁⋅a₁₂ + a₂₁⋅a₂₂
─────────────────
a₁₁⋅a₂₂ - a₁₂⋅a₂₁
=====


=====
sym = sy
symval  = (a11*a22*sqrt(a11**2 + a21**2) - a12*a21*sqrt(a11**2 + a21**2))/(a11**2 + a21**2)
--
           _____________              _____________
          ╱    2      2              ╱    2      2
a₁₁⋅a₂₂⋅╲╱  a₁₁  + a₂₁   - a₁₂⋅a₂₁⋅╲╱  a₁₁  + a₂₁
───────────────────────────────────────────────────
                       2      2
                    a₁₁  + a₂₁
=====

A_recon = ⎡a₁₁  a₁₂⎤
          ⎢        ⎥
          ⎣a₂₁  a₂₂⎦

我还能够让求解器生成正确重构“A_matrix”的解决方案,尽管我不得不跳过一些环节,并且分解采用不同的(有点奇怪)形式。但它确实产生了正确的答案:

mat_equation = sympy.Eq(A_matrix, A_params)
solve_for = (sx, theta, sy, m)
solutions = sympy.solve(mat_equation, *solve_for, dict=True)
solved = {}
# minimal=True, quick=True, cubics=False, quartics=False, quintics=False, check=False)
for sol in solutions:
    for sym, symsol0 in sol.items():
        symsol = sympy.radsimp(symsol0)
        symsol = sympy.trigsimp(symsol)
        symsol = sympy.simplify(symsol)
        symsol = sympy.radsimp(symsol)
        print('\n=====')
        print('sym = {!r}'.format(sym))
        print('symsol  = {!r}'.format(symsol))
        print('--')
        sympy.pretty_print(symsol, wrap_line=False)
        solved[sym] = symsol
        print('--')
        print('=====\n')

    A_matrix[0, :].dot(A_matrix[1, :]) / A_matrix.det()

A_solved_recon = sympy.simplify(A_params.subs(solved))

print(ub.hzcat(['A_solved_recon = ', sympy.pretty(A_solved_recon)]))

虽然我还没有弄清楚所有细节,但似乎这种 sympy-computed 分解是正确的:

=====
sym = sx
symsol  = -sqrt(a11**2 + a21**2)
--
    _____________
   ╱    2      2 
-╲╱  a₁₁  + a₂₁  
--
=====


=====
sym = theta
symsol  = -2*atan((a11 + sqrt(a11**2 + a21**2))/a21)
--
       ⎛         _____________⎞
       ⎜        ╱    2      2 ⎟
       ⎜a₁₁ + ╲╱  a₁₁  + a₂₁  ⎟
-2⋅atan⎜──────────────────────⎟
       ⎝         a₂₁          ⎠
--
=====


=====
sym = m
symsol  = (a11*a12 + a21*a22)/(a11*a22 - a12*a21)
--
a₁₁⋅a₁₂ + a₂₁⋅a₂₂
─────────────────
a₁₁⋅a₂₂ - a₁₂⋅a₂₁
--
=====


=====
sym = sy
symsol  = (-a11*a22*sqrt(a11**2 + a21**2) + a12*a21*sqrt(a11**2 + a21**2))/(a11**2 + a21**2)
--
             _____________              _____________
            ╱    2      2              ╱    2      2 
- a₁₁⋅a₂₂⋅╲╱  a₁₁  + a₂₁   + a₁₂⋅a₂₁⋅╲╱  a₁₁  + a₂₁  
─────────────────────────────────────────────────────
                        2      2                     
                     a₁₁  + a₂₁                      
--
=====

A_solved_recon = ⎡a₁₁  a₁₂⎤
                 ⎢        ⎥
                 ⎣a₂₁  a₂₂⎦

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2018-12-09
    • 1970-01-01
    • 2015-10-17
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多