
cannot import name 'flash_attn_func' from 'flash_attn'

I am trying to load a Llama 2 model:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map=device_map
)

with this bnb_config:

BitsAndBytesConfig {
  "bnb_4bit_compute_dtype": "bfloat16",
  "bnb_4bit_quant_type": "nf4",
  "bnb_4bit_use_double_quant": true,
  "llm_int8_enable_fp32_cpu_offload": false,
  "llm_int8_has_fp16_weight": false,
  "llm_int8_skip_modules": null,
  "llm_int8_threshold": 6.0,
  "load_in_4bit": true,
  "load_in_8bit": false,
  "quant_method": "bitsandbytes"
}
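For reference, the printed configuration above corresponds to a `BitsAndBytesConfig` constructed roughly like this (a sketch; `torch.bfloat16` maps to the `"bfloat16"` compute dtype shown, and fields not passed keep their defaults):

```python
import torch
from transformers import BitsAndBytesConfig

# 4-bit NF4 quantization with double quantization, computing in bfloat16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
```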

and I get this error:

RuntimeError: Failed to import transformers.models.llama.modeling_llama because of the following error (look up to see its traceback):
cannot import name 'flash_attn_func' from 'flash_attn' (/opt/conda/lib/python3.10/site-packages/flash_attn/__init__.py)

Any help would be appreciated.

Hamid K, asked Oct 24 '25


1 Answer

I was having the same error when fine-tuning a Llama 2 model; the solution is to revert to a previous version of transformers:

pip install transformers==4.33.1 --upgrade

This should work.
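If pinning transformers is not an option, another thing worth checking (an assumption on my part, not part of the answer above) is which flash-attn release is installed: older 1.x releases do not export `flash_attn_func`, while 2.x releases do. A quick diagnostic:

```python
# Check whether flash-attn is installed, its version, and whether it
# exposes the symbol that transformers is failing to import.
import importlib.util

spec = importlib.util.find_spec("flash_attn")
if spec is None:
    print("flash_attn is not installed")
else:
    import flash_attn
    print("version:", getattr(flash_attn, "__version__", "unknown"))
    print("has flash_attn_func:", hasattr(flash_attn, "flash_attn_func"))
```

If the symbol is missing, upgrading flash-attn (or uninstalling it so transformers falls back to the standard attention path) may also resolve the error.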

foscraft, answered Oct 28 '25


