TensorFlow implementation of the Restricted Boltzmann Machine
TensorFlow implementation of the Restricted Boltzmann Machine for layer-wise pretraining of deep autoencoders.
This is a fork of Michal Lukac's repository with some corrections and improvements.
The Restricted Boltzmann Machine is a legacy machine learning model that should not be used for any real-world applications. This repository is of historical and educational value only. I have updated the code to TensorFlow 2 so that it runs on modern systems, but I will no longer maintain it.
git clone https://github.com/meownoid/tensorfow-rbm.git
cd tensorfow-rbm
python -m pip install -r requirements.txt
python setup.py install
Bernoulli-Bernoulli RBM is good for Bernoulli-distributed binary input data, such as MNIST. To train the model, construct a tf.data.Dataset containing vectors of shape (n_visible,) and pass it to the fit method.
import tensorflow as tf
from tfrbm import BBRBM
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train / 255
x_test = x_test / 255
dataset = tf.data.Dataset.from_tensor_slices(x_train.reshape(-1, 28 * 28))
dataset = dataset.shuffle(1024, reshuffle_each_iteration=True)
rbm = BBRBM(n_visible=28 * 28, n_hidden=64)
rbm.fit(dataset, epoches=100, batch_size=10)
Now you can use the reconstruct method to check the model's performance.
# x_tensor: (1, n_visible)
# x_reconstructed_tensor: (1, n_visible)
x_reconstructed_tensor = rbm.reconstruct(x_tensor)
The full example code can be found in examples/mnist.py.
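The MNIST example rescales pixels into the range 0 to 1, which suits the Bernoulli-Bernoulli model. For the Gaussian-Bernoulli variant described below, the input is expected to have zero mean and a standard deviation of sigma, so some normalization is usually needed first. The following is a minimal NumPy sketch of one way to do that; the `standardize` helper and the random stand-in data are hypothetical and not part of this repository.

```python
import numpy as np


def standardize(x, eps=1e-8):
    """Shift each feature to zero mean and scale it to unit standard deviation.

    Returns the transformed array together with the statistics, so the same
    transform can be reapplied to held-out data.
    """
    x = np.asarray(x, dtype=np.float32)
    mean = x.mean(axis=0)
    std = x.std(axis=0)
    # eps avoids division by zero for constant features (e.g. always-blank pixels)
    return (x - mean) / (std + eps), mean, std


# Hypothetical stand-in data: 1000 samples with 8 features each
rng = np.random.default_rng(0)
raw = rng.normal(loc=5.0, scale=3.0, size=(1000, 8))

x_std, mean, std = standardize(raw)
```

With data standardized this way, the default sigma=1.0 from the GBRBM signature would match the standard deviation of the input.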
BBRBM(n_visible, n_hidden, learning_rate=0.01, momentum=0.95)
GBRBM(n_visible, n_hidden, learning_rate=0.01, momentum=0.95, sample_visible=False, sigma=1.0)
Initializes Bernoulli-Bernoulli RBM or Gaussian-Bernoulli RBM.
n_visible: number of visible neurons (input size)
n_hidden: number of hidden neurons
Only for GBRBM:
sample_visible: whether to sample the reconstructed data from a Gaussian distribution (with the reconstructed value as the mean and the sigma parameter as the standard deviation); if not, each Gaussian is projected to a single point
sigma: standard deviation of the input data
Advice:
For BBRBM, input values should lie in the range 0 to 1.
For GBRBM, input data should have 0 mean and sigma standard deviation. Normalize input data if necessary.
rbm.fit(dataset, epoches=10, batch_size=10)
Trains the model and returns a list of errors.
dataset: tf.data.Dataset composed of tensors of shape (n_visible,)
epoches: number of epochs
batch_size: batch size, should be as small as possible
rbm.step(x)
Performs one training step and returns reconstruction error.
x: tensor of shape (batch_size, n_visible)
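To give a rough idea of what a single training step involves, here is a NumPy sketch of one contrastive divergence (CD-1) update with momentum for a Bernoulli-Bernoulli RBM. It is a textbook-style illustration under stated assumptions, not necessarily the update this repository performs: the `cd1_step` function, the `vel` momentum buffers, and the use of mean squared error as the reconstruction error are all hypothetical choices made for this example.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def cd1_step(x, w, b_v, b_h, vel, rng, learning_rate=0.01, momentum=0.95):
    """One CD-1 update for a Bernoulli-Bernoulli RBM; returns the reconstruction error.

    x: (batch_size, n_visible), w: (n_visible, n_hidden)
    vel: dict of momentum buffers with keys "w", "b_v", "b_h"
    """
    # Positive phase: hidden probabilities given the data, then a binary sample
    h_prob = sigmoid(x @ w + b_h)
    h_sample = (rng.random(h_prob.shape) < h_prob).astype(x.dtype)

    # Negative phase: reconstruct the visible units, then recompute hidden probabilities
    v_recon = sigmoid(h_sample @ w.T + b_v)
    h_recon = sigmoid(v_recon @ w + b_h)

    # Gradient estimates: data statistics minus reconstruction statistics
    n = x.shape[0]
    grads = {
        "w": (x.T @ h_prob - v_recon.T @ h_recon) / n,
        "b_v": (x - v_recon).mean(axis=0),
        "b_h": (h_prob - h_recon).mean(axis=0),
    }

    # Momentum update; the parameter arrays are modified in place
    for name, param in (("w", w), ("b_v", b_v), ("b_h", b_h)):
        vel[name] = momentum * vel[name] + learning_rate * grads[name]
        param += vel[name]

    return float(np.mean((x - v_recon) ** 2))


# Hypothetical toy setup: 6 visible units, 3 hidden units, a batch of 10 binary vectors
rng = np.random.default_rng(0)
n_visible, n_hidden = 6, 3
w = rng.normal(scale=0.01, size=(n_visible, n_hidden))
b_v = np.zeros(n_visible)
b_h = np.zeros(n_hidden)
vel = {"w": np.zeros_like(w), "b_v": np.zeros_like(b_v), "b_h": np.zeros_like(b_h)}

x = (rng.random((10, n_visible)) < 0.5).astype(np.float64)
errors = [cd1_step(x, w, b_v, b_h, vel, rng) for _ in range(50)]
```

The `learning_rate` and `momentum` defaults mirror the constructor signatures above; everything else in the sketch is an assumption.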
rbm.compute_hidden(x)
Computes hidden state from the input.
x: tensor of shape (batch_size, n_visible)
rbm.compute_visible(hidden)
Computes visible state from hidden state.
hidden: tensor of shape (batch_size, n_hidden)
rbm.reconstruct(x)
Computes visible state from the input. Reconstructs data.
x: tensor of shape (batch_size, n_visible)
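The shape contract of the three inference methods can be illustrated with a small NumPy sketch in which reconstruction is a hidden pass followed by a visible pass. This assumes a Bernoulli-Bernoulli model that returns activation probabilities; the library's own methods may sample or differ in other details, and the standalone functions and weights below are hypothetical stand-ins rather than the repository's code.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def compute_hidden(x, w, b_h):
    # (batch_size, n_visible) -> (batch_size, n_hidden)
    return sigmoid(x @ w + b_h)


def compute_visible(hidden, w, b_v):
    # (batch_size, n_hidden) -> (batch_size, n_visible)
    return sigmoid(hidden @ w.T + b_v)


def reconstruct(x, w, b_v, b_h):
    # Visible -> hidden -> visible, reusing the same weight matrix in both directions
    return compute_visible(compute_hidden(x, w, b_h), w, b_v)


# Hypothetical toy parameters: 6 visible units, 3 hidden units
rng = np.random.default_rng(0)
w = rng.normal(scale=0.1, size=(6, 3))
b_v = np.zeros(6)
b_h = np.zeros(3)

x = rng.random((4, 6))  # batch of 4 inputs with values in [0, 1)
hidden = compute_hidden(x, w, b_h)
x_reconstructed = reconstruct(x, w, b_v, b_h)
```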