Top-KAST: Top-K Always Sparse Training

Created on 2024-02-17T18:49:32-06:00


This card pertains to a resource available on the internet.


In tasks using large networks, only a small subset of the neurons is genuinely needed to reach a high degree of prediction accuracy.

References RigL as a successful implementation of fully sparse learning, while remarking that it is difficult to integrate into existing frameworks.

Talks about using "masks" to control which parameters are involved in a sparse operation.
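
A minimal NumPy sketch of the masking idea (shapes and names are made up, not from the paper): only the weights where the mask is 1 take part in the forward pass.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical dense layer: 4 inputs -> 3 outputs.
weights = rng.normal(size=(4, 3))
x = rng.normal(size=(1, 4))

# Binary mask over the weights; roughly 25% of entries stay active.
mask = (rng.random(size=weights.shape) < 0.25).astype(weights.dtype)

# Sparse forward pass: masked-out parameters contribute nothing.
y_sparse = x @ (weights * mask)
print(y_sparse.shape)  # (1, 3)
```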

Dense-to-sparse training: train a fully dense neural network, then use a distillation process to remove the parts of the network that are not necessary.
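
The note keeps this high-level; as a point of contrast with always-sparse training, here is a sketch of one common dense-to-sparse recipe, magnitude pruning after dense training. The function name and the magnitude criterion are my own illustrative choices, not necessarily the exact procedure the paper describes.

```python
import numpy as np

def prune_smallest(weights: np.ndarray, keep_fraction: float) -> np.ndarray:
    """Zero out all but the largest-magnitude keep_fraction of the weights."""
    k = max(1, int(round(keep_fraction * weights.size)))
    threshold = np.sort(np.abs(weights).ravel())[-k]
    return np.where(np.abs(weights) >= threshold, weights, 0.0)

# Pretend these weights came out of a fully dense training run.
rng = np.random.default_rng(1)
dense_weights = rng.normal(size=(16, 16))

sparse_weights = prune_smallest(dense_weights, keep_fraction=0.2)
print(np.count_nonzero(sparse_weights), "of", dense_weights.size, "weights kept")
```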

Sparsity is controlled by S, a real number between 0 and 1 giving the fraction of the network that should be active in the sparse version of the run.

Exploratory stage: start with a dense neural network but only use a subset of each layer. The "mask" is changed from time to time, in case a better subset of the layer exists somewhere.
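
Following the note's framing (S as the active fraction, subsets drawn per layer), here is a sketch of what drawing and re-drawing an exploratory mask could look like; `draw_mask` and the random-selection rule are illustrative assumptions, not the paper's exact criterion.

```python
import numpy as np

def draw_mask(shape, s, rng):
    """Randomly mark roughly s * total parameters of one layer as active."""
    n = int(np.prod(shape))
    k = max(1, int(round(s * n)))
    mask = np.zeros(n)
    mask[rng.choice(n, size=k, replace=False)] = 1.0
    return mask.reshape(shape)

rng = np.random.default_rng(2)
layer_shape = (128, 64)

mask = draw_mask(layer_shape, s=0.1, rng=rng)   # ~10% of the layer is active
# Later in the exploratory stage the mask is re-drawn, in case a better
# subset of the layer exists somewhere.
mask = draw_mask(layer_shape, s=0.1, rng=rng)
print(int(mask.sum()), "of", int(np.prod(layer_shape)), "parameters active")
```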

Refinement stage: the mask is no longer changed; the neurons selected for each layer become permanent and continue to be trained.

Computing the mask can be delayed: the paper only updates it every 100 steps, which loosens contention on the hardware.

Disabling the exploration phase appears to greatly harm performance. It is important to select random subsets of layers for some time, accumulating gradients and training them, and then, after many steps (the paper mentions 5,000), finally settle on the winners and lock the layer configuration from that point on.
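
A sketch of that schedule as I read it: the mask is refreshed every 100 steps during exploration, then left alone after 5,000 steps. The two numbers come from the note above; the names and the placeholder update are made up.

```python
import numpy as np

rng = np.random.default_rng(3)

S = 0.1                 # active fraction, per the note's convention
EXPLORE_STEPS = 5_000   # lock the mask after this many steps
MASK_REFRESH = 100      # recompute the mask this often while exploring

w = rng.normal(size=256)
k = max(1, int(round(S * w.size)))
mask = np.zeros_like(w)
mask[rng.choice(w.size, size=k, replace=False)] = 1.0

for step in range(10_000):
    if step < EXPLORE_STEPS and step % MASK_REFRESH == 0:
        # Exploratory stage: try a different random subset of the layer.
        mask[:] = 0.0
        mask[rng.choice(w.size, size=k, replace=False)] = 1.0
    # After EXPLORE_STEPS the mask is simply never touched again,
    # which is the refinement stage.

    # Placeholder gradient step, restricted to the active subset.
    fake_grad = rng.normal(size=w.shape)
    w -= 1e-3 * fake_grad * mask
```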

Forward/backward sparsity: these are controlled separately. Forward sparsity controls how many nodes compute a value, while backward sparsity spreads the gradient back to nodes that were not involved in the forward pass. I think they do this so that, when the next random set of nodes is selected, those nodes have already been pushed a little and have a better chance of turning out to be important.
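
A toy sketch of that idea, assuming (as I read the note) that the backward mask covers a wider set of weights than the forward mask; the layer, masks, and loss here are all illustrative.

```python
import numpy as np

rng = np.random.default_rng(4)

w = rng.normal(size=(6, 1))
x = rng.normal(size=(32, 6))
y = rng.normal(size=(32, 1))

# Forward mask: only 2 of the 6 weights help compute the prediction.
forward_mask = np.zeros_like(w)
forward_mask[[0, 3]] = 1.0

# Backward mask: a wider superset, so 4 of the 6 weights receive gradient.
backward_mask = forward_mask.copy()
backward_mask[[1, 5]] = 1.0

for _ in range(100):
    pred = x @ (w * forward_mask)       # forward pass uses the tight mask
    # Gradient of the squared error as if all weights were used,
    # evaluated at the sparse prediction.
    grad = x.T @ (pred - y) / len(x)
    w -= 0.1 * grad * backward_mask     # update flows to the wider set

# Weights 1 and 5 never influenced a prediction, yet they have been
# nudged toward useful values and could win a spot in a future mask.
```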

Quinn note: up to 80% of a large neural network is doing largely nothing. This is derived from Top-KAST being able to recreate transformer models with significantly fewer neurons active, just by having them wired in a more correct way.