
[Feature] Improve GPT module with uncertain-result caching and server-side timeout

* Add GPT_UNCERTAIN symbol for caching uncertain classifications
  - Cache results even when no consensus is reached (a simplified sketch follows this list)
  - Avoid repeated expensive LLM queries for borderline cases
  - Set X-GPT-Reason header with detailed vote statistics
* Add server-side timeout support for OpenAI API requests
  - New request_timeout parameter (optional, multiplied by 0.95)
  - Only sent if explicitly configured (not all APIs support this)
  - Accounts for connection setup and data transfer overhead
* Fix max_ham_prob initialization (was 0, now correctly 1.0)
* Add pcall protection for fold_header_with_encoding with raw fallback
* Improve error messages when the token limit is exceeded
* Add detailed logging for context snippets and consensus decisions
* Pass debug_module parameter to llm_context functions
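
A minimal, self-contained sketch of the uncertain fallback (simplified and hypothetical: the consensus thresholds below merely stand in for the plugin's real consensus check, which is outside this diff; the real code additionally caches the neutral result in Redis and adds the reason header):

-- Hypothetical, simplified illustration of the new fallback: when neither
-- spam nor ham wins, return a neutral verdict instead of nothing, so the
-- result can still be cached and GPT_UNCERTAIN can be inserted.
local function decide(results)
  local nspam, nham = 0, 0
  local max_spam_prob, max_ham_prob = 0, 1.0
  for _, r in ipairs(results) do
    if r.success and r.probability then
      if r.probability > 0.5 then
        nspam = nspam + 1
        max_spam_prob = math.max(max_spam_prob, r.probability)
      else
        nham = nham + 1
        max_ham_prob = math.min(max_ham_prob, r.probability)
      end
    end
  end
  if nspam > nham and max_spam_prob > 0.75 then       -- illustrative threshold
    return { verdict = 'spam', probability = max_spam_prob }
  elseif nham > nspam and max_ham_prob < 0.25 then    -- illustrative threshold
    return { verdict = 'ham', probability = max_ham_prob }
  end
  -- No consensus: neutral probability, still cacheable, marked uncertain
  return {
    verdict = 'uncertain',
    probability = 0.5,
    reason = string.format(
      'Uncertain classification: spam votes=%d (max %.2f), ham votes=%d (min %.2f)',
      nspam, max_spam_prob, nham, max_ham_prob),
  }
end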
Vsevolod Stakhov, commit 74f4837503 (pull/5647/head)

src/plugins/lua/gpt.lua

@@ -62,6 +62,13 @@ if confighelp then
reason_header = "X-GPT-Reason";
# Use JSON format for response
json = false;
# Optional: pass request timeout to the server (in seconds)
# WARNING: Not all API implementations support this parameter (e.g., standard OpenAI API doesn't)
# Only enable if your API endpoint/proxy specifically supports max_completion_time parameter
# If not set, this parameter will not be sent to the server
# Note: the actual value sent to server is multiplied by 0.95 to account for
# connection setup, SSL handshake, and data transfer overhead
# request_timeout = 8;
# Optional user/domain context in Redis
context = {
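
To make the scaling concrete (illustrative numbers only, reusing the commented-out value above): with request_timeout = 8, the value placed into the request body is 8 * 0.95 = 7.6 seconds.

-- Illustrative only: how the configured timeout maps to the request body field
local settings = { request_timeout = 8 }  -- seconds, as in the commented example above
local body = {}
if settings.request_timeout then
  -- reserve ~5% for connection setup, SSL handshake and data transfer
  body.max_completion_time = settings.request_timeout * 0.95
end
print(body.max_completion_time)  -- 7.6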
@@ -133,6 +140,11 @@ local default_extra_symbols = {
description = 'GPT model detected malware content',
category = 'malware',
},
GPT_UNCERTAIN = {
score = 0.0,
description = 'GPT model was uncertain about classification',
category = 'uncertain',
},
}
-- Should be filled from extra symbols
@@ -172,6 +184,7 @@ local settings = {
json = false,
extra_symbols = nil,
cache_prefix = REDIS_PREFIX,
request_timeout = nil, -- Optional: pass request timeout to server (in seconds)
-- user/domain context options (nested table forwarded to llm_context)
context = {
enabled = false,
@@ -432,12 +445,12 @@ local function default_openai_json_conversion(task, input)
elseif reply.probability == "low" then
spam_score = 0.1
else
rspamd_logger.infox(task, "cannot convert to spam probability: %s", reply.probability)
lua_util.debugm(N, task, "cannot convert to spam probability: %s", reply.probability)
end
end
if type(reply.usage) == 'table' then
rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens)
lua_util.debugm(N, task, 'usage: %s tokens', reply.usage.total_tokens)
end
return spam_score, reply.reason, {}
@@ -475,9 +488,22 @@ local function default_openai_plain_conversion(task, input)
end
local first_message = reply.choices[1].message.content
local finish_reason = reply.choices[1].finish_reason or 'unknown'
if not first_message or first_message == "" then
rspamd_logger.errx(task, 'no content in the first message')
if finish_reason == 'length' then
-- Token limit exceeded - provide helpful error message
local usage = reply.usage or {}
local completion_tokens = usage.completion_tokens or 0
local reasoning_tokens = usage.completion_tokens_details and usage.completion_tokens_details.reasoning_tokens or 0
rspamd_logger.errx(task, 'LLM response truncated: token limit exceeded. ' ..
'Used %s completion tokens (including %s reasoning tokens). ' ..
'Increase max_completion_tokens in model_parameters config for this model.',
completion_tokens, reasoning_tokens)
else
rspamd_logger.errx(task, 'no content in the first message (finish_reason: %s, usage: %s)',
finish_reason, reply.usage and ucl.to_format(reply.usage, 'json-compact') or 'none')
end
return
end
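
Illustration of what the new branch reports for a truncated reply. The table below mimics the OpenAI chat-completions shape referenced by the patch (choices, finish_reason, usage.completion_tokens_details); the concrete numbers are made up.

-- Hypothetical truncated response, shaped like an OpenAI chat completion
local reply = {
  choices = { { message = { content = '' }, finish_reason = 'length' } },
  usage = {
    completion_tokens = 1024,
    completion_tokens_details = { reasoning_tokens = 1024 },
  },
}

local finish_reason = reply.choices[1].finish_reason or 'unknown'
if finish_reason == 'length' then
  local usage = reply.usage or {}
  local completion_tokens = usage.completion_tokens or 0
  local reasoning_tokens = usage.completion_tokens_details
      and usage.completion_tokens_details.reasoning_tokens or 0
  print(string.format(
    'LLM response truncated: token limit exceeded. ' ..
    'Used %s completion tokens (including %s reasoning tokens).',
    completion_tokens, reasoning_tokens))
end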
@@ -491,7 +517,7 @@ local function default_openai_plain_conversion(task, input)
local categories = lua_util.str_split(clean_reply_line(lines[3]), ',')
if type(reply.usage) == 'table' then
rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens)
lua_util.debugm(N, task, 'usage: %s tokens', reply.usage.total_tokens)
end
if spam_score then
@@ -592,12 +618,12 @@ local function default_ollama_json_conversion(task, input)
elseif reply.probability == "low" then
spam_score = 0.1
else
rspamd_logger.infox(task, "cannot convert to spam probability: %s", reply.probability)
lua_util.debugm(N, task, "cannot convert to spam probability: %s", reply.probability)
end
end
if type(reply.usage) == 'table' then
rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens)
lua_util.debugm(N, task, 'usage: %s tokens', reply.usage.total_tokens)
end
return spam_score, reply.reason, {}
@@ -647,7 +673,7 @@ local function insert_results(task, result, sel_part)
if result.categories then
process_categories(task, result.categories)
end
else
elseif result.probability < 0.5 then
task:insert_result('GPT_HAM', (0.5 - result.probability) * 2, tostring(result.probability))
if settings.autolearn then
task:set_flag("learn_ham")
@@ -655,12 +681,27 @@ local function insert_results(task, result, sel_part)
if result.categories then
process_categories(task, result.categories)
end
else
-- probability == 0.5, uncertain result, don't set GPT_SPAM/GPT_HAM
if result.categories then
process_categories(task, result.categories)
end
end
if result.reason and settings.reason_header then
local v = lua_util.fold_header_with_encoding(task, settings.reason_header,
tostring(result.reason), { encode = 'auto' })
lua_mime.modify_headers(task,
{ add = { [settings.reason_header] = { value = v, order = 1 } } })
if type(settings.reason_header) == 'string' and #result.reason > 0 then
local ok, v = pcall(lua_util.fold_header_with_encoding, task, settings.reason_header,
result.reason, { encode = false, structured = false })
if ok and v then
lua_mime.modify_headers(task,
{ add = { [settings.reason_header] = { value = v, order = 1 } } })
else
rspamd_logger.warnx(task, 'cannot fold header %s: %s; using raw value', settings.reason_header,
v)
-- Fallback: use raw value without encoding
lua_mime.modify_headers(task,
{ add = { [settings.reason_header] = { value = result.reason, order = 1 } } })
end
end
end
if cache_context then
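
insert_results now distinguishes three cases instead of two. A standalone sketch of the mapping (assuming the symmetric weight formula for GPT_SPAM, which this patch does not change; autolearn, categories and the reason header are omitted):

-- Simplified: which symbol a given LLM spam probability would produce
local function symbol_for(probability)
  if probability > 0.5 then
    return 'GPT_SPAM', (probability - 0.5) * 2
  elseif probability < 0.5 then
    return 'GPT_HAM', (0.5 - probability) * 2
  end
  -- exactly 0.5: neither symbol; the caller may add GPT_UNCERTAIN instead
  return nil
end

print(symbol_for(0.9))  -- GPT_SPAM  0.8
print(symbol_for(0.1))  -- GPT_HAM   0.8
print(symbol_for(0.5))  -- nil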
@@ -669,7 +710,7 @@ local function insert_results(task, result, sel_part)
-- Update long-term user/domain context after classification
if redis_params and settings.context then
llm_context.update_after_classification(task, redis_params, settings.context, result, sel_part)
llm_context.update_after_classification(task, redis_params, settings.context, result, sel_part, N)
end
end
@@ -681,21 +722,21 @@ local function check_consensus_and_insert_results(task, results, sel_part)
end
local nspam, nham = 0, 0
local max_spam_prob, max_ham_prob = 0, 0
local max_spam_prob, max_ham_prob = 0, 1.0
local reasons = {}
for _, result in ipairs(results) do
if result.success then
if result.success and result.probability then
if result.probability > 0.5 then
nspam = nspam + 1
max_spam_prob = math.max(max_spam_prob, result.probability)
lua_util.debugm(N, task, "model: %s; spam: %s; reason: '%s'",
result.model, result.probability, result.reason)
result.model or 'unknown', result.probability, result.reason or 'no reason')
else
nham = nham + 1
max_ham_prob = math.min(max_ham_prob, result.probability)
lua_util.debugm(N, task, "model: %s; ham: %s; reason: '%s'",
result.model, result.probability, result.reason)
result.model or 'unknown', result.probability, result.reason or 'no reason')
end
if result.reason then
@@ -724,8 +765,20 @@ local function check_consensus_and_insert_results(task, results, sel_part)
},
sel_part)
else
-- No consensus
lua_util.debugm(N, task, "no consensus")
-- No consensus - still cache and set uncertain symbol to avoid re-querying LLM
lua_util.debugm(N, task, "no consensus: nspam=%s, nham=%s, max_spam_prob=%s, max_ham_prob=%s",
nspam, nham, max_spam_prob, max_ham_prob)
-- Use 0.5 (neutral) probability with uncertain marker
local uncertain_reason = reason_text or string.format(
"Uncertain classification: spam votes=%d (max %.2f), ham votes=%d (min %.2f)",
nspam, max_spam_prob, nham, max_ham_prob)
insert_results(task, {
probability = 0.5,
reason = uncertain_reason,
categories = { 'uncertain' },
},
sel_part)
task:insert_result('GPT_UNCERTAIN', 1.0)
end
end
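
For instance, with one model voting spam at 0.70 and another voting ham at 0.30 (made-up figures), the reason cached and exposed via the reason header would read as follows:

-- Illustrative output of the uncertain_reason format above
print(string.format(
  'Uncertain classification: spam votes=%d (max %.2f), ham votes=%d (min %.2f)',
  1, 0.70, 1, 0.30))
-- Uncertain classification: spam votes=1 (max 0.70), ham votes=1 (min 0.30)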
@@ -747,7 +800,7 @@ local function check_llm_cached(task, content, sel_part, context_snippet)
end
if data then
rspamd_logger.infox(task, 'found cached response %s', cache_key)
lua_util.debugm(N, task, 'found cached response %s', cache_key)
insert_results(task, data, sel_part)
else
check_llm_uncached(task, content, sel_part, context_snippet)
@@ -757,6 +810,11 @@ end
local function openai_check(task, content, sel_part, context_snippet)
lua_util.debugm(N, task, "sending content to gpt: %s", content)
if context_snippet then
lua_util.debugm(N, task, "with context snippet (%s chars): %s", #context_snippet, context_snippet)
else
lua_util.debugm(N, task, "no context snippet")
end
local upstream
local results = {}
@@ -851,6 +909,13 @@ local function openai_check(task, content, sel_part, context_snippet)
body.response_format = { type = "json_object" }
end
-- Optionally add request timeout for server-side timeout control
-- Only pass if explicitly configured (not all API implementations support this)
-- Multiply by 0.95 to account for connection setup, SSL handshake, and data transfer time
if settings.request_timeout then
body.max_completion_time = settings.request_timeout * 0.95
end
body.model = model
upstream = settings.upstreams:get_upstream_round_robin()
@@ -883,6 +948,11 @@ end
local function ollama_check(task, content, sel_part, context_snippet)
lua_util.debugm(N, task, "sending content to gpt: %s", content)
if context_snippet then
lua_util.debugm(N, task, "with context snippet (%s chars): %s", #context_snippet, context_snippet)
else
lua_util.debugm(N, task, "no context snippet")
end
local upstream
local results = {}
@@ -975,6 +1045,13 @@ local function ollama_check(task, content, sel_part, context_snippet)
body.response_format = { type = "json_object" }
end
-- Optionally add request timeout for server-side timeout control
-- Only pass if explicitly configured (not all API implementations support this)
-- Multiply by 0.95 to account for connection setup, SSL handshake, and data transfer time
if settings.request_timeout then
body.max_completion_time = settings.request_timeout * 0.95
end
body.model = model
upstream = settings.upstreams:get_upstream_round_robin()
@@ -1024,14 +1101,14 @@ local function gpt_check(task)
inferred_result = { probability = 0.1, reason = 'ham by filters', categories = {} }
end
end
llm_context.update_after_classification(task, redis_params, settings.context, inferred_result, sel_part)
llm_context.update_after_classification(task, redis_params, settings.context, inferred_result, sel_part, N)
end
rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s; context updated", content)
lua_util.debugm(N, task, "skip checking gpt as the condition is not met: %s; context updated", content)
return
end
if not ret then
rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content)
lua_util.debugm(N, task, "skip checking gpt as the condition is not met: %s", content)
return
end
@@ -1052,7 +1129,7 @@ local function gpt_check(task)
if context_enabled then
llm_context.fetch(task, redis_params, settings.context, function(_, _, snippet)
proceed(snippet)
end)
end, N)
else
proceed(nil)
end
