I've just come back from the NeurIPS 2024 conference, with Triceps listed as my organization. I've done a project mentoring a family member at JHU, and since it's not work-related, listing my day job would be wrong, and I'm not at JHU either (and I paid for attending the conference myself; JHU turned out to be too cheap for that). So Triceps it is, especially since Triceps is used in the project. And by the way, that's exactly the project I've mentioned before as an application of Triceps; I just couldn't tell the details before it got published. Well, now we've had a poster session at NeurIPS (which people tell me is kind of a big deal; I've just had fun), and the paper is out: "realSEUDO for real-time calcium imaging analysis" at https://neurips.cc/virtual/2024/poster/94683 or https://www.researchgate.net/publication/380895097_realSEUDO_for_real-time_calcium_imaging_analysis or https://arxiv.org/abs/2405.15701. Here is what it does, in simple words.
Calcium imaging is a technique for imaging a live (although not exactly unharmed) brain. People have bred a strain of mice whose neurons light up when they activate. So they take a mouse, open up its skull, and stick a camera with a microscope on it. This allows them to record a movie of a living brain where you can see what is going on literally cell by cell, with the brightness of each cell corresponding to its activation level.
Once you get the movie, you need to analyze it: figure out where the cells are, and what their activation levels are. The "gold standard" detection algorithm is CNMF. It solves an optimization problem (in the mathematical sense, i.e. finding the minimum of a function that represents the error) for two mappings at once: which pixels belong to which cell, and which cells activate at what level in each frame (the activation profile). More precisely, it deals not with separate pixels but with gaussian kernels. A "kernel" in math speak is any repeatable element (for another example, "CUDA kernels" are the CUDA functions that are repeated across the GPU), but in this case it's a sprite that gets stenciled repeatedly, centered on each pixel. "Gaussian" is the shape of the sprite: a 2-dimensional representation of the normal distribution where the brightness corresponds to the weight, so it's an approximation of a circle that is brightest in the center and then smoothly fades to nothing at the edges.
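To make the two mappings concrete, here is a minimal Python sketch of the model in its barest form (the real CNMF adds background and temporal-dynamics terms, and it fits the shapes rather than assuming them): each cell has a spatial footprint, here a gaussian sprite stenciled at an assumed center, and a frame is the sum of the footprints weighted by the cells' activation levels in that frame. The names `gaussian_kernel`, `stamp` and `render_frame` are hypothetical, made up for the illustration.

```python
import math


def gaussian_kernel(radius, sigma):
    """(2*radius+1) x (2*radius+1) sprite shaped like a 2-D normal distribution:
    1.0 in the center, smoothly fading towards the edges."""
    return [[math.exp(-(dx * dx + dy * dy) / (2.0 * sigma * sigma))
             for dx in range(-radius, radius + 1)]
            for dy in range(-radius, radius + 1)]


def stamp(kernel, cy, cx, height, width):
    """Spatial footprint of one cell: the kernel stenciled into an empty
    height x width image, centered on (cy, cx), clipped at the borders."""
    r = len(kernel) // 2
    image = [[0.0] * width for _ in range(height)]
    for dy in range(-r, r + 1):
        for dx in range(-r, r + 1):
            y, x = cy + dy, cx + dx
            if 0 <= y < height and 0 <= x < width:
                image[y][x] = kernel[dy + r][dx + r]
    return image


def render_frame(footprints, activations):
    """One frame of the model: the sum over cells of footprint * activation level."""
    height, width = len(footprints[0]), len(footprints[0][0])
    frame = [[0.0] * width for _ in range(height)]
    for footprint, level in zip(footprints, activations):
        for y in range(height):
            for x in range(width):
                frame[y][x] += level * footprint[y][x]
    return frame


# Two cells; in this frame the first is fully active and the second at half level.
kernel = gaussian_kernel(2, 1.0)
cells = [stamp(kernel, 2, 2, 8, 8), stamp(kernel, 5, 5, 8, 8)]
frame = render_frame(cells, [1.0, 0.5])
```

Detecting the cells means going the other way, from the frames back to the footprints and activation levels, and that inverse is the optimization problem CNMF solves over the whole movie at once.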
The trouble is that CNMF works only on the whole movie after it has been collected. You can't use it to decode frame by frame in real time, because it needs to see all the frames at once. So you can't identify the cells and then immediately use this knowledge to alter the stimulus and see what effect this has on the brain. There is another algorithm, OnACID, that does frame-by-frame decoding using a variation of CNMF logic, except that it requires a starting period of 100 frames or so to get the initial cells, its quality of identifying the cells is much worse than CNMF's, and it's still not very fast: substantially slower than 30 frames per second, even on the 80-CPU machine we used.
In the same area but in a slightly different niche, SEUDO is the JHU professor's previous project, a technique to reduce noise in the activation profiles. It can be used together with the other algorithms. Aside from the random noise, there is an inherent cross-talk: the neurons are stacked in multiple layers and can have very long dendrites that overlap with many other neurons, adding weak little light spots that get registered as weak activations of these other neurons. So the idea there is that to recognize the activations, we use both the known cell shapes and a set of gaussian kernels centered on every pixel, solve the minimization problem on all of them, and then throw away the values for the gaussian kernels, thus discarding the cross-talk from unknown sources. The little light spots that don't fit well with the known cells get attributed to these gaussian kernels, and so the cell activation profiles become smoother. The problem is that it's slow, and it quickly gets slower as you increase the size of the kernels. If you take a kernel of 30x30 pixels, each kernel adds 900 extra coefficients to the quadratic function you're trying to minimize, and there is one such kernel centered on every pixel.
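Here is a deliberately tiny Python sketch of the SEUDO idea, under simplifying assumptions of my own: a 1-D signal instead of an image, a single known cell, and single-pixel "kernels" instead of gaussians, which lets the kernel part be minimized in closed form by soft thresholding. The cell's activation and the kernel weights are fitted together, with a LASSO-style penalty `lam` on the kernel weights, and then the kernel weights are thrown away. The function name `seudo_1d` and the value of `lam` are hypothetical.

```python
def seudo_1d(y, shape, lam, iters=100):
    """Activation of one known cell in a 1-D signal `y`, SEUDO-style.

    Model: y ~ a * shape + w with a >= 0 and w >= 0, minimizing
    0.5 * ||y - a * shape - w||^2 + lam * sum(w).
    With single-pixel kernels the best w for a fixed a is max(0, r - lam)
    per pixel (r = residual), which turns each pixel's loss into r^2/2 for
    r <= lam and lam*r - lam^2/2 above it: bright leftovers get capped.
    The kernel weights w are then simply not returned.
    Assumes the shape values are non-negative.
    """
    def dloss(r):
        # derivative of that per-pixel loss with respect to the residual
        return r if r <= lam else lam

    def slope(a):
        # derivative of the total loss with respect to a; non-decreasing in a
        return -sum(s * dloss(v - a * s) for v, s in zip(y, shape))

    lo = 0.0
    hi = max([0.0] + [v / s for v, s in zip(y, shape) if s > 0])
    for _ in range(iters):  # bisection for the point where the slope crosses 0
        mid = (lo + hi) / 2.0
        if slope(mid) < 0.0:
            lo = mid
        else:
            hi = mid
    return (lo + hi) / 2.0


shape = [0.0] * 5 + [1.0] * 5 + [0.0] * 10   # known cell on pixels 5..9
y = shape[:]                                # true activation 1.0 ...
for i in (9, 10, 11):
    y[i] += 0.5                             # ... plus a cross-talk spot
plain = sum(v * s for v, s in zip(y, shape)) / sum(s * s for s in shape)
seudo = seudo_1d(y, shape, lam=0.2)         # plain is 1.1, seudo is about 1.05
```

Plain least squares blames all of the extra light on pixel 9 on the cell, while the penalized kernel soaks up the part above `lam`, so the estimate moves back towards the true 1.0.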
The goal was to get this whole thing to work in real time (i.e. at least 30 fps), and we've achieved it. It works in real time (even including SEUDO, which is really an optional extra processing step on top), and produces quality that is generally much better than OnACID's and fairly close to CNMF's. Arguably, it's at least sometimes better than CNMF, but the problem is that there is no accepted way to rate "better or worse" (the unscientific method is eyeballing small fragments of the movies); the accepted way of rating is "same as or different from CNMF". I'll talk a bit more about it later, in a separate post.
So the work really has 5 parts:
- The optimization (in the programming sense) and parallelization in C++ of the optimization (in the math sense) algorithm
- Improvements to the optimization (in the math sense) algorithms
- An offshoot: applying some of these improvements to the optimization algorithms, and more, to the problem of training neural networks (the things I did in Triceps)
- Improvements to the SEUDO algorithm
- The completely new logic for automatically finding the cells in the movies (which got the name realSEUDO)
The first part is kind of straightforward (read my book :-) with only a few things worth mentioning. First, outdoing Matlab and the TFOCS package running in it is not easy, since they are already well optimized and parallelized, but it's doable. There are two keys to it. One is to match what TFOCS does: represent the highly sparse matrix of coefficients by "drawing" the sprites, thus regenerating the matrix on the fly whenever it gets used; otherwise the matrix just won't fit into memory. The computation of a gradient requires two passes of drawing. TFOCS does it by generating two complementary drawing functions. We use two passes with one drawing function but compute different sums: the first pass goes by columns and computes the intermediate sums (common subexpressions) that get stored, then the second pass goes by rows and computes the final sums from the intermediate ones. The two passes with stored intermediate values reduce the complexity from O(n^3) to O(n^2) compared to doing everything in one pass. The second key was to make the pixel drawing function a template, so that it gets inlined, rather than a function call, because the overhead of a function call in the inner loop costs more than the rest of the computation. And one more thing worth mentioning is that to pass data between Matlab and C++ code you have to use Matlab's C API: their C++ API is extremely slow, adding much more overhead than the computation itself.
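One way to picture the two-pass, matrix-free gradient is the Python sketch below. It is my own reading of the description above, not the realSEUDO C++ code, and all names in it are hypothetical. The columns of the matrix A are gaussian sprites stenciled at given centers, `draw` regenerates the non-zeros of A on the fly, and the gradient of 0.5*||A x - y||^2, which is A^T (A x - y), is computed in two passes over the same drawing function: the first accumulates the per-pixel model values A x (the stored intermediate sums), the second reuses them to accumulate one gradient component per sprite.

```python
import math


def gaussian_sprite(radius, sigma):
    """Square sprite with a 2-D gaussian shape, peak 1.0 in the center."""
    return [[math.exp(-(dx * dx + dy * dy) / (2.0 * sigma * sigma))
             for dx in range(-radius, radius + 1)]
            for dy in range(-radius, radius + 1)]


def draw(centers, sprite, height, width, visit):
    """Regenerate the non-zeros of A on the fly instead of storing A:
    call visit(j, y, x, value) for every pixel of sprite j = column j of A."""
    r = len(sprite) // 2
    for j, (cy, cx) in enumerate(centers):
        for dy in range(-r, r + 1):
            y = cy + dy
            if not 0 <= y < height:
                continue
            for dx in range(-r, r + 1):
                x = cx + dx
                if 0 <= x < width:
                    visit(j, y, x, sprite[dy + r][dx + r])


def gradient(coef, image, centers, sprite):
    """Gradient of 0.5 * ||A coef - image||^2, i.e. A^T (A coef - image)."""
    height, width = len(image), len(image[0])

    # Pass 1: per-pixel sums of the model, A coef (the stored intermediate values).
    model = [[0.0] * width for _ in range(height)]

    def add_to_model(j, y, x, value):
        model[y][x] += value * coef[j]

    draw(centers, sprite, height, width, add_to_model)
    residual = [[model[y][x] - image[y][x] for x in range(width)]
                for y in range(height)]

    # Pass 2: the same drawing, now summing value * residual per sprite.
    grad = [0.0] * len(centers)

    def add_to_grad(j, y, x, value):
        grad[j] += value * residual[y][x]

    draw(centers, sprite, height, width, add_to_grad)
    return grad
```

In Python `visit` is an ordinary callback invoked once per drawn pixel; in C++ this is the spot where making the callback a template parameter lets the compiler inline it into the inner loop.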
But that's still not fast enough. Then it was the turn of improving the optimization (in the math sense) algorithms. There are three of them, stacked on top of each other: FISTA, which is a momentum gradient descent on top of ISTA; ISTA, which is simple gradient descent with a shrinkage step and the logic for selecting the step size, on top of LASSO; and LASSO, which is the idea of adding a bias to the function being optimized, to drive the dimensions with small values towards 0. The amount of this bias is adjusted by the parameter lambda (we'll get to its effect later).
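For reference, here is a minimal Python sketch of how the three layers fit together for the non-negative problem described here: minimize 0.5*||A x - b||^2 + lambda*sum(x) subject to x >= 0. It is textbook FISTA, not the modified version from the paper, and it uses a fixed step 1/L, with L bounded by the squared Frobenius norm of `A`, in place of ISTA's step-size search. The LASSO term shows up as the shrinkage by `step * lam` in `prox`; the function name is hypothetical.

```python
import math


def fista_nonneg_lasso(A, b, lam, iters=500):
    """Minimize 0.5 * ||A x - b||^2 + lam * sum(x) subject to x >= 0 with FISTA.

    A is a dense list of rows. The step is a fixed 1/L, where L is the squared
    Frobenius norm of A: a simple upper bound on the largest eigenvalue of A^T A
    (practical ISTA/FISTA code usually searches for a better step instead).
    """
    m, n = len(A), len(A[0])
    L = sum(a * a for row in A for a in row)
    step = 1.0 / L

    def grad(v):
        # gradient of the smooth part: A^T (A v - b)
        r = [sum(a * vi for a, vi in zip(row, v)) - bi for row, bi in zip(A, b)]
        return [sum(A[i][j] * r[i] for i in range(m)) for j in range(n)]

    def prox(v):
        # LASSO shrinkage towards 0 by step*lam, then clipping to x >= 0
        return [max(0.0, vi - step * lam) for vi in v]

    x = [0.0] * n
    y = x[:]      # extrapolated point where the gradient is taken
    t = 1.0
    for _ in range(iters):
        g = grad(y)
        x_new = prox([yi - step * gi for yi, gi in zip(y, g)])  # one ISTA step from y
        t_new = (1.0 + math.sqrt(1.0 + 4.0 * t * t)) / 2.0
        momentum = (t - 1.0) / t_new
        y = [xn + momentum * (xn - xo) for xn, xo in zip(x_new, x)]
        x, t = x_new, t_new
    return x


x = fista_nonneg_lasso([[2.0, 0.0], [0.0, 1.0]], [2.0, 0.5], lam=0.2, iters=2000)
# x is close to [0.95, 0.3]; with b = [2.0, 0.1] the second value becomes exactly 0
```

With `momentum` held at 0 the loop is plain ISTA; FISTA's only addition is the extrapolation from the last two iterates.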
I've started writing up a LASSO-ISTA-FISTA summary, but even in summary form it got quite long, so I'll put it into a separate follow-up post. The summary of the summary is that a weird but serendipitous experiment gave me the idea of treating the gradient dimensions separately, and that led to the idea of stopping the momentum on overshoot, by dimension. Momentum gradient descent is kind of like a ball rolling down a well, observed with a stroboscope: each step of the descent is like the passage of time to the next strobe flash. So the next time we see the ball, it might already be past the minimum, with the momentum carrying it up the opposite wall. Eventually it will stop, roll back, and likely overshoot the minimum again, in the other direction. In the real world the ball eventually stops because of friction that drains its energy. In momentum descent, friction is simulated by gradually reducing the momentum coefficient (a smaller coefficient means more friction). The trouble is, you don't know in advance how much friction to put in. With too little friction, the ball bounces around a lot; with too much, it rolls as if in molasses. So people do things like resetting the momentum coefficient after some large number of steps. But if we look at each dimension separately, there is an easy indication of an overshoot: the gradient in that dimension changes sign. At that point we can just immediately kill the momentum, and we don't even have to kill the whole momentum, only its component in that one dimension. This worked really well, to the point that the friction became completely unnecessary. The other place where sudden stopping helps is when the variables try to go out of range (for example, in our case we have the boundary condition that each dimension is non-negative).
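A minimal Python sketch of the per-dimension stopping as described above: plain momentum descent with no friction schedule, where the velocity of a dimension is zeroed when its gradient changes sign, and also when the non-negativity bound clamps it. This is my simplified illustration on a smooth function given by its gradient; the function name is hypothetical and the actual implementation may differ in its details.

```python
def momentum_descent_with_stops(grad, x0, step, beta, iters, lower=0.0):
    """Momentum gradient descent with per-dimension stopping, no friction schedule.

    grad(x) returns the gradient as a list. In each dimension the velocity is
    zeroed when the gradient changes sign (an overshoot) and when the variable
    hits the lower bound (it gets clamped there).
    """
    x = list(x0)
    velocity = [0.0] * len(x)
    grad_prev = [0.0] * len(x)
    for _ in range(iters):
        g = grad(x)
        for i in range(len(x)):
            if g[i] * grad_prev[i] < 0.0:
                velocity[i] = 0.0          # overshoot: stop the ball, in this dimension only
            velocity[i] = beta * velocity[i] - step * g[i]
            x[i] += velocity[i]
            if x[i] < lower:
                x[i] = lower               # out of range: clamp and stop here too
                velocity[i] = 0.0
        grad_prev = g
    return x


# f(p) = 0.5*(p0 - 1)^2 + 0.5*(p1 + 1)^2 restricted to p >= 0: optimum at (1, 0).
x = momentum_descent_with_stops(lambda p: [p[0] - 1.0, p[1] + 1.0],
                                [0.0, 0.0], step=0.5, beta=0.9, iters=200)
```

In the first dimension the ball overshoots to 1.2, gets stopped, and from then on every overshoot is smaller than the last; the second dimension keeps trying to go negative and stays clamped at 0.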
This worked very well for a smooth quadratic function (well, mostly smooth, except for the creases added by the LASSO bias, but then again the creases are at the sign-change boundaries, and we always stay on the positive side). Would it work on the much more jagged functions used in training neural networks? This is what I've tried in the NN code in Triceps, and as it turns out, the answer is sort of yes, at least in certain limited ways, as I wrote at length in other posts. The fraction of dimensions experiencing the stopping can also be used to auto-adjust the training step size (i.e. the training schedule: existing algorithms like Adam rely on preset constants, while here it adjusts automatically). So far it has been applied only to plain gradient descent, and in its current form it doesn't beat stochastic gradient descent, but on the other hand, it makes plain gradient descent with momentum not only work at all but work not that much worse than stochastic descent. Check out the figure with the NN training speed comparison in the appendix. Now at NeurIPS I've learned things that gave me more ideas for marrying this logic with stochastic descent, so there are more things to try in the future.
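The exact rule for turning the stop fraction into a step size isn't spelled out here, so the Python sketch below is a hypothetical one, just to show the shape of the idea: count the dimensions whose gradient changed sign since the previous step, shrink the step if that fraction is above a target, and grow it otherwise. The function name and the `target`, `shrink` and `grow` values are made-up placeholders, not the constants used in Triceps.

```python
def adjust_step(step, grad, grad_prev, target=0.3, shrink=0.7, grow=1.1):
    """Hypothetical step-size rule driven by the fraction of per-dimension stops.

    A dimension "stops" when its gradient changed sign since the previous step
    (the overshoot signal). Many stops suggest the step is too large; few stops
    suggest it can grow. Returns (new_step, fraction_stopped).
    """
    stops = sum(1 for g, gp in zip(grad, grad_prev) if g * gp < 0.0)
    fraction = stops / len(grad)
    new_step = step * (shrink if fraction > target else grow)
    return new_step, fraction


# Half of the dimensions overshot on this step, so the step shrinks to 0.7.
new_step, fraction = adjust_step(1.0, [1.0, -1.0, 1.0, -1.0], [-1.0, 1.0, 1.0, -1.0])
```

The appeal of such a rule is that it needs no schedule tuned in advance: the overshoots themselves report when the step has become too large.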
For SEUDO, the trick turned out to be sparsity. It still works decently well when the kernels are placed less densely, with a spacing of up to about 1/3 of the kernel diameter. So with 30x30 sprites you get a slow-down by a factor of 900, but place these sprites on a grid with a step of 10 pixels, and you get a factor of 100 back. It would probably be even better to place them on a triangular grid (basically, the same grid but with an offset on alternating rows, and with the vertical step being sqrt(3)/2 of the horizontal step, so that neighboring centers form equilateral triangles), to keep the distances between the centers even in every direction, but we haven't gotten around to doing that (yet?).
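The grid arithmetic and the two grid layouts can be sketched in a few lines of Python (the function names are hypothetical). The kernel cost is taken here to be roughly proportional to the number of non-zero coefficients, i.e. the kernel area times the number of kernels per pixel, so a 30x30 kernel on every pixel costs 900 per pixel, and a grid with a step of 10 divides that by 100. The triangular layout puts rows `step * sqrt(3) / 2` apart and shifts odd rows by half a step.

```python
import math


def nonzeros_per_pixel(kernel_size, step):
    """Kernel coefficients per image pixel: kernel area times kernels per pixel."""
    return kernel_size ** 2 / step ** 2


def square_grid(height, width, step):
    """Kernel centers on a square grid."""
    return [(float(y), float(x))
            for y in range(0, height, step) for x in range(0, width, step)]


def triangular_grid(height, width, step):
    """Kernel centers on a triangular grid: rows step*sqrt(3)/2 apart, odd rows
    shifted by step/2, so the nearest neighbors of a center are all `step` away."""
    row_step = step * math.sqrt(3.0) / 2.0
    centers = []
    row = 0
    while row * row_step < height:
        x = step / 2.0 if row % 2 else 0.0
        while x < width:
            centers.append((row * row_step, x))
            x += step
        row += 1
    return centers


# 900 coefficients per pixel with a kernel on every pixel, 9 with a step of 10.
dense_cost, sparse_cost = nonzeros_per_pixel(30, 1), nonzeros_per_pixel(30, 10)
```

On the square grid the diagonal neighbors are about 1.41 steps away while the horizontal and vertical ones are 1 step away; on the triangular grid all six nearest neighbors are exactly one step away.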
And the way to make the detection of the cell images in the video fast (that's what realSEUDO is, and incidentally the plain SEUDO is optional for it) was to do it with old-fashioned image processing. It goes in two stages (well, plus a noise reduction pre-stage, and it also expects the image to be already stabilized). The first stage finds the contiguous spots above the background brightness, then adds time as the third dimension and builds sausage-like shapes of the cell activations. They are pretty jagged at the ends, as the cells start to light up in little spots, then these spots gradually expand and merge, and going dark again looks similar but in the opposite direction. The sausage eventually gets squashed along the time axis to build the reference cell image. The trouble is, there are multiple layers of cells in the brain, with overlapping cells sometimes activating at the same time. So the second stage is used to figure out these overlaps between the cells: when we see the same cell but at a different brightness level, when we see a completely different overlapping cell, and when both (or more) overlapping cells are lighting up at the same time. So it merges and splits the cell images accordingly, and generates events telling what it's doing. At the same time it tries to fit the known cells into the current frame and generates events for the cell activations (Triceps is used for reporting the events, as a more efficient way to collect or forward them than Matlab). The splitting/merging code is very important but purely empirical, with a few important tunable coefficients chosen by trial and error at the moment. It's a tour de force of empirical engineering beating science. But perhaps some merging of both approaches is possible, maybe using the detected shapes as a starting point for further CNMF-like optimization.
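The first stage can be illustrated with a plain 3-D flood fill in Python: treat the movie as a (time, y, x) volume, group the pixels above a threshold into connected blobs (the "sausages"), and squash each blob along the time axis. This is a simplified stand-in under my own assumptions: a single fixed `threshold` instead of the background brightness, 6-connectivity, the per-pixel maximum as the squashing rule, no noise reduction, and no second stage. The names `find_sausages` and `squash` are hypothetical.

```python
from collections import deque


def find_sausages(movie, threshold):
    """Connected blobs of above-threshold pixels in movie[t][y][x], with time as
    the third dimension. Returns a list of blobs, each a list of (t, y, x)."""
    T, H, W = len(movie), len(movie[0]), len(movie[0][0])
    seen = set()
    blobs = []
    steps = ((1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1))
    for t in range(T):
        for y in range(H):
            for x in range(W):
                if movie[t][y][x] <= threshold or (t, y, x) in seen:
                    continue
                blob = []
                queue = deque([(t, y, x)])
                seen.add((t, y, x))
                while queue:  # breadth-first flood fill over space and time
                    pt, py, px = queue.popleft()
                    blob.append((pt, py, px))
                    for dt, dy, dx in steps:
                        q = (pt + dt, py + dy, px + dx)
                        if (0 <= q[0] < T and 0 <= q[1] < H and 0 <= q[2] < W
                                and q not in seen
                                and movie[q[0]][q[1]][q[2]] > threshold):
                            seen.add(q)
                            queue.append(q)
                blobs.append(blob)
    return blobs


def squash(movie, blob):
    """Collapse a blob along time: per pixel, the brightest value it reached."""
    image = {}
    for t, y, x in blob:
        image[(y, x)] = max(image.get((y, x), 0.0), movie[t][y][x])
    return image


# A cell that starts as one spot and grows, plus an unrelated flash later.
movie = [[[0.0] * 4 for _ in range(4)] for _ in range(3)]
movie[0][0][0] = 1.0
movie[1][0][0] = movie[1][0][1] = 2.0
movie[2][3][3] = 5.0
blobs = find_sausages(movie, threshold=0.5)
```

Here the first cell becomes one sausage spanning frames 0 and 1, whose squashed image covers pixels (0, 0) and (0, 1); separating overlapping cells is what the second stage is for.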
Using the detected shapes as a starting point should help, because if the starting point is already very close to the optimum, the short path to the optimum would be fast (as we see when reusing the activation coefficients detected in the previous frame as the starting point for the current frame).
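The benefit of a warm start is easy to see even on a toy problem. The Python sketch below uses a separable non-negative quadratic as a stand-in for the per-frame activation fit (an assumption: the real problem couples the cells) and counts projected-gradient iterations until the update becomes smaller than `tol`, starting either from zero or from a hypothetical previous-frame solution that is already close. The function name is hypothetical.

```python
def solve_nonneg_quadratic(hess_diag, target, x0, step, tol=1e-6, max_iters=10000):
    """Minimize sum_i 0.5 * h_i * (x_i - c_i)^2 over x >= 0 by projected gradient.

    Returns (x, iterations used); stops when no coordinate moves by tol or more.
    """
    x = list(x0)
    for it in range(1, max_iters + 1):
        g = [h * (xi - c) for h, xi, c in zip(hess_diag, x, target)]
        x_new = [max(0.0, xi - step * gi) for xi, gi in zip(x, g)]
        moved = max(abs(a - b) for a, b in zip(x_new, x))
        x = x_new
        if moved < tol:
            return x, it
    return x, max_iters


hess = [1.0, 0.5]
target = [2.0, 1.0]                                  # this frame's optimum
x_cold, cold = solve_nonneg_quadratic(hess, target, [0.0, 0.0], step=1.0)
x_warm, warm = solve_nonneg_quadratic(hess, target, [1.9, 1.05], step=1.0)  # previous frame
```

Here the cold start needs 20 iterations and the warm start 16. The gap grows with how slowly the problem converges: each iteration only shrinks the error by a fixed factor, so starting 20 times closer saves the iterations that factor would have needed.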
And there are other little things, like partitioning a large image into segments, processing them in parallel, and then splicing the cell shapes and activation profiles back together (this is not a new idea; CNMF does it too). Since all the high-level logic has been done in Matlab, the parallelism is a bit weirdly shaped by Matlab's abilities: there is parallelism at the top, with the partitioning into segments, and at the bottom, with the C++ code running FISTA, but there are plenty of opportunities for parallelism in the middle, in the splitting/merging logic, that aren't used. The trouble is basically that Matlab doesn't have shared read-only data: it starts each parallel section with an expensive copy of the whole relevant state, and ends it by copying back the result. The other reason to use segments comes from the overhead on the cell shapes: the way Matlab works, the shapes have to be represented as image-sized matrices, which increases the overhead of all the operations quadratically as the image grows. Sparse matrices (which the C++ part of the code already uses) would fix that; there is such a recent new feature in Matlab, or obviously the code could be moved from Matlab to another language. BTW, these are much simpler sparse matrices than the ones produced in the FISTA equations: here we know that all the non-0 values are located in a small bounding box, so they're trivial to implement.
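A bounding-box shape of this kind can be sketched in Python as a small dense patch plus its offset in the image. The class name `BoxShape` and its methods are hypothetical; the point is that an operation such as the overlap dot product only visits the intersection of the two boxes, so its cost depends on the cell size rather than on the image size.

```python
class BoxShape:
    """A cell shape stored as a small dense patch plus its position in the image."""

    def __init__(self, top, left, patch):
        self.top, self.left, self.patch = top, left, patch
        self.height, self.width = len(patch), len(patch[0])

    def dot(self, other):
        """Sum of pixelwise products, visiting only the overlap of the two boxes."""
        r0 = max(self.top, other.top)
        r1 = min(self.top + self.height, other.top + other.height)
        c0 = max(self.left, other.left)
        c1 = min(self.left + self.width, other.left + other.width)
        total = 0.0
        for r in range(r0, r1):          # empty ranges when the boxes don't overlap
            for c in range(c0, c1):
                total += (self.patch[r - self.top][c - self.left]
                          * other.patch[r - other.top][c - other.left])
        return total

    def to_dense(self, height, width):
        """Expand into a full image-sized matrix (what Matlab would store)."""
        image = [[0.0] * width for _ in range(height)]
        for r in range(self.height):
            for c in range(self.width):
                image[self.top + r][self.left + c] = self.patch[r][c]
        return image


a = BoxShape(0, 0, [[1.0, 2.0], [3.0, 4.0]])
b = BoxShape(1, 1, [[10.0, 20.0], [30.0, 40.0]])
overlap = a.dot(b)   # only pixel (1, 1) is shared: 4.0 * 10.0
```

The dense form gives the same answer but touches every pixel of the image, which is the quadratic overhead described above.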
There are also things to be said about the evaluation of detected profiles (which also comes up in splicing the segments), but it's a largish subject for another post.