Given a set of \(N\) observations
coming from a true (solid blue) but unknown function
Approximate the function from the observations
Linear models (using the basis of standard Polynomials)
Neural Networks (MLP)
Support Vector Machines (SVM)
taking a linear combination of the standard polynomial basis functions \((1, x, x^2, \cdots, x^n)\)
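For instance, with the polynomial basis the approximation is just a weighted sum of fixed basis functions,
\[
y(x, \mathbf{w}) = w_0 + w_1 x + w_2 x^2 + \cdots + w_n x^n = \sum_{j=0}^{n} w_j\, x^{j},
\]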
where the basis is global (fixed)
Look for bases that are local or adaptive to the training data
Bases are global functions of the input variable (a small change around a data point changes the approximated function globally)
Suffers a lot from Curse of Dimensionality (COD)
For example,
Image source: PRML, Bishop
Here, we let the neural network learn the complex feature transformation \(\phi(x)\) via adaptive basis functions
With \(N=50\) data points
Learned bases (activation functions) of the neurons in a hidden layer with 3 units
Neuron-1
Neuron-2
Neuron-3
Ultimately, it is all about combining the activation functions to approximate the true function at the observed points
Image source: PRML, Bishop
UAT: With a one-hidden-layer NN
We can approximate any continuous function to a desired precision \(\epsilon\)
For a single neuron
Here, the activation (basis) function \(\sigma\) is fixed
The adaptability to the data is due to the learnable parameters \(w\)
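Concretely (a sketch in generic notation; \(M\), \(v_j\), \(b_j\) are not from the slide), the one-hidden-layer approximation has the form
\[
f(x) \approx \sum_{j=1}^{M} v_j\, \sigma\!\left(w_j^{\top} x + b_j\right),
\]
with \(\sigma\) fixed and \(v_j, w_j, b_j\) learnable.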
Image source: DL Lectures by Mitesh Khapra
Interpretability is difficult with the vast number of learnable weights \(w\)
Continual learning without re-training is difficult
Catastrophic forgetting due to fixed and non-local activation functions
Sparsification and pruning are difficult due to the large number of parameters
KAN overcomes (at least in toy settings) all of these
Consider a multivariate function \(f(\mathbf{x})\) with \(\mathbf{x} = (x_1, x_2, \cdots, x_n) \in \mathbb{R}^n\)
Bold claim: Any continuous smooth multivariate function \(f(\mathbf{x})\) can be represented using \(n\) univariate functions \(\phi_p(x_p),\ p=1,2,\cdots,n\)
Is \(\phi_p(x_p)\) fixed or learnable? Learnable
How do we combine the output from the \(n\) univariate functions to get \(f(\mathbf{x})\)?
Sum the outputs from all univariate functions
In reality, we need \(n(2n+1)\) \(\phi\)s and \(2n+1\) \(\Phi_q\)s, totalling \(2n^2+3n+1\) learnable univariate functions (not just \(n\) as claimed in the previous slide)
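For reference, these counts come from the Kolmogorov–Arnold representation
\[
f(\mathbf{x}) = \sum_{q=1}^{2n+1} \Phi_q\!\left(\sum_{p=1}^{n} \phi_{q,p}(x_p)\right).
\]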
Well, how do we make a learnable 1D function?
A function in the basis \(B\): \(\phi(x) = \sum_i c_i B_i(x)\)
\(c_i \in \mathbb{R}\): spline coefficients (control points)
\(B_i(x)\): The basis function of order \(k\) (degree \(d=k-1\))
\(t=(t_{-k},\cdots,t_{G+k})\): Uniformly spaced knots (think of this as a uniformly spaced grid of points defined by \(G\))
\(c_i \)'s are learnable parameters (initialized randomly based on the problem)
Basis support:
Spline width = \((k+1)*h=1.6\)
Increasing the value of \(G\) decreases \(h\) and hence the width of the splines (sort of controlling the resolution)
offset by the starting value of \(x\)
The grid is extended on both sides to support the basis functions near the boundaries
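A minimal sketch of such a learnable 1D function in Python (the grid size \(G\), order \(k\), and \([0,1]\) range below are illustrative choices, not defaults from the paper): the basis functions are fixed and local, and only the coefficients \(c_i\) are learned.

```python
import numpy as np

def bspline_basis(x, i, k, t):
    """Cox-de Boor recursion: B-spline basis of order k (degree k-1) at knot index i."""
    if k == 1:
        return np.where((t[i] <= x) & (x < t[i + 1]), 1.0, 0.0)
    left = np.zeros_like(x, dtype=float)
    right = np.zeros_like(x, dtype=float)
    if t[i + k - 1] > t[i]:
        left = (x - t[i]) / (t[i + k - 1] - t[i]) * bspline_basis(x, i, k - 1, t)
    if t[i + k] > t[i + 1]:
        right = (t[i + k] - x) / (t[i + k] - t[i + 1]) * bspline_basis(x, i + 1, k - 1, t)
    return left + right

G, k = 5, 4                               # grid size and spline order (illustrative)
h = 1.0 / G                               # uniform knot spacing on [0, 1]
t = np.arange(-k, G + k + 1) * h          # extended knots t_{-k}, ..., t_{G+k}
num_basis = len(t) - k                    # number of basis functions for this knot vector
c = 0.1 * np.random.randn(num_basis)      # learnable spline coefficients (random init)

x = np.linspace(0.0, 1.0, 200, endpoint=False)
phi = sum(c[i] * bspline_basis(x, i, k, t) for i in range(num_basis))  # phi(x) = sum_i c_i B_i(x)
```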
Each* learnable weight \(w_{ijk}\) in the MLP is replaced by a learnable activation \(\phi_{ijk}\), as shown in the figure on the right
Each node sums the incoming values
*Not replacing the bias (as it is not a function of the input)
A term in the outer summation
KAN-Layer-1
\(n \to 2n+1\)
KAN-Layer-2
\(2n+1 \to 1\)
Number of parameters:
15 (ignoring bias*)
For each activation function, the number of coefficients is \(c = G + k\)
Therefore, \(15 \times (5+3) = 120\)
More parameters than an MLP?
However, it often requires fewer KAN layers than MLP layers
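A quick back-of-the-envelope count (a sketch; the \([2, 5, 1]\) widths and \(G=5\), \(k=3\) follow the slide's example):

```python
def mlp_weight_count(widths):
    # weight matrices only, ignoring biases (as in the slide)
    return sum(n_in * n_out for n_in, n_out in zip(widths[:-1], widths[1:]))

def kan_param_count(widths, G=5, k=3):
    # every edge weight is replaced by a spline with (G + k) coefficients
    return sum(n_in * n_out * (G + k) for n_in, n_out in zip(widths[:-1], widths[1:]))

widths = [2, 5, 1]                    # n -> 2n+1 -> 1 for n = 2
print(mlp_weight_count(widths))       # 15
print(kan_param_count(widths))        # 15 * (5 + 3) = 120
```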
\(2n^2+3n+1\) learnable functions are sufficient if we are able to find the correct \(\phi\)s that represent \(f(\mathbf{x})\).
But \(f(\mathbf{x})\) is not known. Moreover, we choose smooth B-splines (over non-smooth alternatives like fractals) to learn the \(\phi\)s
Therefore, to get a better approximation, one can build a KAN with arbitrary width and depth (replacing weights by learnable activations)
The authors mention, "To the best of our knowledge, there is not yet a “generalized” version of the theorem that corresponds to deeper KANs"
However, they do derive an approximation theory (KAT) assuming \(k\)-th order B-splines for all the \(\Phi\)s.
Optimization: Backpropagation (LBFGS for small networks)
Training tricks: Residual connection \(\phi(x) = w\,(b(x) + \mathrm{spline}(x))\), with \(b(x) = \mathrm{silu}(x)\) and \(w\) learnable
Initialization: Xavier for \(w\), \(\mathcal{N}(0, 0.1)\) for the \(c_i\)s (MLPs typically use Xavier)
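A minimal sketch of that residual activation (assuming the \(\phi(x) = w\,(b(x) + \mathrm{spline}(x))\) form above; spline stands for any learnable 1D spline callable, e.g. the B-spline sum sketched earlier):

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def kan_edge_activation(x, w, spline):
    # phi(x) = w * (silu(x) + spline(x)): the spline carries the learnable shape,
    # the silu term acts as a residual basis that eases optimization
    return w * (silu(x) + spline(x))
```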
Loss: \(L \propto N^{-\alpha}\)
The larger the exponent \(\alpha\), the faster the loss falls with the number of model parameters \(N\); scaling up the model then gives quick improvement
How do we predict \(\alpha\)?
Different theories: some relate \(\alpha\) to the intrinsic dimensionality \(d\) of the data (need to ponder over this), others to the function class of the activations
Function class: order \(k\) of the piecewise polynomial, e.g. ReLU (\(k=1\)) (Maxout?)
For KAN: B-splines of order \(k=3\) give \(\alpha = 4\) (independent of \(d\))
Increasing the number of parameters in a KAN reduces the loss more quickly than in an MLP (albeit with more computation)
Uniform grid intervals (knots)
Splines can adjust their resolution (like wavelets) by varying the grid size \(G\)
Coarse resolution vs. fine resolution
One can increase the grid size (and therefore the number of parameters) on the fly during training (without re-training from scratch)
As per the scaling law for KAN, the loss \(L \propto N^{-4}\) should then decrease quickly
Unique to KAN
Validated in a toy setting
Target: the activation already learned on the coarse grid
Estimator: a spline on the finer grid
Use least squares to solve for the fine-grid coefficients
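A sketch of this grid-extension step using SciPy's least-squares spline fit (grid sizes and degree below are illustrative assumptions; this is not the paper's implementation):

```python
import numpy as np
from scipy.interpolate import BSpline, make_lsq_spline

deg = 3                                           # spline degree (order k = deg + 1)

def uniform_knots(G, deg, lo=0.0, hi=1.0):
    h = (hi - lo) / G
    return lo + np.arange(-deg, G + deg + 1) * h  # extended uniform knot vector

# The activation learned so far, on a coarse grid (random coefficients stand in for training)
t_coarse = uniform_knots(5, deg)
c_coarse = np.random.randn(len(t_coarse) - deg - 1)
coarse_act = BSpline(t_coarse, c_coarse, deg)

# Grid extension: evaluate the coarse activation (the target) at many points and
# solve least squares for the coefficients of a spline on a finer grid (the estimator)
x = np.linspace(0.0, 1.0, 400)
t_fine = uniform_knots(20, deg)
fine_act = make_lsq_spline(x, coarse_act(x), t_fine, k=deg)
c_fine = fine_act.c                               # more coefficients, (approximately) the same function
```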
KANs can be trained continually without catastrophic forgetting (caution: not trained and tested on NLP tasks)
Continual Learning
External DoF: Composition (like MLP) of layers
Internal DoF: Grid size (Unique to KAN)
Given samples from a known function structure, the network learns the structure via sparsification and pruning.
Since we do not have linear weights, the \(L_1\) norm is computed for each KAN layer \(\Phi\)
Here the target function contains three (activation) functions (exp, sin, parabola)
However, the KAN layers have 15 activations
Can it learn to keep only those three out of the 15 activation functions?
Norm of an activation (averaged over its \(N_p\) input samples): \(|\phi|_1 = \frac{1}{N_p} \sum_{s=1}^{N_p} \left|\phi\!\left(x^{(s)}\right)\right|\)
Norm of the layer: \(|\Phi|_1 = \sum_{i=1}^{n_{\text{in}}} \sum_{j=1}^{n_{\text{out}}} |\phi_{i,j}|_1\)
Suppose the function structure is unknown
After training, we end up with this simple structure
Then we could infer that the target function contains only three univariate functions
Humans are good at multi-tasking without forgetting previously learned tasks when learning a new one
Human brains have functionally distinct modules placed locally in space
Most artificial neural networks, including MLPs, do not have this notion of locality, which is probably the reason for catastrophic forgetting.
KANs have local plasticity and can avoid catastrophic forgetting by leveraging the locality of splines (a sample will only affect a few nearby spline coefficients, leaving far-away coefficients intact)
Let's validate this via a simple 1D regression task composed of 5 Gaussian peaks
Data around each peak is presented sequentially (instead of all at once)
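A sketch of this toy setup (the peak positions, widths, and sample counts below are assumptions for illustration, not the paper's exact values):

```python
import numpy as np

centers = np.linspace(-0.8, 0.8, 5)                     # 5 Gaussian peaks
def target(x):
    return sum(np.exp(-(x - c) ** 2 / (2 * 0.05 ** 2)) for c in centers)

# Present the data sequentially, one peak region per phase
rng = np.random.default_rng(0)
phases = []
for c in centers:
    x = rng.uniform(c - 0.2, c + 0.2, size=200)
    phases.append((x, target(x)))
# Train on phases[0], then phases[1], ...; after each phase, evaluate on the earlier
# peak regions to check that the previously fitted peaks are not forgotten.
```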
Slow to train.
KANs are usually 10x slower than MLPs, given the same number of parameters
Findings are validated using toy experimental settings that are closely related to math and physics.
They left it to engineers to optimize (efficient implementations have already started emerging)
A line segment is parameterized by \(t\) (a convex combination of the control points \(c_0, c_1\))
A quadratic curve
Control points: \(c_0, c_1, c_2\)
t: \(t_0, t_1, t_2\), with \(t_0 < t_1 < t_2\)
Construct two line segments, \(p_{0,1}(t)\) and \(p_{1,2}(t)\)
Then the curve is given by combining the two segments
Unlike for the line segment, the choice of the \(t\) values (parameters) affects the shape of the curve
We need to be careful about the choice of intervals to get a smooth curve
Solution?
Control points \(n=3\): \(c_i, i=0,1,2\)
t : \(t_0,t_1,t_2\)
The interval is restricted to [0,1] (convex combination)
The final curve is the convex combination of the above
Substitute the expressions for \(p_{0,1}, p_{1,2}\)
It is smooth
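A minimal sketch of this construction in Python (function and variable names are mine):

```python
import numpy as np

def quadratic_bezier(c0, c1, c2, t):
    """Repeated convex combinations (de Casteljau) with t restricted to [0, 1]."""
    p01 = (1 - t) * c0 + t * c1          # line segment from c0 to c1
    p12 = (1 - t) * c1 + t * c2          # line segment from c1 to c2
    return (1 - t) * p01 + t * p12       # convex combination of the two segments
    # Expanding gives (1-t)^2 c0 + 2 t (1-t) c1 + t^2 c2: the Bernstein basis

t = np.linspace(0.0, 1.0, 100)
c0, c1, c2 = np.array([0.0, 0.0]), np.array([0.5, 1.0]), np.array([1.0, 0.0])
curve = np.stack([quadratic_bezier(c0, c1, c2, ti) for ti in t])   # smooth by construction
```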
We can generalize this to degree \(d\) and \(n\) control points: \(P(t) = \sum_i B_{i,d}(t)\, c_i\)
where \(B_{i,d}(t)\) is the basis function for \(c_i\)
For example, it is a Bernstein polynomial for Bezier curves
Quick check for \(n=3,\ d=n-1\)
We can generalize this to degree \(d\) and \(n+1\) control points
In the case of B-splines, the \(B_{i,d}\) are called blending functions, described by the order \(k=d+1\) together with a non-decreasing sequence of numbers called knots, \(t_i:\ i=0,1,2,\cdots,n+k\)
Let us try to understand this with illustrations (borrowed from Kenneth I Joy's lectures)
Let us start with first-order (\(k=1\)), constant (\(d=0\)) piecewise polynomial functions
B-Splines \(B_{i,0}(t)\), knots \(t_i\) are uniformly spaced in the interval [0,1].
\(B_{0,0}(t)\)
\(B_{1,0}(t)\)
The first-degree (\(d=1\)) splines are obtained by combining the first-order splines
The function for the \(i\)-th control point (or knot) is simply a shifted version of \(B_{0,d}(t)\) by \(i\) units
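For reference, this combining step is the Cox–de Boor recursion (in the degree notation \(B_{i,d}\) used here, with knots \(t_i\)):
\[
B_{i,0}(t) = \begin{cases} 1, & t_i \le t < t_{i+1} \\ 0, & \text{otherwise}, \end{cases}
\qquad
B_{i,d}(t) = \frac{t - t_i}{t_{i+d} - t_i}\, B_{i,d-1}(t) + \frac{t_{i+d+1} - t}{t_{i+d+1} - t_{i+1}}\, B_{i+1,d-1}(t).
\]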
Image source: Lectures by Kenneth I Joy
B-Splines \(B_{i,1}(t)\), knots \(t_i\) are uniformly spaced in the interval [0,1].
Note the increase in support for the function as the order increases
The quadratic (third order) spline is obtained by combining the second order splines
\(B_{0,1}(t)\)
Image source: Lectures by Kenneth I Joy
B-Splines \(B_{i,2}(t)\), knots \(t_i\) are uniformly spaced in the interval [0,1].
\(B_{0,2}(t)\)
\(B_{1,2}(t)\)
Image source: Lectures by Kenneth I Joy
Now treat the control points \(c_i\) as real numbers, \(c_i \in \mathbb{R}\)
Now, the curve \(P(t)\) is learnable via \(c_i\)
The number of control points \(c_i\) and the order \(k=d+1\) are independent.
Now, \(c_i\) is called a B-spline coefficient
For rigorous definitions and proofs: here