Browse Source

[Minor] Add specific calculations for binary classification case

pull/5569/head
Vsevolod Stakhov 2 months ago
parent
commit
b6a3d5c9a6
No known key found for this signature in database GPG Key ID: 7647B6790081437
  1. 58
      src/libstat/classifiers/bayes.c

58
src/libstat/classifiers/bayes.c

@ -331,34 +331,58 @@ bayes_classify_token_multiclass(struct rspamd_classifier *ctx,
w = (fw * total_count) / (1.0 + fw * total_count);
/* Apply multinomial model for each class */
for (j = 0; j < cl->num_classes; j++) {
/* Skip classes with insufficient learns */
if (ctx->cfg->min_learns > 0 && cl->class_learns[j] < ctx->cfg->min_learns) {
continue;
if (cl->num_classes == 2) {
/* Binary-compatible path: normalize per-token probabilities across the two classes */
double f0 = (double) class_counts[0] / MAX(1.0, (double) cl->class_learns[0]);
double f1 = (double) class_counts[1] / MAX(1.0, (double) cl->class_learns[1]);
double denom = f0 + f1;
if (denom > 0.0) {
double p0 = f0 / denom;
double p1 = f1 / denom;
double bp0 = PROB_COMBINE(p0, total_count, w, 0.5);
double bp1 = PROB_COMBINE(p1, total_count, w, 0.5);
/* Bound and apply min strength (relative to 0.5 for binary) */
bp0 = MAX(0.0, MIN(1.0, bp0));
bp1 = MAX(0.0, MIN(1.0, bp1));
if (fabs(bp0 - 0.5) >= ctx->cfg->min_prob_strength) {
cl->class_log_probs[0] += log(bp0);
}
if (fabs(bp1 - 0.5) >= ctx->cfg->min_prob_strength) {
cl->class_log_probs[1] += log(bp1);
}
}
}
else {
/* General multinomial model for N>2 classes */
for (j = 0; j < cl->num_classes; j++) {
/* Skip classes with insufficient learns */
if (ctx->cfg->min_learns > 0 && cl->class_learns[j] < ctx->cfg->min_learns) {
continue;
}
double class_freq = (double) class_counts[j] / MAX(1.0, (double) cl->class_learns[j]);
double class_prob = PROB_COMBINE(class_freq, total_count, w, 1.0 / cl->num_classes);
double class_freq = (double) class_counts[j] / MAX(1.0, (double) cl->class_learns[j]);
double class_prob = PROB_COMBINE(class_freq, total_count, w, 1.0 / cl->num_classes);
/* Ensure probability is properly bounded [0, 1] */
class_prob = MAX(0.0, MIN(1.0, class_prob));
/* Ensure probability is properly bounded [0, 1] */
class_prob = MAX(0.0, MIN(1.0, class_prob));
/* Skip probabilities too close to uniform (1/num_classes) */
double uniform_prior = 1.0 / cl->num_classes;
if (fabs(class_prob - uniform_prior) < ctx->cfg->min_prob_strength) {
continue;
}
/* Skip probabilities too close to uniform (1/num_classes) */
double uniform_prior = 1.0 / cl->num_classes;
if (fabs(class_prob - uniform_prior) < ctx->cfg->min_prob_strength) {
continue;
}
cl->class_log_probs[j] += log(class_prob);
cl->class_log_probs[j] += log(class_prob);
}
}
cl->processed_tokens++;
if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) {
cl->text_tokens++;
}
/* Per-token debug logging removed to reduce verbosity */
}
}

Loading…
Cancel
Save