【发布时间】:2015-01-29 00:03:50
【问题描述】:
有没有人尝试在 Java 中实现 matlab 的 filtfilt() 函数(或至少在 C++ 中)?如果你们有算法,那将有很大帮助。
【问题讨论】:
-
做一个卷积从左到右,然后在结果上重做一次从右到左,这就是你的
filtfilt
标签: java c++ matlab signal-processing
有没有人尝试在 Java 中实现 matlab 的 filtfilt() 函数(或至少在 C++ 中)?如果你们有算法,那将有很大帮助。
【问题讨论】:
filtfilt
标签: java c++ matlab signal-processing
Here 是我在 C++ 中实现的 filtfilt 算法,如在 MATLAB 中实现的。希望对您有所帮助。
【讨论】:
好的,我知道这个问题很古老,但也许我可以帮助到这里的其他人想知道 filtfilt 究竟做了什么。
虽然从文档中可以明显看出 filtfilt 进行了前向后向(又名零相位)过滤,但对我来说它如何处理诸如 padding 和 初始条件。
由于我在此处(或其他地方)找不到任何其他答案以及有关filtfilt 的这些实现细节的足够信息,因此我实现了Python 的scipy.signal.filtfilt 的简化版本,基于其源代码和文档(因此,不是Java,也不是C++,而是Python)。我相信scipy 版本works the same 方式为Matlab 的。
为简单起见,下面的代码是专门为二阶 IIR filter 编写的,它假设系数向量 a 和 b 是已知的(例如从 scipy.signal.butter 或 calculated by hand 获得) .
它匹配filtfilt 默认行为,使用长度为3 * max(len(a), len(b)) 的odd 填充,在前向传递之前应用。使用scipy.signal.lfilter_zi (docs) 的方法找到初始状态。
免责声明:此代码仅旨在提供对filtfilt 某些实现细节的一些见解,因此目标是清晰而不是计算效率/性能。 scipy.signal.filtfilt 的实现要快得多(例如,根据我系统上的快速而肮脏的 timeit 测试,速度快了 100 倍)。
import numpy
def custom_filter(b, a, x):
"""
Filter implemented using state-space representation.
Assume a filter with second order difference equation (assuming a[0]=1):
y[n] = b[0]*x[n] + b[1]*x[n-1] + b[2]*x[n-2] + ...
- a[1]*y[n-1] - a[2]*y[n-2]
"""
# State space representation (transposed direct form II)
A = numpy.array([[-a[1], 1], [-a[2], 0]])
B = numpy.array([b[1] - b[0] * a[1], b[2] - b[0] * a[2]])
C = numpy.array([1.0, 0.0])
D = b[0]
# Determine initial state (solve zi = A*zi + B, see scipy.signal.lfilter_zi)
zi = numpy.linalg.solve(numpy.eye(2) - A, B)
# Scale the initial state vector zi by the first input value
z = zi * x[0]
# Apply filter
y = numpy.zeros(numpy.shape(x))
for n in range(len(x)):
# Determine n-th output value (note this simplifies to y[n] = z[0] + b[0]*x[n])
y[n] = numpy.dot(C, z) + D * x[n]
# Determine next state (i.e. z[n+1])
z = numpy.dot(A, z) + B * x[n]
return y
def custom_filtfilt(b, a, x):
# Apply 'odd' padding to input signal
padding_length = 3 * max(len(a), len(b)) # the scipy.signal.filtfilt default
x_forward = numpy.concatenate((
[2 * x[0] - xi for xi in x[padding_length:0:-1]],
x,
[2 * x[-1] - xi for xi in x[-2:-padding_length-2:-1]]))
# Filter forward
y_forward = custom_filter(b, a, x_forward)
# Filter backward
x_backward = y_forward[::-1] # reverse
y_backward = custom_filter(b, a, x_backward)
# Remove padding and reverse
return y_backward[-padding_length-1:padding_length-1:-1]
请注意,此实现不需要scipy。此外,通过写出zi 的解决方案并使用列表而不是numpy 数组,它可以很容易地适应纯python,甚至不需要numpy。这甚至带来了巨大的性能优势,因为在 python 循环中访问单个 numpy 数组元素比访问列表元素要慢得多。
过滤器本身是在一个简单的Python 循环中实现的。它使用状态空间表示,因为无论如何都会使用它来确定初始条件(请参阅scipy.signal.lfilter_zi)。我相信线性滤波器的实际scipy 实现(即scipy.signal.sigtools._linear_filter)在C 中做了类似的事情,可以看到here(感谢this answer)。
这里有一些代码提供(非常基本的)检查scipy 输出和custom 输出的相等性:
import numpy
import numpy.testing
import scipy.signal
from matplotlib import pyplot
from . import custom_filtfilt
def sinusoid(sampling_frequency_Hz=50.0, signal_frequency_Hz=1.0, periods=1.0,
amplitude=1.0, offset=0.0, phase_deg=0.0, noise_std=0.1):
"""
Create a noisy test signal sampled from a sinusoid (time series)
"""
signal_frequency_rad_per_s = signal_frequency_Hz * 2 * numpy.pi
phase_rad = numpy.radians(phase_deg)
duration_s = periods / signal_frequency_Hz
number_of_samples = int(duration_s * sampling_frequency_Hz)
time_s = (numpy.array(range(number_of_samples), float) /
sampling_frequency_Hz)
angle_rad = signal_frequency_rad_per_s * time_s
signal = offset + amplitude * numpy.sin(angle_rad - phase_rad)
noise = numpy.random.normal(loc=0.0, scale=noise_std, size=signal.shape)
return signal + noise
if __name__ == '__main__':
# Design filter
sampling_freq_hz = 50.0
cutoff_freq_hz = 2.5
order = 2
normalized_frequency = cutoff_freq_hz * 2 / sampling_freq_hz
b, a = scipy.signal.butter(order, normalized_frequency, btype='lowpass')
# Create test signal
signal = sinusoid(sampling_frequency_Hz=sampling_freq_hz,
signal_frequency_Hz=1.5, periods=3, amplitude=2.0,
offset=2.0, phase_deg=25)
# Apply zero-phase filters
filtered_custom = custom_filtfilt(b, a, signal)
filtered_scipy = scipy.signal.filtfilt(b, a, signal)
# Verify near-equality
numpy.testing.assert_array_almost_equal(filtered_custom, filtered_scipy,
decimal=12)
# Plot result
pyplot.subplot(1, 2, 1)
pyplot.plot(signal)
pyplot.plot(filtered_scipy)
pyplot.plot(filtered_custom, '.')
pyplot.title('raw vs filtered signals')
pyplot.legend(['raw', 'scipy filtfilt', 'custom filtfilt'])
pyplot.subplot(1, 2, 2)
pyplot.plot(filtered_scipy-filtered_custom)
pyplot.title('difference (scipy vs custom)')
pyplot.show()
这个基本比较产生了一个如下图,表明至少 14 位小数,对于这种特定情况(机器精度,我猜?):
【讨论】: