Monday, March 12, 2012

How to implement Logistic Regression with stochastic gradient ascent in Java.



In this post I will give a step-by-step tutorial for implementing, from scratch, a simple logistic regression classifier in Java. Note that the provided code is not optimized; it does, however, work just fine.

The problem

The goal is to build a model for binary classification: given a feature vector $\boldsymbol{x}^{(i)}$, we want to predict the corresponding label $y^{(i)} \in \{0,1\}$. Our prediction for instance $i$ is denoted $h_{\boldsymbol\alpha}(\boldsymbol{x}^{(i)})$ and is given by the following equation:

$h_{\boldsymbol{\alpha}}(\boldsymbol{x}) = \frac{1}{1+e^{-\boldsymbol{\alpha}\boldsymbol{x}}}$ ,

where $\boldsymbol{\alpha}$ is a weight vector. The right-hand side of the previous equation is called the logistic function, or the sigmoid function, of $\boldsymbol{\alpha}\boldsymbol{x}$ (see also here).
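Since everything below rests on this function, here is what it looks like as a standalone Java method (a minimal sketch of my own; in the Logistic class further down, the same computation happens inside estimateProbs):

// Sigmoid (logistic) function: maps any real-valued score to the interval (0, 1).
public static double sigmoid(double z) {
    return 1.0 / (1.0 + Math.exp(-z));
}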

If we assume that
$\Pr(y=1|\boldsymbol{x},\boldsymbol{\alpha}) = h_{\boldsymbol\alpha}(\boldsymbol{x})$ ,

$\Pr(y=0|\boldsymbol{x},\boldsymbol{\alpha}) = 1 - h_{\boldsymbol\alpha}(\boldsymbol{x})$ ,

we can write $\Pr(y|\boldsymbol{x},\boldsymbol{\alpha})$ as follows:

$\Pr(y|\boldsymbol{x},\boldsymbol{\alpha})=h_{\boldsymbol\alpha}(\boldsymbol{x})^y(1-h_{\boldsymbol\alpha}(\boldsymbol{x}))^{1-y}$ .

Now we can use maximum likelihood estimation (MLE, see here) to estimate the optimal weights $\boldsymbol\alpha$. Assuming that we have $m$ instances in our training set, the likelihood is:

$L(\boldsymbol\alpha) = \prod_{i=1}^m \Pr(y^{(i)}|\boldsymbol{x}^{(i)},\boldsymbol{\alpha})$ ,

and by taking the logs we get:

$\log L(\boldsymbol\alpha) = \sum_{i=1}^{m}y^{(i)} \log h_{\boldsymbol\alpha}(\boldsymbol{x}^{(i)}) + (1 -y^{(i)}) \log (1-h_{\boldsymbol\alpha}(\boldsymbol{x}^{(i)}))$ .

But how do we maximize this likelihood? There are many techniques available in the literature. In this post, I will use one of the simplest: stochastic gradient ascent (ascent because we maximize; see also here). Considering only one example at a time, the gradient of the log-likelihood with respect to each dimension $j$ is:

$\frac{\partial \log L}{\partial \alpha_j} = (y - h_{\boldsymbol{\alpha}}(\boldsymbol{x}))x_j$ .
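For completeness, this follows from the chain rule together with the identity $\frac{\partial h_{\boldsymbol\alpha}(\boldsymbol{x})}{\partial \alpha_j} = h_{\boldsymbol\alpha}(\boldsymbol{x})(1-h_{\boldsymbol\alpha}(\boldsymbol{x}))x_j$; writing $h$ for $h_{\boldsymbol\alpha}(\boldsymbol{x})$:

$\frac{\partial \log L}{\partial \alpha_j} = \left(\frac{y}{h} - \frac{1-y}{1-h}\right)h(1-h)x_j = \frac{y-h}{h(1-h)}\,h(1-h)x_j = (y-h)x_j$ .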

Now that we have the gradient, the update rule for each dimension is:

$\alpha_j := \alpha_j + \lambda (y^{(i)} - h_{\boldsymbol\alpha}(\boldsymbol{x}^{(i)}))x_j^{(i)}$ , where $\lambda$ is the learning rate.

If you need more information about this procedure, see Andrew Ng's homepage.

The Java Code

In the code below I implemented what I just described. The Logistic class contains all the necessary methods. This is how you can call it:
/* Call the Logistic constructor; the initial
coefficients (coeffs) are set to zero. */
Logistic logistic = new Logistic(coeffs, 0.0000001, l, 10000, 0.00001);
logistic.classifyByInstance();
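For context, coeffs is the zero-initialized weight vector and l is an ArrayList<String> holding one raw instance per row. A rough sketch of how they might be prepared (the file name is a placeholder of mine, not part of the original code; error handling and the usual java.io/java.util imports are omitted):

// Hypothetical setup: load one raw instance per line from a placeholder file.
ArrayList<String> l = new ArrayList<String>();
BufferedReader reader = new BufferedReader(new FileReader("ads.data")); // placeholder path
String line;
while ((line = reader.readLine()) != null)
    l.add(line);
reader.close();

// alpha[0] is the intercept, so we need one more weight than there are features.
// Java zero-initializes arrays, which gives us the all-zero starting coefficients.
double[] coeffs = new double[numFeatures + 1]; // numFeatures: feature count per instance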
Further, I use the following holder class to represent a single instance:
public class BinaryInstance {
    private boolean label; // true for the positive class ("ad")
    private double[] x;    // feature values

    public boolean getLabel() {
        return label;
    }

    public void setLabel(boolean label) {
        this.label = label;
    }

    public double[] getX() {
        return x;
    }

    public void setX(double[] x) {
        this.x = x;
    }
}



Next, I call a method from a "Utils" class to transform a raw data row into a "BinaryInstance":

/* In this specific example, "ad" and "nonad" are the labels. */

public BinaryInstance rowDataToInstance(String instance) {
    BinaryInstance inst = new BinaryInstance();
    // The label is encoded as a ",ad." or ",nonad." suffix on the row.
    boolean labelAd = instance.endsWith(",ad.");
    instance = instance.replace(",ad.", "");
    instance = instance.replace(",nonad.", "");
    inst.setX(getCoeffs(instance));
    inst.setLabel(labelAd);
    return inst;
}

public double[] getCoeffs(String coeffs) {
    // Strip any bracket characters, then parse the comma-separated values.
    coeffs = coeffs.replaceAll("\\[", "");
    coeffs = coeffs.replaceAll("\\]", "");
    String[] tmpAr = coeffs.split(",");
    double[] result = new double[tmpAr.length];
    for (int i = 0; i < tmpAr.length; i++)
        result[i] = Double.parseDouble(tmpAr[i].trim());
    return result;
}
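To make the expected input format concrete, here is a hypothetical row (the feature values are invented for illustration) and what the two methods produce from it:

// A made-up raw row: bracketed feature values followed by the label suffix.
String row = "[0.25, 1.0, 0.0],ad.";
BinaryInstance inst = new Utils().rowDataToInstance(row);
// inst.getLabel() -> true  (the row ends with ",ad.")
// inst.getX()     -> {0.25, 1.0, 0.0}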


Finally, the Logistic class is given below:

package Classifiers;

import java.util.ArrayList;
import java.util.Collections;
// BinaryInstance and Utils are assumed to live in the same package.

public class Logistic {
    private double[] alpha;
    private double lambda;
    private Utils u;
    private double logLikelihood;
    private double prevLogLikelihood;
    private ArrayList<BinaryInstance> allData;
    private int datasetIterations;
    private int maxDatasetIterations;
    private double tolerance;

    /**
     * @param alpha                weights (one per feature, plus the intercept in alpha[0])
     * @param lambda               learning rate
     * @param data                 dataset (raw rows, one instance per row)
     * @param maxDatasetIterations maximum number of passes over the dataset
     * @param tolerance            convergence threshold on the log-likelihood change
     */
    public Logistic(double[] alpha, double lambda, ArrayList<String> data,
            int maxDatasetIterations, double tolerance) {
        this.alpha = alpha;
        this.lambda = lambda;
        prevLogLikelihood = Double.POSITIVE_INFINITY;
        logLikelihood = 0;
        allData = new ArrayList<BinaryInstance>();
        u = new Utils();
        for (String row : data) {
            BinaryInstance instance = u.rowDataToInstance(row);
            allData.add(instance);
        }
        datasetIterations = 0;
        this.maxDatasetIterations = maxDatasetIterations;
        this.tolerance = tolerance;
    }

    public void classifyByInstance() {
        while (evaluateCondition()) {
            prevLogLikelihood = logLikelihood;
            datasetIterations++;
            // Shuffle so each pass visits the instances in a new random order.
            Collections.shuffle(allData);
            for (BinaryInstance instance : allData) {
                double probPositive = estimateProbs(instance.getX());
                double label = instance.getLabel() ? 1 : 0;
                adjustWeights(instance.getX(), probPositive, label);
            }
            logLikelihood = calculateLogL(allData);
        }
    }

    // Keep iterating while the log-likelihood still changes by more than the
    // tolerance and the maximum number of passes has not been reached.
    private boolean evaluateCondition() {
        return Math.abs(logLikelihood - prevLogLikelihood) > tolerance
                && datasetIterations < maxDatasetIterations;
    }

    // Sigmoid of alpha * x; alpha[0] is the intercept.
    private double estimateProbs(double[] x) {
        double sum = alpha[0];
        for (int i = 1; i < this.alpha.length; i++)
            sum += this.alpha[i] * x[i - 1];
        double exponent = Math.exp(-sum);
        double probPositive = 1 / (1 + exponent);
        // Clamp to avoid log(0) in calculateLogL.
        if (probPositive == 0)
            probPositive = 0.00001;
        else if (probPositive == 1)
            probPositive = 0.9999;
        return probPositive;
    }

    // Stochastic gradient ascent step: alpha_j += lambda * (y - h(x)) * x_j.
    private void adjustWeights(double[] x, double probPositive, double label) {
        // for the intercept
        this.alpha[0] += this.lambda * (label - probPositive);
        for (int i = 1; i < this.alpha.length; i++) {
            this.alpha[i] += this.lambda * x[i - 1] * (label - probPositive);
        }
    }

    private double calculateLogL(ArrayList<BinaryInstance> allData) {
        double logl = 0;
        for (BinaryInstance instance : allData) {
            double probPositive = estimateProbs(instance.getX());
            double label = instance.getLabel() ? 1 : 0;
            double probNegative = 1 - probPositive;
            logl += label * Math.log(probPositive)
                    + (1 - label) * Math.log(probNegative);
        }
        return logl;
    }
}
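The class as given only learns the weights; the post stops at training. If you also want to classify new instances, a minimal sketch (my addition, not part of the original class) could reuse estimateProbs with a 0.5 threshold:

    // Hypothetical addition to the Logistic class: label a new feature vector
    // as positive when the estimated probability exceeds 0.5.
    public boolean predict(double[] x) {
        return estimateProbs(x) > 0.5;
    }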