【问题标题】:Convert a tensorflow dataset to a python list with strings将 tensorflow 数据集转换为带有字符串的 python 列表
【发布时间】:2022-01-18 13:32:26
【问题描述】:

考虑下面的代码:

import numpy as np
import tensorflow as tf

simple_data_samples = np.array([
         [1, 1, 1, -1, -1],
         [2, 2, 2, -2, -2],
         [3, 3, 3, -3, -3],
         [4, 4, 4, -4, -4],
         [5, 5, 5, -5, -5],
         [6, 6, 6, -6, -6],
         [7, 7, 7, -7, -7],
         [8, 8, 8, -8, -8],
         [9, 9, 9, -9, -9],
         [10, 10, 10, -10, -10],
         [11, 11, 11, -11, -11],
         [12, 12, 12, -12, -12],
])

def timeseries_dataset_multistep_combined(features, label_slice, input_sequence_length, output_sequence_length, batch_size):
    feature_ds = tf.keras.preprocessing.timeseries_dataset_from_array(features, None, input_sequence_length + output_sequence_length, batch_size=batch_size)

    def split_feature_label(x):
        x=tf.strings.as_string(x)

        return x[:, :input_sequence_length, :], x[:, input_sequence_length:, label_slice]

    feature_ds = feature_ds.map(split_feature_label)

    return feature_ds

ds = timeseries_dataset_multistep_combined(simple_data_samples, slice(None, None, None), input_sequence_length=4, output_sequence_length=2,
batch_size=1)
def print_dataset(ds):
    for inputs, targets in ds:
        print("---Batch---")
        print("Feature:", inputs.numpy())
        print("Label:", targets.numpy())
        print("")



print_dataset(ds)

张量流数据集“ds”由输入和目标组成。 现在我想将 tensorflow 数据集转换为具有以下属性的 python 列表:

Index Type Size  Value 
0     str    13   1  2  3  4      5  6 
1     str    13   1  2  3  4      5  6
2     str    13   1  2  3  4      5  6
3     str    13   -1 -2 -3 -4    -5 -6   
4     str    13   -1 -2 -3 -4    -5 -6
5     str    13    2  3  4  5     6  7
.... and so on

在上面的示例中,我们假设创建了一个包含字符串的 python 列表。在“值”字段中,您可以在左侧看到 tensorflow 数据集的输入(例如,字符串之间有空格的 1 2 3 4),在右侧,您可以看到相应的目标(例如,5 6 和字符串之间的空格)。需要注意的是,输入和目标之间有一个水平制表符“\t”(例如 1 2 3 4.\t5 6.)

我将如何编码?

【问题讨论】:

    标签: python list tensorflow tensorflow2.0 tensorflow-datasets


    【解决方案1】:

    如果你想要一个pandas 数据框,你可以试试这样的:

    features = np.concatenate(list(ds.map(lambda x, y: tf.transpose(tf.squeeze(x, axis=0)))))
    targets = np.concatenate(list(ds.map(lambda x, y: tf.transpose(tf.squeeze(y, axis=0)))))
    
    values = list(map(lambda x: x[0]+ "\t" + x[1], zip([" ".join(item) for item in features.astype(str)], 
                                                       [" ".join(item) for item in targets.astype(str)])))
    types = [type(v).__name__ for v in values]
    sizes = [len(v) for v in values]
    df = pd.DataFrame({'Size':sizes, 'Type':types, 'Value':values})
    df.index.name = 'Index'
    print(df.head())
    

    【讨论】:

    • 哇,很棒的编码。谢谢
    【解决方案2】:

    我使用了你的 print_dataset 函数。

    def print_dataset(ds):
    
        list_sets = []
    
        for input, targets in ds:
    
            input = np.transpose(np.array(inputs)[0])
            label = np.transpose(np.array(targets)[0])
    
            for input_set, label_set in zip(input, label):
    
                set = ""
                set = "".join(str(value).replace("b'", "").replace("'", "") + " " for value in input_set)
    
                set += "\t" # add the tab
    
                set += "".join(str(value).replace("b'", "").replace("'", "") + " " for value in label_set)
                set = set[:-1] # remove the trailing white space
    
                # print(set) #prints each line individually 
                list_sets.append(set)
    
        print(list_sets) # prints the whole list
    

    如果您打印每行都可以正常工作,请忽略您可以看到“\t”而不是带有空格的制表符。 Python 只打印“\t”,通过用快捷方式替换无用的空格来缩短长度。

    【讨论】:

    • 听说过简单的return list_sets 吗?
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-10-11
    • 1970-01-01
    • 2022-08-17
    • 2014-08-06
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多