- 数据挖掘中决策树C4.5预测算法实现(半成品,还要写规则后剪枝及对非离散数据信息增益计算),下一篇博客讲原理
- package org.struct.decisiontree;
-
-
import java.util.ArrayList;
-
import java.util.Arrays;
-
import java.util.List;
-
import java.util.TreeSet;
-
-
-
-
-
public class DecisionTreeBaseC4p5 {
-
-
-
-
-
private DecisionTreeNode root;
-
-
-
-
-
private boolean[] visable;
-
-
private static final int NOT_FOUND = -1;
-
-
private static final int DATA_START_LINE = 1;
-
-
private Object[] trainingArray;
-
-
private String[] columnHeaderArray;
-
-
-
-
-
private int nodeIndex;
-
-
-
-
-
@SuppressWarnings("boxing")
-
public static void main(String[] args) {
-
Object[] array = new Object[] {
-
new String[] { "age", "income", "student", "credit_rating", "buys_computer" },
-
new String[] { "youth", "high", "no", "fair", "no" },
-
new String[] { "youth", "high", "no", "excellent", "no" },
-
new String[] { "middle_aged", "high", "no", "fair", "yes" },
-
new String[] { "senior", "medium", "no", "fair", "yes" },
-
new String[] { "senior", "low", "yes", "fair", "yes" },
-
new String[] { "senior", "low", "yes", "excellent", "no" },
-
new String[] { "middle_aged", "low", "yes", "excellent", "yes" },
-
new String[] { "youth", "medium", "no", "fair", "no" },
-
new String[] { "youth", "low", "yes", "fair", "yes" },
-
new String[] { "senior", "medium", "yes", "fair", "yes" },
-
new String[] { "youth", "medium", "yes", "excellent", "yes" },
-
new String[] { "middle_aged", "medium", "no", "excellent", "yes" },
-
new String[] { "middle_aged", "high", "yes", "fair", "yes" },
-
new String[] { "senior", "medium", "no", "excellent", "no" },
- };
-
-
DecisionTreeBaseC4p5 tree = new DecisionTreeBaseC4p5();
-
tree.create(array, 4);
-
System.out.println("===============END PRINT TREE===============");
-
System.out.println("===============DECISION RESULT===============");
-
- }
-
-
-
-
-
-
public void forecast(String[] printData, DecisionTreeNode node) {
-
int index = getColumnHeaderIndexByName(node.nodeName);
-
if (index == NOT_FOUND) {
- System.out.println(node.nodeName);
- }
- DecisionTreeNode[] childs = node.childNodesArray;
-
for (int i = 0; i < childs.length; i++) {
-
if (childs[i] != null) {
-
if (childs[i].parentArrtibute.equals(printData[index])) {
- forecast(printData, childs[i]);
- }
- }
- }
- }
-
-
-
-
-
-
public void create(Object[] array, int index) {
-
this.trainingArray = Arrays.copyOfRange(array, DATA_START_LINE,
- array.length);
- init(array, index);
-
createDecisionTree(this.trainingArray);
- printDecisionTree(root);
- }
-
-
-
-
-
-
@SuppressWarnings("boxing")
-
public Object[] getMaxGain(Object[] array) {
-
Object[] result = new Object[2];
-
double gain = 0;
-
int index = -1;
-
-
for (int i = 0; i < visable.length; i++) {
-
if (!visable[i]) {
-
-
double value = gainRatio(array, i, this.nodeIndex);
- System.out.println(value);
-
if (gain < value) {
- gain = value;
- index = i;
- }
- }
- }
-
result[0] = gain;
-
result[1] = index;
-
-
if (index != -1) {
-
visable[index] = true;
- }
-
return result;
- }
-
-
-
-
-
public void createDecisionTree(Object[] array) {
- Object[] maxgain = getMaxGain(array);
-
if (root == null) {
-
root = new DecisionTreeNode();
-
root.parentNode = null;
-
root.parentArrtibute = null;
-
root.arrtibutesArray = getArrtibutesArray(((Integer) maxgain[1])
- .intValue());
-
root.nodeName = getColumnHeaderNameByIndex(((Integer) maxgain[1])
- .intValue());
-
root.childNodesArray = new DecisionTreeNode[root.arrtibutesArray.length];
- insertDecisionTree(array, root);
- }
- }
-
-
-
-
-
-
public void insertDecisionTree(Object[] array, DecisionTreeNode parentNode) {
- String[] arrtibutes = parentNode.arrtibutesArray;
-
for (int i = 0; i < arrtibutes.length; i++) {
- Object[] pickArray = pickUpAndCreateSubArray(array, arrtibutes[i],
- getColumnHeaderIndexByName(parentNode.nodeName));
- Object[] info = getMaxGain(pickArray);
-
double gain = ((Double) info[0]).doubleValue();
-
if (gain != 0) {
-
int index = ((Integer) info[1]).intValue();
-
DecisionTreeNode currentNode = new DecisionTreeNode();
- currentNode.parentNode = parentNode;
- currentNode.parentArrtibute = arrtibutes[i];
- currentNode.arrtibutesArray = getArrtibutesArray(index);
- currentNode.nodeName = getColumnHeaderNameByIndex(index);
-
currentNode.childNodesArray = new DecisionTreeNode[currentNode.arrtibutesArray.length];
- parentNode.childNodesArray[i] = currentNode;
- insertDecisionTree(pickArray, currentNode);
-
} else {
-
DecisionTreeNode leafNode = new DecisionTreeNode();
- leafNode.parentNode = parentNode;
- leafNode.parentArrtibute = arrtibutes[i];
-
leafNode.arrtibutesArray = new String[0];
-
leafNode.nodeName = getLeafNodeName(pickArray,this.nodeIndex);
-
leafNode.childNodesArray = new DecisionTreeNode[0];
- parentNode.childNodesArray[i] = leafNode;
- }
- }
- }
-
-
-
-
-
public void printDecisionTree(DecisionTreeNode node) {
- System.out.println(node.nodeName);
- DecisionTreeNode[] childs = node.childNodesArray;
-
for (int i = 0; i < childs.length; i++) {
-
if (childs[i] != null) {
- System.out.println(childs[i].parentArrtibute);
- printDecisionTree(childs[i]);
- }
- }
- }
-
-
-
-
-
-
-
-
public void init(Object[] dataArray, int index) {
-
this.nodeIndex = index;
-
-
this.columnHeaderArray = (String[]) dataArray[0];
-
visable = new boolean[((String[]) dataArray[0]).length];
-
for (int i = 0; i < visable.length; i++) {
-
if (i == index) {
-
visable[i] = true;
-
} else {
-
visable[i] = false;
- }
- }
- }
-
-
-
-
-
-
-
-
public Object[] pickUpAndCreateSubArray(Object[] array, String arrtibute,
-
int index) {
-
List<String[]> list = new ArrayList<String[]>();
-
for (int i = 0; i < array.length; i++) {
- String[] strs = (String[]) array[i];
-
if (strs[index].equals(arrtibute)) {
- list.add(strs);
- }
- }
-
return list.toArray();
- }
-
-
-
-
-
-
-
-
-
public double gain(Object[] array, int index, int nodeIndex) {
-
int[] counts = separateToSameValueArrays(array, nodeIndex);
- String[] arrtibutes = getArrtibutesArray(index);
-
double infoD = infoD(array, counts);
-
double infoaD = infoaD(array, index, nodeIndex, arrtibutes);
-
return infoD - infoaD;
- }
-
-
-
-
-
-
-
public int[] separateToSameValueArrays(Object[] array, int nodeIndex) {
- String[] arrti = getArrtibutesArray(nodeIndex);
-
int[] counts = new int[arrti.length];
-
for (int i = 0; i < counts.length; i++) {
-
counts[i] = 0;
- }
-
for (int i = 0; i < array.length; i++) {
- String[] strs = (String[]) array[i];
-
for (int j = 0; j < arrti.length; j++) {
-
if (strs[nodeIndex].equals(arrti[j])) {
- counts[j]++;
- }
- }
- }
-
return counts;
- }
-
-
-
-
-
-
-
-
-
-
public double gainRatio(Object[] array,int index,int nodeIndex){
-
double gain = gain(array,index,nodeIndex);
-
int[] counts = separateToSameValueArrays(array, index);
-
double splitInfo = splitInfoaD(array,counts);
-
if(splitInfo != 0){
-
return gain/splitInfo;
- }
-
return 0;
- }
-
-
-
-
-
-
-
-
-
public double infoD(Object[] array, int[] counts) {
-
double infoD = 0;
-
for (int i = 0; i < counts.length; i++) {
- infoD += DecisionTreeUtil.info(counts[i], array.length);
- }
-
return infoD;
- }
-
-
-
-
-
-
-
-
-
public double splitInfoaD(Object[] array, int[] counts) {
-
return infoD(array, counts);
- }
-
-
-
-
-
-
-
-
-
-
public double infoaD(Object[] array, int index, int nodeIndex,
- String[] arrtibutes) {
-
double sv_total = 0;
-
for (int i = 0; i < arrtibutes.length; i++) {
- sv_total += infoDj(array, index, nodeIndex, arrtibutes[i],
- array.length);
- }
-
return sv_total;
- }
-
-
-
-
-
-
-
-
-
-
-
public double infoDj(Object[] array, int index, int nodeIndex,
-
String arrtibute, int allTotal) {
- String[] arrtibutes = getArrtibutesArray(nodeIndex);
-
int[] counts = new int[arrtibutes.length];
-
for (int i = 0; i < counts.length; i++) {
-
counts[i] = 0;
- }
-
-
for (int i = 0; i < array.length; i++) {
- String[] strs = (String[]) array[i];
-
if (strs[index].equals(arrtibute)) {
-
for (int k = 0; k < arrtibutes.length; k++) {
-
if (strs[nodeIndex].equals(arrtibutes[k])) {
- counts[k]++;
- }
- }
- }
- }
-
-
int total = 0;
-
double infoDj = 0;
-
for (int i = 0; i < counts.length; i++) {
- total += counts[i];
- }
-
for (int i = 0; i < counts.length; i++) {
- infoDj += DecisionTreeUtil.info(counts[i], total);
- }
-
return DecisionTreeUtil.getPi(total, allTotal) * infoDj;
- }
-
-
-
-
-
-
@SuppressWarnings("unchecked")
-
public String[] getArrtibutesArray(int index) {
-
TreeSet<String> set = new TreeSet<String>(new SequenceComparator());
-
for (int i = 0; i < trainingArray.length; i++) {
- String[] strs = (String[]) trainingArray[i];
- set.add(strs[index]);
- }
-
String[] result = new String[set.size()];
-
return set.toArray(result);
- }
-
-
-
-
-
-
public String getColumnHeaderNameByIndex(int index) {
-
for (int i = 0; i < columnHeaderArray.length; i++) {
-
if (i == index) {
-
return columnHeaderArray[i];
- }
- }
-
return null;
- }
-
-
-
-
-
-
public String getLeafNodeName(Object[] array,int nodeIndex) {
-
if (array != null && array.length > 0) {
-
String[] strs = (String[]) array[0];
-
return strs[nodeIndex];
- }
-
return null;
- }
-
-
-
-
-
-
public int getColumnHeaderIndexByName(String name) {
-
for (int i = 0; i < columnHeaderArray.length; i++) {
-
if (name.equals(columnHeaderArray[i])) {
-
return i;
- }
- }
-
return NOT_FOUND;
- }
- }
- package org.struct.decisiontree;
-
-
-
-
-
/**
 * One node of the decision tree. Fields are package-private and are
 * assigned directly by the tree builder.
 */
public class DecisionTreeNode {

    /** Node above this one; {@code null} for the root. */
    DecisionTreeNode parentNode;

    /** Value of the parent's attribute that leads to this node; {@code null} for the root. */
    String parentArrtibute;

    /** Name shown for this node: a column header on split nodes, a class label on leaves. */
    String nodeName;

    /** Distinct values of this node's split attribute; empty on leaves. */
    String[] arrtibutesArray;

    /** Children, parallel to {@link #arrtibutesArray}; empty on leaves. */
    DecisionTreeNode[] childNodesArray;

}
- package org.struct.decisiontree;
-
-
-
-
-
/**
 * Static helpers for the entropy calculations used by the decision tree.
 */
public class DecisionTreeUtil {

    /** Natural logarithm of 2, used to convert natural logs to base 2. */
    private static final double LN_2 = Math.log(2);

    /**
     * Entropy term {@code -p * log2(p)} for {@code p = x / total}.
     *
     * @param x     number of occurrences
     * @param total total number of samples
     * @return the entropy contribution; 0 when {@code x} is 0
     */
    public static double info(int x, int total) {
        if (x == 0) {
            // By convention 0 * log2(0) is taken as 0.
            return 0;
        }
        double p = getPi(x, total);
        return -(p * logYBase2(p));
    }

    /**
     * Base-2 logarithm.
     *
     * @param y value
     * @return {@code log2(y)}
     */
    public static double logYBase2(double y) {
        return Math.log(y) / LN_2;
    }

    /**
     * Proportion {@code x / total} computed in floating point.
     *
     * @param x     numerator
     * @param total denominator
     * @return {@code x / total} as a {@code double}
     */
    public static double getPi(int x, int total) {
        return x / (double) total;
    }

}
- package org.struct.decisiontree;
-
-
import java.util.Comparator;
-
-
-
-
-
-
/**
 * Compares two objects as {@link String}s using their natural ordering.
 */
public class SequenceComparator implements Comparator<Object> {

    /**
     * Compares the arguments with {@link String#compareTo(String)}.
     *
     * @param o1 first value; must be a {@code String}
     * @param o2 second value; must be a {@code String}
     * @return negative, zero or positive as {@code o1} sorts before, equal
     *         to or after {@code o2}
     * @throws ClassCastException if either argument is not a {@code String}
     */
    @Override
    public int compare(Object o1, Object o2) {
        String str1 = (String) o1;
        String str2 = (String) o2;
        return str1.compareTo(str2);
    }
}
相关文章: