A Group Symmetric Stochastic Differential Equation Model for Molecule Multi-modal Pretraining, ICML'23
ICML 2023
Shengchao Liu+, Weitao Du+, Zhiming Ma, Hongyu Guo, Jian Tang
+ Equal contribution
[Project Page] [Paper] [ArXiv] [Checkpoints on HuggingFace]
All pretrained checkpoints are available on HuggingFace.
A detailed mapping between the checkpoints and the tables is given in README_checkpoints.md.
To set up the conda environment:
conda create -n Geom3D python=3.7
conda activate Geom3D
conda install -y -c rdkit rdkit
conda install -y numpy networkx scikit-learn
conda install -y -c conda-forge -c pytorch pytorch=1.9.1
conda install -y -c pyg -c conda-forge pyg=2.0.2
pip install ogb==1.2.1
pip install sympy
pip install ase # for SchNet
pip install -e .
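A quick way to confirm that the pinned versions above ended up in the Geom3D environment is a small import check. This is a minimal sketch, not part of the repository. The helper name check_environment is hypothetical, and the sketch assumes the import names torch, torch_geometric (for the pyg package), ogb, rdkit, sympy and ase. It uses only the standard library, so it also runs under Python 3.7.

```python
import importlib

# Versions pinned by the conda/pip commands above; None means "any version".
# Keys are import names: the conda package pyg is imported as torch_geometric.
EXPECTED = {
    "torch": "1.9.1",
    "torch_geometric": "2.0.2",
    "ogb": "1.2.1",
    "rdkit": None,
    "sympy": None,
    "ase": None,
}


def check_environment(expected=EXPECTED):
    """Hypothetical helper: map each module name to (expected, found, ok)."""
    report = {}
    for module_name, wanted in expected.items():
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            # Not installed, or the import itself failed.
            report[module_name] = (wanted, None, False)
            continue
        found = getattr(module, "__version__", None)
        if wanted is None:
            ok = True  # being importable is enough when no version is pinned
        else:
            # Builds such as "1.9.1+cu111" carry a local suffix after "+".
            ok = found is not None and str(found).split("+")[0] == wanted
        report[module_name] = (wanted, found, ok)
    return report


if __name__ == "__main__":
    for name, (wanted, found, ok) in check_environment().items():
        print(f"{name}: expected {wanted or 'any'}, found {found} -> {'OK' if ok else 'CHECK'}")
```

Run it with `python` inside the activated Geom3D environment; any line marked CHECK points at a package worth reinstalling.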
Place the raw PCQM4Mv2 files under data/PCQM4Mv2/raw, so that the directory looks like this:
.
├── data
│   └── PCQM4Mv2
│       └── raw
│           ├── data.csv
│           ├── data.csv.gz
│           ├── pcqm4m-v2-train.sdf
│           └── pcqm4m-v2-train.sdf.tar.gz
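If only the two archives have been downloaded, the two unpacked files in the tree can be produced with the standard library. This is a minimal sketch under stated assumptions: the helper name prepare_raw_dir is hypothetical, and it assumes pcqm4m-v2-train.sdf sits at the top level of the tar archive. It returns whichever of the four expected files are still missing.

```python
import gzip
import shutil
import tarfile
from pathlib import Path

# File names taken from the directory tree above.
CSV_GZ = "data.csv.gz"
CSV = "data.csv"
SDF_TAR = "pcqm4m-v2-train.sdf.tar.gz"
SDF = "pcqm4m-v2-train.sdf"


def prepare_raw_dir(raw_dir):
    """Hypothetical helper: unpack the archives if needed, return missing files."""
    raw = Path(raw_dir)

    # data.csv.gz -> data.csv (a plain gzip stream).
    csv_gz, csv = raw / CSV_GZ, raw / CSV
    if csv_gz.is_file() and not csv.exists():
        with gzip.open(csv_gz, "rb") as src, open(csv, "wb") as dst:
            shutil.copyfileobj(src, dst)

    # pcqm4m-v2-train.sdf.tar.gz -> pcqm4m-v2-train.sdf.
    # Assumption: the SDF file is stored at the top level of the archive.
    sdf_tar, sdf = raw / SDF_TAR, raw / SDF
    if sdf_tar.is_file() and not sdf.exists():
        with tarfile.open(sdf_tar, "r:gz") as tar:
            # Copy the single member instead of calling extractall(), so no
            # unexpected paths from the archive are written to disk.
            with tar.extractfile(tar.getmember(SDF)) as src, open(sdf, "wb") as dst:
                shutil.copyfileobj(src, dst)

    expected = [CSV, CSV_GZ, SDF, SDF_TAR]
    return [name for name in expected if not (raw / name).exists()]


if __name__ == "__main__":
    print("missing:", prepare_raw_dir("data/PCQM4Mv2/raw") or "none")
```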
The script examples/generate_PCQM4Mv2.py then generates the PCQM4Mv2 dataset from these raw files.
The QM9 and MD17 datasets are stored under data/molecule_datasets/QM9 and data/MD17, respectively.
The MoleculeNet datasets (e.g., bace and bbbp) are organized as follows:
.
├── data
│   ├── molecule_datasets
│   │   ├── bace
│   │   │   ├── BACE_README
│   │   │   └── raw
│   │   │       └── bace.csv
│   │   ├── bbbp
...
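The tree follows a per-dataset pattern, data/molecule_datasets/<name>/raw/<file>.csv. Assuming every MoleculeNet dataset follows that pattern (only bace is shown in full above), a short check can list which datasets are present before fine-tuning. The helper name find_moleculenet_csvs is hypothetical; this is a sketch, not part of the repository.

```python
from pathlib import Path


def find_moleculenet_csvs(root="data/molecule_datasets"):
    """Hypothetical helper: map each dataset folder to its sorted raw/*.csv names."""
    root = Path(root)
    found = {}
    if not root.is_dir():
        return found
    for dataset_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        raw_dir = dataset_dir / "raw"
        if not raw_dir.is_dir():
            continue  # e.g. a folder without a raw/ subdirectory
        csvs = sorted(p.name for p in raw_dir.glob("*.csv"))
        if csvs:
            found[dataset_dir.name] = csvs
    return found


if __name__ == "__main__":
    for name, files in find_moleculenet_csvs().items():
        print(f"{name}: {', '.join(files)}")
```

Folders without a raw/*.csv file are skipped, so datasets stored in other formats will not show up in this listing.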
A quick demo of pretraining:
cd examples
python pretrain_MoleculeSDE.py \
--verbose --input_data_dir=../data --dataset=PCQM4Mv2 \
--model_3d=SchNet \
--lr=1e-4 --num_workers=0 --batch_size=256 --SSL_masking_ratio=0 --gnn_3d_lr_scale=0.1 --dropout_ratio=0 --graph_pooling=mean --emb_dim=300 --epochs=1 \
--SDE_coeff_contrastive=1 --CL_similarity_metric=EBM_node_dot_prod --T=0.1 --normalize --SDE_coeff_contrastive_skip_epochs=0 \
--SDE_coeff_generative_2Dto3D=1 --SDE_2Dto3D_model=SDEModel2Dto3D_02 --SDE_type_2Dto3D=VE --use_extend_graph \
--SDE_coeff_generative_3Dto2D=1 --SDE_3Dto2D_model=SDEModel3Dto2D_node_adj_dense --SDE_type_3Dto2D=VE --noise_on_one_hot \
--output_model_dir=[MODEL_DIR]
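The three coefficient flags in the command suggest that the pretraining loss is a weighted sum of one contrastive term and two generative SDE terms (2D to 3D and 3D to 2D). The formula below is a sketch inferred from the flag names, not copied from the code; the symbols α₁, α₂, α₃ are introduced here and stand for --SDE_coeff_contrastive, --SDE_coeff_generative_2Dto3D and --SDE_coeff_generative_3Dto2D, all set to 1 in the demo.

```latex
\mathcal{L} = \alpha_1\,\mathcal{L}_{\text{contrastive}}
            + \alpha_2\,\mathcal{L}_{\text{2D}\to\text{3D}}
            + \alpha_3\,\mathcal{L}_{\text{3D}\to\text{2D}},
\qquad \alpha_1 = \alpha_2 = \alpha_3 = 1 \text{ in the demo command.}
```

Setting one of the coefficients to 0 would, under this reading, drop the corresponding term from the objective.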
Here, [MODEL_DIR] is the directory where the models/checkpoints will be saved.
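For hyperparameter sweeps it can help to assemble the same pretraining command from a dictionary, so that each flag appears exactly once. This is a minimal sketch: build_pretrain_argv is a hypothetical helper, the flag values are copied from the demo command, True marks value-less switches such as --verbose, and it assumes it is run from the examples folder.

```python
import shlex
import sys

# Flag values copied from the demo command; True marks a value-less switch.
# Values are kept as strings so they are passed through exactly as written.
DEMO_FLAGS = {
    "verbose": True,
    "input_data_dir": "../data",
    "dataset": "PCQM4Mv2",
    "model_3d": "SchNet",
    "lr": "1e-4",
    "num_workers": "0",
    "batch_size": "256",
    "SSL_masking_ratio": "0",
    "gnn_3d_lr_scale": "0.1",
    "dropout_ratio": "0",
    "graph_pooling": "mean",
    "emb_dim": "300",
    "epochs": "1",
    "SDE_coeff_contrastive": "1",
    "CL_similarity_metric": "EBM_node_dot_prod",
    "T": "0.1",
    "normalize": True,
    "SDE_coeff_contrastive_skip_epochs": "0",
    "SDE_coeff_generative_2Dto3D": "1",
    "SDE_2Dto3D_model": "SDEModel2Dto3D_02",
    "SDE_type_2Dto3D": "VE",
    "use_extend_graph": True,
    "SDE_coeff_generative_3Dto2D": "1",
    "SDE_3Dto2D_model": "SDEModel3Dto2D_node_adj_dense",
    "SDE_type_3Dto2D": "VE",
    "noise_on_one_hot": True,
}


def build_pretrain_argv(flags, model_dir, script="pretrain_MoleculeSDE.py"):
    """Hypothetical helper: one --name=value (or bare --name) entry per dict key."""
    argv = [sys.executable, script]
    for name, value in flags.items():
        if value is True:
            argv.append(f"--{name}")
        elif value is not False and value is not None:
            argv.append(f"--{name}={value}")
    argv.append(f"--output_model_dir={model_dir}")
    return argv


if __name__ == "__main__":
    argv = build_pretrain_argv(DEMO_FLAGS, model_dir="[MODEL_DIR]")
    # Print a shell-ready command; subprocess.run(argv, check=True) would launch it.
    print(" ".join(shlex.quote(arg) for arg in argv))
```

Because dictionary keys are unique, a sweep can override a single entry, e.g. dict(DEMO_FLAGS, epochs="50"), without leaving a stale duplicate flag on the command line.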
The downstream scripts are in the examples folder. A few simple examples follow.
Fine-tuning on MoleculeNet with finetune_MoleculeNet.py:
python finetune_MoleculeNet.py \
--dataset=tox21 \
--input_model_file=[MODEL_DIR]/model_complete.pth
Fine-tuning on QM9 with finetune_QM9.py:
python finetune_QM9.py \
--dataset=QM9 --task=gap \
--model_3d=SchNet \
--input_model_file=[MODEL_DIR]/model_complete.pth
Fine-tuning on MD17 with finetune_MD17.py:
python finetune_MD17.py \
--dataset=MD17 --task=aspirin \
--model_3d=SchNet \
--input_model_file=[MODEL_DIR]/model_complete.pth
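All three fine-tuning commands load the same checkpoint, [MODEL_DIR]/model_complete.pth. The sketch below, built around a hypothetical helper finetune_argv, assembles those three commands from the flags shown above and fails early if the checkpoint has not been written yet. It assumes it is run from the examples folder, like the commands themselves, and takes the model directory as its first command-line argument.

```python
import subprocess
import sys
from pathlib import Path

# Script-specific flags copied from the three commands above.
FINETUNE_JOBS = {
    "finetune_MoleculeNet.py": {"dataset": "tox21"},
    "finetune_QM9.py": {"dataset": "QM9", "task": "gap", "model_3d": "SchNet"},
    "finetune_MD17.py": {"dataset": "MD17", "task": "aspirin", "model_3d": "SchNet"},
}


def finetune_argv(script, flags, model_dir):
    """Hypothetical helper: argv for one fine-tuning run, after checking the checkpoint."""
    checkpoint = Path(model_dir) / "model_complete.pth"
    if not checkpoint.is_file():
        raise FileNotFoundError(f"pretrained checkpoint not found: {checkpoint}")
    argv = [sys.executable, script]
    argv += [f"--{name}={value}" for name, value in flags.items()]
    argv.append(f"--input_model_file={checkpoint}")
    return argv


if __name__ == "__main__":
    model_dir = sys.argv[1] if len(sys.argv) > 1 else "[MODEL_DIR]"
    for script, flags in FINETUNE_JOBS.items():
        try:
            argv = finetune_argv(script, flags, model_dir)
        except FileNotFoundError as err:
            print(err)
            break
        # Runs the jobs one after another; check=True stops on the first failure.
        subprocess.run(argv, check=True)
```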
Please consider citing this work if you find it useful:
@inproceedings{liu2023group,
title={A group symmetric stochastic differential equation model for molecule multi-modal pretraining},
author={Liu, Shengchao and Du, Weitao and Ma, Zhi-Ming and Guo, Hongyu and Tang, Jian},
booktitle={International Conference on Machine Learning},
pages={21497--21526},
year={2023},
organization={PMLR}
}