Adaptive wavelet distillation from neural networks through interpretations
Created on 2023-02-21T18:14:10-06:00
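Wavelet transforms are implemented using the "filter bank" model, in which a signal is recursively split into low-frequency (approximation) and high-frequency (detail) components by a pair of low-pass and high-pass filters derived from the chosen wavelet; each level re-splits the low-frequency part.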
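A minimal sketch of the filter-bank recursion, assuming the hand-picked Haar filters and periodic boundary handling (both choices are for illustration only; `analysis_step` and `wavedec` are hypothetical names, not from the paper). Each level correlates the current low-pass signal with both filters and downsamples by 2, then recurses on the low-pass output.

```python
import numpy as np

# Haar analysis filters: an assumed, hand-selected wavelet for illustration.
HAAR_LOW = np.array([1.0, 1.0]) / np.sqrt(2.0)
HAAR_HIGH = np.array([1.0, -1.0]) / np.sqrt(2.0)


def analysis_step(x, h, g):
    """One filter-bank level: correlate with h and g, keep every second output.

    Uses periodic (wrap-around) boundaries so any filter length works.
    """
    n = len(x)
    # Row i holds the samples x[2i], x[2i+1], ..., wrapped modulo n.
    idx = (np.arange(0, n, 2)[:, None] + np.arange(len(h))[None, :]) % n
    windows = x[idx]
    return windows @ h, windows @ g


def wavedec(x, h, g, levels):
    """Recursively split the low-pass branch.

    Returns [approximation, detail_coarsest, ..., detail_finest].
    """
    approx = np.asarray(x, dtype=float)
    details = []
    for _ in range(levels):
        approx, detail = analysis_step(approx, h, g)
        details.append(detail)
    return [approx] + details[::-1]


x = np.array([4.0, 2.0, 6.0, 8.0, 1.0, 1.0, 3.0, 5.0])
coeffs = wavedec(x, HAAR_LOW, HAAR_HIGH, levels=2)
```

Here `coeffs` is `[[10, 5], [-4, -3], [2, -2, 0, -2] / sqrt(2)]`, and because the Haar filter bank is orthogonal the coefficients carry the same total energy as `x`.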
In this paper, the wavelet itself is also learned rather than hand-selected.
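One common way to make the wavelet learnable, assumed here rather than taken from the notes, is to treat only the low-pass filter `h` as a trainable parameter and derive the high-pass filter `g` from it with the alternating-flip rule g[n] = (-1)^n h[L-1-n]. The function name below is hypothetical.

```python
import numpy as np


def highpass_from_lowpass(h):
    """Alternating flip: g[n] = (-1)**n * h[L - 1 - n] for an even-length h."""
    h = np.asarray(h, dtype=float)
    signs = (-1.0) ** np.arange(len(h))
    return signs * h[::-1]


# In a learned-wavelet setup, h would be updated by gradient descent;
# here it is just a random initialization.
rng = np.random.default_rng(0)
h = rng.normal(size=8)
g = highpass_from_lowpass(h)
```

For any even-length `h`, this construction makes `g` orthogonal to `h`, so the two branches of the filter bank never duplicate each other; applied to the Haar low-pass filter it recovers the Haar high-pass filter.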
It starts with a deep neural network that has already been trained on some prediction task; the method has been tested with LSTMs, simple deep networks, and ResNets.
Uses "Transformation Importance" (TRIM) to score how important each wavelet coefficient of the input is to the network's prediction.
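The core TRIM idea is to compose the trained model with the inverse transform, so the model can be viewed as a function of the wavelet coefficients, and then run an attribution method in coefficient space. The sketch below assumes gradient-times-input as the attribution method, estimates gradients with finite differences, and uses a toy linear "network" and a one-level Haar matrix; all of these are illustrative assumptions, and `trim_attributions` is a hypothetical name.

```python
import numpy as np


def trim_attributions(model, inverse_transform, w, eps=1e-6):
    """Gradient x input attributions of model(inverse_transform(w)) w.r.t. w.

    Gradients are central finite differences so any black-box model works;
    a real implementation would use automatic differentiation instead.
    """
    w = np.asarray(w, dtype=float)
    grads = np.zeros_like(w)
    for i in range(w.size):
        step = np.zeros_like(w)
        step[i] = eps
        f_plus = model(inverse_transform(w + step))
        f_minus = model(inverse_transform(w - step))
        grads[i] = (f_plus - f_minus) / (2 * eps)
    return grads * w


s = 1 / np.sqrt(2)
# One-level Haar transform of a length-4 signal: 2 approximation rows, 2 detail rows.
W = s * np.array([[1, 1, 0, 0],
                  [0, 0, 1, 1],
                  [1, -1, 0, 0],
                  [0, 0, 1, -1]])
inverse = lambda w: W.T @ w  # W is orthogonal, so its inverse is its transpose

v = np.array([1.0, 1.0, 0.0, 0.0])
model = lambda x: float(v @ x)  # toy "network": sums the first two samples

x = np.array([1.0, 1.0, 0.0, 0.0])
attr = trim_attributions(model, inverse, W @ x)  # close to [2, 0, 0, 0]
```

Since this toy model only responds to the low-frequency content of the first two samples, essentially all of the attribution lands on the first approximation coefficient, and for a linear model the attributions sum to the prediction itself.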
"Wavelet Loss" is used to pressure the system to learn a valid, invertible, wavelet transform. An invertible wavelet is important so data is not lost in the transform which is important because AWD is intended to help understand *why* a network did what it did.
"Interpretation Loss" is used to pressure the system to learn a sparse wavelet. The majority of important coefficients (via TRIM) should be clustered together.
K-winner-take-all keeps the 6 most important coefficients at each scale. These identify which features are responsible for the network's predictive accuracy, and a new model is then trained on only those features.
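A minimal sketch of per-scale k-winner-take-all selection. It assumes coefficients are ranked by absolute value; ranking by TRIM attribution instead would only change the sort key. `top_k_per_scale` is a hypothetical name, and the concatenated output is one possible feature vector for the downstream model.

```python
import numpy as np


def top_k_per_scale(coeffs, k=6):
    """Keep the k largest-magnitude coefficients in each scale, zeroing the rest."""
    kept = []
    for c in coeffs:
        c = np.asarray(c, dtype=float)
        out = np.zeros_like(c)
        if c.size:
            idx = np.argsort(np.abs(c))[-k:]  # indices of the k largest |c|
            out[idx] = c[idx]
        kept.append(out)
    return kept


scales = [np.array([3.0, -7.0, 1.0]), np.arange(10.0)]
kept = top_k_per_scale(scales, k=2)
features = np.concatenate(kept)  # input for the new, distilled model
```

If a scale has fewer than `k` coefficients, all of them are kept.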
This resulted in a "200-fold" reduction in computation compared with the LSTM. (This is probably because the problem did not inherently need the recurrent memory.)