diff --git a/TrainerClass.py b/TrainerClass.py index 237ac41..e8d0389 100644 --- a/TrainerClass.py +++ b/TrainerClass.py @@ -372,7 +372,7 @@ class eNoseTrainer: "mse": mse, "mae": mae, "rmse": rmse, - "num_params": sum(t.count("\n") for t in optimized_model.get_booster().get_dump(dump_format="json")) + "num_params": None if model_params.get('multi_strategy') == 'multi_output_tree' else sum(t.count("\n") for t in optimized_model.get_booster().get_dump()) }] ) self.ledger = pd.concat([self.ledger, newrow], ignore_index=True) self.bar.update() @@ -430,7 +430,7 @@ class eNoseTrainer: "mse": mse, "mae": mae, "rmse": rmse, - "num_params": sum(t.count("\n") for t in optimized_model.get_booster().get_dump(dump_format="json")) + "num_params": None if model_params.get('multi_strategy') == 'multi_output_tree' else sum(t.count("\n") for t in optimized_model.get_booster().get_dump()) }] ) self.ledger = pd.concat([self.ledger, newrow], ignore_index=True)