ifiguero 2024-12-09 17:16:47 -03:00
parent 6d13cc2a37
commit 041f602b81
1 changed files with 5 additions and 5 deletions

View File

@ -645,19 +645,19 @@ class BinaryTuner:
# expected_value = expected_value[1] # expected_value = expected_value[1]
# shap_values = explainer.shap_values(X_test)[1] # shap_values = explainer.shap_values(X_test)[1]
self.logger.info("Columns: {}".format(Xbase.columns)) self.logger.info("Columns: {}".format(Xbase.columns))
# eng_columns = ['sex', 'family hist', 'age diag', 'BMI', 'base glu', 'glu 120','HbA1c'] # label_columns = ['sex', 'family hist', 'age diag', 'BMI', 'base glu', 'glu 120','HbA1c']
esp_columns = ['sexo', 'hist fam', 'edad diag', 'IMC', 'glu ayu', 'glu 120','A1c'] label_columns = ['sexo', 'hist fam', 'edad diag', 'IMC', 'glu ayu', 'glu 120','A1c']
explainer = shap.Explainer(model.predict, X_train, seed=seed) explainer = shap.Explainer(model.predict, X_train, seed=seed)
shap_values = explainer(X_model) shap_values = explainer(X_model)
exp = shap.Explanation(shap_values, exp = shap.Explanation(shap_values,
data=X_model, data=X_model,
feature_names=esp_columns) feature_names=label_columns)
# #
# shap.plots.initjs() # shap.plots.initjs()
shap.plots.decision(exp.base_values[0], exp.values, features=eng_columns, show=False) shap.plots.decision(exp.base_values[0], exp.values, features=label_columns, show=False)
# shap.plots.force(exp.base_values, exp.values, feature_names=Xbase.columns, show=False) # shap.plots.force(exp.base_values, exp.values, feature_names=Xbase.columns, show=False)
# shap.plots.force(exp.base_values[0], exp.values[0, :], feature_names=Xbase.columns, matplotlib=True, show=False) # shap.plots.force(exp.base_values[0], exp.values[0, :], feature_names=Xbase.columns, matplotlib=True, show=False)
# shap.plots.force(expected_values[0], shap_values.values, Xbase.columns , show=False) # shap.plots.force(expected_values[0], shap_values.values, Xbase.columns , show=False)
@ -671,7 +671,7 @@ class BinaryTuner:
exp = shap.Explanation(shap_values, exp = shap.Explanation(shap_values,
data=X_explain, data=X_explain,
feature_names=eng_columns) feature_names=label_columns)
for i in range(5): for i in range(5):
shap.plots.waterfall(exp[i], show=False) shap.plots.waterfall(exp[i], show=False)