
Fast Kolmogorov-Arnold Network in JAX, initial experiments


stergiosba/kanx


KANX: Fast Implementation (Approximation) of Kolmogorov-Arnold Networks in JAX

Work in progress

Introduction

A fast implementation of Kolmogorov-Arnold Networks (KANs) in JAX, built with Equinox and based on fast-kan.

The original implementation of KAN is pykan.
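fast-kan replaces the B-spline basis of the original KAN with Gaussian radial basis functions (RBFs) on a fixed grid, which is presumably what "(Approximation)" in the title refers to. The snippet below is a minimal sketch of one such layer in plain JAX, assuming a uniform grid on [-2, 2] with the grid spacing used as the RBF width. The function names (`init_rbf_kan_layer`, `rbf_kan_layer`) and defaults are hypothetical and are not the KANX API; the sketch also uses bare arrays instead of Equinox modules and omits any normalization or base (residual) term that the actual implementations may add.

```python
import jax
import jax.numpy as jnp


def init_rbf_kan_layer(key, in_dim, out_dim, num_centers=8, scale=0.1):
    # Hypothetical initializer: one coefficient per (input, basis function, output).
    return scale * jax.random.normal(key, (in_dim, num_centers, out_dim))


def rbf_kan_layer(weights, x, grid_min=-2.0, grid_max=2.0):
    # Map a batch x of shape (batch, in_dim) to shape (batch, out_dim).
    num_centers = weights.shape[1]
    centers = jnp.linspace(grid_min, grid_max, num_centers)  # fixed RBF centers
    width = (grid_max - grid_min) / (num_centers - 1)  # grid spacing as RBF width
    # phi[b, i, k] = exp(-((x[b, i] - centers[k]) / width) ** 2)
    phi = jnp.exp(-(((x[..., None] - centers) / width) ** 2))
    # Learned 1-D functions f_{i,o}(x_i) = sum_k weights[i, k, o] * phi[., i, k],
    # summed over inputs i: y[b, o] = sum_i f_{i,o}(x[b, i]).
    return jnp.einsum("bik,iko->bo", phi, weights)


key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
w1 = init_rbf_kan_layer(k1, 784, 64)  # e.g. a flattened 28x28 input
w2 = init_rbf_kan_layer(k2, 64, 10)   # 10 output classes
x = jax.random.uniform(jax.random.PRNGKey(1), (32, 784))
logits = rbf_kan_layer(w2, rbf_kan_layer(w1, x))
print(logits.shape)  # (32, 10)
```

Because the basis is fixed, the whole layer reduces to one elementwise `exp` and one `einsum`, which is likely a large part of why this style of KAN runs quickly on accelerators.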

Installation

pip install .
pip install -r requirements.txt

Example

KANX comes with an MNIST training example:

python examples/train_mnist.py
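The MNIST script itself is not reproduced here. As a rough, hypothetical illustration of what one training step for such a model can look like, the sketch below stacks two Gaussian-RBF KAN layers (784 -> 64 -> 10, matching flattened 28x28 digits and 10 classes) and runs plain gradient descent with `jax.value_and_grad` on a random stand-in batch. The layer definition, the learning rate, and the use of plain SGD are all assumptions; the actual script may use Equinox modules, an optimizer library, and different sizes.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 1e-3  # illustrative step size for plain gradient descent (assumption)


def rbf_kan_layer(w, x, grid_min=-2.0, grid_max=2.0):
    # Hypothetical FastKAN-style layer: Gaussian RBF expansion, then a learned mix.
    centers = jnp.linspace(grid_min, grid_max, w.shape[1])
    width = (grid_max - grid_min) / (w.shape[1] - 1)
    phi = jnp.exp(-(((x[..., None] - centers) / width) ** 2))
    return jnp.einsum("bik,iko->bo", phi, w)


def model(params, x):
    # Two stacked KAN layers: 784 -> 64 -> 10.
    return rbf_kan_layer(params[1], rbf_kan_layer(params[0], x))


def loss_fn(params, x, y):
    # Mean softmax cross-entropy against integer labels y.
    log_probs = jax.nn.log_softmax(model(params, x))
    return -jnp.mean(jnp.take_along_axis(log_probs, y[:, None], axis=1))


@jax.jit
def sgd_step(params, x, y):
    # One gradient-descent update on every array in the params pytree.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, loss


key = jax.random.PRNGKey(0)
k1, k2, kx, ky = jax.random.split(key, 4)
params = [
    0.1 * jax.random.normal(k1, (784, 8, 64)),
    0.1 * jax.random.normal(k2, (64, 8, 10)),
]
# Random stand-in batch; real training would iterate over MNIST batches instead.
x = jax.random.uniform(kx, (128, 784))
y = jax.random.randint(ky, (128,), 0, 10)

losses = []
for _ in range(20):
    params, loss = sgd_step(params, x, y)  # loss is measured before each update
    losses.append(float(loss))
print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
```

Keeping `params` as a plain list lets `jax.tree_util.tree_map` apply the update to every layer without a custom container; an Equinox model would play the same role as a pytree.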

Benchmark

We tested the implementation on MNIST and report the following wall-clock times for 3000 epochs:

Device               Wall time (s)
CPU (i5-1135G7)      130.51
CPU (i9-12900K)      67.85
GPU (RTX 3070 Ti)    13.55
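The README does not say how these wall times were measured. When timing JAX code, two details matter: the first call to a `jax.jit`-compiled function includes XLA compilation, and JAX dispatches work asynchronously, so the clock should only stop after `block_until_ready()`. The sketch below shows that pattern on a hypothetical stand-in workload (`dummy_step`), not on the KANX model; `time_steps` is likewise an illustrative helper.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def dummy_step(w, x):
    # Hypothetical stand-in workload; substitute the real jitted training step.
    return jnp.tanh(x @ w) @ w.T


def time_steps(step_fn, w, x, num_steps):
    # Returns (seconds for the first, compiling call, seconds for num_steps calls).
    t0 = time.perf_counter()
    out = step_fn(w, x).block_until_ready()  # first call traces and compiles
    t1 = time.perf_counter()
    for _ in range(num_steps):
        out = step_fn(w, out)  # calls are dispatched asynchronously
    out.block_until_ready()  # wait for all queued work before reading the clock
    t2 = time.perf_counter()
    return t1 - t0, t2 - t1


key = jax.random.PRNGKey(0)
kw, kx = jax.random.split(key)
d = 256
w = jax.random.normal(kw, (d, d)) / jnp.sqrt(d)
x = jax.random.normal(kx, (64, d))
compile_s, run_s = time_steps(dummy_step, w, x, num_steps=100)
platform = jax.devices()[0].platform
print(f"{platform}: first call {compile_s:.3f} s, 100 steps {run_s:.3f} s")
```

For reference, the table above puts the RTX 3070 Ti run at roughly 9.6x the speed of the i5-1135G7 and 5.0x that of the i9-12900K.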

Plots from the GPU experiment:

[Figure: mlp_kan_compare (MLP vs. KAN comparison plots)]

More experiments to come...
