grouped_dataset = dataset.group_by_window(
        key_func=lambda seq: smart_length(tf_rank1_tensor_len(seq), bucket_bounds=bounds), # choose a bucket
        reduce_func=lambda key, ds: pad_batch(ds, batch_size, padding=padding, padded_shapes=pad_shape), # apply reduce function to pad
        window_size=window_size)
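Change it to the following (in TF 1.x, `Dataset` objects have no `group_by_window` method, so the transformation is passed to `dataset.apply` instead):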
grouped_dataset = dataset.apply(tf.compat.v1.data.experimental.group_by_window(
        key_func=lambda seq: smart_length(tf_rank1_tensor_len(seq), bucket_bounds=bounds), # choose a bucket
        reduce_func=lambda key, ds: pad_batch(ds, batch_size, padding=padding, padded_shapes=pad_shape), # apply reduce function to pad
        window_size=window_size))
That's all it takes.
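To make it clearer what the fixed call does, here is a plain-Python sketch of the `group_by_window` semantics: each element is mapped to a key, up to `window_size` elements are collected per key, each full window is handed to `reduce_func`, and leftover partial windows are reduced at the end of input. This is not TensorFlow's implementation. `bucket_key` and `pad_batch` below are hypothetical stand-ins for the post's `smart_length` and `pad_batch` helpers, whose definitions are not shown, and the flush order of partial windows is an assumption of the sketch.

```python
from typing import Callable, Dict, Hashable, Iterable, Iterator, List, Sequence

Seq = Sequence[int]


def group_by_window(
    elements: Iterable[Seq],
    key_func: Callable[[Seq], Hashable],
    reduce_func: Callable[[Hashable, List[Seq]], Iterable[List[List[int]]]],
    window_size: int,
) -> Iterator[List[List[int]]]:
    """Plain-Python model of group_by_window semantics (not TF's implementation)."""
    windows: Dict[Hashable, List[Seq]] = {}
    for elem in elements:
        key = key_func(elem)  # pick a bucket for this element
        window = windows.setdefault(key, [])
        window.append(elem)
        if len(window) == window_size:
            # A full window is handed to reduce_func, then the bucket starts over.
            yield from reduce_func(key, window)
            del windows[key]
    # At end of input, partial windows are still reduced (flush order here is
    # simply dict insertion order, an assumption of this sketch).
    for key, window in windows.items():
        yield from reduce_func(key, window)


def bucket_key(seq: Seq, bucket_bounds: Sequence[int]) -> int:
    """Hypothetical stand-in for smart_length: index of the first bound >= len(seq)."""
    for i, bound in enumerate(bucket_bounds):
        if len(seq) <= bound:
            return i
    return len(bucket_bounds)  # overflow bucket for sequences longer than every bound


def pad_batch(window: List[Seq], batch_size: int, padding: int = 0) -> Iterator[List[List[int]]]:
    """Hypothetical stand-in for pad_batch: split into batches, pad each to its own max length."""
    for start in range(0, len(window), batch_size):
        chunk = window[start:start + batch_size]
        max_len = max(len(s) for s in chunk)
        yield [list(s) + [padding] * (max_len - len(s)) for s in chunk]


seqs = [[1], [2, 2, 2, 2, 2], [3, 3], [4, 4, 4, 4], [5]]
bounds = [2, 4]
batches = list(group_by_window(
    seqs,
    key_func=lambda s: bucket_key(s, bounds),
    reduce_func=lambda key, w: pad_batch(w, batch_size=2, padding=0),
    window_size=2,
))
print(batches)
# [[[1, 0], [3, 3]], [[2, 2, 2, 2, 2]], [[4, 4, 4, 4]], [[5]]]
```

Because each batch is padded only to the longest sequence in its own bucket, short sequences are not padded out to the global maximum length, which is the point of bucketing before batching.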

Reference: https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/data/experimental/group_by_window
