diff --git a/src/libstat/classifiers/bayes.c b/src/libstat/classifiers/bayes.c index dbae98cc2..1d5bb2a6f 100644 --- a/src/libstat/classifiers/bayes.c +++ b/src/libstat/classifiers/bayes.c @@ -331,34 +331,58 @@ bayes_classify_token_multiclass(struct rspamd_classifier *ctx, w = (fw * total_count) / (1.0 + fw * total_count); - /* Apply multinomial model for each class */ - for (j = 0; j < cl->num_classes; j++) { - /* Skip classes with insufficient learns */ - if (ctx->cfg->min_learns > 0 && cl->class_learns[j] < ctx->cfg->min_learns) { - continue; + if (cl->num_classes == 2) { + /* Binary-compatible path: normalize per-token probabilities across the two classes */ + double f0 = (double) class_counts[0] / MAX(1.0, (double) cl->class_learns[0]); + double f1 = (double) class_counts[1] / MAX(1.0, (double) cl->class_learns[1]); + double denom = f0 + f1; + + if (denom > 0.0) { + double p0 = f0 / denom; + double p1 = f1 / denom; + double bp0 = PROB_COMBINE(p0, total_count, w, 0.5); + double bp1 = PROB_COMBINE(p1, total_count, w, 0.5); + + /* Bound and apply min strength (relative to 0.5 for binary) */ + bp0 = MAX(0.0, MIN(1.0, bp0)); + bp1 = MAX(0.0, MIN(1.0, bp1)); + + if (fabs(bp0 - 0.5) >= ctx->cfg->min_prob_strength) { + cl->class_log_probs[0] += log(bp0); + } + if (fabs(bp1 - 0.5) >= ctx->cfg->min_prob_strength) { + cl->class_log_probs[1] += log(bp1); + } } + } + else { + /* General multinomial model for N>2 classes */ + for (j = 0; j < cl->num_classes; j++) { + /* Skip classes with insufficient learns */ + if (ctx->cfg->min_learns > 0 && cl->class_learns[j] < ctx->cfg->min_learns) { + continue; + } - double class_freq = (double) class_counts[j] / MAX(1.0, (double) cl->class_learns[j]); - double class_prob = PROB_COMBINE(class_freq, total_count, w, 1.0 / cl->num_classes); + double class_freq = (double) class_counts[j] / MAX(1.0, (double) cl->class_learns[j]); + double class_prob = PROB_COMBINE(class_freq, total_count, w, 1.0 / cl->num_classes); - /* Ensure probability is properly bounded [0, 1] */ - class_prob = MAX(0.0, MIN(1.0, class_prob)); + /* Ensure probability is properly bounded [0, 1] */ + class_prob = MAX(0.0, MIN(1.0, class_prob)); - /* Skip probabilities too close to uniform (1/num_classes) */ - double uniform_prior = 1.0 / cl->num_classes; - if (fabs(class_prob - uniform_prior) < ctx->cfg->min_prob_strength) { - continue; - } + /* Skip probabilities too close to uniform (1/num_classes) */ + double uniform_prior = 1.0 / cl->num_classes; + if (fabs(class_prob - uniform_prior) < ctx->cfg->min_prob_strength) { + continue; + } - cl->class_log_probs[j] += log(class_prob); + cl->class_log_probs[j] += log(class_prob); + } } cl->processed_tokens++; if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) { cl->text_tokens++; } - - /* Per-token debug logging removed to reduce verbosity */ } }