## Monday, March 12, 2012

### How to implement Logistic Regression with stochastic gradient ascent in Java.

In this post I will give a step-by-step tutorial on implementing simple logistic regression for classification from scratch in Java. Note that the provided code is not optimized; however, it works just fine.

#### The problem

The goal is to build a model for binary classification: given a vector of features $\boldsymbol{x}^{(i)}$ we want to predict the corresponding label ($y^{(i)} \in \{0,1\}$). We assume that $h_{\boldsymbol\alpha}(\boldsymbol{x}^{(i)})$ is our prediction for instance $i$, which is given by the following equation:

$h_{\boldsymbol{\alpha}}(\boldsymbol{x}) = \frac{1}{1+e^{-\boldsymbol{\alpha} \boldsymbol{x}}}$ ,

where $\boldsymbol{\alpha}$ is a weight vector. The right-hand side of the previous equation is called the logistic function, or sigmoid function, of $\boldsymbol{\alpha}\boldsymbol{x}$.
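The sigmoid of the dot product $\boldsymbol{\alpha}\boldsymbol{x}$ can be computed directly with `Math.exp`; here is a minimal sketch (the class and method names are my own, for illustration only):

```java
public class Sigmoid {

    // Computes h_alpha(x) = 1 / (1 + exp(-alpha . x)),
    // where alpha and x have the same length.
    public static double sigmoid(double[] alpha, double[] x) {
        double dot = 0.0;
        for (int i = 0; i < alpha.length; i++) {
            dot += alpha[i] * x[i];
        }
        return 1.0 / (1.0 + Math.exp(-dot));
    }
}
```

Note that the output is always strictly between 0 and 1, which is what lets us interpret it as a probability below.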

If we assume that
$\Pr(y=1|\boldsymbol{x},\boldsymbol{\alpha}) = h_{\boldsymbol\alpha}(\boldsymbol{x})$ ,

$\Pr(y=0|\boldsymbol{x},\boldsymbol{\alpha}) = 1 - h_{\boldsymbol\alpha}(\boldsymbol{x})$ ,

we can write $\Pr(y|\boldsymbol{x},\boldsymbol{\alpha})$ as follows:

$\Pr(y|\boldsymbol{x},\boldsymbol{\alpha})=h_{\boldsymbol\alpha}(\boldsymbol{x})^y(1-h_{\boldsymbol\alpha}(\boldsymbol{x}))^{1-y}$ .

Now we can use maximum likelihood estimation (MLE) to estimate the optimal weights $\boldsymbol\alpha$. Assuming that we have $m$ instances in our training set, the likelihood is:

$L(\boldsymbol\alpha) = \prod_{i=1}^m \Pr(y^{(i)}|\boldsymbol{x}^{(i)},\boldsymbol{\alpha})$ ,

and by taking the logs we get:

$\log L(\boldsymbol\alpha) = \sum_{i=1}^{m}\left[y^{(i)} \log h_{\boldsymbol\alpha}(\boldsymbol{x}^{(i)}) + (1 -y^{(i)}) \log (1-h_{\boldsymbol\alpha}(\boldsymbol{x}^{(i)}))\right]$ .

But how do we maximize this likelihood? There are many techniques available in the literature. In this post, I will use the simplest one: stochastic gradient ascent (ascent, because we maximize). By considering only one example at a time, the gradient of the log likelihood with respect to each dimension $j$ is:

$\frac{\partial \log L}{\partial \alpha_j} = (y - h_{\boldsymbol{\alpha}}(\boldsymbol{x}))x_j$ .
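For completeness, this compact form follows from differentiating the single-example log likelihood and using the sigmoid derivative $\frac{\partial h}{\partial \alpha_j} = h(1-h)x_j$ (writing $h$ for $h_{\boldsymbol\alpha}(\boldsymbol{x})$):

```latex
\frac{\partial \log L}{\partial \alpha_j}
  = \left(\frac{y}{h} - \frac{1-y}{1-h}\right)\frac{\partial h}{\partial \alpha_j}
  = \frac{y - h}{h(1-h)} \, h(1-h)\,x_j
  = (y - h)\,x_j .
```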

Now that we have the gradient, the update rule for each dimension $j$ is:

$\alpha_j := \alpha_j + \lambda (y^{(i)} - h_{\boldsymbol\alpha}(\boldsymbol{x}^{(i)}))x_j^{(i)}$ ,

where $\lambda$ is the learning rate.
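A single update step of this rule can be sketched in a few lines of Java (again, the class and method names here are my own, chosen for illustration):

```java
public class SgaStep {

    // One stochastic gradient ascent update on a single example:
    // alpha_j += lambda * (y - h_alpha(x)) * x_j for every dimension j.
    public static void update(double[] alpha, double[] x, double y, double lambda) {
        // Compute the prediction h_alpha(x).
        double dot = 0.0;
        for (int i = 0; i < alpha.length; i++) {
            dot += alpha[i] * x[i];
        }
        double h = 1.0 / (1.0 + Math.exp(-dot));

        // Move each weight in the direction of the gradient.
        for (int j = 0; j < alpha.length; j++) {
            alpha[j] += lambda * (y - h) * x[j];
        }
    }
}
```

Repeating this step over shuffled training examples, for several passes over the dataset, is exactly the loop implemented by the `Logistic` class below.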

#### The Java Code

In the code below I implement what I just described. The `Logistic` class contains all the necessary methods. This is how you can call it:

```java
/* Call the Logistic constructor; the initial coefficients (coeffs) are set to
   zero, and l is an ArrayList<String> holding the raw data rows. */
Logistic logistic = new Logistic(coeffs, 0.0000001, l, 10000, 0.00001);
logistic.classifyByInstance();
```
Further, I am using the following holder class to represent an instance:

```java
public class BinaryInstance {

    private boolean label;
    private double[] x;

    public boolean getLabel() {
        return label;
    }

    public void setLabel(boolean label) {
        this.label = label;
    }

    public double[] getX() {
        return x;
    }

    public void setX(double[] x) {
        this.x = x;
    }
}
```

Next, I call a method from a `Utils` class to transform the raw data into `BinaryInstance` objects:

```java
/* In this specific example, "ad" and "nonad" are the labels. */
public BinaryInstance rowDataToInstance(String instance) {
    BinaryInstance inst = new BinaryInstance();
    boolean labelAd = instance.endsWith(",ad.");
    instance = instance.replace(",ad.", "");
    instance = instance.replace(",nonad.", "");
    inst.setX(getCoeffs(instance));
    inst.setLabel(labelAd);
    return inst;
}

public double[] getCoeffs(String coeffs) {
    // "$" is a regex metacharacter, so it must be escaped.
    coeffs = coeffs.replaceAll("\\$", "");
    String[] tmpAr = coeffs.split(",");
    double[] result = new double[tmpAr.length];
    for (int i = 0; i < tmpAr.length; i++)
        result[i] = Double.parseDouble(tmpAr[i].trim());
    return result;
}
```

Finally, the Logistic class is given below:

```java
package Classifiers;

import java.util.ArrayList;
import java.util.Collections;

public class Logistic {

    private double[] alpha;
    private double lambda;
    private Utils u;
    private double logLikelihood;
    private double prevLogLikelihood;
    private ArrayList<BinaryInstance> allData;
    private int datasetIterations;
    private int maxDatasetIterations;
    private double tolerance;

    /**
     * @param alpha                : weights
     * @param lambda               : learning rate
     * @param data                 : dataset (instances to vector of values)
     * @param maxDatasetIterations : maximum number of passes over the dataset
     * @param tolerance            : convergence threshold on the log likelihood
     */
    public Logistic(double[] alpha, double lambda, ArrayList<String> data,
            int maxDatasetIterations, double tolerance) {
        this.alpha = alpha;
        this.lambda = lambda;
        prevLogLikelihood = Double.POSITIVE_INFINITY;
        logLikelihood = 0;
        allData = new ArrayList<BinaryInstance>();
        u = new Utils();
        for (String row : data) {
            BinaryInstance instance = u.rowDataToInstance(row);
            allData.add(instance);
        }
        datasetIterations = 0;
        this.maxDatasetIterations = maxDatasetIterations;
        this.tolerance = tolerance;
    }

    public void classifyByInstance() {
        while (evaluateCondition()) {
            prevLogLikelihood = logLikelihood;
            datasetIterations++;
            Collections.shuffle(allData);
            for (BinaryInstance instance : allData) {
                double probPositive = estimateProbs(instance.getX());
                double label = instance.getLabel() ? 1 : 0;
                adjustWeights(instance.getX(), probPositive, label);
            }
            logLikelihood = calculateLogL(allData);
        }
    }

    private boolean evaluateCondition() {
        return Math.abs(logLikelihood - prevLogLikelihood) > tolerance
                && datasetIterations < maxDatasetIterations;
    }

    private double estimateProbs(double[] x) {
        double sum = alpha[0]; // alpha[0] is the intercept
        for (int i = 1; i < this.alpha.length; i++)
            sum += this.alpha[i] * x[i - 1];
        double exponent = Math.exp(-sum);
        double probPositive = 1 / (1 + exponent);
        // Clamp to avoid log(0) in the likelihood computation.
        if (probPositive == 0)
            probPositive = 0.00001;
        else if (probPositive == 1)
            probPositive = 0.9999;
        return probPositive;
    }

    private void adjustWeights(double[] x, double probPositive, double label) {
        // for the intercept
        this.alpha[0] += this.lambda * (label - probPositive);
        for (int i = 1; i < this.alpha.length; i++) {
            this.alpha[i] += this.lambda * x[i - 1] * (label - probPositive);
        }
    }

    private double calculateLogL(ArrayList<BinaryInstance> allData) {
        double logl = 0;
        for (BinaryInstance instance : allData) {
            double probPositive = estimateProbs(instance.getX());
            double label = instance.getLabel() ? 1 : 0;
            double probNegative = 1 - probPositive;
            logl += label * Math.log(probPositive)
                    + (1 - label) * Math.log(probNegative);
        }
        return logl;
    }
}
```
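To sanity-check that stochastic gradient ascent really does learn a separating boundary, here is a minimal standalone sketch of the same training loop on a toy one-feature dataset. It is self-contained (no `Utils` or `BinaryInstance`), and the class name `ToyLogistic` is my own, not part of the code above:

```java
public class ToyLogistic {

    // Trains a model with an intercept by stochastic gradient ascent on a
    // toy dataset; alpha[0] is the intercept, alpha[j+1] weights feature j.
    public static double[] train(double[][] xs, double[] ys,
                                 double lambda, int epochs) {
        double[] alpha = new double[xs[0].length + 1];
        for (int e = 0; e < epochs; e++) {
            for (int i = 0; i < xs.length; i++) {
                // Prediction h_alpha(x) for the current example.
                double dot = alpha[0];
                for (int j = 0; j < xs[i].length; j++)
                    dot += alpha[j + 1] * xs[i][j];
                double h = 1.0 / (1.0 + Math.exp(-dot));
                // Gradient ascent update: alpha += lambda * (y - h) * x.
                alpha[0] += lambda * (ys[i] - h);
                for (int j = 0; j < xs[i].length; j++)
                    alpha[j + 1] += lambda * (ys[i] - h) * xs[i][j];
            }
        }
        return alpha;
    }
}
```

On data where negative x values are labeled 0 and positive values are labeled 1, the learned slope comes out positive, so the model classifies both ends of the line correctly.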