An extensible Decision-Tree framework in Java

January 8th, 2009

Problem Description: Decision trees are widely used predictive models in data mining and machine learning. This post describes a Java-based decision-tree framework that can be built upon and extended as needed. The implementation is modeled on the C5 library. I assume the reader has a basic knowledge of the classifier, so I won’t explain what a decision tree is. More info can be found here.

Framework Description: The framework essentially comprises 4 parts.


1.) TreeNode and DataPoint classes to hold data and define tree nodes: -

A DataPoint holds the attribute values of a single example in an int array, with each value stored as an index into that attribute's domain. Each TreeNode instance holds a Vector of the DataPoint objects that reach that node. The structures for these classes are as follows.

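import java.util.Vector;

class TreeNode {

    public double entropy;             // entropy of the data at this node
    public Vector data;                // DataPoint objects that reach this node
    public int decompositionAttribute; // attribute this node splits on
    public int decompositionValue;     // value of the parent's split attribute that leads here
    public TreeNode [] children;       // one child per value of the split attribute; null for a leaf
    public TreeNode parent;            // null for the root

    public TreeNode () {
        data = new Vector ();
    }
}

class DataPoint {

    public int [] attributes; // attribute values, stored as indices into each attribute's domain
    public String label;      // name of the example

    public DataPoint (final int numattributes) {
        attributes = new int [numattributes];
    }
}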

2.) DecisionTree: -

This is the base class for the decision-tree implementation and contains the TreeNode instances that make up the tree. It encapsulates all the implementation details for decision-tree creation and entropy calculation. It splits each node on the attribute with the maximum information gain, that is, the lowest size-weighted average entropy across the resulting subsets. Pruning has not been implemented yet.

The field definitions for DecisionTree are as follows.

public class DecisionTree {

    private int numAttributes;
    private String [] attributeNames;
    private Vector [] domains;
    private TreeNode trainingRoot;
    private TreeNode testingRoot;
    private String namesFile;
    private String trainingDataFile;
    private String testingDataFile;
    private int correctCount;
    private int inCorrectCount;
    private boolean debug;
    private int maxDepth;
    ….

}

The DecisionTree class invokes the InputProcessor (explained below), which populates the training and testing roots with input data. It then invokes the createDecisionTree and induce methods to create the actual tree.

The overall structure of induce is as follows.

public void induce (final TreeNode node) {

    double bestEntropy = 0;
    boolean selected = false;
    int selectedAttribute = 0;
    int numdata = node.data.size ();
    int numinputattributes = numAttributes - 1; // the last attribute is the output class
    node.entropy = calculateEntropy (node.data);
    if (node.entropy == 0) {
        // pure (or empty) node: nothing left to split
        return;
    }

    // find the unused attribute whose split gives the lowest weighted average entropy
    for (int i = 0; i < numinputattributes; i++) {
        int numvalues = domains[i].size ();
        if (alreadyUsedToDecompose (node, i)) {
            continue;
        }

        double averageentropy = 0;

        for (int j = 0; j < numvalues; j++) {
            Vector subset = getSubset (node.data, i, j);
            if (subset.size () == 0) {
                continue;
            }

            double subentropy = calculateEntropy (subset);
            averageentropy += subentropy * subset.size ();
        }

        averageentropy = averageentropy / numdata;

        // keep the first candidate, then any candidate that improves on the best so far
        if (!selected || averageentropy < bestEntropy) {
            selected = true;
            bestEntropy = averageentropy;
            selectedAttribute = i;
        }
    }

    if (!selected) {
        // every input attribute was skipped as already used
        return;
    }

    // create one child per value of the selected attribute
    int numvalues = domains[selectedAttribute].size ();
    node.decompositionAttribute = selectedAttribute;
    node.children = new TreeNode [numvalues];
    for (int j = 0; j < numvalues; j++) {
        node.children[j] = new TreeNode ();
        node.children[j].parent = node;
        node.children[j].data = getSubset (node.data, selectedAttribute, j);
        node.children[j].decompositionValue = j;
    }

    // recurse into the children, then release this node's data
    for (int j = 0; j < numvalues; j++) {
        induce (node.children[j]);
    }
    node.data = null;
}
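
The induce method above calls two helpers, getSubset and alreadyUsedToDecompose, that are not shown in this post. The following is only a sketch of what they might look like, inferred from how induce uses them; the class names SketchPoint, SketchNode and InduceHelpers are hypothetical stand-ins, and the framework's real helpers may differ.

```java
import java.util.Vector;

// Hypothetical, trimmed-down stand-ins for DataPoint and TreeNode.
class SketchPoint {
    int[] attributes;
    SketchPoint(int numAttributes) { attributes = new int[numAttributes]; }
}

class SketchNode {
    Vector<SketchPoint> data = new Vector<>();
    int decompositionAttribute;
    SketchNode[] children;
    SketchNode parent;
}

class InduceHelpers {
    // Returns the points whose value for `attribute` equals `value`.
    static Vector<SketchPoint> getSubset(Vector<SketchPoint> data, int attribute, int value) {
        Vector<SketchPoint> subset = new Vector<>();
        for (SketchPoint point : data) {
            if (point.attributes[attribute] == value) {
                subset.addElement(point);
            }
        }
        return subset;
    }

    // Returns true if any ancestor of `node` already split on `attribute`.
    // Ancestors always have their split attribute set, because induce
    // assigns it before creating children.
    static boolean alreadyUsedToDecompose(SketchNode node, int attribute) {
        for (SketchNode n = node.parent; n != null; n = n.parent) {
            if (n.decompositionAttribute == attribute) {
                return true;
            }
        }
        return false;
    }
}
```

Checking only the ancestors is a deliberate choice in this sketch: the node being induced has not been assigned a split attribute yet, so its own decompositionAttribute field still holds the default value 0 and should not be consulted.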


The most important call within induce is to calculateEntropy, which is given below.

public double calculateEntropy (final Vector data) {

    int numdata = data.size ();
    if (numdata == 0) {
        return 0;
    }

    // the output class is stored as the last attribute
    int attribute = numAttributes - 1;
    int numvalues = domains[attribute].size ();
    double sum = 0;

    for (int i = 0; i < numvalues; i++) {
        // count the points that belong to class i
        int count = 0;
        for (int j = 0; j < numdata; j++) {
            DataPoint point = (DataPoint) data.elementAt (j);
            if (point.attributes[attribute] == i) {
                count++;
            }
        }

        double probability = ((double) count) / numdata;

        // skip empty classes; Math.log is the natural log, so the result is in nats
        if (count > 0) {
            sum += -probability * Math.log (probability);
        }
    }
    return sum;
}
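
To make the numbers concrete, here is a small worked sketch of the same entropy formula applied to plain class counts, together with the size-weighted average that induce minimizes. The class EntropyExample and its methods are hypothetical and not part of the framework; like the code above, it uses the natural log.

```java
class EntropyExample {
    // Entropy in nats of a class distribution given as per-class counts.
    static double entropy(int[] counts) {
        int total = 0;
        for (int c : counts) {
            total += c;
        }
        if (total == 0) {
            return 0;
        }
        double sum = 0;
        for (int c : counts) {
            if (c > 0) {
                double p = (double) c / total;
                sum -= p * Math.log(p);
            }
        }
        return sum;
    }

    // Size-weighted average entropy of the subsets produced by a split.
    // Each row holds the class counts of one subset.
    static double averageEntropy(int[][] subsetCounts) {
        int total = 0;
        double weighted = 0;
        for (int[] counts : subsetCounts) {
            int size = 0;
            for (int c : counts) {
                size += c;
            }
            weighted += entropy(counts) * size;
            total += size;
        }
        return total == 0 ? 0 : weighted / total;
    }

    public static void main(String[] args) {
        int[] parent = {5, 5};                // 5 examples of each class
        int[][] perfectSplit = {{5, 0}, {0, 5}};
        int[][] weakSplit = {{4, 1}, {1, 4}};
        System.out.println(entropy(parent));              // ln 2, roughly 0.693
        System.out.println(averageEntropy(perfectSplit)); // 0.0: both subsets are pure
        System.out.println(averageEntropy(weakSplit));    // between 0 and ln 2
    }
}
```

The information gain of a split is the parent entropy minus the average entropy, so picking the lowest average entropy, as induce does, is the same as picking the highest gain.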

3.) InputProcessor: -

This class handles all input processing and populates the node data; DecisionTree uses it to obtain populated TreeNode objects from the raw input. It is a fairly complicated class, with separate modules to read the attribute definitions and the actual data. It can also randomly split the data into training and test sets when only one data set is provided, and discretize continuous data into bins.
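
The post does not say which binning scheme the discretization uses. As an illustration only, the sketch below assumes equal-width bins over a known [min, max] range; the class Discretizer and the method toBin are hypothetical names, not part of the framework.

```java
class Discretizer {
    // Maps `value` to a bin index in [0, numBins - 1] using equal-width bins
    // over [min, max]. Values outside the range are clamped to the end bins.
    static int toBin(double value, double min, double max, int numBins) {
        if (numBins <= 0) {
            throw new IllegalArgumentException("numBins must be positive");
        }
        if (max <= min) {
            return 0; // degenerate range: everything falls into one bin
        }
        int bin = (int) ((value - min) / (max - min) * numBins);
        if (bin < 0) {
            return 0;
        }
        return Math.min(bin, numBins - 1);
    }
}
```

With this sketch, the resulting bin index could be stored in a DataPoint attribute in the same way as the index of a symbolic value.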

The key module is readDataSets.

public int readDataSets (final String fileName, final TreeNode root, final String toSplit) throws Exception {

    FileInputStream in = null;
    /************************************************
     * Reading the Training / Testing data set File *
     ************************************************/
    try {
        File inputFile = new File (fileName);
        in = new FileInputStream (inputFile);
    } catch (Exception e) {
        System.err.println ("Unable to open file: " + fileName + "\n" + e);
        return -1;
    }

    BufferedReader bin = new BufferedReader (new InputStreamReader (in));
    String input;
    int index = 1;
    while (true) {
        input = bin.readLine ();
        if (input == null) {
            break;
        } else if (input.startsWith ("|")) {
            // a line starting with '|' is a comment
            continue;
        } else if (input.contains ("|")) {
            // strip a trailing comment
            input = input.substring (0, input.indexOf ("|"));
        }
        if (input.equals ("")) {
            continue;
        }

        StringTokenizer tokenizer = new StringTokenizer (input, ",");
        int numtokens = tokenizer.countTokens ();
        // skipCount > -1 means column skipCount holds the example label,
        // so each line carries one token more than numAttributes
        if (skipCount > -1) {
            if (numtokens != numAttributes + 1) {
                bin.close ();
                return -1;
            }
        } else if (numtokens != numAttributes) {
            bin.close ();
            return -1;
        }

        DataPoint point = new DataPoint (numAttributes);
        // if there is no label to skip
        if (skipCount == -1) {
            point.label = "Example#" + index;
            for (int i = 0; i < numAttributes; i++) {
                point.attributes[i] = getSymbolValue (i, tokenizer.nextToken ());
            }
        } else if (skipCount > -1) {
            int attributeIndex = 0;
            for (int panditIndex = 0; panditIndex < numAttributes + 1; panditIndex++) {
                // assign label to the data point and skip it as an attribute field
                if (panditIndex == skipCount) {
                    point.label = tokenizer.nextToken ();
                    continue;
                }

                point.attributes[attributeIndex] = getSymbolValue (attributeIndex, tokenizer.nextToken ());
                attributeIndex = attributeIndex + 1;
            }
        }

        /************************************************
         * Required 2/3-1/3 random data split follows   *
         ************************************************/
        if (toSplit.equals ("SPLIT")) {
            // roughly one point in three goes to the testing set
            double randomNumber = 3 * Math.random ();
            if (randomNumber > 2) {
                testingRoot.data.addElement (point);
            } else {
                root.data.addElement (point);
            }
        } else {
            root.data.addElement (point);
        }
        index = index + 1;
    }
    bin.close (); // also closes the underlying FileInputStream
    return 1;
}
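
readDataSets relies on getSymbolValue to turn each token into an integer, but that method is not shown here. The sketch below is one plausible shape for it, assuming the domain of each attribute is a Vector of symbols and that unseen symbols are appended on first use. The class SymbolTable is hypothetical, and so are the trimming of tokens and the add-on-first-use behavior; the real framework may instead reject symbols that are missing from the names file.

```java
import java.util.Vector;

class SymbolTable {
    // domains[i] lists the distinct symbols known for attribute i.
    private final Vector<String>[] domains;

    @SuppressWarnings("unchecked")
    SymbolTable(int numAttributes) {
        domains = new Vector[numAttributes];
        for (int i = 0; i < numAttributes; i++) {
            domains[i] = new Vector<>();
        }
    }

    // Returns the index of `symbol` in the domain of `attribute`,
    // adding it to the domain if it has not been seen before.
    int getSymbolValue(int attribute, String symbol) {
        String trimmed = symbol.trim();
        int index = domains[attribute].indexOf(trimmed);
        if (index < 0) {
            domains[attribute].addElement(trimmed);
            index = domains[attribute].size() - 1;
        }
        return index;
    }
}
```

For example, feeding the tokens "sunny", "rainy", "sunny" for attribute 0 into this sketch yields the indices 0, 1, 0.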

4.) DisplayProcessor: -

This class implements the display functionality. The output formatting is similar to what you see here. The key module is displayTree, which displays each level of the tree recursively.

public void displayTree (final TreeNode node, final String tab, final int numAttributes, final String [] attributeNames, final Vector [] domains, final int depth) {

    int outputattr = numAttributes - 1;
    if (node.children == null) {
        // leaf: print the output class, or the set of classes if the leaf is not pure
        int [] values = getAllValues (node.data, outputattr, domains);
        if (values.length == 1) {
            System.out.print (" " + domains[outputattr].elementAt (values[0]));
            return;
        }

        System.out.print (" {");
        for (int i = 0; i < values.length; i++) {
            System.out.print ("\"" + domains[outputattr].elementAt (values[i]) + "\"");
            if (i != values.length - 1) {
                System.out.print (" , ");
            }
        }

        System.out.print (" };");
        return;
    }

    // internal node: print one "attribute = value:" line per child and recurse
    int numvalues = node.children.length;
    for (int i = 0; i < numvalues; i++) {
        System.out.print ("\n" + tab + attributeNames[node.decompositionAttribute] + " = " + domains[node.decompositionAttribute].elementAt (i) + ":");
        if (depth <= this.MAX_DEPTH) {
            displayTree (node.children[i], tab + ": ", numAttributes, attributeNames, domains, depth + 1);
        } else {
            System.out.print ("...");
        }

        if (i != numvalues - 1) {
            System.out.print ("\t");
        } else {
            System.out.print ("");
        }
    }
}
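
displayTree calls getAllValues to find out which output classes occur at a leaf, but the post does not include that helper. The following is only a sketch inferred from the call site: it returns the distinct value indices of one attribute in ascending order. LeafPoint and LeafValues are hypothetical names, and the real helper may order or compute the values differently.

```java
import java.util.Vector;

// Hypothetical stand-in for DataPoint.
class LeafPoint {
    int[] attributes;
    LeafPoint(int... attributes) { this.attributes = attributes; }
}

class LeafValues {
    // Returns the distinct value indices of `attribute` that occur in `data`,
    // in ascending order. An empty data set yields an empty array.
    static int[] getAllValues(Vector<LeafPoint> data, int attribute, Vector<String>[] domains) {
        boolean[] seen = new boolean[domains[attribute].size()];
        int distinct = 0;
        for (LeafPoint point : data) {
            int value = point.attributes[attribute];
            if (!seen[value]) {
                seen[value] = true;
                distinct++;
            }
        }
        int[] values = new int[distinct];
        int next = 0;
        for (int value = 0; value < seen.length; value++) {
            if (seen[value]) {
                values[next++] = value;
            }
        }
        return values;
    }
}
```

Under this sketch, a leaf that received no training examples gets an empty array, which displayTree would print as an empty set.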

Again, this can be easily modified or extended as needed, since the display functionality is cleanly decoupled from the decision-tree creation logic and the other parts explained above. Finally, if you need the complete framework implementation along with the data sets and instructions, please drop me a note explaining the purpose and need for it, and I will get back to you.
