I’m trying to compute the AUC score for a multiclass problem using sklearn’s roc_auc_score() function.
I have a prediction matrix of shape [n_samples, n_classes] and a ground truth vector of shape [n_samples], named np_pred and np_label respectively.
What I’m trying to achieve is a set of AUC scores, one for each of my classes.
To do so, I would like to set the average parameter to None and the multi_class parameter to "ovr", but if I run
roc_auc_score(y_score=np_pred, y_true=np_label, multi_class="ovr",average=None)
I get:
ValueError: average must be one of ('macro', 'weighted') for multiclass problems
This error is expected from the sklearn function in the multiclass case. However, if you look at the source code of roc_auc_score, you can see that when the multi_class parameter is set to "ovr" and the average is one of the accepted values, the multiclass case is treated as a multilabel one, and the internal multilabel function accepts None as the average parameter.
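To make that mechanism concrete, here is a minimal sketch of the "fake multilabel" route: one-hot encode the labels with sklearn.preprocessing.label_binarize, then pass the resulting indicator matrix to roc_auc_score, where average=None is accepted and returns one AUC per column. The random np_label/np_pred arrays, the sample count and the choice of 3 classes are assumptions for illustration only.

```python
import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import label_binarize

# Hypothetical data: 3 classes, 200 samples (illustrative only)
rng = np.random.default_rng(0)
n_samples, n_classes = 200, 3
np_label = rng.integers(0, n_classes, size=n_samples)
raw = rng.random((n_samples, n_classes))
np_pred = raw / raw.sum(axis=1, keepdims=True)  # rows sum to 1, like predict_proba output

# One-hot encode the ground truth: column k is 1 where np_label == k
y_onehot = label_binarize(np_label, classes=np.arange(n_classes))

# With a multilabel-indicator target, average=None returns one AUC per column
per_class_auc = roc_auc_score(y_onehot, np_pred, average=None)
print(per_class_auc)  # array of shape (n_classes,)
```

Column k of y_onehot must line up with column k of np_pred, which is why the classes are passed explicitly in the same order as the prediction columns.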
So, judging from the code, it seems that I should be able to compute a multiclass One-vs-Rest AUC with average=None, but the if checks in the source code do not allow that combination.
Am I wrong?
If I’m wrong, from a theoretical point of view, should I fake a multilabel case just to get the per-class AUCs, or should I write my own function that loops over the classes and outputs the AUCs?
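Writing your own loop is straightforward: for each class, treat it as the positive class and all others as negative, and score the matching column of the prediction matrix. In the sketch below, per_class_ovr_auc is a hypothetical helper and the random data stands in for np_label and np_pred. As a sanity check, the unweighted mean of the per-class scores is compared with sklearn's built-in macro OvR score, which averages the same per-class values.

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def per_class_ovr_auc(y_true, y_score, classes):
    """Hypothetical helper: return {class: One-vs-Rest AUC}.

    Assumes y_score[:, i] holds the score for classes[i].
    """
    scores = {}
    for i, cls in enumerate(classes):
        # Positive = this class, negative = every other class
        scores[cls] = roc_auc_score(y_true == cls, y_score[:, i])
    return scores


# Hypothetical data with the same shapes as np_label / np_pred in the question
rng = np.random.default_rng(42)
np_label = rng.integers(0, 4, size=300)
raw = rng.random((300, 4))
np_pred = raw / raw.sum(axis=1, keepdims=True)  # rows sum to 1

aucs = per_class_ovr_auc(np_label, np_pred, classes=np.arange(4))

# The unweighted mean of the per-class AUCs should equal sklearn's macro OvR score
macro = roc_auc_score(np_label, np_pred, multi_class="ovr", average="macro")
print(aucs, np.mean(list(aucs.values())), macro)
```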
Thanks
Answer
As you already know, sklearn’s multiclass ROC AUC currently handles only the macro and weighted averages. However, you can implement the One-vs-Rest computation yourself so that it returns the score for each class individually.
In principle, you can implement OvR yourself and compute a per-class roc_auc_score as follows:
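from sklearn.metrics import roc_auc_score

# One list of AUC scores per class (your own series, classifier and data frames)
roc = {label: [] for label in multi_class_series.unique()}
for label in multi_class_series.unique():
    # Fit a binary "label vs. rest" classifier
    selected_classifier.fit(train_set_dataframe, train_class == label)
    # Column 1 is the probability of the positive (True) class
    predictions_proba = selected_classifier.predict_proba(test_set_dataframe)
    # Binarize the test labels the same way before scoring
    roc[label].append(roc_auc_score(test_class == label, predictions_proba[:, 1]))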
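For a self-contained run of the same per-class OvR idea, the sketch below swaps in concrete stand-ins that are assumptions for illustration: the Iris dataset in place of train_set_dataframe/test_set_dataframe, and a fresh LogisticRegression per class in place of selected_classifier. Any classifier with predict_proba should work the same way.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

# Stand-ins for the variables used above (assumed, for illustration)
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=0
)

roc = {}
for label in np.unique(y):
    # Fresh binary classifier: "label" vs. all other classes
    clf = LogisticRegression(max_iter=1000)
    clf.fit(X_train, y_train == label)
    # classes_ is [False, True], so column 1 is the positive-class probability
    proba = clf.predict_proba(X_test)[:, 1]
    roc[label] = roc_auc_score(y_test == label, proba)

for label, auc in roc.items():
    print(f"class {label}: AUC = {auc:.3f}")
```

Stratifying the split keeps every class present in the test set; without it, a class missing from y_test would make its AUC undefined.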