【问题标题】:Use Dask to run PyBaMM battery simulations in parallel使用 Dask 并行运行 PyBaMM 电池模拟
【发布时间】:2021-10-15 02:06:28
【问题描述】:

我正在使用PyBaMM 包对电池进行建模,我想使用 Dask 并行运行多个模拟。下面的例子是我尝试使用dask.delayed。 Dask 方法比非 Dask 方法慢。在这个例子中是否有更好的方法来使用 Dask?我应该设置一个 Dask Client() 来并行运行模拟吗?我在我的本地机器上运行这个例子,但我最终想在集群上运行一个类似的例子。

在 8 核 MacBook Pro 上运行示例所经过的时间如下所示。注释掉 main() 中的相应部分,以便在有或没有 Dask 的情况下运行。

Example Elapsed time
No Dask 8.02 seconds
Dask 8.74 seconds
import matplotlib.pyplot as plt
import pybamm
import time
import dask

def generate_plots(discharge, t, capacity, current, voltage):

    def styleplot(ax):
        ax.legend(loc='best')
        ax.grid(color='0.9')
        ax.set_frame_on(False)
        ax.tick_params(color='0.9')

    _, ax = plt.subplots(tight_layout=True)
    for i in range(len(discharge)):
        ax.plot(t[i], current[i], label=f'{discharge[i]} A')
    ax.set_xlabel('Time [s]')
    ax.set_ylabel('Current [A]')
    styleplot(ax)

    _, ax = plt.subplots(tight_layout=True)
    for i in range(len(discharge)):
        ax.plot(t[i], voltage[i], label=f'{discharge[i]} A')
    ax.set_xlabel('Time [s]')
    ax.set_ylabel('Terminal voltage [V]')
    styleplot(ax)

    _, ax = plt.subplots(tight_layout=True)
    for i in range(len(discharge)):
        ax.plot(capacity[i], voltage[i], label=f'{discharge[i]} A')
    ax.set_xlabel('Discharge capacity [Ah]')
    ax.set_ylabel('Terminal voltage [V]')
    styleplot(ax)

    plt.show()

def run_simulation(dis, t_eval):

    model = pybamm.lithium_ion.SPMe()

    param = model.default_parameter_values
    param['Current function [A]'] = '[input]'

    sim = pybamm.Simulation(model, parameter_values=param)
    sim.solve(t_eval, inputs={'Current function [A]': dis})

    return sim.solution

def main():
    tic = time.perf_counter()

    discharge = [4, 3.5, 3, 2.5, 2, 1.8, 1.5, 1]  # discharge currents [A]
    t_eval = [0, 4000]                            # evaluation time [s]

    # No Dask
    # ------------------------------------------------------------------------

    label = 'no Dask'

    sols = []
    for dis in discharge:
        sol = run_simulation(dis, t_eval)
        sols.append(sol)

    # Dask
    # ------------------------------------------------------------------------

    # label = 'Dask'

    # lazy_sols = []
    # for dis in discharge:
    #     sol = dask.delayed(run_simulation)(dis, t_eval)
    #     lazy_sols.append(sol)

    # sols = dask.compute(*lazy_sols)

    # ------------------------------------------------------------------------

    t = []
    capacity = []
    current = []
    voltage = []

    for sol in sols:
        t.append(sol['Time [s]'].entries)
        capacity.append(sol['Discharge capacity [A.h]'].entries)
        current.append(sol['Current [A]'].entries)
        voltage.append(sol["Terminal voltage [V]"].entries)

    toc = time.perf_counter()
    print(f'Elapsed time ({label}) = {toc - tic:.2f} s')

    generate_plots(discharge, t, capacity, current, voltage)

if __name__ == '__main__':
    main()

【问题讨论】:

  • 您能否提供a minimal code example 来重现您的问题?这将有助于诊断您的代码可能在哪里变慢。您是否还尝试通过实例化 Client() 对象来使用分布式调度程序运行 Dask?看这里docs.dask.org/en/stable/setup/…
  • @rrpelgrim 我在我的问题中提供了一个代码示例。您应该能够使用最新版本的 Python 和相关包来运行它。
  • 这是一个代码示例,但我不会称之为最小。我有很多代码可以在这里筛选以解决您的问题。如果问题归结为普遍适用的形式,我(和与我一起的其他人)更有可能参与。另外,我会尝试的第一件事是使用上面提到的分布式调度程序。
  • 使用分布式调度程序将允许您检查 Dask 仪表板,该仪表板应该能够告诉您您的模拟是否实际并行运行。
  • 可以通过client.get_dashboard()访问仪表盘的地址

标签: python dask


【解决方案1】:

根据@rrpelgrim 的建议,我实现了一个Client() 对象,它似乎通过使用分布式调度程序改进了我的示例代码的并行执行。修改后的示例如下所示。您可以通过注释掉 main() 中的相应部分来比较使用和不使用 Dask 所经过的时间。表中给出了使用 8 核 CPU 所经过的时间。

Example Elapsed time
No Dask 8.57 seconds
Dask 3.83 seconds
import matplotlib.pyplot as plt
import pybamm
import time
from dask.distributed import Client

def create_plots(discharge, t, capacity, current, voltage):

    def styleplot(ax, xlabel, ylabel):
        ax.legend(loc='best')
        ax.grid(color='0.9')
        ax.set_frame_on(False)
        ax.tick_params(color='0.9')
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

    _, ax = plt.subplots(tight_layout=True)
    for i in range(len(discharge)):
        ax.plot(t[i], current[i], label=f'{discharge[i]} A')
    styleplot(ax, xlabel='Time [s]', ylabel='Current [A]')

    _, ax = plt.subplots(tight_layout=True)
    for i in range(len(discharge)):
        ax.plot(t[i], voltage[i], label=f'{discharge[i]} A')
    styleplot(ax, xlabel='Time [s]', ylabel='Terminal voltage [V]')

    _, ax = plt.subplots(tight_layout=True)
    for i in range(len(discharge)):
        ax.plot(capacity[i], voltage[i], label=f'{discharge[i]} A')
    styleplot(ax, xlabel='Discharge capacity [Ah]', ylabel='Terminal voltage [V]')

    plt.show()

def run_simulation(dis, t_eval):

    model = pybamm.lithium_ion.SPMe()

    param = model.default_parameter_values
    param['Current function [A]'] = '[input]'

    sim = pybamm.Simulation(model, parameter_values=param)
    sim.solve(t_eval, inputs={'Current function [A]': dis})

    return sim.solution

def main(client):
    tic = time.perf_counter()

    discharge = [4, 3.5, 3, 2.5, 2, 1.8, 1.5, 1]  # discharge currents [A]
    t_eval = [0, 4000]                            # evaluation time [s]

    # No Dask
    # ------------------------------------------------------------------------

    # label = 'no Dask'

    # sols = []
    # for dis in discharge:
    #     sol = run_simulation(dis, t_eval)
    #     sols.append(sol)

    # Dask
    # ------------------------------------------------------------------------

    label = 'Dask'

    lazy_sols = client.map(run_simulation, discharge, t_eval=t_eval)
    sols = client.gather(lazy_sols)

    # ------------------------------------------------------------------------

    t = []
    capacity = []
    current = []
    voltage = []

    for sol in sols:
        t.append(sol['Time [s]'].entries)
        capacity.append(sol['Discharge capacity [A.h]'].entries)
        current.append(sol['Current [A]'].entries)
        voltage.append(sol["Terminal voltage [V]"].entries)

    toc = time.perf_counter()
    print(f'Elapsed time ({label}) = {toc - tic:.2f} s')

    create_plots(discharge, t, capacity, current, voltage)

if __name__ == '__main__':
    client = Client()
    print(client)
    main(client)

【讨论】:

  • 很高兴看到这很有帮助!如果我没记错的话,您可以保持延迟任务不变,无需切换到 client.map/gather API。实例化您的 Client() 对象后,Dask 应该使用它(即分布式调度程序)自动运行并行化延迟任务。
  • @rrpelgrim 使用dask.delayeddask.computer 与使用client.mapclient.gather 有什么区别和/或优势?
  • 来自文档:“[The Futures] 接口(即 client.map、client.gather 等)适用于像 dask.delayed 这样的任意任务调度,但它是即时的而不是惰性的,它提供在计算可能随时间演变的情况下具有更大的灵活性。” docs.dask.org/en/stable/…
猜你喜欢
  • 1970-01-01
  • 2019-03-15
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2018-11-14
  • 1970-01-01
  • 1970-01-01
  • 2021-05-13
相关资源
最近更新 更多