diff --git a/trainer.py b/trainer.py index c13efb4..b7eb8c0 100644 --- a/trainer.py +++ b/trainer.py @@ -700,13 +700,13 @@ class BinaryTuner: for i in range(5): shap.plots.waterfall(exp[i], show=False) - plt.title(r"{0} $y_{{{1}}}=0$".format(modelname, i)) + plt.title(r"{0} $y_{{{1}}}=1$".format(modelname, i)) plt.savefig("{}/pos_{}_{}_{}_{}.png".format(self.name, i, modelname, dataset, seed),dpi=150, bbox_inches='tight') plt.close() for i in range(5, 10): shap.plots.waterfall(exp[i], show=False) - plt.title(r"{0} $y_{{{1}}}=1$".format(modelname, i-5)) + plt.title(r"{0} $y_{{{1}}}=0$".format(modelname, i-5)) plt.savefig("{}/neg_{}_{}_{}_{}.png".format(self.name, i-5, modelname, dataset, seed),dpi=150, bbox_inches='tight') plt.close()