From 44d5456d0a36027dcee4237cd695d5ceb248414f Mon Sep 17 00:00:00 2001 From: Israel Figueroa Date: Tue, 10 Dec 2024 01:05:03 -0300 Subject: [PATCH] middle --- trainer.py | 35 ++++++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/trainer.py b/trainer.py index c2989f3..4d80380 100644 --- a/trainer.py +++ b/trainer.py @@ -655,17 +655,42 @@ class BinaryTuner: data=X_model, feature_names=label_columns) - # -# shap.plots.initjs() shap.plots.decision(exp.base_values[0], exp.values, features=label_columns, show=False) -# shap.plots.force(exp.base_values, exp.values, feature_names=Xbase.columns, show=False) -# shap.plots.force(exp.base_values[0], exp.values[0, :], feature_names=Xbase.columns, matplotlib=True, show=False) -# shap.plots.force(expected_values[0], shap_values.values, Xbase.columns , show=False) plt.title(r"{0}".format(modelname)) plt.ylabel("Respuesta del Modelo: 0 Negativo, 1 Positivo") plt.savefig("{}/shap_{}_{}_{}.png".format(self.name, modelname, dataset, seed),dpi=150, bbox_inches='tight') plt.close() + y_pred = model.predict(X_model) + # make a numpy array from y_pred where all the values > 0.5 become 1 and all remaining values are 0 + if type_of_target(y_pred) == "continuous": + y_pred = np.where(y_pred > 0.5, 1, 0) + + X_pos = X_model[y_pred == 1] + shap_values = explainer(X_pos) + exp = shap.Explanation(shap_values, + data=X_pos, + feature_names=label_columns) + + shap.plots.decision(exp.base_values[0], exp.values, features=label_columns, show=False) + plt.title(r"{0}".format(modelname)) + plt.ylabel("Respuesta del Modelo Positivas") + plt.savefig("{}/shap_pos_{}_{}_{}.png".format(self.name, modelname, dataset, seed),dpi=150, bbox_inches='tight') + plt.close() + + + X_pos = X_model[y_pred == 0] + shap_values = explainer(X_pos) + exp = shap.Explanation(shap_values, + data=X_pos, + feature_names=label_columns) + + shap.plots.decision(exp.base_values[0], exp.values, features=label_columns, show=False) + plt.title(r"{0}".format(modelname)) + plt.ylabel("Respuesta del Modelo Negativas") + plt.savefig("{}/shap_pos_{}_{}_{}.png".format(self.name, modelname, dataset, seed),dpi=150, bbox_inches='tight') + plt.close() + shap_values = explainer(X_explain)