CUDA implementation of Hierarchical Navigable Small World Graph algorithm
Efficient CUDA implementation of the Hierarchical Navigable Small World (HNSW) graph algorithm for Approximate Nearest Neighbor (ANN) search
This project speeds up the HNSW algorithm with CUDA. Anyone interested in this project is likely already familiar with the following paper and open source project; if not, I strongly recommend checking them first.
I also adapted some ideas from the following project.
From a brief survey, I found several papers and projects that propose speeding up ANN algorithms with GPUs.
I started this project because I was originally interested in both CUDA programming and ANN algorithms. I am releasing it because it achieves meaningful performance, and I hope to develop it further through community participation.
This package builds HNSW graphs on the GPU and performs approximate nearest neighbor search through the built graphs. The model file format is compatible with hnswlib: you can build an HNSW graph with this package, save it, and load it from hnswlib for search, and vice versa.
Install from pypi:

```shell
pip install cuhnsw
```
Or build from source:

```shell
# clone repo and submodules
git clone [email protected]:js1010/cuhnsw.git && cd cuhnsw && git submodule update --init

# install requirements
pip install -r requirements.txt

# generate proto
python -m grpc_tools.protoc --python_out cuhnsw/ --proto_path cuhnsw/proto/ config.proto

# install
python setup.py install
```
`examples/example1.py` and `examples/README.md` will be very helpful to understand the usage.

Build and save an index:

```python
import h5py
import numpy as np

from cuhnsw import CuHNSW

h5f = h5py.File("glove-50-angular.hdf5", "r")
data = h5f["train"][:, :].astype(np.float32)
h5f.close()

ch0 = CuHNSW(opt={})
ch0.set_data(data)
ch0.build()
ch0.save_index("cuhnsw.index")
```
Load the index and search:

```python
import h5py
import numpy as np

from cuhnsw import CuHNSW

h5f = h5py.File("glove-50-angular.hdf5", "r")
data = h5f["test"][:, :].astype(np.float32)
h5f.close()

ch0 = CuHNSW(opt={})
ch0.load_index("cuhnsw.index")
nns, distances, found_cnt = ch0.search_knn(data, topk=10, ef_search=300)
```
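`search_knn` returns, for each query, the neighbor ids, their distances, and the number of neighbors actually found. A common way to evaluate the result is recall against exact ground truth. The helper below is a hypothetical sketch (not part of cuhnsw); the `neighbors` dataset referenced in the comment is the ground-truth array that ann-benchmarks HDF5 files typically provide.

```python
import numpy as np

def recall_at_k(found, truth, k):
    """Average fraction of the true top-k neighbors recovered per query.

    found, truth: 2-D arrays of neighbor ids, one row per query.
    """
    found = np.asarray(found)[:, :k]
    truth = np.asarray(truth)[:, :k]
    hits = sum(len(set(f) & set(t)) for f, t in zip(found, truth))
    return hits / (found.shape[0] * k)

# hypothetical usage with the arrays from above:
# truth = h5py.File("glove-50-angular.hdf5", "r")["neighbors"][:, :10]
# print(recall_at_k(nns, truth, k=10))
```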
All configurable options are defined in `cuhnsw/proto/config.proto`:

- `seed`: numpy random seed (used to draw random levels)
- `c_log_level`: log level of the C++ logging (spdlog)
- `py_log_level`: log level of the Python logging
- `max_m`: maximum number of links in the layers above the ground layer
- `max_m0`: maximum number of links in the ground layer
- `level_mult`: multiplier used to draw the level of each element (default: 0, which is set to 1 / log(max_m0) at initialization, as recommended in the HNSW paper)
- `save_remains`: link to the remaining candidates in SearchHeuristic (adapted from n2)
- `heuristic_coef`: select some closest candidates by default (also adapted from n2)
- `hyper_threads`: sets the number of GPU blocks so that the total number of concurrent threads exceeds the physical number of cores
- `block_dim`: block dimension (should be at most 32^2 = 1024 and a multiple of 32)
- `nrz`: normalize the data vectors if True
- `visited_table_size`: size of the table storing the visited nodes in each search
- `visited_list_size`: size of the list storing the visited nodes in each search (useful to reset the table after each search)
- `reverse_cand`: select the candidate with the furthest distance if True (makes the build slower but achieves better quality)
- `dist_type`: euclidean distance if "l2", inner product distance if "dot"
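Options are passed to the `CuHNSW` constructor as a plain dict, as in the examples above; unspecified keys keep their defaults. The values below are illustrative choices, not the library's actual defaults:

```python
# illustrative option values; keys are from the list above,
# values are example choices, not cuhnsw's actual defaults
opt = {
    "seed": 777,          # numpy random seed
    "max_m": 12,          # max links in layers above the ground layer
    "max_m0": 24,         # max links in the ground layer
    "block_dim": 32,      # must be a multiple of 32, at most 1024
    "nrz": True,          # normalize data vectors
    "dist_type": "dot",   # inner product distance
}
# ch0 = CuHNSW(opt=opt)  # requires a CUDA-capable GPU
```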
Experiment setup (see `examples/example1.py`):

- batch search (i.e. processing a large number of queries at once)
- `ef_construction=150` for hnswlib and `ef_construction=110` for cuhnsw to achieve the same build quality
- `ef_search=300`

Build performance:

attr | 1 vcpu | 2 vcpu | 4 vcpu | 8 vcpu | gpu |
---|---|---|---|---|---|
build time (sec) | 343.909 | 179.836 | 89.7936 | 70.5476 | 8.2847 |
build quality | 0.863193 | 0.863301 | 0.863238 | 0.863165 | 0.865471 |
The search quality is guaranteed to be the same (exact match), so only search time is compared:

attr | 1 vcpu | 2 vcpu | 4 vcpu | 8 vcpu | gpu |
---|---|---|---|---|---|
search time (sec) | 556.605 | 287.967 | 146.331 | 115.431 | 29.7008 |
Update: measured the search time for 1M queries on a cpu-only instance (c5.24xlarge, 96 vcpu, 4.08 USD / hr): 22.4642 sec.

The reason the parallel efficiency drops significantly from 4 vcpu to 8 vcpu might be hyper-threading (there may be only 4 "physical" cores on this instance).
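Parallel efficiency here means speedup over the 1-vcpu baseline divided by the number of vcpus; computed from the search-time table above:

```python
# search times (sec) from the table above
t1, t4, t8 = 556.605, 146.331, 115.431

eff4 = (t1 / t4) / 4  # parallel efficiency with 4 vcpus, roughly 0.95
eff8 = (t1 / t8) / 8  # parallel efficiency with 8 vcpus, roughly 0.60
```

The near-ideal efficiency at 4 vcpus versus the sharp drop at 8 is what suggests only 4 physical cores are available.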