PyTorch Implementation of "Distilling a Neural Network Into a Soft Decision Tree." Nicholas Frosst and Geoffrey Hinton, 2017.
This is a PyTorch implementation of the Soft Decision Tree (SDT) proposed in the paper "Distilling a Neural Network Into a Soft Decision Tree" (2017, https://arxiv.org/abs/1711.09784).
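Below is a minimal NumPy sketch of the soft decision tree forward pass described in the paper. It is not the code in this repository. Each inner node i computes a probability p_i(x) = sigmoid(x·w_i + b_i) of taking the right branch. A leaf's path probability is the product of the branch probabilities along its path, and each leaf holds a softmax distribution over classes. The names `sdt_forward`, `W`, `b` and `leaf_logits` are hypothetical. Two further choices are assumptions: any temperature scaling of the filters is omitted, and the output averages the leaf distributions weighted by path probability. The paper also discusses an alternative, which is to predict with the single leaf that has the greatest path probability.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def sdt_forward(x, W, b, leaf_logits, depth):
    """Hypothetical sketch of a soft decision tree forward pass.

    x:           (batch, input_dim) inputs
    W, b:        (2**depth - 1, input_dim) and (2**depth - 1,) logistic filters,
                 one per inner node, ordered breadth-first (root first)
    leaf_logits: (2**depth, output_dim) learned leaf parameters
    Returns path probabilities (batch, 2**depth) and the path-weighted
    class distribution (batch, output_dim).
    """
    batch = x.shape[0]
    p_right = sigmoid(x @ W.T + b)  # probability of branching right at each inner node
    mu = np.ones((batch, 1))        # path probability of the root
    node = 0
    for d in range(depth):
        n_level = 2 ** d
        p = p_right[:, node:node + n_level]
        # Node j at this level sends mu*(1-p) to child 2j and mu*p to child 2j+1.
        mu = np.stack([mu * (1 - p), mu * p], axis=2).reshape(batch, 2 * n_level)
        node += n_level
    # Leaf distributions Q = softmax(leaf_logits), computed stably.
    e = np.exp(leaf_logits - leaf_logits.max(axis=1, keepdims=True))
    Q = e / e.sum(axis=1, keepdims=True)
    return mu, mu @ Q

# Toy usage: random parameters with MNIST-like shapes and depth 5.
rng = np.random.default_rng(0)
depth, input_dim, output_dim = 5, 784, 10
x = rng.normal(size=(3, input_dim))
W = rng.normal(scale=0.01, size=(2 ** depth - 1, input_dim))
b = np.zeros(2 ** depth - 1)
leaf_logits = rng.normal(size=(2 ** depth, output_dim))
mu, probs = sdt_forward(x, W, b, leaf_logits, depth)
```

A tree of depth `depth` has 2**depth - 1 inner nodes and 2**depth leaves. Every row of `mu` sums to 1, so `probs` is a valid class distribution.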
To run the demo on MNIST, simply use the following commands:
git clone https://github.com/AaronX121/Soft-Decision-Tree.git
cd Soft-Decision-Tree
python main.py
Parameter | Type | Description |
---|---|---|
input_dim | int | The number of input dimensions |
output_dim | int | The number of output dimensions (e.g., the number of classes for multi-class classification) |
depth | int | Tree depth, the default is 5 |
lamda | float | The coefficient of the regularization term, the default is 1e-3 |
use_cuda | bool | Whether to use a GPU to train / evaluate the model, the default is False |
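The `lamda` coefficient scales a regularizer that encourages each inner node to split its inputs evenly. Below is a rough NumPy sketch of the penalty as the paper describes it. For each inner node i, alpha_i is its right-branch probability averaged over the batch and weighted by the node's path probability. The penalty is -lamda * sum_i 2**(-d_i) * (0.5*log(alpha_i) + 0.5*log(1 - alpha_i)), where d_i is the node's depth. The function name `tree_penalty` and the `eps` clipping are hypothetical, and the exact weighting used in this repository may differ.

```python
import numpy as np

def tree_penalty(p_right, depth, lamda=1e-3, eps=1e-7):
    """Hypothetical sketch of the SDT balance regularizer.

    p_right: (batch, 2**depth - 1) right-branch probabilities of the inner
             nodes, ordered breadth-first (root first).
    """
    batch = p_right.shape[0]
    mu = np.ones((batch, 1))  # path probabilities of the nodes at the current level
    penalty = 0.0
    node = 0
    for d in range(depth):
        n_level = 2 ** d
        p = p_right[:, node:node + n_level]
        # alpha_i: path-probability-weighted mean of p_i over the batch
        alpha = (mu * p).sum(axis=0) / (mu.sum(axis=0) + eps)
        alpha = np.clip(alpha, eps, 1 - eps)  # keep log() finite
        # Cross-entropy against a 50/50 split, down-weighted by 2**-d for deeper nodes
        penalty += -lamda * 2.0 ** (-d) * np.sum(0.5 * np.log(alpha) + 0.5 * np.log(1 - alpha))
        mu = np.stack([mu * (1 - p), mu * p], axis=2).reshape(batch, 2 * n_level)
        node += n_level
    return penalty
```

The penalty is smallest when every alpha_i equals 0.5, that is, when each node sends half of its probability mass to each child. Larger values of `lamda` push the tree harder toward balanced splits.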
when the absolute value of input is large.
After training for 40 epochs with batch_size 128, the best testing accuracies of SDT models with depth 5 and 7 are 94.15% and 94.38%, respectively, which is close to the accuracy reported in the original paper. The related hyper-parameters are available in main.py. Better and more stable performance can be achieved by fine-tuning the hyper-parameters.
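To get a sense of the training budget, here is a small arithmetic sketch that converts epochs and batch size into optimizer steps. It assumes the standard 60,000-image MNIST training split. It also assumes the last partial batch is either kept or dropped, mirroring the `drop_last` option of PyTorch's `DataLoader`. The helper `training_iterations` is hypothetical.

```python
import math

def training_iterations(n_train, batch_size, epochs, drop_last=False):
    """Return (steps per epoch, total steps) for mini-batch training."""
    if drop_last:
        per_epoch = n_train // batch_size             # discard the incomplete final batch
    else:
        per_epoch = math.ceil(n_train / batch_size)   # keep the incomplete final batch
    return per_epoch, per_epoch * epochs

# 60,000 images, batch_size 128, 40 epochs
per_epoch, total = training_iterations(60000, 128, 40)
print(per_epoch, total)  # 469 18760
```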
Below are the testing accuracy curve and training loss curve. The testing accuracy of SDT is evaluated after each training epoch.
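The "best testing accuracy" is the maximum over the per-epoch evaluations. A minimal sketch of that bookkeeping follows, assuming predictions are class-probability rows and accuracy is based on arg-max. The helpers `accuracy` and `best_accuracy` are hypothetical, and the numbers in the usage lines are toy values, not results from this repository.

```python
import numpy as np

def accuracy(pred_probs, labels):
    """Fraction of samples whose arg-max prediction matches the label."""
    return float(np.mean(np.argmax(pred_probs, axis=1) == labels))

def best_accuracy(per_epoch_acc):
    """Return (best accuracy, 1-based epoch at which it was reached)."""
    best_idx = int(np.argmax(per_epoch_acc))
    return per_epoch_acc[best_idx], best_idx + 1

# Toy values for illustration only
probs = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
acc = accuracy(probs, np.array([0, 1, 1]))    # 2 of 3 correct
best, epoch = best_accuracy([0.90, 0.93, 0.92])  # (0.93, 2)
```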
SDT was originally developed in Python 3.6.5. The names and versions of the packages used in SDT are listed below. In my experience, it also works fine under other versions of Python and PyTorch.