middle
parent
041f602b81
commit
44d5456d0a
35
trainer.py
35
trainer.py
|
@ -655,17 +655,42 @@ class BinaryTuner:
|
|||
data=X_model,
|
||||
feature_names=label_columns)
|
||||
|
||||
#
|
||||
# shap.plots.initjs()
|
||||
shap.plots.decision(exp.base_values[0], exp.values, features=label_columns, show=False)
|
||||
# shap.plots.force(exp.base_values, exp.values, feature_names=Xbase.columns, show=False)
|
||||
# shap.plots.force(exp.base_values[0], exp.values[0, :], feature_names=Xbase.columns, matplotlib=True, show=False)
|
||||
# shap.plots.force(expected_values[0], shap_values.values, Xbase.columns , show=False)
|
||||
plt.title(r"{0}".format(modelname))
|
||||
plt.ylabel("Respuesta del Modelo: 0 Negativo, 1 Positivo")
|
||||
plt.savefig("{}/shap_{}_{}_{}.png".format(self.name, modelname, dataset, seed),dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
y_pred = model.predict(X_model)
|
||||
# make a numpy array from y_pred where all the values > 0.5 become 1 and all remaining values are 0
|
||||
if type_of_target(y_pred) == "continuous":
|
||||
y_pred = np.where(y_pred > 0.5, 1, 0)
|
||||
|
||||
X_pos = X_model[y_pred == 1]
|
||||
shap_values = explainer(X_pos)
|
||||
exp = shap.Explanation(shap_values,
|
||||
data=X_pos,
|
||||
feature_names=label_columns)
|
||||
|
||||
shap.plots.decision(exp.base_values[0], exp.values, features=label_columns, show=False)
|
||||
plt.title(r"{0}".format(modelname))
|
||||
plt.ylabel("Respuesta del Modelo Positivas")
|
||||
plt.savefig("{}/shap_pos_{}_{}_{}.png".format(self.name, modelname, dataset, seed),dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
|
||||
X_pos = X_model[y_pred == 0]
|
||||
shap_values = explainer(X_pos)
|
||||
exp = shap.Explanation(shap_values,
|
||||
data=X_pos,
|
||||
feature_names=label_columns)
|
||||
|
||||
shap.plots.decision(exp.base_values[0], exp.values, features=label_columns, show=False)
|
||||
plt.title(r"{0}".format(modelname))
|
||||
plt.ylabel("Respuesta del Modelo Negativas")
|
||||
plt.savefig("{}/shap_pos_{}_{}_{}.png".format(self.name, modelname, dataset, seed),dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
|
||||
shap_values = explainer(X_explain)
|
||||
|
||||
|
|
Loading…
Reference in New Issue