I'm trying to plot the summary metrics below using Plotly.
Data:

```
    Model               F1_Score  Precision  Recall    Accuracy  ROC_AUC   CV_Score
0   LogisticRegression  0.815068  0.777778   0.856115  0.739130  0.678058  0.752876
1   K-NearestNeighbors  0.828767  0.790850   0.870504  0.758454  0.699958  0.714476
2   SVM                 0.852459  0.783133   0.935252  0.782609  0.702920  0.665067
3   GaussianProcess     0.825503  0.773585   0.884892  0.748792  0.677740  0.665067
4   MLP                 0.774436  0.811024   0.741007  0.710145  0.694033  0.735327
5   DecisionTree        0.747170  0.785714   0.712230  0.676329  0.657586  0.692216
6   ExtraTrees          0.859060  0.805031   0.920863  0.797101  0.732490  0.792698
7   RandomForest        0.826667  0.770186   0.892086  0.748792  0.673984  0.778324
8   XGBoost             0.838488  0.802632   0.877698  0.772947  0.718261  0.764025
9   AdaBoostClassifier  0.800000  0.780822   0.820144  0.724638  0.674778  0.728927
10  GBClassifier        0.835017  0.784810   0.892086  0.763285  0.696043  0.754451
11  CatBoost            0.843854  0.783951   0.913669  0.772947  0.699482  0.768787
12  Stacking            0.833333  0.776398   0.899281  0.758454  0.684934  0.787949
13  Voting              0.836120  0.781250   0.899281  0.763285  0.692287  0.778337
14  Bagging             0.855263  0.787879   0.935252  0.787440  0.710273  0.792673
```
```python
import plotly.graph_objects as go

# `scores` is the DataFrame shown above.
mark_color = ['rgba(246, 78, 139, 0.6)', 'rgba(58, 71, 80, 0.6)', 'rgba(50, 171, 96, 0.6)', 'rgba(38, 24, 74, 0.6)', 'rgba(155, 83, 109, 0.6)', 'rgba(297, 55, 74, 0.6)']
line_color = ['rgba(246, 78, 139, 1.0)', 'rgba(58, 71, 80, 1.0)', 'rgba(50, 171, 96, 1.0)', 'rgba(38, 24, 74, 1.0)', 'rgba(155, 83, 109, 1.0)', 'rgba(297, 55, 74, 1.0)']

y_labels = ["F1_Score", "Precision", "Recall", "Accuracy", "ROC_AUC", "CV_Score"]

fig = go.Figure()

# One trace per metric column
for i, j in enumerate(y_labels):
    fig.add_trace(go.Bar(
        y=y_labels,
        x=list(scores[j].values),
        name=j,
        orientation='h',
        marker=dict(
            color=mark_color[i]
        )
    ))

fig.update_layout(
    barmode='stack',
    title="Summary Metrics",
    xaxis_title="Metric Value",
    yaxis_title="Metric Name",
    legend_title="Model",
)

fig.show()
```
So far, I'm able to produce a plot with the code above, but I'm unable to add the model names to it. How can I use the Model column as the legend and include every model's values in the plot?
Answer
- Reshape the data frame into long format first, so there is one row per (metric, model) pair:
  `df2 = df.set_index("Model").unstack().to_frame().reset_index()`
- Then it's a simple case of using Plotly Express.
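To see what that reshape produces, here is a minimal sketch on a two-model subset of the table. It also shows `DataFrame.melt` as an equivalent alternative (a swapped-in technique, not part of the original answer); the `Metric` and `Value` column names are assumptions chosen for this sketch.

```python
import io

import pandas as pd

# A two-model subset of the scores table from the question (values copied verbatim).
wide = pd.read_csv(
    io.StringIO(
        """Model F1_Score Precision Recall Accuracy ROC_AUC CV_Score
LogisticRegression 0.815068 0.777778 0.856115 0.739130 0.678058 0.752876
SVM 0.852459 0.783133 0.935252 0.782609 0.702920 0.665067"""
    ),
    sep=r"\s+",
)

# The answer's reshape: metric names end up in "level_0", model names in
# "Model", and the scores in a column named 0 (the integer, not a string).
long_unstack = wide.set_index("Model").unstack().to_frame().reset_index()

# Equivalent reshape with melt, using explicit (assumed) column names.
long_melt = wide.melt(id_vars="Model", var_name="Metric", value_name="Value")

print(long_unstack.head(4))
print(long_melt.head(4))
```

Both produce 12 rows (6 metrics × 2 models) in the same order. `melt` just lets you name the columns up front instead of referring to `"level_0"` and `0` later.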
```python
import io

import pandas as pd
import plotly.express as px

# Rebuild the scores DataFrame from the printed table. The header has one
# field fewer than the data rows, so pandas uses the first column as the index.
df = pd.read_csv(
    io.StringIO(
        """Model F1_Score Precision Recall Accuracy ROC_AUC CV_Score
0 LogisticRegression 0.815068 0.777778 0.856115 0.739130 0.678058 0.752876
1 K-NearestNeighbors 0.828767 0.790850 0.870504 0.758454 0.699958 0.714476
2 SVM 0.852459 0.783133 0.935252 0.782609 0.702920 0.665067
3 GaussianProcess 0.825503 0.773585 0.884892 0.748792 0.677740 0.665067
4 MLP 0.774436 0.811024 0.741007 0.710145 0.694033 0.735327
5 DecisionTree 0.747170 0.785714 0.712230 0.676329 0.657586 0.692216
6 ExtraTrees 0.859060 0.805031 0.920863 0.797101 0.732490 0.792698
7 RandomForest 0.826667 0.770186 0.892086 0.748792 0.673984 0.778324
8 XGBoost 0.838488 0.802632 0.877698 0.772947 0.718261 0.764025
9 AdaBoostClassifier 0.800000 0.780822 0.820144 0.724638 0.674778 0.728927
10 GBClassifier 0.835017 0.784810 0.892086 0.763285 0.696043 0.754451
11 CatBoost 0.843854 0.783951 0.913669 0.772947 0.699482 0.768787
12 Stacking 0.833333 0.776398 0.899281 0.758454 0.684934 0.787949
13 Voting 0.836120 0.781250 0.899281 0.763285 0.692287 0.778337
14 Bagging 0.855263 0.787879 0.935252 0.787440 0.710273 0.792673"""
    ),
    sep=r"\s+",  # split on any run of whitespace
)

# Wide -> long: one row per (metric, model) pair. Metric names land in
# "level_0", model names in "Model", and the scores in column 0.
df2 = df.set_index("Model").unstack().to_frame().reset_index()

# Colouring by "Model" gives one trace, and one legend entry, per model.
fig = px.bar(
    df2,
    y="level_0",
    x=0,
    color="Model",
    color_discrete_map={
        "LogisticRegression": "#2E91E5",
        "K-NearestNeighbors": "#E15F99",
        "SVM": "#1CA71C",
        "GaussianProcess": "#FB0D0D",
        "MLP": "#DA16FF",
        "DecisionTree": "#222A2A",
        "ExtraTrees": "#B68100",
        "RandomForest": "#750D86",
        "XGBoost": "#EB663B",
        "AdaBoostClassifier": "#511CFB",
        "GBClassifier": "#00A08B",
        "CatBoost": "#FB00D1",
        "Stacking": "#FC0080",
        "Voting": "#B2828D",
        "Bagging": "#6C7C32",
    },
)

fig.update_layout(
    title="Summary Metrics",
    xaxis_title="Metric Value",
    yaxis_title="Metric Name",
    legend_title="Model",
)

fig.show()
```