【发布时间】:2021-11-11 12:48:59
【问题描述】:
编辑:帖子末尾的可能答案
您好,我正在尝试将 LSTM 转换为 tflite 模型,但我遇到了
TypeError: 'generator' 对象不可调用
错误。我的代码以前使用 python3.6,但是由于使用 TensorFlow-nightly 2.7 版本(LSTM 转换所需),我必须使用 python 3.7
现在我遇到了一个错误,我想知道我的代码是不是从一开始就有错误,或者我应该打开一个 git 票证。
在我的代码中,我设置了一个生成器函数
def my_batch_generator(X, batch_size = 500):
indices = np.arange(len(X))
batch=[]
while True:
for i in indices:
batch.append(i)
if len(batch)==batch_size:
yield X[batch]
batch=[]
数据输入 X 是从 csv 文件中读取的。
data=pd.read_csv('./test_x_data_OOP3.csv', index_col=[0])
data=np.array(data)
data=reshape_for_Lstm(data) #a function that just transforms the array
后来我调用了代表性数据集的生成器:
converter.representative_dataset = my_batch_generator(data)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
converter._experimental_lower_tensor_list_ops = False
quantized_tflite_model = converter.convert()
并为这一行抛出错误(长回溯,我将为我的 git 票保留 :-))
converter.representative_dataset = my_batch_generator(data)
问题:您是否发现我的生成器函数有任何错误以及我如何称呼它?或者可能是因为使用了 python 3.7?
谢谢
编辑:
Thierry Lathuille,你是对的。我将添加回溯。我还将上传一个工作代码示例。但是,我确保不会像您提示的那样覆盖我的函数。
您可以在此处找到一个简单版本,其中包含下载模型和 csv 文件的信息。 https://github.com/JanderHungrige/forstackoverflow
回溯
File "/home/base/Documents/Git/KundenProjekte2021/Ginko/pump_sensor/Quantizing_LSTM_3.py", line 59, in <module>
quantized_tflite_model = converter.convert()
File "/home/base/anaconda3/envs/AInight/lib/python3.7/site-packages/tensorflow/lite/python/lite.py", line 775, in wrapper
return self._convert_and_export_metrics(convert_func, *args, **kwargs)
File "/home/base/anaconda3/envs/AInight/lib/python3.7/site-packages/tensorflow/lite/python/lite.py", line 761, in _convert_and_export_metrics
result = convert_func(self, *args, **kwargs)
File "/home/base/anaconda3/envs/AInight/lib/python3.7/site-packages/tensorflow/lite/python/lite.py", line 1044, in convert
result, quant_mode, quant_io=self.experimental_new_quantizer)
File "/home/base/anaconda3/envs/AInight/lib/python3.7/site-packages/tensorflow/lite/python/convert_phase.py", line 226, in wrapper
raise error from None # Re-throws the exception.
File "/home/base/anaconda3/envs/AInight/lib/python3.7/site-packages/tensorflow/lite/python/convert_phase.py", line 216, in wrapper
return func(*args, **kwargs)
File "/home/base/anaconda3/envs/AInight/lib/python3.7/site-packages/tensorflow/lite/python/lite.py", line 722, in _optimize_tflite_model
model, q_in_type, q_out_type, q_activations_type, q_allow_float)
File "/home/base/anaconda3/envs/AInight/lib/python3.7/site-packages/tensorflow/lite/python/lite.py", line 530, in _quantize
self.representative_dataset.input_gen)
File "/home/base/anaconda3/envs/AInight/lib/python3.7/site-packages/tensorflow/lite/python/convert_phase.py", line 226, in wrapper
raise error from None # Re-throws the exception.
File "/home/base/anaconda3/envs/AInight/lib/python3.7/site-packages/tensorflow/lite/python/convert_phase.py", line 216, in wrapper
return func(*args, **kwargs)
File "/home/base/anaconda3/envs/AInight/lib/python3.7/site-packages/tensorflow/lite/python/optimize/calibrator.py", line 228, in calibrate
self._feed_tensors(dataset_gen, resize_input=True)
File "/home/base/anaconda3/envs/AInight/lib/python3.7/site-packages/tensorflow/lite/python/optimize/calibrator.py", line 97, in _feed_tensors
for sample in dataset_gen():
TypeError: 'generator' object is not callable
我的回答
似乎有两种方法可以解决这个问题。
首先很简单,只需调用函数batch_generator 而不传递任何值。 X 和批量大小是在函数中获取而不是传递(batch_generator() 函数如何知道data 和batch_size 对我来说不是很清楚)。所以就像这样:
def batch_generator():
for X in data:
batch_size = 2
indices = np.arange(len(X))
batch=[]
while True:
for i in indices:
batch.append(i)
if len(batch)==batch_size:
yield X[batch]
batch=[]
data=pd.read_csv('./test_x_data_OOP3.csv', index_col=[0])
data=np.array(data)
data=reshape_for_Lstm(data)
converter.representative_dataset = batch_generator
joanis 提出的第二种更优雅的方法是使用 init 创建一个类对象并调用,然后初始化生成器。如下:
class BatchGenerator():
def __init__(self, X, batch_size):
self.X=X
self.batch_size=batch_size
def __call__(self):
indices = np.arange(len(self.X))
batch=[]
while True:
for i in indices:
batch.append(i)
if len(batch)==self.batch_size:
yield self.X[batch]
batch=[]
data=pd.read_csv('./test_x_data_OOP3.csv', index_col=[0])
data=np.array(data)
data=reshape_for_Lstm(data)
batch_generator=BatchGenerator(data, 2)
converter.representative_dataset = batch_generator
感谢您的所有意见
【问题讨论】:
-
“长回溯,我会为我的 git 票保留它” - 好吧,你真的应该把它包含在你的问题中,并提供一个minimal reproducible example。
my_batch_generator显然不是您目前所期望的。您的代码中可能有类似my_batch_generator = <something that is a generator, maybe my_batch_generator()>的内容。 -
我怀疑生成器的基本使用在 Python 3.6 和 Python 3.7 之间发生了变化。我怀疑您做了一些无意的更改,破坏了 3.6 中的代码。
-
您能否将您的错误回溯格式化为代码,以保留其结构?
-
问题可能不在于尝试调用
my_batch_generator,而在于下游尝试调用my_batch_generator(data)。如果回溯最后一行中的dataset_gen是my_batch_generator(data),那么dataset_gen()将抛出该错误。 -
只传递没有
()的生成器是不够的,因为它不会被任何参数调用,所以你必须在创建生成器时保存这些参数。跨度>
标签: python python-3.x generator tensorflow2.0 tensorflow-lite