An implementation of the ID3 Decision Tree algorithm in Java from scratch.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

727 lines
18 KiB

package com.AI;
import java.util.*;
/**
 * ID3 decision-tree learner over {@link Row}/{@link Cell} examples.
 *
 * <p>There are two code paths. The sequential one works directly on {@code ArrayList<Row>}.
 * The "Parallel" one delegates counting and splitting to {@link MyJOCL}, operating on an
 * int-encoded copy of the examples.
 *
 * <p>The target attribute is treated as binary. Its labels are the strings "Yes" and "No",
 * compared case-insensitively. Ties between the two resolve to "No".
 */
class ID3
{
    private ArrayList<Attribute> attributesList;
    private Set<String> attributesSet;
    /** "Yes" / "No" target counts from the most recent counting pass. */
    private double P, N;
    /** Most common target label over the examples given to setModeAndPN / setModeAndPNParallel. */
    private String modeOfTargetAttribute;
    private MyJOCL myJocl;
    private Attribute targetAttribute;
    /** Most common target label of the examples last passed to allOneValues / allOneValuesParallel. */
    private String modeOfCurrentAttribute;
    int examplesSize;
    int maxLength;
    int cellSpan;

    ID3()
    {
        attributesList = new ArrayList<>();
        attributesSet = new LinkedHashSet<>();
        targetAttribute = new Attribute();
        P = 1;
        N = 1;
    }

    // ---------------------------------------------------------------- accessors

    public Set<String> getAttributesSet() {
        return attributesSet;
    }
    public void setAttributesSet(Set<String> attributesSet) {
        this.attributesSet = attributesSet;
    }
    public double getP() {
        return P;
    }
    public void setP(double p) {
        this.P = p;
    }
    public double getN() {
        return N;
    }
    public void setN(double n) {
        this.N = n;
    }
    public ArrayList<Attribute> getAttributesList() {
        return attributesList;
    }
    public void setAttributesList(ArrayList<Attribute> attribute) {
        this.attributesList = attribute;
    }
    public Attribute getTargetAttribute() {
        return targetAttribute;
    }
    public void setTargetAttribute(Attribute targetAttribute) {
        this.targetAttribute = targetAttribute;
    }
    public String getModeOfTargetAttribute() {
        return modeOfTargetAttribute;
    }
    public void setModeOfTargetAttribute(String modeOfTargetAttribute) {
        this.modeOfTargetAttribute = modeOfTargetAttribute;
    }
    public int getExamplesSize() {
        return examplesSize;
    }
    public void setExamplesSize(int examplesSize) {
        this.examplesSize = examplesSize;
    }
    public int getMaxLength() {
        return maxLength;
    }
    public void setMaxLength(int maxLength) {
        this.maxLength = maxLength;
    }
    public int getCellSpan() {
        return cellSpan;
    }
    public void setCellSpan(int cellSpan) {
        this.cellSpan = cellSpan;
    }
    public MyJOCL getMyJocl() {
        return myJocl;
    }
    public void setMyJocl(MyJOCL myJocl) {
        this.myJocl = myJocl;
    }

    // ------------------------------------------------------------ dataset setup

    /**
     * Collects the candidate attribute names from the first example and records the
     * target attribute's name.
     *
     * <p>The target and index columns are excluded from the candidates.
     *
     * @param targetAttribute name of the class attribute (matched case-sensitively)
     * @param indexAttribute  name of a column to skip (matched case-insensitively)
     * @param examples        training rows; only the first row is inspected
     * @return the new candidate set, which is also stored in this instance. The set is
     *         empty when {@code examples} is empty.
     */
    Set<String> getUniqueAttributes(String targetAttribute, String indexAttribute, ArrayList<Row> examples)
    {
        attributesSet = new LinkedHashSet<>();
        this.targetAttribute.setAttributeName(targetAttribute);
        if (examples.isEmpty())
            return attributesSet;   // previously threw IndexOutOfBoundsException
        for (Cell cell : examples.get(0).getCellsList()) {
            String name = cell.getAttributeName();
            if (name.equals(targetAttribute) || name.equalsIgnoreCase(indexAttribute))
                continue;
            attributesSet.add(name);
        }
        return attributesSet;
    }

    /**
     * Rebuilds {@code attributesList}: one {@link Attribute} per name in
     * {@code attributesSet}.
     *
     * <p>Each attribute's values set holds every value that attribute takes across
     * {@code examples}.
     */
    void setAttributesNew(ArrayList<Row> examples)
    {
        attributesList = new ArrayList<>();
        for (String s : attributesSet) {
            Attribute attribute = new Attribute();
            attribute.setAttributeName(s);
            for (Row row : examples) {
                for (Cell cell : row.getCellsList()) {
                    if (cell.getAttributeName().equalsIgnoreCase(s))
                        attribute.addToValuesSet(cell.getValue());
                }
            }
            attributesList.add(attribute);
        }
    }

    /**
     * Assigns column indices to the attributes, in list order, starting at 1.
     *
     * <p>The target attribute gets the index after the last one. These indices are used
     * as positions into a row's cell list by getBestAttribute. Presumably column 0 holds
     * the index attribute — TODO confirm.
     */
    void setAttributesIndex(List<Attribute> attributesList)
    {
        int i = 1;
        for (Attribute a : attributesList)
        {
            a.setAttributeIndex(i++);
        }
        targetAttribute.setAttributeIndex(i);
    }

    /** Debug dump of each attribute's name and its values. */
    void displayAttributesList(List<Attribute> attributeList)
    {
        for (Attribute a : attributeList)
        {
            System.out.println("Attribute: " + a.getAttributeName());
            for (String value : a.getValuesSet())
            {
                System.out.print("Value: " + value);
            }
            System.out.println();
        }
    }

    // ---------------------------------------------------------------- entropy

    /**
     * Computes the binary entropy, in bits, of a set with {@code p} positive and
     * {@code n} negative examples.
     *
     * @return the entropy, or 0 when either count is zero (the set is pure or empty)
     */
    static double calcEntropy(double p, double n) {
        if (p == 0 || n == 0)
            return 0;
        else
            return ((-p / (p + n)) * Logarithms.log2(p / (p + n))) - ((n / (p + n)) * Logarithms.log2(n / (p + n)));
    }

    /** Returns the share of {@code total} examples that fall into a subset of {@code subsetSize}; 0 if total is 0. */
    private static double weight(int subsetSize, double total)
    {
        return total == 0 ? 0 : subsetSize / total;
    }

    /**
     * Reports whether {@code cell} belongs to the target column.
     *
     * <p>When no target name has been configured, every cell counts, which matches the
     * historical behaviour.
     */
    private boolean isTargetCell(Cell cell)
    {
        String targetName = targetAttribute.getAttributeName();
        return targetName == null || targetName.equalsIgnoreCase(cell.getAttributeName());
    }

    /**
     * Counts the "Yes" and "No" target labels in {@code examples} in a single pass.
     *
     * @return a two-element array {@code {yesCount, noCount}}
     */
    private int[] countYesNo(List<Row> examples)
    {
        int yes = 0, no = 0;
        for (Row row : examples) {
            for (Cell cell : row.getCellsList()) {
                // Only the target column counts. Otherwise feature columns holding
                // yes/no values would inflate P + N.
                if (!isTargetCell(cell))
                    continue;
                if (cell.getValue().equalsIgnoreCase("Yes"))
                    yes++;
                else if (cell.getValue().equalsIgnoreCase("No"))
                    no++;
            }
        }
        return new int[] {yes, no};
    }

    /**
     * Records the most common label in modeOfCurrentAttribute (ties resolve to "No") and
     * reports label purity.
     *
     * @return "Yes" or "No" if every label has that value; "No" if there are no labels;
     *         "" if both labels occur
     */
    private String recordModeAndPurity(int yesCount, int noCount)
    {
        modeOfCurrentAttribute = yesCount > noCount ? "Yes" : "No";
        if (yesCount == 0)
            return "No";
        if (noCount == 0)
            return "Yes";
        return "";
    }

    /** Looks up an attribute by name (case-insensitive, last match wins); returns a blank Attribute if absent. */
    private Attribute findAttribute(String attributeName)
    {
        Attribute found = new Attribute();
        for (Attribute a : getAttributesList()) {
            if (a.getAttributeName().equalsIgnoreCase(attributeName))
                found = a;
        }
        return found;
    }

    /**
     * Returns a copy of {@code attributes} without {@code attributeName}
     * (matched case-insensitively).
     *
     * <p>Each subtree receives its own copy, so sibling branches are not affected by
     * removals made deeper in the recursion. The caller's set is never mutated.
     */
    private static Set<String> withoutAttribute(Set<String> attributes, String attributeName)
    {
        Set<String> remaining = new LinkedHashSet<>(attributes);
        for (Iterator<String> it = remaining.iterator(); it.hasNext(); )
        {
            String name = it.next();
            if (name.equalsIgnoreCase(attributeName))
            {
                System.out.println("Removed: " + name);
                it.remove();
            }
        }
        return remaining;
    }

    // ------------------------------------------------------- sequential ID3 path

    /**
     * Recursively builds an ID3 subtree for {@code examples}.
     *
     * <p>If this is the first node created, it becomes the tree's head.
     *
     * @param tree          tree that receives the head and the child links
     * @param attributesSet candidate attributes for this subtree; not modified
     * @param examples      examples reaching this node
     * @return a "Yes"/"No" leaf, or an internal node labelled with the chosen attribute
     */
    Node createDecisionTree(NaryTree tree, Set<String> attributesSet, ArrayList<Row> examples) {
        String targetAttributeValue = allOneValues(examples);
        if (targetAttributeValue.equalsIgnoreCase("Yes"))
            return new Node("Yes");
        if (targetAttributeValue.equalsIgnoreCase("No"))
            return new Node("No");
        // Capture now, because the recursive calls below overwrite modeOfCurrentAttribute.
        String currentMode = modeOfCurrentAttribute;
        if (attributesSet.isEmpty())
            return new Node(currentMode);

        String attributeName = getBestAttribute(examples, attributesSet);
        Node node = new Node(attributeName);
        if (tree.getHead() == null)
            tree.setHead(node);

        Set<String> remaining = withoutAttribute(attributesSet, attributeName);
        for (String value : findAttribute(attributeName).getValuesSet()) {
            ArrayList<Row> examplesVi = splitDataSet(examples, value, attributeName);
            // ID3: an empty branch becomes a leaf labelled with the most common label at this node.
            Node child = examplesVi.isEmpty()
                    ? new Node(currentMode)
                    : createDecisionTree(tree, remaining, examplesVi);
            tree.createNode(null, node, value, child);
        }
        return node;
    }

    /**
     * Stores the training set's Yes/No counts in P/N and its most common label in
     * modeOfTargetAttribute (ties resolve to "No").
     */
    void setModeAndPN(ArrayList<Row> examples)
    {
        int[] counts = countYesNo(examples);
        this.P = counts[0];
        this.N = counts[1];
        modeOfTargetAttribute = P > N ? "Yes" : "No";
        System.out.println("Most common value: " + modeOfTargetAttribute);
    }

    /** Returns the target-label entropy of {@code examples}; side effect: sets P and N to its Yes/No counts. */
    double getClassEntropy(ArrayList<Row> examples)
    {
        int[] counts = countYesNo(examples);
        P = counts[0];
        N = counts[1];
        return calcEntropy(counts[0], counts[1]);
    }

    /**
     * Picks the attribute in {@code attributesSet} with the highest information gain
     * over {@code examples}.
     *
     * <p>Ties go to the attribute iterated last. Reads attribute and target cells by the
     * indices assigned in setAttributesIndex.
     *
     * @return the chosen attribute name, or "Attributes set is empty"
     */
    String getBestAttribute(ArrayList<Row> examples, Set<String> attributesSet)
    {
        if (attributesSet.isEmpty())
            return "Attributes set is empty";
        double classEntropy = getClassEntropy(examples);   // also refreshes P and N
        System.out.println("Class Entropy = " + classEntropy);
        double total = P + N;
        int targetIndex = targetAttribute.getAttributeIndex();
        // Start below any possible gain, so round-off (a "zero" gain that comes out
        // slightly negative) can never leave the result blank.
        double maxGain = Double.NEGATIVE_INFINITY;
        String bestAttribute = "";
        for (String attributeName : attributesSet)
        {
            Attribute attribute = findAttribute(attributeName);
            int attributeIndex = attribute.getAttributeIndex();
            double attributeEntropy = 0;
            for (String value : attribute.getValuesSet()) {
                int p = 0, n = 0;
                for (Row row : examples) {
                    if (!row.getCellsList().get(attributeIndex).getValue().equalsIgnoreCase(value))
                        continue;
                    String label = row.getCellsList().get(targetIndex).getValue();
                    if (label.equalsIgnoreCase("Yes"))
                        p++;
                    else if (label.equalsIgnoreCase("No"))
                        n++;
                }
                attributeEntropy += weight(p + n, total) * calcEntropy(p, n);
            }
            System.out.println("Attribute Entropy" + "(" + attributeName + ") = " + attributeEntropy);
            double gain = classEntropy - attributeEntropy;
            if (gain >= maxGain)
            {
                maxGain = gain;
                bestAttribute = attributeName;
            }
            System.out.println("Gain = " + gain + " " + "Max Gain = " + maxGain);
        }
        System.out.println("Best Attribute = " + bestAttribute);
        return bestAttribute;
    }

    /** Debug helper that prints the class entropy and the best attribute. */
    void test(ArrayList<Row> examples, Set<String> attributesSet)
    {
        double classEntropy = getClassEntropy(examples);
        String bestAttribute = getBestAttribute(examples, attributesSet);
        System.out.println("Class Entropy = " + classEntropy + " Best Attribute = " + bestAttribute);
    }

    /** Returns the rows that have a cell named {@code attributeName} whose value is {@code value} (both case-insensitive). */
    ArrayList<Row> splitDataSet(ArrayList<Row> examples, String value, String attributeName)
    {
        ArrayList<Row> examplesVi = new ArrayList<>();
        for (Row row : examples)
        {
            for (Cell cell : row.getCellsList())
            {
                if (cell.getValue().equalsIgnoreCase(value) && cell.getAttributeName().equalsIgnoreCase(attributeName))
                {
                    examplesVi.add(row);
                    break;
                }
            }
        }
        return examplesVi;
    }

    /** Target-label purity test for {@code examples}; see recordModeAndPurity for the return values. */
    private String allOneValues(ArrayList<Row> examples)
    {
        int[] counts = countYesNo(examples);
        return recordModeAndPurity(counts[0], counts[1]);
    }

    // --------------------------------------------------------- MyJOCL ID3 path

    /** Like allOneValues, but with the Yes/No counts obtained from myJocl.getCount. */
    String allOneValuesParallel(int[] examplesInt, int examplesSize, int maxLength, int cellSpan)
    {
        int yesCount = myJocl.getCount(examplesInt, examplesSize, maxLength, "Yes", cellSpan, targetAttribute);
        int noCount = myJocl.getCount(examplesInt, examplesSize, maxLength, "No", cellSpan, targetAttribute);
        return recordModeAndPurity(yesCount, noCount);
    }

    /** Like getClassEntropy, but with the counts obtained from myJocl.getCount; also sets P and N. */
    double getClassEntropyParallel(int[] examplesInt, int examplesSize, int maxLength, int cellSpan)
    {
        int yesCounter = myJocl.getCount(examplesInt, examplesSize, maxLength, "Yes", cellSpan, targetAttribute);
        int noCounter = myJocl.getCount(examplesInt, examplesSize, maxLength, "No", cellSpan, targetAttribute);
        P = yesCounter;
        N = noCounter;
        return calcEntropy(yesCounter, noCounter);
    }

    /** Returns the examples array produced by myJocl.testDataSetParallel for {@code attributeName = value}. */
    int[] splitDataSetParallel(int[] examplesInt, String value, String attributeName, int examplesSize, int maxLength, int cellSpan)
    {
        ReturnObj returnObj = myJocl.testDataSetParallel(examplesInt, examplesSize, maxLength, cellSpan, attributeName, value, attributesList, targetAttribute);
        return returnObj.getExamplesInt();
    }

    /**
     * Like getBestAttribute, but with the per-value Yes/No counts obtained from
     * myJocl.testDataSetParallel2.
     *
     * @return the chosen attribute name, or "Attributes set is empty"
     */
    String getBestAttributeParallel(Set<String> attributesSet,
                                    int[] examplesInt, int examplesSize, int maxLength, int cellSpan)
    {
        if (attributesSet.isEmpty())
            return "Attributes set is empty";
        double classEntropy = getClassEntropyParallel(examplesInt, examplesSize, maxLength, cellSpan);
        System.out.println("Class Entropy = " + classEntropy);
        double total = P + N;
        double maxGain = Double.NEGATIVE_INFINITY;
        String bestAttribute = "";
        for (String attributeName : attributesSet)
        {
            double attributeEntropy = 0;
            for (String value : findAttribute(attributeName).getValuesSet()) {
                ReturnObj counts = myJocl.testDataSetParallel2(examplesInt, examplesSize, maxLength, cellSpan,
                        attributeName, value, attributesList, targetAttribute);
                int p = counts.getYesCount();
                int n = counts.getNoCount();
                attributeEntropy += weight(p + n, total) * calcEntropy(p, n);
            }
            System.out.println("Attribute Entropy" + "(" + attributeName + ") = " + attributeEntropy);
            double gain = classEntropy - attributeEntropy;
            if (gain >= maxGain)
            {
                maxGain = gain;
                bestAttribute = attributeName;
            }
            System.out.println("Gain = " + gain + " " + "Max Gain = " + maxGain);
        }
        System.out.println("Best Attribute = " + bestAttribute);
        return bestAttribute;
    }

    /**
     * Recursively builds an ID3 subtree; counting, attribute selection and splitting
     * all go through MyJOCL.
     *
     * @param attributesSet candidate attributes for this subtree; not modified
     */
    Node createDecisionTreeParallel(NaryTree tree, Set<String> attributesSet, int[] examplesInt) {
        String targetAttributeValue = allOneValuesParallel(examplesInt, examplesSize, maxLength, cellSpan);
        if (targetAttributeValue.equalsIgnoreCase("Yes"))
            return new Node("Yes");
        if (targetAttributeValue.equalsIgnoreCase("No"))
            return new Node("No");
        // Capture now, because the recursive calls below overwrite modeOfCurrentAttribute.
        String currentMode = modeOfCurrentAttribute;
        if (attributesSet.isEmpty())
            return new Node(currentMode);

        String attributeName = getBestAttributeParallel(attributesSet, examplesInt, examplesSize, maxLength, cellSpan);
        Node node = new Node(attributeName);
        if (tree.getHead() == null)
            tree.setHead(node);

        Set<String> remaining = withoutAttribute(attributesSet, attributeName);
        for (String value : findAttribute(attributeName).getValuesSet()) {
            System.out.println("For value: " + value);
            int[] examplesViInt = splitDataSetParallel(examplesInt, value, attributeName, examplesSize, maxLength, cellSpan);
            // A zero sum is treated as an empty subset. This presumes the returned array
            // is a per-row 0/1 mask — TODO confirm against MyJOCL.
            Node child = reduceArray(examplesViInt) == 0
                    ? new Node(currentMode)
                    : createDecisionTreeParallel(tree, remaining, examplesViInt);
            tree.createNode(null, node, value, child);
        }
        return node;
    }

    /**
     * Hybrid builder: only the root's attribute is chosen via MyJOCL (on
     * {@code examplesInt}); every subtree is built by the sequential createDecisionTree.
     *
     * @param attributesSet candidate attributes for this subtree; not modified
     */
    Node createDecisionTreeParallel2(ArrayList<Row> examples, NaryTree tree, Set<String> attributesSet, int[] examplesInt) {
        String targetAttributeValue = allOneValues(examples);
        if (targetAttributeValue.equalsIgnoreCase("Yes"))
            return new Node("Yes");
        if (targetAttributeValue.equalsIgnoreCase("No"))
            return new Node("No");
        String currentMode = modeOfCurrentAttribute;
        if (attributesSet.isEmpty())
            return new Node(currentMode);

        String attributeName = getBestAttributeParallel(attributesSet, examplesInt, examplesSize, maxLength, cellSpan);
        Node node = new Node(attributeName);
        if (tree.getHead() == null)
            tree.setHead(node);

        Set<String> remaining = withoutAttribute(attributesSet, attributeName);
        for (String value : findAttribute(attributeName).getValuesSet()) {
            ArrayList<Row> examplesVi = splitDataSet(examples, value, attributeName);
            Node child = examplesVi.isEmpty()
                    ? new Node(currentMode)
                    : createDecisionTree(tree, remaining, examplesVi);
            tree.createNode(null, node, value, child);
        }
        return node;
    }

    /** Like setModeAndPN, but with the counts obtained from myJocl.getCount. */
    void setModeAndPNParallel(int[] examplesInt)
    {
        this.P = myJocl.getCount(examplesInt, examplesSize, maxLength, "Yes", cellSpan, targetAttribute);
        this.N = myJocl.getCount(examplesInt, examplesSize, maxLength, "No", cellSpan, targetAttribute);
        modeOfTargetAttribute = P > N ? "Yes" : "No";
        System.out.println("Most common value: " + modeOfTargetAttribute);
    }

    /** Returns the sum of all elements of {@code array}. */
    private int reduceArray(int[] array)
    {
        int sum = 0;
        for (int value : array) {
            sum += value;
        }
        return sum;
    }
}
/**
 * Plain mutable holder for results handed back by MyJOCL calls.
 *
 * <p>It carries an int-encoded examples array, a size, and "Yes"/"No" counts. The ID3
 * class reads the array after testDataSetParallel and the counts after
 * testDataSetParallel2.
 */
class ReturnObj
{
    private int[] examplesInt;
    private int size;
    private int yesCount;
    private int noCount;

    /** Creates an empty holder: no examples array, and size and counts of zero. */
    ReturnObj()
    {
        this.size = 0;
    }

    // ----- accessors -----

    public int[] getExamplesInt() { return this.examplesInt; }

    public int getSize() { return this.size; }

    public int getYesCount() { return this.yesCount; }

    public int getNoCount() { return this.noCount; }

    // ----- mutators -----

    public void setExamplesInt(int[] examplesInt) { this.examplesInt = examplesInt; }

    public void setSize(int size) { this.size = size; }

    public void setYesCount(int yesCount) { this.yesCount = yesCount; }

    public void setNoCount(int noCount) { this.noCount = noCount; }
}