Robustness Analysis (Classification)#

This example demonstrates how to analyze model robustness for classification problems using various methods and metrics.

Installation

# To install the required package, use the following command:
# !pip install modeva

Authentication

# To authenticate, run the following commands (replace the token below with your own to get full access):
# from modeva.utils.authenticate import authenticate
# authenticate(auth_code='eaaa4301-b140-484c-8e93-f9f633c8bacb')

Import required modules

from modeva import DataSet
from modeva import TestSuite
from modeva.models import MoLGBMClassifier
from modeva.models import MoXGBClassifier
from modeva.testsuite.utils.slicing_utils import get_data_info

Load and prepare dataset

ds = DataSet()
ds.load(name="TaiwanCredit")
ds.set_random_split()

Train models

model1 = MoXGBClassifier()
model1.fit(ds.train_x, ds.train_y)

model2 = MoLGBMClassifier(max_depth=2, verbose=-1, random_state=0)
model2.fit(ds.train_x, ds.train_y.ravel())
MoLGBMClassifier(boosting_type='gbdt', class_weight=None, colsample_bytree=1.0,
                 importance_type='split', learning_rate=0.1, max_depth=2,
                 min_child_samples=20, min_child_weight=0.001,
                 min_split_gain=0.0, n_estimators=100, n_jobs=None,
                 num_leaves=31, objective=None, random_state=0, reg_alpha=0.0,
                 reg_lambda=0.0, subsample=1.0, subsample_for_bin=200000,
                 subsample_freq=0, verbose=-1)


Basic robustness analysis#

ts = TestSuite(ds, model1)
results = ts.diagnose_robustness(perturb_features=("PAY_1", "EDUCATION",),
                                 noise_levels=(0.1, 0.2, 0.3, 0.4),
                                 metric="AUC")
results.table
Noise Level     0.0     0.1     0.2     0.3     0.4
Repeats
0            0.7662  0.7433  0.7414  0.7411  0.7420
1            0.7662  0.7454  0.7445  0.7448  0.7456
2            0.7662  0.7439  0.7438  0.7434  0.7412
3            0.7662  0.7427  0.7435  0.7421  0.7424
4            0.7662  0.7466  0.7450  0.7444  0.7437
5            0.7662  0.7464  0.7445  0.7448  0.7460
6            0.7662  0.7428  0.7418  0.7414  0.7417
7            0.7662  0.7455  0.7456  0.7446  0.7443
8            0.7662  0.7439  0.7443  0.7435  0.7406
9            0.7662  0.7418  0.7434  0.7393  0.7405
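The table above is easiest to read after aggregating over repeats. A minimal pandas sketch, using the baseline and 0.1 noise-level columns copied from the output above, is:

```python
import pandas as pd

# Test AUC per repeat, copied from the table above
# (baseline column 0.0 and the 0.1 noise level)
table = pd.DataFrame({
    "0.0": [0.7662] * 10,
    "0.1": [0.7433, 0.7454, 0.7439, 0.7427, 0.7466,
            0.7464, 0.7428, 0.7455, 0.7439, 0.7418],
})

# Mean and spread of AUC at each noise level
summary = table.agg(["mean", "std"]).round(4)
print(summary)
```

Under 10% perturbation of PAY_1 and EDUCATION, the mean AUC drops from the 0.7662 baseline to roughly 0.744, with very little variance across repeats.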


Box plot of robustness performance

results.plot(figsize=(6, 5))


Analyze data drift between the groups with small and large prediction changes

data_results = ds.data_drift_test(**results.value[0.2]["data_info"],
                                  distance_metric="PSI",
                                  psi_method="uniform",
                                  psi_bins=10)
data_results.plot("summary")
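For reference, PSI with uniform binning compares the bin frequencies of two samples. A minimal NumPy sketch of the metric follows; Modeva's implementation may differ in binning and smoothing details, so treat this as illustrative only:

```python
import numpy as np

def psi_uniform(reference, current, bins=10, eps=1e-6):
    """Population Stability Index over uniform bins spanning the reference range."""
    edges = np.linspace(np.min(reference), np.max(reference), bins + 1)
    # Bin frequencies of each sample; eps avoids log(0) on empty bins
    p = np.histogram(reference, bins=edges)[0] / len(reference) + eps
    q = np.histogram(current, bins=edges)[0] / len(current) + eps
    return float(np.sum((p - q) * np.log(p / q)))

rng = np.random.default_rng(0)
x = rng.normal(size=5000)
print(psi_uniform(x, x))        # identical samples give PSI of 0
print(psi_uniform(x, x + 0.5))  # a shifted sample gives a clearly positive PSI
```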


Analyze data drift for single variable

data_results.plot(("density", "PAY_1"))


Slicing robustness analysis#

Single feature slicing

results = ts.diagnose_slicing_robustness(features="PAY_1",
                                         perturb_features=("PAY_1", "EDUCATION",),
                                         noise_levels=0.1,
                                         metric="AUC",
                                         method="auto-xgb1",
                                         threshold=0.7)
results.plot()
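Conceptually, slicing robustness bins a feature into segments and evaluates the metric on perturbed data within each segment, flagging segments that fall below the threshold as weak. A simplified NumPy-only sketch of that per-segment evaluation (hypothetical helpers, not the Modeva API) is:

```python
import numpy as np

def rank_auc(y_true, y_score):
    """AUC via the rank-sum (Mann-Whitney) statistic; assumes untied scores."""
    ranks = np.empty(len(y_score))
    ranks[np.argsort(y_score)] = np.arange(1, len(y_score) + 1)
    n_pos = int(y_true.sum())
    n_neg = len(y_true) - n_pos
    return (ranks[y_true == 1].sum() - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)

def slice_metric(feature, y_true, y_score, bins=10, threshold=0.7):
    """AUC per uniform segment of `feature`; a segment is weak below `threshold`."""
    edges = np.linspace(feature.min(), feature.max(), bins + 1)
    rows = []
    for i, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        # The last segment is closed on the right, matching [lo, hi] in the tables
        mask = (feature >= lo) & ((feature < hi) | ((i == bins - 1) & (feature <= hi)))
        # AUC is only defined when both classes appear in the segment
        if mask.sum() > 1 and 0 < y_true[mask].sum() < mask.sum():
            auc = rank_auc(y_true[mask], y_score[mask])
            rows.append((round(lo, 2), round(hi, 2), int(mask.sum()), auc, auc < threshold))
    return rows
```

Segments with too few samples (or a single class) get no AUC, which is why small segments show NaN in the tables below.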


Analyze data drift for a specific feature

data_info = get_data_info(res_value=results.value)
data_results = ds.data_drift_test(**data_info["PAY_1"],
                                  distance_metric="PSI",
                                  psi_method="uniform",
                                  psi_bins=10)
data_results.plot("summary")


Single feature density plot

data_results.plot(("density", "PAY_1"))


Bivariate feature slicing

results = ts.diagnose_slicing_robustness(features=("PAY_1", "PAY_2"),
                                         perturb_features=("PAY_1", "EDUCATION",),
                                         noise_levels=0.1,
                                         metric="AUC",
                                         threshold=0.7)
results.table
    Feature1  Segment1        Feature2  Segment2         Size     AUC  Threshold   Weak
6   PAY_1     [-1.00, -0.10)  PAY_2     [4.40, 5.30)        2  0.0000        0.7   True
25  PAY_1     [0.80, 1.70)    PAY_2     [3.50, 4.40)        6  0.0000        0.7   True
54  PAY_1     [3.50, 4.40)    PAY_2     [2.60, 3.50)       11  0.1200        0.7   True
34  PAY_1     [1.70, 2.60)    PAY_2     [2.60, 3.50)       12  0.3750        0.7   True
30  PAY_1     [1.70, 2.60)    PAY_2     [-1.00, -0.10)     22  0.4615        0.7   True
..  ...       ...             ...       ...               ...     ...        ...    ...
95  PAY_1     [7.10, 8.00]    PAY_2     [3.50, 4.40)        0     NaN        0.7  False
96  PAY_1     [7.10, 8.00]    PAY_2     [4.40, 5.30)        0     NaN        0.7  False
97  PAY_1     [7.10, 8.00]    PAY_2     [5.30, 6.20)        0     NaN        0.7  False
98  PAY_1     [7.10, 8.00]    PAY_2     [6.20, 7.10)        4     NaN        0.7  False
99  PAY_1     [7.10, 8.00]    PAY_2     [7.10, 8.00]        0     NaN        0.7  False

100 rows × 8 columns



Batch mode single feature slicing

results = ts.diagnose_slicing_robustness(features=(("PAY_1",), ("PAY_2",), ("PAY_3",)),
                                         perturb_features=("PAY_1", "EDUCATION",),
                                         noise_levels=0.1,
                                         perturb_method="quantile",
                                         metric="AUC",
                                         threshold=0.7)
results.table
    Feature  Segment          Size     AUC  Threshold   Weak
0   PAY_2    [4.40, 5.30)        4  0.1000        0.7   True
1   PAY_2    [3.50, 4.40)       14  0.3245        0.7   True
2   PAY_3    [5.30, 6.20)        5  0.3750        0.7   True
3   PAY_3    [3.50, 4.40)       13  0.4975        0.7   True
4   PAY_2    [2.60, 3.50)       56  0.5444        0.7   True
5   PAY_1    [1.70, 2.60)      513  0.5773        0.7   True
6   PAY_3    [6.20, 7.10)        4  0.6000        0.7   True
7   PAY_1    [0.80, 1.70)      752  0.6138        0.7   True
8   PAY_3    [2.60, 3.50)       49  0.6257        0.7   True
9   PAY_2    [-1.00, -0.10)   1188  0.6257        0.7   True
10  PAY_3    [-1.00, -0.10)   1191  0.6326        0.7   True
11  PAY_1    [-1.00, -0.10)   1125  0.6472        0.7   True
12  PAY_1    [-0.10, 0.80)    3521  0.6539        0.7   True
13  PAY_1    [3.50, 4.40)       14  0.6800        0.7   True
14  PAY_2    [1.70, 2.60)      773  0.6814        0.7   True
15  PAY_3    [1.70, 2.60)      740  0.7077        0.7  False
16  PAY_1    [2.60, 3.50)       69  0.7081        0.7  False
17  PAY_2    [-0.10, 0.80)    3956  0.7157        0.7  False
18  PAY_3    [-0.10, 0.80)    3995  0.7349        0.7  False
19  PAY_2    [0.80, 1.70)        3  1.0000        0.7  False
20  PAY_1    [4.40, 5.30)        1     NaN        0.7  False
21  PAY_1    [5.30, 6.20)        1     NaN        0.7  False
22  PAY_1    [6.20, 7.10)        0     NaN        0.7  False
23  PAY_1    [7.10, 8.00]        4     NaN        0.7  False
24  PAY_2    [5.30, 6.20)        1     NaN        0.7  False
25  PAY_2    [6.20, 7.10)        4     NaN        0.7  False
26  PAY_2    [7.10, 8.00]        1     NaN        0.7  False
27  PAY_3    [0.80, 1.70)        1     NaN        0.7  False
28  PAY_3    [4.40, 5.30)        1     NaN        0.7  False
29  PAY_3    [7.10, 8.00]        1     NaN        0.7  False
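The `perturb_method="quantile"` option perturbs features in quantile space rather than on the raw scale. The source does not show Modeva's exact implementation, but one common construction, sketched here under that assumption, maps each value to its empirical rank, adds uniform noise in rank space, and maps back through the empirical quantile function:

```python
import numpy as np

def quantile_perturb(x, noise_level, rng):
    """Perturb the values of x by `noise_level` in empirical-quantile space."""
    n = len(x)
    # Empirical quantile (rank in [0, 1]) of each value
    ranks = np.argsort(np.argsort(x)) / (n - 1)
    # Add uniform noise in quantile space, clipped to the valid range
    noisy = np.clip(ranks + rng.uniform(-noise_level, noise_level, n), 0.0, 1.0)
    # Map back to the data scale via the empirical quantile function
    return np.quantile(x, noisy)

rng = np.random.default_rng(0)
x = rng.exponential(size=1000)
x_perturbed = quantile_perturb(x, 0.1, rng)
```

Unlike raw-scale noise, this keeps perturbed values inside the observed data range and makes the perturbation magnitude comparable across features with very different scales.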


Batch mode 1D slicing (all features, by setting features=None)

results = ts.diagnose_slicing_robustness(features=None,
                                         perturb_features=("PAY_1", "EDUCATION",),
                                         noise_levels=0.1,
                                         perturb_method="quantile",
                                         metric="AUC",
                                         threshold=0.7)
results.table
     Feature    Segment          Size     AUC  Threshold   Weak
0    PAY_4      [5.40, 6.20)        3  0.0000        0.7   True
1    PAY_6      [4.40, 5.30)        2  0.0000        0.7   True
2    PAY_2      [4.40, 5.30)        4  0.1000        0.7   True
3    BILL_AMT2  [-4.42, -3.38)      9  0.2250        0.7   True
4    BILL_AMT3  [-4.79, -3.72)      6  0.2889        0.7   True
..   ...        ...               ...     ...        ...    ...
204  PAY_6      [6.20, 7.10)        8     NaN        0.7  False
205  PAY_6      [7.10, 8.00]        1     NaN        0.7  False
206  BILL_AMT1  [-4.03, -3.03)     14     NaN        0.7  False
207  BILL_AMT3  [0.57, 1.64)        2     NaN        0.7  False
208  BILL_AMT4  [0.57, 1.64)        1     NaN        0.7  False

209 rows × 6 columns



Robustness comparison#

tsc = TestSuite(ds, models=[model1, model2])

Compare robustness performance of multiple models

results = tsc.compare_robustness(perturb_features=("PAY_1", "EDUCATION",),
                                 noise_levels=(0.1, 0.2, 0.3, 0.4),
                                 perturb_method="quantile",
                                 metric="AUC")
results.plot(figsize=(6, 5))


Compare robustness performance of multiple models under a single slicing feature

results = tsc.compare_slicing_robustness(features="PAY_1", noise_levels=0.1,
                                         method="quantile", metric="AUC")
results.plot()


Total running time of the script: (0 minutes 22.228 seconds)

Gallery generated by Sphinx-Gallery