TensorFlow implementation of DeepFM for CTR prediction.
This project includes a TensorFlow implementation of DeepFM [1].
This implementation requires the input data in the following format:
Please see example/DataReader.py for an example of how to prepare the data in the required format for DeepFM.
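The variable names in the usage code below (Xi, Xv, y) suggest the expected layout. As a minimal sketch, assuming that every sample has one entry per feature field, that Xi holds a global feature index and Xv the matching feature value, that categorical fields are one-hot encoded (value 1.0) and that each numeric field keeps its raw value under a single shared index, a toy conversion could look like the following. The helpers build_feature_index and to_xi_xv and the toy records are hypothetical; example/DataReader.py remains the project's reference.

```python
from typing import Dict, List, Tuple

# Hypothetical toy records: two categorical fields and one numeric field.
RECORDS = [
    {"color": "red", "city": "NY", "age": 23.0, "label": 1},
    {"color": "blue", "city": "SF", "age": 31.0, "label": 0},
    {"color": "red", "city": "SF", "age": 45.0, "label": 0},
]
CATEGORICAL = ["color", "city"]
NUMERIC = ["age"]
FIELDS = CATEGORICAL + NUMERIC  # fixed field order shared by every Xi/Xv row


def build_feature_index(records: List[dict]) -> Dict[Tuple[str, object], int]:
    """Assign a global index to each (field, category) pair and to each numeric field."""
    index: Dict[Tuple[str, object], int] = {}
    for field in FIELDS:
        if field in NUMERIC:
            index[(field, None)] = len(index)  # one shared index per numeric field
        else:
            for value in sorted({r[field] for r in records}):
                index[(field, value)] = len(index)
    return index


def to_xi_xv(records: List[dict], index: Dict[Tuple[str, object], int]):
    """Return Xi (feature indices), Xv (feature values) and y, one row per sample."""
    xi, xv, y = [], [], []
    for r in records:
        row_i, row_v = [], []
        for field in FIELDS:
            if field in NUMERIC:
                row_i.append(index[(field, None)])
                row_v.append(float(r[field]))  # numeric: keep the raw value
            else:
                row_i.append(index[(field, r[field])])
                row_v.append(1.0)  # categorical: one-hot, so the value is 1
        xi.append(row_i)
        xv.append(row_v)
        y.append(r["label"])
    return xi, xv, y


feature_index = build_feature_index(RECORDS)
Xi, Xv, y = to_xi_xv(RECORDS, feature_index)
print(Xi)  # [[1, 2, 4], [0, 3, 4], [1, 3, 4]]
print(Xv)  # [[1.0, 1.0, 23.0], [1.0, 1.0, 31.0], [1.0, 1.0, 45.0]]
```

Every row has the same length (the number of fields), while the largest index plus one (here 5) gives the total number of distinct features the embedding table has to cover.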
```python
import tensorflow as tf
from sklearn.metrics import roc_auc_score

# Assumes DeepFM.py from this repository is on the import path.
from DeepFM import DeepFM

# params
dfm_params = {
    "use_fm": True,
    "use_deep": True,
    "embedding_size": 8,
    "dropout_fm": [1.0, 1.0],
    "deep_layers": [32, 32],
    "dropout_deep": [0.5, 0.5, 0.5],
    "deep_layers_activation": tf.nn.relu,
    "epoch": 30,
    "batch_size": 1024,
    "learning_rate": 0.001,
    "optimizer_type": "adam",
    "batch_norm": 1,
    "batch_norm_decay": 0.995,
    "l2_reg": 0.01,
    "verbose": True,
    "eval_metric": roc_auc_score,
    "random_seed": 2017,
}

# prepare training and validation data in the required format
# (prepare(...) stands for your own preprocessing; see example/DataReader.py)
Xi_train, Xv_train, y_train = prepare(...)
Xi_valid, Xv_valid, y_valid = prepare(...)

# init a DeepFM model
dfm = DeepFM(**dfm_params)

# fit a DeepFM model
dfm.fit(Xi_train, Xv_train, y_train)

# make prediction
dfm.predict(Xi_valid, Xv_valid)

# evaluate a trained model
dfm.evaluate(Xi_valid, Xv_valid, y_valid)
```
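In the parameters above, dropout_fm has two entries and dropout_deep has one more entry than deep_layers. A reasonable reading, which is an assumption rather than something this README states, is that dropout_fm covers the first- and second-order FM terms, that dropout_deep covers the embedding input plus each hidden layer, and that the values are keep probabilities (1.0 means no dropout). Under that assumption, a small hypothetical helper can catch length mismatches before training:

```python
def check_dropout_params(params: dict) -> None:
    """Hypothetical sanity check for the dropout lists in a DeepFM params dict.

    Assumes dropout_fm = [first-order keep prob, second-order keep prob] and
    dropout_deep = [input keep prob] + [one keep prob per hidden layer].
    """
    if len(params["dropout_fm"]) != 2:
        raise ValueError("dropout_fm should have 2 entries (first and second order)")
    expected = len(params["deep_layers"]) + 1
    if len(params["dropout_deep"]) != expected:
        raise ValueError(
            f"dropout_deep has {len(params['dropout_deep'])} entries, "
            f"expected {expected} (input + {len(params['deep_layers'])} hidden layers)"
        )
    for p in params["dropout_fm"] + params["dropout_deep"]:
        if not 0.0 < p <= 1.0:
            raise ValueError(f"keep probability {p} is outside (0, 1]")


# The README's own settings pass the check.
check_dropout_params({"dropout_fm": [1.0, 1.0], "deep_layers": [32, 32],
                      "dropout_deep": [0.5, 0.5, 0.5]})
```

Adding a third hidden layer to deep_layers would then also require a fourth entry in dropout_deep.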
You can use early stopping during training as follows:
```python
dfm.fit(Xi_train, Xv_train, y_train, Xi_valid, Xv_valid, y_valid, early_stopping=True)
```
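The README does not spell out the stopping criterion behind early_stopping=True. One common rule, shown here only as an illustrative sketch that may differ from what fit does internally, is patience-based: record the validation metric after every epoch and stop once it has not improved for a fixed number of epochs. Since eval_metric is roc_auc_score in the example, higher is treated as better by default; the function name should_stop and the patience of 5 are hypothetical choices.

```python
from typing import List


def should_stop(history: List[float], patience: int = 5,
                greater_is_better: bool = True) -> bool:
    """Return True once the best validation score is at least `patience` epochs old.

    `history` holds one validation score per completed epoch. Ties do not count
    as improvements, because max/min return the first best index.
    """
    if len(history) <= patience:
        return False
    pick = max if greater_is_better else min
    best_epoch = pick(range(len(history)), key=lambda i: history[i])
    return len(history) - 1 - best_epoch >= patience


# Simulated validation AUCs, one per epoch.
scores = [0.70, 0.72, 0.73, 0.73, 0.72, 0.71, 0.70, 0.69]
for epoch in range(1, len(scores) + 1):
    if should_stop(scores[:epoch]):
        print(f"stopping after epoch {epoch}")  # stopping after epoch 8
        break
```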
You can refit the model on the combined training and validation sets as follows:
```python
dfm.fit(Xi_train, Xv_train, y_train, Xi_valid, Xv_valid, y_valid, early_stopping=True, refit=True)
```
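The README gives no details on refit=True. A plausible reading, stated here as an assumption, is: remember the epoch with the best validation score during the early-stopping run, then merge the training and validation rows and retrain for about that many epochs. Because Xi, Xv and y are plain per-sample lists, merging them is concatenation. The helpers below are hypothetical and only show this bookkeeping, not the TensorFlow training itself.

```python
from typing import Sequence, Tuple

Dataset = Tuple[list, list, list]  # (Xi, Xv, y)


def best_epoch_count(history: Sequence[float], greater_is_better: bool = True) -> int:
    """Number of epochs (1-based) needed to reach the best validation score."""
    pick = max if greater_is_better else min
    best_index = pick(range(len(history)), key=lambda i: history[i])
    return best_index + 1


def merge_for_refit(train: Dataset, valid: Dataset) -> Dataset:
    """Concatenate Xi, Xv and y of the training and validation sets."""
    xi_t, xv_t, y_t = train
    xi_v, xv_v, y_v = valid
    return xi_t + xi_v, xv_t + xv_v, y_t + y_v


train = ([[1, 2, 4]], [[1.0, 1.0, 23.0]], [1])
valid = ([[0, 3, 4]], [[1.0, 1.0, 31.0]], [0])
Xi_all, Xv_all, y_all = merge_for_refit(train, valid)
print(len(Xi_all), y_all)  # 2 [1, 0]
print(best_epoch_count([0.70, 0.72, 0.73, 0.71]))  # 3
```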
You can use only the FM part or only the DNN part by setting the parameter use_deep or use_fm, respectively, to False.
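In the DeepFM paper [1], the prediction combines a first-order (linear) term, a second-order FM interaction term and a deep network output, all built on the same field embeddings, and passes the result through a sigmoid. The pure-Python forward pass below is a simplified sketch under that assumption (one hidden layer, no dropout, batch norm or bias terms, hand-picked toy weights for five features with embedding_size 2) meant only to show which terms use_fm and use_deep switch off. It is not this project's TensorFlow code, which may combine the parts differently.

```python
import math
from typing import List

# Hypothetical toy parameters: 5 features, embedding_size = 2.
W1 = [0.1, -0.2, 0.05, 0.3, 0.01]  # first-order weight per feature
EMB = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.4], [-0.2, 0.1], [0.05, 0.05]]  # embedding per feature
# One hidden unit per row; each row has 6 weights because the toy input has 3 fields x 2 dims.
HIDDEN_W = [[0.1, -0.1, 0.2, 0.0, 0.3, 0.1],
            [0.0, 0.2, -0.1, 0.1, 0.0, 0.2]]
OUT_W = [0.5, -0.3]


def forward(xi: List[int], xv: List[float], use_fm: bool = True, use_deep: bool = True) -> float:
    """Predicted click probability for one sample given its feature indices and values."""
    # Embedding lookup, scaled by the feature value (numeric fields keep their magnitude).
    emb = [[e * v for e in EMB[i]] for i, v in zip(xi, xv)]
    logit = 0.0
    if use_fm:
        first_order = sum(W1[i] * v for i, v in zip(xi, xv))
        # 0.5 * sum_k [(sum_j e_jk)^2 - sum_j e_jk^2] equals the sum of all pairwise
        # inner products <e_i, e_j> (i < j), computed in linear time.
        second_order = 0.5 * sum(
            sum(row[k] for row in emb) ** 2 - sum(row[k] ** 2 for row in emb)
            for k in range(len(emb[0]))
        )
        logit += first_order + second_order
    if use_deep:
        flat = [x for row in emb for x in row]  # concatenate the field embeddings
        hidden = [max(0.0, sum(w * x for w, x in zip(ws, flat))) for ws in HIDDEN_W]  # ReLU
        logit += sum(w * h for w, h in zip(OUT_W, hidden))
    return 1.0 / (1.0 + math.exp(-logit))  # sigmoid for CTR


xi, xv = [1, 2, 4], [1.0, 1.0, 0.5]
for flags in [(True, True), (True, False), (False, True)]:
    print(flags, forward(xi, xv, *flags))
```

With both flags set to False the logit stays at 0 and the output is a constant 0.5, which is why at least one part must remain enabled.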
This implementation also supports regression tasks. To use DeepFM for regression, set loss_type to mse. Accordingly, set eval_metric to a regression metric, e.g., MSE or MAE.
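For regression, the function passed as eval_metric has to score real-valued predictions. Assuming the model calls it the same way as roc_auc_score, i.e. metric(y_true, y_pred), any function with that signature should work; sklearn.metrics.mean_squared_error and mean_absolute_error already follow it. Plain-Python equivalents and the two settings mentioned above look like this; the dictionary is a fragment to merge into dfm_params, not a complete configuration. For these metrics lower is better, so any logic that treats a larger score as an improvement (such as early stopping) needs to know that; whether the model handles this automatically is not stated here.

```python
from typing import Sequence


def mse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean squared error, using the (y_true, y_pred) order of sklearn metrics."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def mae(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean absolute error, same calling convention."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


# Settings to merge into dfm_params for a regression run.
regression_overrides = {
    "loss_type": "mse",
    "eval_metric": mse,  # or mae
}

print(mse([3.0, 5.0], [2.0, 7.0]))  # 2.5
print(mae([3.0, 5.0], [2.0, 7.0]))  # 1.5
```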
The folder example includes an example usage of the DeepFM/FM/DNN models for the Porto Seguro's Safe Driver Prediction competition on Kaggle.
Please download the data from the competition website and put it into the example/data folder.
To train a DeepFM model on this dataset, run:
```shell
$ cd example
$ python main.py
```
Please see example/DataReader.py for how to parse the raw dataset into the required format for DeepFM.
[1] DeepFM: A Factorization-Machine based Neural Network for CTR Prediction, Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li, Xiuqiang He.
This project draws inspiration from the following projects:
License: MIT