Topic clustering library built on transformer embeddings and cosine similarity. Compatible with any BERT-base transformer from HuggingFace.
ClusterTransformer is a topic clustering library that embeds input sentences with a transformer model and analyses the cosine similarity between the resulting embeddings. Topics are clustered either with k-means or agglomeratively, depending on the use case, and the embeddings can be obtained from any of the transformer models available on HuggingFace.
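To make the similarity measure concrete, here is a minimal sketch of cosine similarity in plain NumPy. This is an illustration of the metric itself, not the library's own API; the toy 4-dimensional vectors stand in for real transformer embeddings.

```python
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two embedding vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Toy 4-dimensional "embeddings" standing in for transformer outputs
emb_a = np.array([1.0, 0.0, 1.0, 0.0])
emb_b = np.array([1.0, 0.0, 1.0, 0.0])
emb_c = np.array([0.0, 1.0, 0.0, 1.0])

print(cosine_similarity(emb_a, emb_b))  # identical vectors -> 1.0
print(cosine_similarity(emb_a, emb_c))  # orthogonal vectors -> 0.0
```

Sentences whose embeddings score close to 1.0 are treated as belonging to the same topic neighbourhood; the cutoff threshold used later controls how strict that grouping is.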
Installation is carried out using the pip command as follows:
pip install ClusterTransformer==0.1
For using inside the Jupyter Notebook or Python IDE:
import ClusterTransformer.ClusterTransformer as ct
The 'ClusterTransformer_test.py' file contains an example of using the library in this context.
The steps to use this library are as follows:
1. Initialise the class: ClusterTransformer().
2. Provide the input list of sentences: in this case, the Quora similar-questions dataframe has been used for experimental purposes.
3. Declare the hyperparameters.
4. Call the methods.
The code below contains all the steps required to create the clusters; the 'compute_topics' method proceeds as follows:
%%time
import ClusterTransformer.ClusterTransformer as cluster_transformer

def compute_topics(transformer_name):
    # Instantiate the object
    ct = cluster_transformer.ClusterTransformer()
    # Transformer model for inference
    model_name = transformer_name
    # Hyperparameters for model inference
    batch_size = 500
    max_seq_length = 64
    convert_to_numpy = False
    normalize_embeddings = False
    # Hyperparameters for agglomerative clustering
    neighborhood_min_size = 3
    cutoff_threshold = 0.95
    # Hyperparameters for k-means clustering
    kmeans_max_iter = 100
    kmeans_random_state = 42
    kmeans_no_clusters = 8
    # Subset of the input data list (merged_set holds the input sentences)
    sub_merged_sent = merged_set[:200]
    # Transformer embeddings
    embeddings = ct.model_inference(sub_merged_sent, batch_size, model_name,
                                    max_seq_length, normalize_embeddings,
                                    convert_to_numpy)
    # Hierarchical agglomerative detection
    output_dict = ct.neighborhood_detection(sub_merged_sent, embeddings,
                                            cutoff_threshold,
                                            neighborhood_min_size)
    # K-means detection
    output_kmeans_dict = ct.kmeans_detection(sub_merged_sent, embeddings,
                                             kmeans_no_clusters,
                                             kmeans_max_iter,
                                             kmeans_random_state)
    # Agglomerative clustering output as a dataframe
    neighborhood_detection_df = ct.convert_to_df(output_dict)
    # K-means clustering output as a dataframe
    kmeans_df = ct.convert_to_df(output_kmeans_dict)
    return neighborhood_detection_df, kmeans_df
Calling the driver code:
%%time
import numpy as np
import matplotlib.pyplot as plt
n_df,k_df=compute_topics('bert-large-uncased')
kg_df=k_df.groupby('Cluster').agg({'Text':'count'}).reset_index()
ng_df=n_df.groupby('Cluster').agg({'Text':'count'}).reset_index()
#Plotting
fig,(ax1,ax2)=plt.subplots(1,2,figsize=(15,5))
rng = np.random.RandomState(0)
s=1000*rng.rand(len(kg_df['Text']))
s1=1000*rng.rand(len(ng_df['Text']))
ax1.scatter(kg_df['Cluster'],kg_df['Text'],s=s,c=kg_df['Cluster'],alpha=0.3)
ax1.set_title('Kmeans clustering')
ax1.set_xlabel('No of clusters')
ax1.set_ylabel('No of topics')
ax2.scatter(ng_df['Cluster'],ng_df['Text'],s=s1,c=ng_df['Cluster'],alpha=0.3)
ax2.set_title('Agglomerative clustering')
ax2.set_xlabel('No of clusters')
ax2.set_ylabel('No of topics')
plt.show()
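Beyond counting topics per cluster, it can help to look at the sentences each cluster actually holds. The sketch below uses a toy dataframe with the 'Text' and 'Cluster' columns implied by the driver code above (the real dataframes come from convert_to_df):

```python
import pandas as pd

# Toy stand-in for a convert_to_df result: the driver code above
# implies it exposes 'Text' and 'Cluster' columns.
k_df = pd.DataFrame({
    "Text": ["how to learn python", "best python tutorial",
             "cheap flights to paris", "paris travel tips"],
    "Cluster": [0, 0, 1, 1],
})

# Count of topics per cluster, as used in the plots above
counts = k_df.groupby("Cluster").agg({"Text": "count"}).reset_index()
print(counts)

# Peek at a couple of representative sentences from each cluster
print(k_df.groupby("Cluster").head(2))
```

Inspecting representative sentences per cluster is a quick sanity check that the cutoff threshold and cluster count are set sensibly for the data.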
Cluster images (created with Facebook BART)
Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.
License: MIT