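[Posted]: 2019-03-16 19:05:18
[Question]:
The documentation is a bit vague on this, and I thought it would be a fairly straightforward thing to implement.
The k_mean algorithm applied to the MNIST digit dataset outputs 10 regions, each with a specific number associated with it, but that number is not the digit represented by most of the samples contained in that region.
I do have my own table of ground_truth labels.
How can I make each region generated by the k_mean algorithm end up labeled with the digit it most probably covers?
I spent a few hours yesterday writing this code to do that, but it is still incomplete: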

import numpy as np


# TODO: for centroid-average method, see https://stackoverflow.com/a/25831425/9768291
def most_probable_digit(indices, data):
    """
    Given an array of indices (for one specific label assigned by scikit, obtained with "get_indices_of_label")
    that point at the true labels in 'data', this function counts how many times each true label
    appears and returns the one that appeared most often (and which therefore has the highest probability
    of being the ground_truth_label designated by the region delimited by scikit).
    :param indices: array of the indices in 'data' that belong to one k_mean region
    :param data: all the data distributed across the k_mean regions
    :return: the most probable value (the digit) associated with this region
    """
    actual_labels = []
    for i in indices:
        actual_labels.append(data[i])
    # Note: 'verbose' and 'count_labels' are defined elsewhere and are not shown in this post.
    if verbose: print("The actual labels for each of those digits are:", actual_labels)
    counts = count_labels("actual labels", actual_labels)
    probable = counts.index(max(counts))
    if verbose: print("Most probable digit:", probable)
    return probable


def get_list_of_indices(data, label):
    """
    Returns a list of indices corresponding to every position
    in 'data' where the specified 'label' can be found.
    :param data:
    :param label: the number associated with a region generated by k_mean
    :return:
    """
    return (np.where(data == label))[0].tolist()


# TODO: reassign in case of doubles
def obtain_corresponding_labels(data, real_labels):
    """
    Assign the most probable label to each region.
    :param data: list of regions associated with x_train or x_test (the order is preserved!)
    :param real_labels: actual labels to assign to the region numbers
    :return: the list of corresponding actual labels to region numbers
    """
    switches_to_make = []
    for i in range(10):
        list_of_indices = get_list_of_indices(data, i)  # indices in 'data' which are associated with region "i"
        probable_label = most_probable_digit(list_of_indices, real_labels)
        print("The assigned region", i, "should be considered as representing the digit ", probable_label)
        switches_to_make.append(probable_label)
    return switches_to_make


def rearrange_labels(switches_to_make, to_change):
    """
    Takes region numbers and assigns the most probable digit (label) to it.
    For example, if switches_to_make[3] = 5, it means that the 4th region (index 3 of the list)
    should be considered as representing the digit "5".
    :param switches_to_make: list of changes to make
    :param to_change: this table will be changed according to 'switches_to_make'
    :return: nothing, the change is made in-situ
    """
    for region in range(len(to_change)):
        for label in range(len(switches_to_make)):
            if to_change[region] == label:  # if it corresponds to the "wrong" label given by scikit
                to_change[region] = switches_to_make[label]  # assign the "most probable" label
                break


def count_error_rate(found, truth):
    wrong = 0
    for i in range(len(found)):
        if found[i] != truth[i]:
            wrong += 1
    print("Error rate = ", wrong / len(found) * 100, "%\n\n")


def treat_data(switches_to_make, predictions, truth):
    rearrange_labels(switches_to_make, predictions)  # Rearranging the training labels
    count_error_rate(predictions, truth)  # Counting error rate

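For comparison, the majority vote that the helpers above perform can probably be written more compactly with a contingency table. The following is only a sketch under stated assumptions: `majority_label_map` and `error_rate` are hypothetical names that do not appear in the original code, the labels are assumed to be integer arrays with values in 0..9, and the tiny arrays at the bottom are made-up data rather than MNIST.

```python
import numpy as np


def majority_label_map(cluster_ids, true_labels, n_clusters=10, n_classes=10):
    """Hypothetical helper: return an array m where m[c] is the most
    frequent true label among the samples assigned to cluster c."""
    cluster_ids = np.asarray(cluster_ids)
    true_labels = np.asarray(true_labels)
    # contingency[c, d] = number of samples in cluster c whose true label is d
    contingency = np.zeros((n_clusters, n_classes), dtype=int)
    np.add.at(contingency, (cluster_ids, true_labels), 1)
    # For each cluster (row), pick the label (column) with the highest count.
    # Empty clusters have an all-zero row and fall back to label 0.
    return contingency.argmax(axis=1)


def error_rate(predicted, truth):
    """Fraction of positions where the two label arrays disagree."""
    return float(np.mean(np.asarray(predicted) != np.asarray(truth)))


# Made-up example with three non-empty clusters
cluster_ids = np.array([0, 0, 0, 1, 1, 2, 2, 2])
true_labels = np.array([7, 7, 1, 3, 3, 7, 7, 3])

mapping = majority_label_map(cluster_ids, true_labels)
remapped = mapping[cluster_ids]  # new array; the cluster ids are left untouched
print(mapping[:3], error_rate(remapped, true_labels))
```

Indexing with `mapping[cluster_ids]` builds a new array instead of modifying the input in place, so an array such as `kmeans.labels_` would stay intact. In this example clusters 0 and 2 both map to the digit 7, which is the same kind of duplicate the original helpers can produce.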
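Currently, the problem with my code is that it can generate duplicates (if two regions have the same most probable digit, that digit ends up associated with both regions).
Here is how I use the code: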

from sklearn.cluster import KMeans
import utils  # the module that holds the helper functions above

kmeans = KMeans(n_clusters=10)  # TODO: eventually use "init=ndarray" to be able to use custom centroids for init ?
kmeans.fit(x_train)
training_labels = kmeans.labels_
print("Done with calculating the k-mean.\n")
switches_to_make = utils.obtain_corresponding_labels(training_labels, y_train) # Obtaining the most probable labels
utils.treat_data(switches_to_make, training_labels, y_train)
print("Assigned labels: ", training_labels)
print("Real labels: ", y_train)
print("\n####################################################\nMoving on to predictions")
predictions = kmeans.predict(x_test)
utils.treat_data(switches_to_make, predictions, y_test)

My code has an error rate of about 50%.
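Regarding the duplicates described above, one possible way to force a one-to-one mapping between regions and digits is to treat it as an assignment problem on the region-versus-digit count table and solve it with `scipy.optimize.linear_sum_assignment`. This is a sketch, not a confirmed fix for the 50% error rate: `one_to_one_label_map` is a hypothetical name, the labels are assumed to be integers in `0..n-1`, and the example uses made-up data with only three regions.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def one_to_one_label_map(cluster_ids, true_labels, n=10):
    """Hypothetical helper: map each of n clusters to a distinct label so
    that the total number of agreeing samples is as large as possible."""
    cluster_ids = np.asarray(cluster_ids)
    true_labels = np.asarray(true_labels)
    # contingency[c, d] = number of samples in cluster c whose true label is d
    contingency = np.zeros((n, n), dtype=int)
    np.add.at(contingency, (cluster_ids, true_labels), 1)
    # linear_sum_assignment minimizes total cost, so negate the counts
    # to maximize the total agreement instead.
    rows, cols = linear_sum_assignment(-contingency)
    mapping = np.empty(n, dtype=int)
    mapping[rows] = cols
    return mapping


# Made-up example: clusters 0 and 1 both have label 1 as their majority
cluster_ids = np.array([0, 0, 0, 0, 1, 1, 1, 2, 2])
true_labels = np.array([1, 1, 1, 2, 1, 1, 2, 0, 0])

mapping = one_to_one_label_map(cluster_ids, true_labels, n=3)
print(mapping)  # each label is used exactly once
```

Note that a plain majority vote already gives the lowest possible training error for a fixed clustering, so forcing distinct digits removes the duplicates but cannot by itself reduce that error. A high error rate may instead indicate that 10 clusters do not separate the digits well.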
[Comments]:
Tags: python-3.x scikit-learn k-means