From 4b017c205a552d2a9834b6b78b4df4c9876fccb6 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Wed, 23 Jul 2025 14:33:04 +0100 Subject: [PATCH] [Project] Fix binary classification and lua scripts --- lualib/redis_scripts/bayes_classify.lua | 4 +- src/libserver/cfg_file.h | 1 + src/libserver/cfg_rcl.cxx | 64 ++++++++++++++++++------- src/libserver/cfg_utils.cxx | 47 +++++++++++++----- src/libstat/classifiers/bayes.c | 11 ++++- 5 files changed, 96 insertions(+), 31 deletions(-) diff --git a/lualib/redis_scripts/bayes_classify.lua b/lualib/redis_scripts/bayes_classify.lua index 923adcc5a..d6132e631 100644 --- a/lualib/redis_scripts/bayes_classify.lua +++ b/lualib/redis_scripts/bayes_classify.lua @@ -35,7 +35,7 @@ end -- Get token data for all classes (ordered) local token_results = {} -for i, label in ipairs(class_labels) do +for i, _ in ipairs(class_labels) do token_results[i] = {} end @@ -54,7 +54,7 @@ if has_learns then local token_data = redis.call('HMGET', token, unpack(class_labels)) if token_data then - for j, label in ipairs(class_labels) do + for j, _ in ipairs(class_labels) do local count = token_data[j] if count and tonumber(count) > 0 then table.insert(token_results[j], { i, tonumber(count) }) diff --git a/src/libserver/cfg_file.h b/src/libserver/cfg_file.h index 5aaaece35..9f83f8024 100644 --- a/src/libserver/cfg_file.h +++ b/src/libserver/cfg_file.h @@ -142,6 +142,7 @@ struct rspamd_statfile_config { char *class_name; /**< class name for multi-class classification */ unsigned int class_index; /**< class index for O(1) lookup during classification */ gboolean is_spam; /**< DEPRECATED: spam flag - use class_name instead */ + gboolean is_spam_converted; /**< TRUE if class_name was converted from is_spam flag */ struct rspamd_classifier_config *clcf; /**< parent pointer of classifier configuration */ gpointer data; /**< opaque data */ }; diff --git a/src/libserver/cfg_rcl.cxx b/src/libserver/cfg_rcl.cxx index 5afb46745..68b6460d8 100644 --- a/src/libserver/cfg_rcl.cxx +++ b/src/libserver/cfg_rcl.cxx @@ -1215,11 +1215,13 @@ rspamd_rcl_statfile_handler(rspamd_mempool_t *pool, const ucl_object_t *obj, strlen(st->symbol), "spam", 4) != -1) { st->is_spam = TRUE; st->class_name = rspamd_mempool_strdup(pool, "spam"); + st->is_spam_converted = TRUE; } else if (rspamd_substring_search_caseless(st->symbol, strlen(st->symbol), "ham", 3) != -1) { st->is_spam = FALSE; st->class_name = rspamd_mempool_strdup(pool, "ham"); + st->is_spam_converted = TRUE; } else { g_set_error(err, @@ -1242,6 +1244,7 @@ rspamd_rcl_statfile_handler(rspamd_mempool_t *pool, const ucl_object_t *obj, else { st->class_name = rspamd_mempool_strdup(pool, "ham"); } + st->is_spam_converted = TRUE; } /* If class field is present, it was already parsed by the default parser */ return TRUE; @@ -1439,31 +1442,60 @@ rspamd_rcl_classifier_handler(rspamd_mempool_t *pool, cfg->classifiers = g_list_prepend(cfg->classifiers, ccf); - /* Populate class_names array from statfiles */ + /* Populate class_names array from statfiles - only for explicit multiclass configs */ if (ccf->statfiles) { GList *cur = ccf->statfiles; - ccf->class_names = g_ptr_array_new(); + gboolean has_explicit_classes = FALSE; + /* Check if any statfile uses explicit class declaration (not converted from is_spam) */ + cur = ccf->statfiles; while (cur) { struct rspamd_statfile_config *stcf = (struct rspamd_statfile_config *) cur->data; - if (stcf->class_name) { - /* Check if class already exists */ - bool found = false; - for (unsigned int i = 0; i < ccf->class_names->len; i++) { - if (strcmp((char *) g_ptr_array_index(ccf->class_names, i), stcf->class_name) == 0) { - stcf->class_index = i; /* Store the index for O(1) lookup */ - found = true; - break; + msg_debug("checking statfile %s: class_name=%s, is_spam_converted=%s", + stcf->symbol, stcf->class_name ? stcf->class_name : "NULL", + stcf->is_spam_converted ? "true" : "false"); + if (stcf->class_name && !stcf->is_spam_converted) { + has_explicit_classes = TRUE; + break; + } + cur = g_list_next(cur); + } + + msg_debug("has_explicit_classes = %s", has_explicit_classes ? "true" : "false"); + + /* Only populate class_names for explicit multiclass configurations */ + if (has_explicit_classes) { + msg_debug("populating class_names for multiclass configuration"); + } + else { + msg_debug("skipping class_names population for binary configuration"); + } + + if (has_explicit_classes) { + ccf->class_names = g_ptr_array_new(); + + cur = ccf->statfiles; + while (cur) { + struct rspamd_statfile_config *stcf = (struct rspamd_statfile_config *) cur->data; + if (stcf->class_name) { + /* Check if class already exists */ + bool found = false; + for (unsigned int i = 0; i < ccf->class_names->len; i++) { + if (strcmp((char *) g_ptr_array_index(ccf->class_names, i), stcf->class_name) == 0) { + stcf->class_index = i; /* Store the index for O(1) lookup */ + found = true; + break; + } } - } - if (!found) { - /* Add new class */ - stcf->class_index = ccf->class_names->len; - g_ptr_array_add(ccf->class_names, g_strdup(stcf->class_name)); + if (!found) { + /* Add new class */ + stcf->class_index = ccf->class_names->len; + g_ptr_array_add(ccf->class_names, g_strdup(stcf->class_name)); + } } + cur = g_list_next(cur); } - cur = g_list_next(cur); } } diff --git a/src/libserver/cfg_utils.cxx b/src/libserver/cfg_utils.cxx index c8c083439..2533bd65e 100644 --- a/src/libserver/cfg_utils.cxx +++ b/src/libserver/cfg_utils.cxx @@ -3181,18 +3181,41 @@ rspamd_config_validate_class_config(struct rspamd_classifier_config *ccf, GError class_count); } - /* Initialize classifier class tracking */ - if (ccf->class_names) { - g_ptr_array_unref(ccf->class_names); - } - ccf->class_names = g_ptr_array_new_with_free_func(g_free); - - /* Populate class names array */ - GHashTableIter iter; - gpointer key, value; - g_hash_table_iter_init(&iter, seen_classes); - while (g_hash_table_iter_next(&iter, &key, &value)) { - g_ptr_array_add(ccf->class_names, g_strdup((const char *) key)); + /* Initialize classifier class tracking - only for explicit multiclass configurations */ + gboolean has_explicit_classes = FALSE; + + /* Check if any statfile uses explicit class declaration (not converted from is_spam) */ + cur = ccf->statfiles; + while (cur) { + stcf = (struct rspamd_statfile_config *) cur->data; + if (stcf->class_name && !stcf->is_spam_converted) { + has_explicit_classes = TRUE; + break; + } + cur = g_list_next(cur); + } + + /* Only populate class_names for explicit multiclass configurations */ + if (has_explicit_classes) { + if (ccf->class_names) { + g_ptr_array_unref(ccf->class_names); + } + ccf->class_names = g_ptr_array_new_with_free_func(g_free); + + /* Populate class names array */ + GHashTableIter iter; + gpointer key, value; + g_hash_table_iter_init(&iter, seen_classes); + while (g_hash_table_iter_next(&iter, &key, &value)) { + g_ptr_array_add(ccf->class_names, g_strdup((const char *) key)); + } + } + else { + /* Binary configuration - ensure class_names is NULL */ + if (ccf->class_names) { + g_ptr_array_unref(ccf->class_names); + ccf->class_names = nullptr; + } } g_hash_table_destroy(seen_classes); diff --git a/src/libstat/classifiers/bayes.c b/src/libstat/classifiers/bayes.c index 4d070ee20..3fd7190ae 100644 --- a/src/libstat/classifiers/bayes.c +++ b/src/libstat/classifiers/bayes.c @@ -620,18 +620,27 @@ bayes_classify(struct rspamd_classifier *ctx, g_assert(tokens != NULL); /* Check if this is a multi-class classifier */ + msg_debug_bayes("classification check: class_names=%p, len=%uz", + ctx->cfg->class_names, + ctx->cfg->class_names ? ctx->cfg->class_names->len : 0); + if (ctx->cfg->class_names && ctx->cfg->class_names->len >= 2) { /* Verify that at least one statfile has class_name set (indicating new multi-class config) */ gboolean has_class_names = FALSE; for (i = 0; i < ctx->statfiles_ids->len; i++) { int id = g_array_index(ctx->statfiles_ids, int, i); struct rspamd_statfile *st = g_ptr_array_index(ctx->ctx->statfiles, id); + msg_debug_bayes("checking statfile %s: class_name=%s, is_spam_converted=%s", + st->stcf->symbol, + st->stcf->class_name ? st->stcf->class_name : "NULL", + st->stcf->is_spam_converted ? "true" : "false"); if (st->stcf->class_name) { has_class_names = TRUE; - break; } } + msg_debug_bayes("has_class_names=%s", has_class_names ? "true" : "false"); + if (has_class_names) { msg_debug_bayes("using multiclass classification with %ud classes", (unsigned int) ctx->cfg->class_names->len);