
Neural networks and large data sets

I have a basic framework for a neural network to recognize numeric digits, but I'm having some problems training it. My back-propagation works for small data sets, but when I have more than 50 data points, the return value starts converging to 0. And when I have data sets in the thousands, I get NaNs for the costs and return values.

Basic structure: 3 layers: 784 : 15 : 1

784 is the number of pixels per input, 15 is the number of neurons in my hidden layer, and there is one output neuron which returns a value from 0 to 1 (multiplying by 10 gives a digit).
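For example, decoding the network's output for a single training example back into a digit is just a multiply-and-round; a minimal sketch, assuming yHat is the N x 1 output matrix and that its backing array is the public field m used in the cost loop below:

int digit = (int) Math.round(yHat.m[i][0] * 10); // row i of the output matrix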

public class NetworkManager {
    int inputSize;
    int hiddenSize;
    int outputSize;
    public Matrix W1;
    public Matrix W2;

    public NetworkManager(int input, int hidden, int output) {
        inputSize = input;
        hiddenSize = hidden;
        outputSize = output;
        W1 = new Matrix(inputSize, hiddenSize);
        W2 = new Matrix(hiddenSize, output);
    }

    Matrix z2, z3;
    Matrix a2;
    // Forward pass: input -> hidden layer (sigmoid) -> output layer (sigmoid)
    public Matrix forward(Matrix X) {
        z2 = X.dot(W1);
        a2 = sigmoid(z2);

        z3 = a2.dot(W2);
        Matrix yHat = sigmoid(z3);

        return yHat;
    }

    // Cost: sum of squared differences between predictions and targets
    public double costFunction(Matrix X, Matrix y) {
        Matrix yHat = forward(X);

        Matrix cost = yHat.sub(y);
        cost = cost.mult(cost);

        double returnValue = 0;
        int i = 0;
        while (i < cost.m.length) {
            returnValue += cost.m[i][0];
            i++;
        }
        return returnValue;
    }

    Matrix yHat;
    // Back-propagation: returns the gradients {dJdW1, dJdW2}
    public Matrix[] costFunctionPrime(Matrix X, Matrix y) {

        yHat = forward(X);

        Matrix delta3 = (yHat.sub(y)).mult(sigmoidPrime(z3));
        Matrix dJdW2 = a2.t().dot(delta3);

        Matrix delta2 = (delta3.dot(W2.t())).mult(sigmoidPrime(z2));
        Matrix dJdW1 = X.t().dot(delta2);

        return new Matrix[]{dJdW1, dJdW2};
    }
}   

That's the code for the network framework. I pass double arrays of length 784 into the forward method.
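The sigmoid and sigmoidPrime helpers aren't shown above; a minimal sketch of what they would look like, assuming Matrix exposes its backing double[][] as a public field m (as the cost loop does) and has a Matrix(double[][]) constructor:

// Element-wise logistic sigmoid: 1 / (1 + e^-z)
static Matrix sigmoid(Matrix z) {
    double[][] out = new double[z.m.length][z.m[0].length];
    for (int i = 0; i < z.m.length; i++) {
        for (int j = 0; j < z.m[0].length; j++) {
            out[i][j] = 1.0 / (1.0 + Math.exp(-z.m[i][j]));
        }
    }
    return new Matrix(out); // assumes a Matrix(double[][]) constructor
}

// Element-wise derivative of the sigmoid: s(z) * (1 - s(z))
static Matrix sigmoidPrime(Matrix z) {
    double[][] out = new double[z.m.length][z.m[0].length];
    for (int i = 0; i < z.m.length; i++) {
        for (int j = 0; j < z.m[0].length; j++) {
            double s = 1.0 / (1.0 + Math.exp(-z.m[i][j]));
            out[i][j] = s * (1.0 - s);
        }
    }
    return new Matrix(out);
}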

    int t = 0;
    while (t < 10000) {
        dJdW = Nn.costFunctionPrime(X, y);

        Nn.W1 = Nn.W1.sub(dJdW[0].scalar(3));
        Nn.W2 = Nn.W2.sub(dJdW[1].scalar(3));

        t++;
    }

I call this to adjust the weights. With small sets, the cost converges to 0 pretty well, but with larger sets it doesn't (the cost for 100 characters always converges to 13). And if the set is too large, the first adjustment works (the cost goes down), but after the second all I get is NaN.

Why does this implementation fail to train with larger data sets, and how can I fix it? I tried a similar structure with 10 outputs instead of 1, where each would return a value near 0 or 1 acting like a boolean, but the same thing happened.

I'm also doing this in Java, by the way, and I'm wondering if that has something to do with the problem. I also wondered whether I was running out of memory, but I haven't gotten any heap-space errors. Is there a problem with how I'm back-propagating, or is something else happening?

EDIT: I think I know what's happening. I think my back-propagation function is getting caught in local minima. For large data sets, the training sometimes succeeds and sometimes fails. Because I'm starting with random weights, I get random initial costs. What I've noticed is that when the initial cost exceeds a certain amount (it depends on how much data is involved), the cost converges to a clean number (sometimes 27, others 17.4) and the outputs converge to 0 (which makes sense).

I was warned about local minima in the cost function when I began, and I'm beginning to realize why. So now the question becomes: how do I go about gradient descent so that I actually find the global minimum? I'm working in Java, by the way.

Jean Valjean asked Nov 17 '25


1 Answer

This seems like a problem with weight initialization.

As far as I can see, you never initialize the weights to any specific values, so the network diverges. You should at least initialize them randomly.
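A minimal sketch of what that could look like in the asker's NetworkManager constructor, assuming the Matrix class exposes its backing array as a public double[][] m (the randomize helper below is hypothetical):

public NetworkManager(int input, int hidden, int output) {
    inputSize = input;
    hiddenSize = hidden;
    outputSize = output;
    W1 = new Matrix(inputSize, hiddenSize);
    W2 = new Matrix(hiddenSize, output);

    // Fill both weight matrices with small random values instead of
    // leaving them at their default (zero) state.
    java.util.Random rng = new java.util.Random();
    randomize(W1, rng);
    randomize(W2, rng);
}

// Hypothetical helper: fill a weight matrix with values in roughly [-0.05, 0.05)
private static void randomize(Matrix w, java.util.Random rng) {
    for (int i = 0; i < w.m.length; i++) {
        for (int j = 0; j < w.m[i].length; j++) {
            w.m[i][j] = (rng.nextDouble() - 0.5) * 0.1;
        }
    }
}

Drawing each weight from a small symmetric range breaks the symmetry between hidden neurons and keeps the sigmoids away from their saturated regions, where the gradients vanish.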

Thomas Pinetz answered Nov 19 '25