我的数据挖掘算法代码:https://github.com/linyiqun/DataMiningAlgorithm

介绍

Apriori算法是一个经典的数据挖掘算法,Apriori的单词的意思是"先验的",说明这个算法是具有先验性质的,就是说要通过上一次的结果推导出下一次的结果,这个如何体现将会在下面的分析中会慢慢的体现出来。Apriori算法的用处是挖掘频繁项集的,频繁项集粗俗的理解就是找出经常出现的组合,然后根据这些组合最终推出我们的关联规则。

Apriori算法原理

Apriori算法是一种逐层搜索的迭代式算法,其中k项集用于挖掘(k+1)项集,这是依靠他的先验性质的:

频繁项集的所有非空子集一定是也是频繁的。

通过这个性质可以对候选集进行剪枝。用k项集如何生成(k+1)项集呢,这个是算法里面最难也是最核心的部分。

通过2个步骤

1、连接步,将频繁项自己与自己进行连接运算。

2、剪枝步,去除候选集项中的不符合要求的候选项,不符合要求指的是这个候选项的子集并非都是频繁项,要遵守上文提到的先验性质。

3、通过1,2步骤还不够,在后面还要根据支持度计数筛选掉不满足最小支持度数的候选集。

算法实例

首先是测试数据:

交易ID

商品ID列表

T100

I1I2I5

T200

I2I4

T300

I2I3

T400

I1I2I4

T500

I1I3

T600

I2I3

T700

I1I3

T800

I1I2I3I5

T900

I1I2I3

算法的步骤图:

Apriori算法--关联规则挖掘

最后我们可以看到频繁3项集的结果为{1, 2, 3}和{1, 2, 5},然后我们去后者{1, 2, 5}作为频繁项集来生产他的关联规则,但是在这之前得先知道一些概念,怎么样才能够成为一条关联规则,关有频繁项集还是不够的。

关联规则

confidence(置信度)

confidence的中文意思为自信的,在这里其实表示的是一种条件概率,当在A条件下,B发生的概率就可以表示为confidence(A->B)=p(B|A),意为在A的情况下,推出B的概率。那么关联规则与有什么关系呢,请继续往下看。

最小置信度阈值

按照字面上的意思就是限制置信度值的一个限制条件嘛,这个很好理解。

强规则

强规则就是指的是置信度满足最小置信度(就是>=最小置信度)的推断就是一个强规则,也就是文中所说的关联规则了。这个在下面的程序中会有所体现。

算法的代码实现

我自己写的算法实现可能会让你有点晦涩难懂,不过重在理解算法的整个思路即可,尤其是连接步和剪枝步是最难点所在,可能还存在bug。

输入数据:

  1. T1125
  2. T224
  3. T323
  4. T4124
  5. T513
  6. T623
  7. T713
  8. T81235
  9. T9123
频繁项类:

  1. /**
  2. *频繁项集
  3. *
  4. *@authorlyq
  5. *
  6. */
  7. publicclassFrequentItemimplementsComparable<FrequentItem>{
  8. //频繁项集的集合ID
  9. privateString[]idArray;
  10. //频繁项集的支持度计数
  11. privateintcount;
  12. //频繁项集的长度,1项集或是2项集,亦或是3项集
  13. privateintlength;
  14. publicFrequentItem(String[]idArray,intcount){
  15. this.idArray=idArray;
  16. this.count=count;
  17. length=idArray.length;
  18. }
  19. publicString[]getIdArray(){
  20. returnidArray;
  21. }
  22. publicvoidsetIdArray(String[]idArray){
  23. this.idArray=idArray;
  24. }
  25. publicintgetCount(){
  26. returncount;
  27. }
  28. publicvoidsetCount(intcount){
  29. this.count=count;
  30. }
  31. publicintgetLength(){
  32. returnlength;
  33. }
  34. publicvoidsetLength(intlength){
  35. this.length=length;
  36. }
  37. @Override
  38. publicintcompareTo(FrequentItemo){
  39. //TODOAuto-generatedmethodstub
  40. returnthis.getIdArray()[0].compareTo(o.getIdArray()[0]);
  41. }
  42. }
主程序类:

  1. packageDataMining_Apriori;
  2. importjava.io.BufferedReader;
  3. importjava.io.File;
  4. importjava.io.FileReader;
  5. importjava.io.IOException;
  6. importjava.text.MessageFormat;
  7. importjava.util.ArrayList;
  8. importjava.util.Collections;
  9. importjava.util.HashMap;
  10. importjava.util.Map;
  11. /**
  12. *apriori算法工具类
  13. *
  14. *@authorlyq
  15. *
  16. */
  17. publicclassAprioriTool{
  18. //最小支持度计数
  19. privateintminSupportCount;
  20. //测试数据文件地址
  21. privateStringfilePath;
  22. //每个事务中的商品ID
  23. privateArrayList<String[]>totalGoodsIDs;
  24. //过程中计算出来的所有频繁项集列表
  25. privateArrayList<FrequentItem>resultItem;
  26. //过程中计算出来频繁项集的ID集合
  27. privateArrayList<String[]>resultItemID;
  28. publicAprioriTool(StringfilePath,intminSupportCount){
  29. this.filePath=filePath;
  30. this.minSupportCount=minSupportCount;
  31. readDataFile();
  32. }
  33. /**
  34. *从文件中读取数据
  35. */
  36. privatevoidreadDataFile(){
  37. Filefile=newFile(filePath);
  38. ArrayList<String[]>dataArray=newArrayList<String[]>();
  39. try{
  40. BufferedReaderin=newBufferedReader(newFileReader(file));
  41. Stringstr;
  42. String[]tempArray;
  43. while((str=in.readLine())!=null){
  44. tempArray=str.split("");
  45. dataArray.add(tempArray);
  46. }
  47. in.close();
  48. }catch(IOExceptione){
  49. e.getStackTrace();
  50. }
  51. String[]temp=null;
  52. totalGoodsIDs=newArrayList<>();
  53. for(String[]array:dataArray){
  54. temp=newString[array.length-1];
  55. System.arraycopy(array,1,temp,0,array.length-1);
  56. //将事务ID加入列表吧中
  57. totalGoodsIDs.add(temp);
  58. }
  59. }
  60. /**
  61. *判读字符数组array2是否包含于数组array1中
  62. *
  63. *@paramarray1
  64. *@paramarray2
  65. *@return
  66. */
  67. publicbooleaniSStrContain(String[]array1,String[]array2){
  68. if(array1==null||array2==null){
  69. returnfalse;
  70. }
  71. booleaniSContain=false;
  72. for(Strings:array2){
  73. //新的字母比较时,重新初始化变量
  74. iSContain=false;
  75. //判读array2中每个字符,只要包括在array1中,就算包含
  76. for(Strings2:array1){
  77. if(s.equals(s2)){
  78. iSContain=true;
  79. break;
  80. }
  81. }
  82. //如果已经判断出不包含了,则直接中断循环
  83. if(!iSContain){
  84. break;
  85. }
  86. }
  87. returniSContain;
  88. }
  89. /**
  90. *项集进行连接运算
  91. */
  92. privatevoidcomputeLink(){
  93. //连接计算的终止数,k项集必须算到k-1子项集为止
  94. intendNum=0;
  95. //当前已经进行连接运算到几项集,开始时就是1项集
  96. intcurrentNum=1;
  97. //商品,1频繁项集映射图
  98. HashMap<String,FrequentItem>itemMap=newHashMap<>();
  99. FrequentItemtempItem;
  100. //初始列表
  101. ArrayList<FrequentItem>list=newArrayList<>();
  102. //经过连接运算后产生的结果项集
  103. resultItem=newArrayList<>();
  104. resultItemID=newArrayList<>();
  105. //商品ID的种类
  106. ArrayList<String>idType=newArrayList<>();
  107. for(String[]a:totalGoodsIDs){
  108. for(Strings:a){
  109. if(!idType.contains(s)){
  110. tempItem=newFrequentItem(newString[]{s},1);
  111. idType.add(s);
  112. resultItemID.add(newString[]{s});
  113. }else{
  114. //支持度计数加1
  115. tempItem=itemMap.get(s);
  116. tempItem.setCount(tempItem.getCount()+1);
  117. }
  118. itemMap.put(s,tempItem);
  119. }
  120. }
  121. //将初始频繁项集转入到列表中,以便继续做连接运算
  122. for(Map.Entryentry:itemMap.entrySet()){
  123. list.add((FrequentItem)entry.getValue());
  124. }
  125. //按照商品ID进行排序,否则连接计算结果将会不一致,将会减少
  126. Collections.sort(list);
  127. resultItem.addAll(list);
  128. String[]array1;
  129. String[]array2;
  130. String[]resultArray;
  131. ArrayList<String>tempIds;
  132. ArrayList<String[]>resultContainer;
  133. //总共要算到endNum项集
  134. endNum=list.size()-1;
  135. while(currentNum<endNum){
  136. resultContainer=newArrayList<>();
  137. for(inti=0;i<list.size()-1;i++){
  138. tempItem=list.get(i);
  139. array1=tempItem.getIdArray();
  140. for(intj=i+1;j<list.size();j++){
  141. tempIds=newArrayList<>();
  142. array2=list.get(j).getIdArray();
  143. for(intk=0;k<array1.length;k++){
  144. //如果对应位置上的值相等的时候,只取其中一个值,做了一个连接删除操作
  145. if(array1[k].equals(array2[k])){
  146. tempIds.add(array1[k]);
  147. }else{
  148. tempIds.add(array1[k]);
  149. tempIds.add(array2[k]);
  150. }
  151. }
  152. resultArray=newString[tempIds.size()];
  153. tempIds.toArray(resultArray);
  154. booleanisContain=false;
  155. //过滤不符合条件的的ID数组,包括重复的和长度不符合要求的
  156. if(resultArray.length==(array1.length+1)){
  157. isContain=isIDArrayContains(resultContainer,
  158. resultArray);
  159. if(!isContain){
  160. resultContainer.add(resultArray);
  161. }
  162. }
  163. }
  164. }
  165. //做频繁项集的剪枝处理,必须保证新的频繁项集的子项集也必须是频繁项集
  166. list=cutItem(resultContainer);
  167. currentNum++;
  168. }
  169. //输出频繁项集
  170. for(intk=1;k<=currentNum;k++){
  171. System.out.println("频繁"+k+"项集:");
  172. for(FrequentItemi:resultItem){
  173. if(i.getLength()==k){
  174. System.out.print("{");
  175. for(Stringt:i.getIdArray()){
  176. System.out.print(t+",");
  177. }
  178. System.out.print("},");
  179. }
  180. }
  181. System.out.println();
  182. }
  183. }
  184. /**
  185. *判断列表结果中是否已经包含此数组
  186. *
  187. *@paramcontainer
  188. *ID数组容器
  189. *@paramarray
  190. *待比较数组
  191. *@return
  192. */
  193. privatebooleanisIDArrayContains(ArrayList<String[]>container,
  194. String[]array){
  195. booleanisContain=true;
  196. if(container.size()==0){
  197. isContain=false;
  198. returnisContain;
  199. }
  200. for(String[]s:container){
  201. //比较的视乎必须保证长度一样
  202. if(s.length!=array.length){
  203. continue;
  204. }
  205. isContain=true;
  206. for(inti=0;i<s.length;i++){
  207. //只要有一个id不等,就算不相等
  208. if(s[i]!=array[i]){
  209. isContain=false;
  210. break;
  211. }
  212. }
  213. //如果已经判断是包含在容器中时,直接退出
  214. if(isContain){
  215. break;
  216. }
  217. }
  218. returnisContain;
  219. }
  220. /**
  221. *对频繁项集做剪枝步骤,必须保证新的频繁项集的子项集也必须是频繁项集
  222. */
  223. privateArrayList<FrequentItem>cutItem(ArrayList<String[]>resultIds){
  224. String[]temp;
  225. //忽略的索引位置,以此构建子集
  226. intigNoreIndex=0;
  227. FrequentItemtempItem;
  228. //剪枝生成新的频繁项集
  229. ArrayList<FrequentItem>newItem=newArrayList<>();
  230. //不符合要求的id
  231. ArrayList<String[]>deleteIdArray=newArrayList<>();
  232. //子项集是否也为频繁子项集
  233. booleanisContain=true;
  234. for(String[]array:resultIds){
  235. //列举出其中的一个个的子项集,判断存在于频繁项集列表中
  236. temp=newString[array.length-1];
  237. for(igNoreIndex=0;igNoreIndex<array.length;igNoreIndex++){
  238. isContain=true;
  239. for(intj=0,k=0;j<array.length;j++){
  240. if(j!=igNoreIndex){
  241. temp[k]=array[j];
  242. k++;
  243. }
  244. }
  245. if(!isIDArrayContains(resultItemID,temp)){
  246. isContain=false;
  247. break;
  248. }
  249. }
  250. if(!isContain){
  251. deleteIdArray.add(array);
  252. }
  253. }
  254. //移除不符合条件的ID组合
  255. resultIds.removeAll(deleteIdArray);
  256. //移除支持度计数不够的id集合
  257. inttempCount=0;
  258. for(String[]array:resultIds){
  259. tempCount=0;
  260. for(String[]array2:totalGoodsIDs){
  261. if(isStrArrayContain(array2,array)){
  262. tempCount++;
  263. }
  264. }
  265. //如果支持度计数大于等于最小最小支持度计数则生成新的频繁项集,并加入结果集中
  266. if(tempCount>=minSupportCount){
  267. tempItem=newFrequentItem(array,tempCount);
  268. newItem.add(tempItem);
  269. resultItemID.add(array);
  270. resultItem.add(tempItem);
  271. }
  272. }
  273. returnnewItem;
  274. }
  275. /**
  276. *数组array2是否包含于array1中,不需要完全一样
  277. *
  278. *@paramarray1
  279. *@paramarray2
  280. *@return
  281. */
  282. privatebooleanisStrArrayContain(String[]array1,String[]array2){
  283. booleanisContain=true;
  284. for(Strings2:array2){
  285. isContain=false;
  286. for(Strings1:array1){
  287. //只要s2字符存在于array1中,这个字符就算包含在array1中
  288. if(s2.equals(s1)){
  289. isContain=true;
  290. break;
  291. }
  292. }
  293. //一旦发现不包含的字符,则array2数组不包含于array1中
  294. if(!isContain){
  295. break;
  296. }
  297. }
  298. returnisContain;
  299. }
  300. /**
  301. *根据产生的频繁项集输出关联规则
  302. *
  303. *@paramminConf
  304. *最小置信度阈值
  305. */
  306. publicvoidprintAttachRule(doubleminConf){
  307. //进行连接和剪枝操作
  308. computeLink();
  309. intcount1=0;
  310. intcount2=0;
  311. ArrayList<String>childGroup1;
  312. ArrayList<String>childGroup2;
  313. String[]group1;
  314. String[]group2;
  315. //以最后一个频繁项集做关联规则的输出
  316. String[]array=resultItem.get(resultItem.size()-1).getIdArray();
  317. //子集总数,计算的时候除去自身和空集
  318. inttotalNum=(int)Math.pow(2,array.length);
  319. String[]temp;
  320. //二进制数组,用来代表各个子集
  321. int[]binaryArray;
  322. //除去头和尾部
  323. for(inti=1;i<totalNum-1;i++){
  324. binaryArray=newint[array.length];
  325. numToBinaryArray(binaryArray,i);
  326. childGroup1=newArrayList<>();
  327. childGroup2=newArrayList<>();
  328. count1=0;
  329. count2=0;
  330. //按照二进制位关系取出子集
  331. for(intj=0;j<binaryArray.length;j++){
  332. if(binaryArray[j]==1){
  333. childGroup1.add(array[j]);
  334. }else{
  335. childGroup2.add(array[j]);
  336. }
  337. }
  338. group1=newString[childGroup1.size()];
  339. group2=newString[childGroup2.size()];
  340. childGroup1.toArray(group1);
  341. childGroup2.toArray(group2);
  342. for(String[]a:totalGoodsIDs){
  343. if(isStrArrayContain(a,group1)){
  344. count1++;
  345. //在group1的条件下,统计group2的事件发生次数
  346. if(isStrArrayContain(a,group2)){
  347. count2++;
  348. }
  349. }
  350. }
  351. //{A}-->{B}的意思为在A的情况下发生B的概率
  352. System.out.print("{");
  353. for(Strings:group1){
  354. System.out.print(s+",");
  355. }
  356. System.out.print("}-->");
  357. System.out.print("{");
  358. for(Strings:group2){
  359. System.out.print(s+",");
  360. }
  361. System.out.print(MessageFormat.format(
  362. "},confidence(置信度):{0}/{1}={2}",count2,count1,count2
  363. *1.0/count1));
  364. if(count2*1.0/count1<minConf){
  365. //不符合要求,不是强规则
  366. System.out.println("由于此规则置信度未达到最小置信度的要求,不是强规则");
  367. }else{
  368. System.out.println("为强规则");
  369. }
  370. }
  371. }
  372. /**
  373. *数字转为二进制形式
  374. *
  375. *@parambinaryArray
  376. *转化后的二进制数组形式
  377. *@paramnum
  378. *待转化数字
  379. */
  380. privatevoidnumToBinaryArray(int[]binaryArray,intnum){
  381. intindex=0;
  382. while(num!=0){
  383. binaryArray[index]=num%2;
  384. index++;
  385. num/=2;
  386. }
  387. }
  388. }
调用类:

  1. /**
  2. *apriori关联规则挖掘算法调用类
  3. *@authorlyq
  4. *
  5. */
  6. publicclassClient{
  7. publicstaticvoidmain(String[]args){
  8. StringfilePath="C:\\Users\\lyq\\Desktop\\icon\\testInput.txt";
  9. AprioriTooltool=newAprioriTool(filePath,2);
  10. tool.printAttachRule(0.7);
  11. }
  12. }
输出的结果:

  1. 频繁1项集:
  2. {1,},{2,},{3,},{4,},{5,},
  3. 频繁2项集:
  4. {1,2,},{1,3,},{1,5,},{2,3,},{2,4,},{2,5,},
  5. 频繁3项集:
  6. {1,2,3,},{1,2,5,},
  7. 频繁4项集:
  8. {1,}-->{2,5,},confidence(置信度):2/6=0.333由于此规则置信度未达到最小置信度的要求,不是强规则
  9. {2,}-->{1,5,},confidence(置信度):2/7=0.286由于此规则置信度未达到最小置信度的要求,不是强规则
  10. {1,2,}-->{5,},confidence(置信度):2/4=0.5由于此规则置信度未达到最小置信度的要求,不是强规则
  11. {5,}-->{1,2,},confidence(置信度):2/2=1为强规则
  12. {1,5,}-->{2,},confidence(置信度):2/2=1为强规则
  13. {2,5,}-->{1,},confidence(置信度):2/2=1为强规则

程序算法的问题和技巧

在实现Apiori算法的时候,碰到的一些问题和待优化的点特别要提一下:

1、首先程序的运行效率不高,里面有大量的for嵌套循环叠加上循环,当然这有本身算法的原因(连接运算所致)还有我的各个的方法选择,很多一部分用来比较字符串数组。

2、这个是我觉得会是程序的一个漏洞,当生成的候选项集加入resultItemId时,会出现{1, 2, 3}和{3, 2, 1}会被当成不同的侯选集,未做顺序的判断。

3、程序的调试过程中由于未按照从小到大的排序,导致,生成的候选集与真实值不一致的情况,所以这里必须在频繁1项集的时候就应该是有序的。

4、在输出关联规则的时候,用到了数字转二进制数组的形式,输出他的各个非空子集,然后最出关联规则的判断。

Apriori算法的缺点

此算法的的应用非常广泛,但是他在运算的过程中会产生大量的侯选集,而且在匹配的时候要进行整个数据库的扫描,因为要做支持度计数的统计操作,在小规模的数据上操作还不会有大问题,如果是大型的数据库上呢,他的效率还是有待提高的。

相关文章: