【问题标题】:Unable to Implement cosine similarity for simple k-means in JAVA for WEKA无法在 JAVA for WEKA 中实现简单 k-means 的余弦相似度
【发布时间】:2014-05-08 09:22:37
【问题描述】:

我对 Java 的 ML 的 WEKA API 还是很陌生。

由于 weka 中没有余弦相似度算法,所以想通过修改 WEKA 的 simpleKmeans 算法,将这个算法加入到 WEKA 中。

weka 中的 simpleKmeans 算法使用 EuclideanDistance,我希望使用余弦相似度而不是 euclideanDistance。

我google了很多关于如何修改weka开源软件simpleKmeans算法的代码,在网上找到了这个问题(基本上是pedro的观点)

http://comments.gmane.org/gmane.comp.ai.weka/22681

这里提到的步骤是:

  1. 扩展weka.core.EuclideanDistance并覆盖距离(实例优先, 实例二,PerformanceStats stats) 方法。

  2. 使用EuclideanDistance作为类型将其实例化为扩展类, 将实例作为扩展类构造函数的参数传递。

  3. 使用SimpleKMeans 类中的setDistanceFunction 方法传递
    EuclideanDistance 实例。

这是 WEKA 流程第一部分的代码。

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */

package weka.core;

import weka.core.Attribute;
//import weka.core.EuclideanDistance;
import java.util.Enumeration;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.core.neighboursearch.PerformanceStats;
import weka.core.TechnicalInformation.Type;

/**
 *
 * @author Sgr
 */
public class CosineSimilarity extends EuclideanDistance{

 public Instances m_Data = null;
 public String version ="1.0";

 @Override
 public double distance(Instance arg0, Instance arg1) {
  // TODO Auto-generated method stub
  return distance(arg0, arg1, Double.POSITIVE_INFINITY, null);
 }

 @Override
 public double distance(Instance arg0, Instance arg1, PerformanceStats arg2) {
  // TODO Auto-generated method stub
  return distance(arg0, arg1, Double.POSITIVE_INFINITY, arg2);
 }

 @Override
 public double distance(Instance arg0, Instance arg1, double arg2) {
  // TODO Auto-generated method stub
  return distance(arg0, arg1, arg2, null);
 }

 @Override
 public double distance(Instance first, Instance second, double cutOffValue,PerformanceStats arg3) {

    double distance = 0;
    int firstI, secondI;
    int firstNumValues = first.numValues();
    int secondNumValues = second.numValues();
    int numAttributes = m_Data.numAttributes();
    int classIndex = m_Data.classIndex();
    double normA, normB;
    normA = 0;
    normB = 0;

    for (int p1 = 0, p2 = 0; p1 < firstNumValues || p2 < secondNumValues;) {

        if (p1 >= firstNumValues)
            firstI = numAttributes;
        else firstI = first.index(p1);


        if (p2 >= secondNumValues)
            secondI = numAttributes;
        else secondI = second.index(p2);

        if (firstI == classIndex) {
            p1++;
           continue;
        }
//   if ((firstI < numAttributes)) {
//    p1++;
//    continue;
//   }

        if (secondI == classIndex) {
            p2++;
            continue;
        }
//   if ((secondI < numAttributes)) {
//    p2++;
//    continue;
//   }

        double diff;

        if (firstI == secondI) {

            diff = difference(firstI, first.valueSparse(p1), second.valueSparse(p2));
            normA += Math.pow(first.valueSparse(p1), 2);
            normB += Math.pow(second.valueSparse(p2), 2);
            p1++;
            p2++;

        } 

        else if (firstI > secondI) {

            diff = difference(secondI, 0, second.valueSparse(p2));
            normB += Math.pow(second.valueSparse(p2), 2);
            p2++;

        }

        else {
            diff = difference(firstI, first.valueSparse(p1), 0);
            normA += Math.pow(first.valueSparse(p1), 2);
            p1++;
        }

        if (arg3 != null)
            arg3.incrCoordCount();

        distance = updateDistance(distance, diff);

        if (distance > cutOffValue)
            return Double.POSITIVE_INFINITY;
        }

  //do the post here, don't depends on other functions
  //System.out.println(distance + " " + normA + " "+ normB);
        distance = distance/Math.sqrt(normA)/Math.sqrt(normB);
        distance = 1-distance;

        if(distance < 0 || distance > 1)
            System.err.println("unknown: " + distance);

        return distance;

    }

 public double updateDistance(double currDist, double diff){

     double result;
    result = currDist;
    result += diff;

    return result;
 }

 public double difference(int index, double val1, double val2){

     switch(m_Data.attribute(index).type()){

         case Attribute.NOMINAL:
                            return Double.NaN;
                            //break;
         case Attribute.NUMERIC:
                              return val1 * val2;
                            //break;
    }

     return Double.NaN;
 }

 @Override
 public String getAttributeIndices() {
  // TODO Auto-generated method stub
  return null;
 }

 @Override
 public Instances getInstances() {
  // TODO Auto-generated method stub
  return m_Data;
 }

 @Override
 public boolean getInvertSelection() {
  // TODO Auto-generated method stub
  return false;
 }

 @Override
 public void postProcessDistances(double[] arg0) {
  // TODO Auto-generated method stub

 }

 @Override
 public void setAttributeIndices(String arg0) {
  // TODO Auto-generated method stub

 }

 @Override
 public void setInstances(Instances arg0) {
  // TODO Auto-generated method stub
  m_Data = arg0;
 }

 @Override
 public void setInvertSelection(boolean arg0) {
  // TODO Auto-generated method stub


  //do nothing
 }

 @Override
 public void update(Instance arg0) {
  // TODO Auto-generated method stub

  //do nothing
 }

 @Override
 public String[] getOptions() {
  // TODO Auto-generated method stub
  return null;
 }

 @Override
 public Enumeration listOptions() {
  // TODO Auto-generated method stub
  return null;
 }

 @Override
 public void setOptions(String[] arg0) throws Exception {
  // TODO Auto-generated method stub

 }

 @Override
 public String getRevision() {
  // TODO Auto-generated method stub
  return "Cosine Distance function writtern by Sgr, version " + version;
 }


}

但我无法处理接下来的两个步骤,因为我不太熟悉 weka。

我在 weka 中看到了 simpleKmeans 的源代码,并观察到它创建了一个 EuclideanDistance 类的实例,但我对进一步的过程一无所知。

请帮助我了解接下来要执行的两个步骤。如果余弦相似度的这种实现有错误,请找出答案。此外,如果有人可以为我的余弦实现修改 weka 中的 SimpleKmeans 代码,或者向我解释我应该在该代码中进行更改的地方,那将非常有帮助。

【问题讨论】:

    标签: cluster-analysis data-mining weka cosine-similarity


    【解决方案1】:

    Weka 在集群方面确实。它也相当慢。

    你看过ELKI。在聚类和异常值检测方面,它比 Weka 有更多的选择。您可以在 ELKI 中开箱即用地尝试 k-means 中的余弦相似度。

    但请注意,k-means 不是基于距离的。它正在最小化方差(平方和),如果您使用其他距离函数,k-means 可能会停止收敛。原因是均值是 L2 最优中心,但它确实没有优化其他距离函数。它只是优化平方和,这与平方欧几里得距离相同。

    通常,具有其他距离(例如余弦)的 k 均值可能适用于您的数据集。但是收敛证明需要平方和。事实上,当集群的均值变为 0(即使您的数据不包含零向量)时,使用具有余弦相似度的 k-means 也可能会产生除以 0 的错误。

    有许多变体,例如 k-medoid,确实支持其他距离函数。据我记得,它们也应该在 ELKI 中可用。

    【讨论】:

    • @Anony-Mousse 谢谢先生!!但我希望在 weka 中有一个使用余弦相似度的distance function。我能够在weka.core 包中添加我的cosinesimilarity 算法。但我无法正确执行2nd and 3rd step(我在问题中提到的那些)。如果您能帮助我实现这一目标,那将非常有用。我的主要目标是add cosinesimilarity in simplekmeans algo in weka。
    • 我不再使用 Weka。它太慢了。对于聚类,ELKI 要好得多;对于分类,我使用 scikit-learn。 CosineSimilarity extends EuclideanDistance 听起来像是对 OOP 继承的危险滥用...余弦不是欧几里得距离的子类型。
    • 尽管如此,k-means 根本不应该使用距离恕我直言。最快的 k-means 实现实际上计算任何距离 - 它们直接优化目标函数 sum |x_ki - c_ji|^2,它只是“看起来像”平方欧几里得距离。
    • 无论哪种方式,请尝试使用 ELKI。您需要编写 0 行代码,因为它已经包含了您需要的所有内容。您可以在 ELKI 中配置 k-means 以使用余弦相似度运行(但它会警告这可能不会收敛)。
    猜你喜欢
    • 2021-05-12
    • 2018-03-06
    • 2017-06-25
    • 1970-01-01
    • 2011-01-23
    • 2021-03-16
    • 2020-08-12
    • 2011-01-01
    相关资源
    最近更新 更多