【问题标题】:Is there a way to write a math formula on matplotlib plot dynamically?有没有办法在 matplotlib 绘图上动态编写数学公式?
【发布时间】:2019-09-08 11:05:18
【问题描述】:

我正在为我的 Python 实验室工作制作模板。总结一下它的目的,就是绘制数据点,用 scipy curve_fit 拟合一个预定义的模型。通常我拟合多项式或指数曲线。我设法在绘图上动态打印拟合参数,但我每次都必须手动输入相关方程。我想知道,有没有一种优雅的方式来动态地做到这一点?我读过关于 sympy 的文章,但暂时我做不到。

代码如下:

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from datetime import datetime

#two example functions
def f(x, p0, p1):
    return p0 * x + p1

def g(x, p0, p1):
    return p0 * np.exp(x * p1)

#example data
xval = np.array([0,1,2,3,4,5,6])
yval = np.array([0, 2,3.95,5.8,8.1, 10.2, 12.4])

#curve fitting
popt, pcov = curve_fit(f, xval, yval)

plt.rcParams.update({'font.size': 12})
plt.figure(figsize=(9,7))
plt.plot(xval, yval,'ko', label = 'Data points', markersize = 7)
plt.title('TITLE', fontsize = 15)
plt.grid()
plt.plot(xval, f(xval, *popt),'r-', label = 'Fit')
#printing the params on plot
for idx in range(len(popt)):
    plt.text(0.8,0.05+0.05*(idx+1), 'p'+str(idx)+' = {0:.5f}'.format(popt[idx]), transform=plt.gca().transAxes)

#manually writing the equation, that's what I want to print dynamically
plt.text(0.8, 0.05, '$y = p0 \cdot x + p1 $' , transform=plt.gca().transAxes)

plt.text(0.86, 1.01, datetime.today().strftime('%Y.%m.%d.'), transform=plt.gca().transAxes)
plt.text(0 ,1.01, 'NAME', transform=plt.gca().transAxes)
plt.ylabel('Y axis title')
plt.xlabel('X axis title')
plt.legend()
plt.show()

预期结果是:

如果我使用拟合函数 - 假设 g(x, p0, p1) 返回 p0 * np.exp(x * p1) 然后返回的公式本身应该打印在图上,就像示例代码中的另一个一样:

plt.text(0.8, 0.05, '$y = p0 \cdot x + p1 $' , transform=plt.gca().transAxes) 

除非它是手动解决方案。

我非常感谢任何建议。

【问题讨论】:

  • 您需要在某些时候输入公式。可能你想把它存储在它所属的函数附近?
  • 是的,也许我会将公式存储在变量中,并在需要时使用它们。数量不多。

标签: python matplotlib scipy


【解决方案1】:

我认为你可以使用 sympy 包。 它允许定义自定义变量、创建、表达式然后对其进行评估。我不确定对性能有什么影响

这是您的更改代码:

import numpy as np
import sympy
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from datetime import datetime


#two example functions
x, p0, p1 = sympy.var("x p0 p1")

f = p0 * x + p1 

g = p0 * sympy.exp(x*p1)

def partial_fun(sympy_expr):

    def res_fun(X, P0, P1):
        return np.array([sympy_expr.evalf(subs={x: x_, p0: P0, p1: P1}) for x_ in X], dtype=np.float)

    return res_fun

#example data
xval = np.array([0,1,2,3,4,5,6])
yval = np.array([0, 2,3.95,5.8,8.1, 10.2, 12.4])

#curve fitting
popt, pcov = curve_fit(partial_fun(f), xval, yval)

plt.rcParams.update({'font.size': 12})
plt.figure(figsize=(9,7))
plt.plot(xval, yval,'ko', label = 'Data points', markersize = 7)
plt.title('TITLE', fontsize = 15)
plt.grid()
plt.plot(xval, partial_fun(f)(xval, *popt),'r-', label = 'Fit')
#printing the params on plot
for idx in range(len(popt)):
    plt.text(0.8,0.05+0.05*(idx+1), 'p'+str(idx)+' = {0:.5f}'.format(popt[idx]), transform=plt.gca().transAxes)

#manually writing the equation, that's what I want to print dynamically
plt.text(0.8, 0.05, f'$y = {f} $' , transform=plt.gca().transAxes)

plt.text(0.86, 1.01, datetime.today().strftime('%Y.%m.%d.'), transform=plt.gca().transAxes)
plt.text(0 ,1.01, 'NAME', transform=plt.gca().transAxes)
plt.ylabel('Y axis title')
plt.xlabel('X axis title')
plt.legend()
plt.show()

【讨论】:

    【解决方案2】:

    我实际上设法提出了一个解决方案(虽然没有同情),我必须手动输入公式,但它们是自动选择的。我为此使用字典。

    代码如下:

    import numpy as np
    import matplotlib.pyplot as plt
    from scipy.optimize import curve_fit
    from datetime import datetime
    
    fun_dict = {}
    
    #three example functions
    def f(x, p0, p1):
        return p0 * x + p1
    
    def g(x, p0, p1):
        return p0 * np.exp(x * p1)
    
    def h(x, p0, p1, p2):
        return p0 * x ** 2 + p1 * x + p2
    
    f_string = '$y = p0 \cdot x + p1 $'
    fun_dict['f'] = f_string
    g_string = '$y = p0 \cdot e^{p1 \cdot x} $'
    fun_dict['g'] = g_string
    h_string = '$y = p0 \cdot x^2 + p1 \cdot x + p2$'
    fun_dict['h'] = h_string
    #example data
    xval = np.array([0,1,2,3,4,5,6])
    yval = np.array([0, 2,3.95,5.8,8.1, 10.2, 12.4])
    
    
    def get_fun(func):
        popt, _ = curve_fit(func, xval, yval)
        return popt, fun_dict[str(func.__name__)], func
    
    popt, str_name, func = get_fun(h)
    
    
    
    plt.rcParams.update({'font.size': 12})
    plt.figure(figsize=(9,7))
    plt.plot(xval, yval,'ko', label = 'Data points', markersize = 7)
    plt.title('TITLE', fontsize = 15)
    plt.grid()
    plt.plot(xval, func(xval, *popt),'r-', label = 'Fit')
    for idx in range(len(popt)):
        plt.text(0.8,0.05+0.05*(idx+1), 'p'+str(idx)+' = {0:.5f}'.format(popt[idx]), transform=plt.gca().transAxes)
    
    plt.text(0.7, 0.05, str_name, transform=plt.gca().transAxes)
    
    plt.text(0.86, 1.01, datetime.today().strftime('%Y.%m.%d.'), transform=plt.gca().transAxes)
    plt.text(0 ,1.01, 'NAME', transform=plt.gca().transAxes)
    plt.ylabel('Y axis title')
    plt.xlabel('X axis title')
    plt.legend()
    plt.show()
    

    【讨论】:

      猜你喜欢
      • 2013-04-06
      • 1970-01-01
      • 1970-01-01
      • 2021-03-16
      • 2016-03-24
      • 2020-12-01
      • 1970-01-01
      • 2017-12-24
      • 2017-09-16
      相关资源
      最近更新 更多