scikeras.wrappers.KerasClassifier returning ValueError: Could not interpret metric identifier: loss

I was looking into KerasClassifier because I would like to plug it into a scikit-learn pipeline, but I'm getting the aforementioned ValueError.

The following code should reproduce the error:

from sklearn.model_selection import KFold, cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from scikeras.wrappers import KerasClassifier
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from sklearn.datasets import load_iris
import numpy as np

data = load_iris()
X = data.data
y = data.target

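def create_model():
    # 4 iris features in, 3 class probabilities out
    model = Sequential()
    model.add(Dense(8, input_dim=4, activation='relu'))
    model.add(Dense(3, activation='softmax'))
    # Labels are integers (0, 1, 2), hence the sparse categorical loss
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])
    return model

# scikeras recommends `model=`; the older `build_fn=` name is deprecated
clf = KerasClassifier(model=create_model,
                      epochs=100,
                      batch_size=10,
                      verbose=1)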

pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('clf', clf)
])

kf = KFold(n_splits=5, shuffle=True, random_state=42)
results = cross_val_score(pipeline, X, y, cv=kf)
print("Cross-Validation Accuracy:", np.mean(results))

The model does seem to build and train, since the progress for each epoch is printed. Afterwards, however, I get the error:

ValueError: Could not interpret metric identifier: loss

The versions of the tensorflow and scikeras libraries are:

scikeras==0.12.0
tensorflow==2.15.0
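
When reporting versions like these, it can help to print what is actually installed in the environment that runs the script, rather than what was intended to be pinned. The sketch below uses only the standard library (`importlib.metadata`, Python 3.8+) and assumes the PyPI distribution names `scikeras`, `tensorflow`, `scikit-learn` and `keras`; the helper name `installed_versions` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {distribution name: installed version, or None if missing}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # The distribution is not installed in this environment.
            found[name] = None
    return found


if __name__ == "__main__":
    names = ["scikeras", "tensorflow", "scikit-learn", "keras"]
    for name, ver in installed_versions(names).items():
        print(f"{name}=={ver}" if ver else f"{name}: not installed")
```

Including `keras` in the list is useful because TensorFlow installs it as a separate distribution, so its version shows which Keras the `tensorflow.keras` imports resolve to.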

EDIT: After experimenting with different library versions, I found that the following combination runs the code successfully, so the issue seems to have been caused by the scikit-learn version:

scikeras==0.12.0
tensorflow==2.15.0
scikit-learn==1.4.1
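
The known-good combination from the EDIT can be turned into a quick environment check. This is a minimal sketch with hypothetical helpers `parse_pins` and `pin_mismatches`: it parses `name==version` lines and compares them to the installed versions as plain strings, so it deliberately skips PEP 440 details such as `2.15` matching `2.15.0` (a simplification that keeps it standard-library only).

```python
from importlib.metadata import PackageNotFoundError, version

# Known-good pins from the EDIT above.
KNOWN_GOOD = """
scikeras==0.12.0
tensorflow==2.15.0
scikit-learn==1.4.1
"""


def parse_pins(text):
    """Parse 'name==version' lines into a dict, skipping blank lines."""
    pins = {}
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        name, sep, ver = line.partition("==")
        if not sep:
            raise ValueError(f"not an exact pin: {line!r}")
        pins[name.strip()] = ver.strip()
    return pins


def pin_mismatches(pins, installed):
    """List differences between pins and installed versions.

    Versions are compared as plain strings (no PEP 440 normalization).
    """
    problems = []
    for name, wanted in pins.items():
        have = installed.get(name)
        if have is None:
            problems.append(f"{name}: not installed (want {wanted})")
        elif have != wanted:
            problems.append(f"{name}: have {have}, want {wanted}")
    return problems


if __name__ == "__main__":
    pins = parse_pins(KNOWN_GOOD)
    installed = {}
    for name in pins:
        try:
            installed[name] = version(name)
        except PackageNotFoundError:
            installed[name] = None
    for problem in pin_mismatches(pins, installed) or ["all pins match"]:
        print(problem)
```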
asked Nov 01 '25 17:11 by Frederico Portela


2 Answers

Downgrading tensorflow to version 2.15 did the trick.

tensorflow==2.15
scikit-learn==1.4.1.post1
scikeras==0.12
answered Nov 04 '25 13:11 by Datagniel


This is purely a TensorFlow version problem, and it can be solved with tensorflow==2.15.0. It has nothing to do with the scikit-learn, scikeras, or Python versions.
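
One relevant fact here: from TensorFlow 2.16 onward, `tf.keras` defaults to Keras 3, while 2.15 and earlier ship Keras 2. The answers don't say so, so treat it as an assumption that this switch is what breaks scikeras 0.12 in this case. A minimal guard along those lines might look like this; `tf_uses_keras2` is a hypothetical helper, and it only looks up the `tensorflow` distribution (variants such as `tensorflow-cpu` would need their own name).

```python
import re
from importlib.metadata import PackageNotFoundError, version


def tf_uses_keras2(tf_version):
    """True if this TensorFlow release predates the Keras 3 default (i.e. < 2.16)."""
    # Only the leading numeric major.minor is used; suffixes like "rc0" are ignored.
    match = re.match(r"(\d+)\.(\d+)", tf_version)
    if match is None:
        raise ValueError(f"unrecognized TensorFlow version: {tf_version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) < (2, 16)


if __name__ == "__main__":
    try:
        tf_version = version("tensorflow")
    except PackageNotFoundError:
        print("tensorflow is not installed")
    else:
        if tf_uses_keras2(tf_version):
            print(f"tensorflow {tf_version}: Keras 2, matches the versions in the answers")
        else:
            print(f"tensorflow {tf_version}: Keras 3 by default; "
                  "consider tensorflow==2.15.0 with scikeras 0.12")
```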

answered Nov 04 '25 13:11 by shadow