Monday, January 2, 2023

optimization 13 - using the gradient sign changes

When I previously experimented with the FISTA momentum descent, one thing that worked quite well was detecting when some dimension of the gradient changes sign and then resetting the momentum in that dimension to 0.

One of the typical problems of the momentum methods is that they tend to "circle the drain" around the minimum. Think of one of those coin-collection bins for charity where you send a coin rolling down a trough, and then it circles the "gravity well" of the bin for quite a while before losing momentum and dropping into the center hole. This happens because the velocity (momentum) of the coin is initially directed at an angle to the minimum (the center hole). The same happens with the momentum descent in optimization: the momentum usually develops at an angle, and just as the current point gets close to the minimum, the momentum carries it past and away to the other side of the "well", where it eventually reverses direction and comes back, hopefully closer to the minimum this time. But there is a clear indication that we're getting carried past the minimum: the gradient changes sign in the dimension along which we're being carried past. So if we kill the momentum in that dimension at that point, we won't get carried past. It helps a good deal with faster convergence.
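A minimal sketch of the per-dimension reset, assuming a plain heavy-ball momentum update with a constant coefficient `beta` rather than the actual FISTA code in Triceps. The function name `momentum_descent` and all parameter values are illustrative. The test function is an elongated quadratic "trough" with curvatures 1 and 25, where plain momentum circles the minimum for a long time.

```python
import numpy as np

def momentum_descent(grad_fn, x0, rate, beta, steps, reset_on_sign_change=True):
    """Heavy-ball momentum descent (illustrative sketch).

    If reset_on_sign_change is set, the momentum is zeroed in every
    dimension whose gradient has changed sign since the previous step,
    i.e. where the point has just been carried past the minimum.
    """
    x = np.asarray(x0, dtype=float).copy()
    v = np.zeros_like(x)       # momentum (velocity), one entry per dimension
    prev_g = np.zeros_like(x)  # gradient seen on the previous step
    for _ in range(steps):
        g = grad_fn(x)
        if reset_on_sign_change:
            flipped = g * prev_g < 0  # dimensions where the gradient sign changed
            v[flipped] = 0.0          # kill the momentum only in those dimensions
        v = beta * v - rate * g
        x = x + v
        prev_g = g
    return x

# Example: f(x) = (x0**2 + 25 * x1**2) / 2, whose gradient is curvature * x.
curvature = np.array([1.0, 25.0])
x_reset = momentum_descent(lambda x: curvature * x, [10.0, 1.0], rate=0.03, beta=0.9, steps=200)
x_plain = momentum_descent(lambda x: curvature * x, [10.0, 1.0], rate=0.03, beta=0.9, steps=200,
                           reset_on_sign_change=False)
print(np.linalg.norm(x_reset), np.linalg.norm(x_plain))
```

On this toy problem the reset version ends up much closer to the minimum after the same number of steps, because each overshoot restarts the motion from rest instead of letting the momentum carry the point around the well.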

An overshoot is not the only reason why a dimension of the gradient might change sign. Our "virtual coin" might be rolling down a multi-dimensional trough, oscillating a little left and right within it. But there, killing the momentum wouldn't hurt either: it would just dampen these oscillations, which is also a good thing.

This gave me the idea that once we have this, there is no point to the gradual reduction of momentum that is built into FISTA through the parameter t. If the momentum reduction on overshoot is taken care of as described above, there is no need to shrink the momentum otherwise.

And then I thought of applying the same logic to estimating the training rate. I've been thinking about the Adam method that I linked to in a recent post, and it doesn't really solve the issue with the training rate. It adjusts the rate between the different dimensions, but it still contains a constant that essentially represents the global training rate. In their example they just came up with this constant empirically, but it would be different for different problems, and how do we find it? It's the same problem as finding the plain training rate. Sign change detection to the rescue.
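To make the point about the global constant concrete, here is a sketch of one step of the standard Adam update as published by Kingma and Ba, with their suggested defaults (alpha=0.001, beta1=0.9, beta2=0.999, eps=1e-8). The per-dimension division by `sqrt(v_hat)` normalizes the gradient scale, but `alpha` still sets the overall size of the steps and has to be picked by hand.

```python
import numpy as np

def adam_step(theta, g, m, v, t, alpha=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One step of standard Adam; t is the 1-based step counter."""
    m = beta1 * m + (1 - beta1) * g          # running mean of the gradient
    v = beta2 * v + (1 - beta2) * g * g      # running mean of the squared gradient
    m_hat = m / (1 - beta1 ** t)             # bias corrections for the zero start
    v_hat = v / (1 - beta2 ** t)
    # Per-dimension normalization, but alpha is one global rate for all dimensions.
    theta = theta - alpha * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

# Gradients that differ by a factor of a million still produce steps of about alpha.
theta, m, v = adam_step(np.zeros(2), np.array([1000.0, 0.001]), np.zeros(2), np.zeros(2), t=1)
print(theta)  # both coordinates move by roughly -0.001
```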

After all, what happens when a too-high rate starts tearing the model apart? A rate that is too high causes an algorithm step to overshoot the minimum. And not just overshoot, but overshoot it so much that the gradient grows. So on the next step it overshoots back even more, and the gradient grows again, and so on, getting farther from the minimum on each step. And the momentum methods tend to exacerbate this problem.
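The simplest case shows this exactly. For plain gradient descent on a one-dimensional quadratic f(x) = a*x^2/2, the gradient is a*x and each step multiplies x (and therefore the gradient) by (1 - rate*a). With rate < 2/a the sign may alternate but the gradient shrinks; with rate > 2/a the factor is below -1, so the sign flips on every step and the gradient grows. A small sketch (the function name `gd_trace` is just for illustration):

```python
def gd_trace(a, rate, x0, steps):
    """Plain gradient descent on f(x) = a * x**2 / 2; returns the gradients seen."""
    x = x0
    grads = []
    for _ in range(steps):
        g = a * x          # gradient at the current point
        grads.append(g)
        x -= rate * g      # x becomes (1 - rate * a) * x
    return grads

# a = 1, so the critical rate is 2 / a = 2.
print(gd_trace(1.0, 1.5, 1.0, 5))  # [1.0, -0.5, 0.25, -0.125, 0.0625]: overshoots, but shrinks
print(gd_trace(1.0, 2.5, 1.0, 5))  # [1.0, -1.5, 2.25, -3.375, 5.0625]: overshoots and grows
```

The second trace is the signature of the meltdown: a sign change in a dimension together with a growing gradient magnitude.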

So if we detect the dimensions whose gradient changes sign, check whether the gradient in them grows, and by how much, and compare that with the gradient in the dimensions that didn't change sign, we can detect the onset of this resonance and dampen it by reducing the training rate. I've tried it, and it works very well; check out https://sourceforge.net/p/triceps/code/1792/tree/, it's the FloatNeuralNet option autoRate2_. So far its tunables are hardcoded, and I think that I've set them a bit too conservatively, but altogether it works very well, training a little faster than anything I've seen before, and without the meltdowns.
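The detection part of this idea can be sketched as follows. This is a hypothetical illustration, not the actual autoRate2_ code in Triceps: the names `adjust_rate` and `descend_with_rate_check`, the growth threshold of 1.0, and the halving factor are all assumptions made for the example, and this sketch only ever lowers the rate. In the example, a rate of 0.1 is too high for the dimension with curvature 25 (factor 1 - 0.1*25 = -1.5), and the check halves it once to 0.05 (factor -0.25).

```python
import numpy as np

def adjust_rate(rate, g_prev, g_cur, growth_limit=1.0, shrink=0.5):
    """Hypothetical sign-change rate check (NOT the actual autoRate2_ logic).

    If the gradient magnitude grew in the dimensions that changed sign,
    and grew more than in the dimensions that kept their sign, treat it
    as the onset of a resonance and cut the training rate.
    """
    tiny = 1e-300  # guards the division when all previous gradients are zero
    flipped = g_prev * g_cur < 0
    if not flipped.any():
        return rate
    growth_flipped = np.abs(g_cur[flipped]).sum() / (np.abs(g_prev[flipped]).sum() + tiny)
    steady = ~flipped
    growth_steady = 0.0
    if steady.any():
        growth_steady = np.abs(g_cur[steady]).sum() / (np.abs(g_prev[steady]).sum() + tiny)
    if growth_flipped > growth_limit and growth_flipped > growth_steady:
        return rate * shrink
    return rate

def descend_with_rate_check(curvature, x0, rate, steps):
    """Plain gradient descent on f(x) = sum(curvature * x**2) / 2 with the check."""
    x = np.array(x0, dtype=float)
    g_prev = curvature * x
    x = x - rate * g_prev
    for _ in range(steps):
        g = curvature * x
        rate = adjust_rate(rate, g_prev, g)
        x = x - rate * g
        g_prev = g
    return x, rate

x, final_rate = descend_with_rate_check(np.array([1.0, 25.0]), [1.0, 1.0], rate=0.1, steps=200)
print(x, final_rate)
```

Without the check, the same starting rate would make the second dimension grow by a factor of 1.5 on every step.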

Another thing I've changed in the current version is the logic that boosts the rarely-seen unusual cases for better recognition. It previously didn't work well with the momentum methods, because it changed the direction of the gradient drastically between the passes. I've changed it to make the boosting more persistent between the passes and, instead of shrinking the gradients of the correct cases, to gradually grow the representation of the incorrect cases, expanding their gradients. It's still a work in progress but looks promising.
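One way to picture the persistent boosting is a per-case weight that survives from pass to pass and multiplies that case's gradient. This is a sketch under assumptions, not the Triceps implementation: the names `update_case_weights` and `weighted_gradient`, the growth factor of 1.1, and the cap of 10 are all hypothetical. Correct cases keep whatever weight they have, and incorrect ones get their weight grown gradually.

```python
import numpy as np

def update_case_weights(weights, correct, grow=1.1, max_weight=10.0):
    """Hypothetical persistent per-case boost, applied once per training pass.

    Weights carry over between passes; a case answered incorrectly gets its
    weight (and so its share of the gradient) grown gradually, up to a cap,
    while the weights of correct cases are left unchanged.
    """
    weights = np.asarray(weights, dtype=float)
    correct = np.asarray(correct, dtype=bool)
    return np.where(correct, weights, np.minimum(weights * grow, max_weight))

def weighted_gradient(case_grads, weights):
    """Sum of the per-case gradients, each scaled by its case weight."""
    return (np.asarray(weights, dtype=float)[:, None] * np.asarray(case_grads, dtype=float)).sum(axis=0)

w = update_case_weights([1.0, 1.0, 2.0, 9.5], [True, False, False, False])
print(w)  # the last weight hits the cap of 10
```

Because the weights change only a little per pass, the direction of the combined gradient stays fairly stable, which is what the momentum methods need.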

Oh, and BTW, one thing that didn't work out was the attempt to similarly boost the cases that give the correct answer (the highest output points to the right digit) but at a very low confidence, so that even the best output is below 0 (and sometimes substantially below 0, something like -0.95). All I could do was shrink the percentage of such cases slightly, from about 17.5% to about 16.5%. And I'm not sure what can be done about it. I guess it's just another manifestation of the great variability of handwritten characters. Maybe it could be solved by vastly growing the model size and the training set, but even if it could, it would be nice to find some smarter way. Perhaps a better topological representation of the digits instead of a plain bitmap would do the trick, but I don't know how to do it. One of my theories was that it's caused by a natural pull towards negative numbers, because in each training case we have one output with 1 and nine outputs with -1. So if we changed the negatives to, say, -0.1, that would pull the numbers higher. But that's not a solution either: it just moves the average up, diluting the strength of the negatives.
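A small sketch of how such low-confidence correct cases can be counted, plus the arithmetic behind the average target. The function names are hypothetical, and it assumes 10 outputs per case with the label given as the digit index. With one target of 1 and nine of -1, the average target per case is (1 - 9)/10 = -0.8; with negatives of -0.1 it becomes (1 - 0.9)/10 = 0.01.

```python
import numpy as np

def low_confidence_correct_fraction(outputs, labels):
    """Fraction of cases where the highest output points to the right digit
    but that highest output is still below 0.

    outputs: (N, 10) array of network outputs; labels: N digit indices.
    """
    outputs = np.asarray(outputs, dtype=float)
    labels = np.asarray(labels)
    predicted = outputs.argmax(axis=1)
    best = outputs.max(axis=1)
    low = (predicted == labels) & (best < 0.0)
    return low.mean()

def mean_target(negative):
    """Average training target of one case: one output at 1, nine at `negative`."""
    return (1.0 + 9 * negative) / 10

outputs = np.full((3, 10), -0.9)
outputs[0, 3] = 0.5   # correct and confident
outputs[1, 7] = -0.3  # correct, but even the best output is negative
outputs[2, 2] = 0.8   # wrong digit
print(low_confidence_correct_fraction(outputs, [3, 7, 5]))
print(mean_target(-1.0), mean_target(-0.1))
```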
