/*
 * An implementation of the ID3 decision tree algorithm, written from scratch in Java.
 *
 * You cannot select more than 25 topics. Topics must start with a letter or number,
 * can include dashes ('-'), and can be up to 35 characters long.
 *
 * 255 lines, 7.5 KiB
 */

package com.AI;
import com.DBpackage.MyDatabase;
//import java.io.BufferedReader;
//import java.io.IOException;
//import java.io.InputStreamReader;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
/**
 * Command-line driver for the ID3 decision-tree experiments.
 *
 * <p>The driver loads the training examples through {@link MyDatabase}, prints the
 * attributes discovered by {@code ID3}, flattens the examples into a fixed-width
 * {@code char} buffer that is handed to {@code MyJOCL}, and finally times the serial
 * tree construction ({@code ID3.createDecisionTree}).
 */
class Main
{
    /** Divisor that converts a {@link System#nanoTime()} difference into seconds. */
    private static final double NANOS_PER_SECOND = 1000000000.0;

    /** Number of chars reserved for one flattened example row. */
    private static final int MAX_ROW_LENGTH = 1000;

    /** Number of chars reserved for one attribute-name slot or one value slot. */
    private static final int CELL_SPAN = 20;

    /**
     * Runs the experiment end to end and prints the timings to standard output.
     *
     * @param args ignored
     * @throws IllegalArgumentException if the examples do not fit the fixed-width layout
     *         (see {@link #flattenExamples})
     */
    public static void main(String[] args) {
        MyDatabase myDatabase = new MyDatabase();
        myDatabase.connectDatabase();

        ID3 id3 = new ID3();
        MyJOCL myJocl = null;
        try {
            ArrayList<Row> examples = myDatabase.storeValues();

            // Other attribute pairs that were used with this call during experiments:
            //   ("Play", "Day") and ("Result_of_Treatment", "Patient").
            // NOTE(review): presumably the first argument is the target attribute and the
            // second the row-identifier attribute -- confirm against ID3.getUniqueAttributes.
            Set<String> attributesSet = id3.getUniqueAttributes("Income", "Person", examples);
            for (String attribute : attributesSet) {
                System.out.println(attribute);
            }

            id3.setAttributesNew(examples);
            List<Attribute> attributesList = id3.getAttributesList();
            id3.setAttributesIndex(attributesList);
            for (Attribute attribute : id3.getAttributesList()) {
                System.out.printf("Attribute = %s Index = %d \n",
                        attribute.getAttributeName(), attribute.getAttributeIndex());
            }
            id3.displayAttributesList(attributesList);

            myJocl = new MyJOCL();
            int examplesSize = examples.size();
            char[] examplesChar = new char[flatBufferLength(examplesSize, MAX_ROW_LENGTH)];
            id3.setMaxLength(MAX_ROW_LENGTH);
            id3.setExamplesSize(examplesSize);
            id3.setCellSpan(CELL_SPAN);
            myJocl.setMaxLength(MAX_ROW_LENGTH);
            myJocl.setExamplesSize(examplesSize);
            System.out.println("Examples size = " + examplesSize);

            // Setup time covers flattening the examples and handing the buffer to MyJOCL.
            long start = System.nanoTime();
            flattenExamples(examples, examplesChar, MAX_ROW_LENGTH, CELL_SPAN);
            myJocl.setExamplesCharBuffer(examplesChar);
            long end = System.nanoTime();
            double setupTime = seconds(start, end);

            id3.setMyJocl(myJocl);

            // One flag per example, all set to 1.
            // NOTE(review): presumably 1 marks the example as part of the current subset.
            int[] examplesInt = new int[examplesSize];
            Arrays.fill(examplesInt, 1);
            id3.setModeAndPNParallel(examplesInt);

            // The parallel tree build (id3.createDecisionTreeParallel2) is currently
            // disabled, so this interval is empty and the reported time is close to zero.
            start = System.nanoTime();
            end = System.nanoTime();
            System.out.println("Time elapsed setup: " + setupTime);
            System.out.println("Time elapsed Parallel: " + seconds(start, end));

            NaryTree tree2 = new NaryTree();
            start = System.nanoTime();
            id3.createDecisionTree(tree2, attributesSet, examples);
            end = System.nanoTime();
            System.out.println("Time elapsed Serial: " + seconds(start, end));
        } catch (SQLException se) {
            se.printStackTrace();
        } finally {
            // Run the MyJOCL shutdown hook even when a later step fails, not only on success.
            if (myJocl != null) {
                myJocl.endCL();
            }
        }
    }

    /**
     * Sums all elements of {@code array}.
     *
     * @param array values to add up; must not be {@code null}
     * @return the sum of the elements, or 0 for an empty array
     */
    static int reduceArray(int[] array)
    {
        int sum = 0;
        for (int value : array) {
            sum += value;
        }
        return sum;
    }

    /**
     * Computes the length of the flattened example buffer, rejecting sizes that would
     * overflow {@code int} instead of silently wrapping around.
     *
     * @param rowCount number of example rows
     * @param rowWidth chars reserved per row
     * @return {@code rowCount * rowWidth}
     * @throws IllegalArgumentException if the product exceeds {@link Integer#MAX_VALUE}
     */
    private static int flatBufferLength(int rowCount, int rowWidth) {
        long length = (long) rowCount * rowWidth;
        if (length > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Flattened example buffer needs " + length
                    + " chars, which exceeds the maximum array size");
        }
        return (int) length;
    }

    /**
     * Writes every example into {@code buffer} using a fixed-width layout.
     *
     * <p>Row {@code i} starts at {@code i * maxLength}. Inside a row, cell {@code j} stores its
     * attribute name at offset {@code 2 * j * cellSpan} and its value at offset
     * {@code (2 * j + 1) * cellSpan}. Positions that are not written keep the array's
     * existing content ({@code '\0'} for a freshly allocated buffer). Text longer than
     * {@code cellSpan} is truncated to the slot width so it cannot overwrite the next slot.
     *
     * @param examples  rows to flatten
     * @param buffer    destination, at least {@code examples.size() * maxLength} chars long
     * @param maxLength chars reserved per row
     * @param cellSpan  chars reserved per name slot and per value slot
     * @throws IllegalArgumentException if a row has more cells than fit into {@code maxLength}
     */
    private static void flattenExamples(List<Row> examples, char[] buffer, int maxLength, int cellSpan) {
        for (int i = 0; i < examples.size(); i++) {
            Row row = examples.get(i);
            int cellCount = row.getCellsList().size();
            if (2L * cellCount * cellSpan > maxLength) {
                throw new IllegalArgumentException("Row " + i + " has " + cellCount
                        + " cells, but only " + (maxLength / (2 * cellSpan))
                        + " fit into a row of " + maxLength + " chars");
            }

            int rowStart = i * maxLength;
            for (int j = 0; j < cellCount; j++) {
                Cell cell = row.getCellsList().get(j);
                int nameOffset = rowStart + 2 * j * cellSpan;
                copyIntoSlot(cell.getAttributeName(), buffer, nameOffset, cellSpan);
                copyIntoSlot(cell.getValue(), buffer, nameOffset + cellSpan, cellSpan);
            }

            // Terminates the PREVIOUS row by overwriting its last char.
            // NOTE(review): the final row therefore never receives a '\n'; confirm that the
            // consumer of this buffer does not rely on one.
            if (i > 0) {
                buffer[rowStart - 1] = '\n';
            }
        }
    }

    /**
     * Copies at most {@code slotWidth} leading chars of {@code text} into {@code buffer}
     * starting at {@code offset}.
     */
    private static void copyIntoSlot(String text, char[] buffer, int offset, int slotWidth) {
        int length = Math.min(text.length(), slotWidth);
        text.getChars(0, length, buffer, offset);
    }

    /** Converts a pair of {@link System#nanoTime()} readings into elapsed seconds. */
    private static double seconds(long startNanos, long endNanos) {
        return (endNanos - startNanos) / NANOS_PER_SECOND;
    }
}