 

Saving BERT Sentence Embedding

I'm currently working on an information retrieval task. I'm using SBERT to perform a semantic search. I have already followed the documentation here

The model I use:

model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-mpnet-base-v2')

The outline is

  1. You have a corpus, a list of sentences like this:
    data = ['A man is eating food.',
          'A man is eating a piece of bread.',
          'The girl is carrying a baby.',
          'A man is riding a horse.',
          'A woman is playing violin.',
          'Two men pushed carts through the woods.',
          'A man is riding a white horse on an enclosed ground.',
          'A monkey is playing drums.',
          'A cheetah is running behind its prey.'
          ]
  2. You have a query like this:
queries = ['A man is eating pasta.']
  3. Encode both the query and the corpus
query_embedding = model.encode(queries)
doc_embedding = model.encode(data)

The encode function outputs a numpy.ndarray for model.encode(data)

  4. Calculate the similarity using cosine similarity like this
similarity = util.cos_sim(query_embedding, doc_embedding)
  5. If you print the similarity, you get a torch.Tensor containing the similarity scores like this
tensor([[0.4389, 0.4288, 0.6079, 0.5571, 0.4063, 0.4432, 0.5467, 0.3392, 0.4293]])
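For reference, `util.cos_sim` computes plain cosine similarity between every row of the first matrix and every row of the second. The same scores can be reproduced with NumPy alone; this is a minimal sketch using random arrays as stand-ins for the real embeddings (the shapes mirror one query against nine corpus sentences):

```python
import numpy as np

def cos_sim(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Cosine similarity between every row of a and every row of b.
    Mirrors sentence_transformers.util.cos_sim; result shape is (len(a), len(b))."""
    a_norm = a / np.linalg.norm(a, axis=1, keepdims=True)
    b_norm = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a_norm @ b_norm.T

# Stand-ins for query_embedding (1 x dim) and doc_embedding (n x dim)
rng = np.random.default_rng(0)
query_embedding = rng.normal(size=(1, 768))
doc_embedding = rng.normal(size=(9, 768))

similarity = cos_sim(query_embedding, doc_embedding)
print(similarity.shape)  # one score per corpus sentence
```

The only difference from the snippet above is the return type: `util.cos_sim` returns a `torch.Tensor`, while this sketch returns a `numpy.ndarray`.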

And it works fine and fast. But of course this is only a small corpus. With a large corpus, the encoding will take time.

Note: encoding the query takes no time because it is only one sentence, but encoding the corpus will take a while.

So, the question is: can we save the doc_embedding locally and reuse it, especially when using a large corpus?

Is there any built-in class/function in the transformers library to do this?

Muhammad Hafidz Alfarizi asked Oct 28 '25 07:10


1 Answer

Save them as pickle files and load them later =]

import pickle

with open('doc_embedding.pickle', 'wb') as pkl:
    pickle.dump(doc_embedding, pkl)

with open('doc_embedding.pickle', 'rb') as pkl:
    doc_embedding = pickle.load(pkl)
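The full round trip looks like this; it is a minimal sketch where a random NumPy array stands in for the real `model.encode(data)` output, and the file name is just an example:

```python
import pickle
import numpy as np

# Stand-in for model.encode(data): nine 768-dimensional embeddings
doc_embedding = np.random.default_rng(0).normal(size=(9, 768)).astype(np.float32)

# Encode once, persist to disk
with open('doc_embedding.pickle', 'wb') as pkl:
    pickle.dump(doc_embedding, pkl)

# On later runs: load the cached embeddings instead of re-encoding the corpus
with open('doc_embedding.pickle', 'rb') as pkl:
    cached = pickle.load(pkl)

# The loaded array is identical to the original
assert np.array_equal(doc_embedding, cached)
```

Since `model.encode` returns a `numpy.ndarray`, `np.save` / `np.load` would work just as well and avoids pickle's arbitrary-code-execution risk when loading files from untrusted sources.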
Kevin Choi answered Nov 01 '25 05:11