
How to use fine-tuned model in huggingface for actual prediction after re-loading?

I'm trying to reload a DistilBertForSequenceClassification model I've fine-tuned and use that to predict some sentences into their appropriate labels (text classification).

In Google Colab, after successfully training the DistilBERT model, I saved it and downloaded it:

trainer.train()
trainer.save_model("distilbert_classification")

The downloaded model has three files: config.json, pytorch_model.bin, training_args.bin.

I placed them inside a folder named 'distilbert_classification' in my Google Drive.

Afterwards, I reloaded the model in a different Colab notebook:


reloadtrainer = DistilBertForSequenceClassification.from_pretrained('google drive directory/distilbert_classification')

Up to this point, I have succeeded without any errors.

However, how do I use this reloaded model (the 'reloadtrainer' object) to actually make predictions on sentences? What code do I need afterwards? I tried

reloadtrainer.predict("sample sentence")

but it doesn't work. I would appreciate any help!

Robin311 asked Dec 20 '25 17:12


1 Answer

Remember that you also need to tokenize the input to your model, just as in the training phase. Merely feeding a raw sentence to the model will not work (unless you use pipeline(), but that's another discussion).
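If you do want the pipeline route, it can be sketched as a small helper. Note the assumptions: "path_to_model" is a placeholder for your saved folder, the helper name classify is hypothetical, and the tokenizer is assumed to be the distilbert-base-uncased checkpoint you fine-tuned from:

```python
from transformers import pipeline

def classify(text, model="path_to_model", tokenizer="distilbert-base-uncased"):
    # "path_to_model" is a placeholder; point it at your saved folder.
    # The pipeline handles tokenization and softmax for you and returns
    # a list of {"label": ..., "score": ...} dicts.
    clf = pipeline("text-classification", model=model, tokenizer=tokenizer)
    return clf(text)
```

The returned labels default to generic names like LABEL_0/LABEL_1 unless you set id2label on the config during training.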

You may use AutoModelForSequenceClassification and AutoTokenizer to make things easier.

Note that the way I am saving the model is via model.save_pretrained("path_to_model") rather than model.save().
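For completeness, the saving side can be sketched as a small helper (the helper name is hypothetical). Saving the tokenizer next to the model means a later from_pretrained() call can load both from the same folder:

```python
def save_for_reload(model, tokenizer, model_dir):
    """Save a fine-tuned model and its tokenizer into one folder so both
    can later be reloaded with from_pretrained(model_dir)."""
    model.save_pretrained(model_dir)
    tokenizer.save_pretrained(model_dir)
```

This avoids the situation in the question, where the saved folder contains only the model files (config.json, pytorch_model.bin, training_args.bin) and the tokenizer has to be fetched separately.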

One possible approach could be the following (say you trained with uncased distilbert):

from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

model = AutoModelForSequenceClassification.from_pretrained("path_to_model")
# Replace with whatever tokenizer you used during training
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased", use_fast=True)

input_text = "This is the text I am trying to classify."
tokenized_text = tokenizer(input_text,
                           truncation=True,
                           is_split_into_words=False,
                           return_tensors="pt")

# Pass the attention mask along with the input ids, and disable gradient
# tracking since this is inference, not training
with torch.no_grad():
    outputs = model(**tokenized_text)
predicted_label = outputs.logits.argmax(-1)
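The argmax above gives a class index, not a label name. To turn it into a readable label you can look it up in the id2label mapping stored on model.config; a minimal sketch, using a hypothetical two-class mapping in place of a real model:

```python
import torch

# Hypothetical mapping for illustration; after fine-tuning, the real one
# lives at model.config.id2label.
id2label = {0: "negative", 1: "positive"}

logits = torch.tensor([[0.1, 2.3]])          # stand-in for outputs.logits
predicted_index = logits.argmax(-1).item()   # -> 1
print(id2label[predicted_index])             # prints "positive"
```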
Timbus Calin answered Dec 23 '25 06:12


