Browse Source

[Feature] Add bidirectional context support for LLM

* Unify context for incoming and outgoing mail
* Same identity used for authenticated/local sender and recipient
* Follows replies module pattern for direction detection
* Make llm_context.lua module-agnostic with debug_module parameter
* Improve userdata handling (use :sub instead of string.sub)
* Add nil-safety to all debug logging calls
* Add cache expiration timestamps to context logs
pull/5647/head
Vsevolod Stakhov 6 days ago
parent
commit
38c48e5a62
No known key found for this signature in database GPG Key ID: 7647B6790081437
  1. 162
      lualib/llm_context.lua

162
lualib/llm_context.lua

@ -2,8 +2,8 @@
Context management for LLM-based spam detection
Provides:
- fetch(task, redis_params, opts, callback): load context JSON from Redis and format prompt snippet
- update_after_classification(task, redis_params, opts, result, sel_part): update context after LLM result
- fetch(task, redis_params, opts, callback, debug_module): load context JSON from Redis and format prompt snippet
- update_after_classification(task, redis_params, opts, result, sel_part, debug_module): update context after LLM result
Opts (all optional, safe defaults applied):
enabled: boolean
@ -17,6 +17,8 @@ Opts (all optional, safe defaults applied):
summary_max_chars: number (truncate stored text)
flagged_phrases: array of strings (case-insensitive match)
last_labels_count: number
debug_module: optional string, module name for debug logging (default: 'llm_context')
]]
local M = {}
@ -56,49 +58,86 @@ local function to_seconds(v)
return tonumber(v) or 0
end
local function get_principal_recipient(task)
return task:get_principal_recipient()
end
local function get_domain_from_addr(addr)
if not addr then return nil end
return string.match(addr, '.*@(.+)')
end
local function compute_identity(task, opts)
local scope = opts.level or DEFAULTS.level
-- Determine our user/domain - same identity for both incoming and outgoing mail
local function get_our_identity(task, scope)
-- For outgoing mail: authenticated user or sender from local network
-- For incoming mail: principal recipient
local user = task:get_user()
local ip = task:get_ip()
local is_outgoing = user or (ip and ip:is_local())
local identity
if scope == 'user' then
identity = task:get_user() or get_principal_recipient(task)
if not identity then
local from = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['addr']
identity = from
if is_outgoing then
-- Outgoing: use sender (authenticated user or from address)
identity = user or task:get_reply_sender()
if not identity then
local from = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['addr']
identity = from
end
else
-- Incoming: use recipient
identity = task:get_principal_recipient()
end
elseif scope == 'domain' then
local rcpt = get_principal_recipient(task)
identity = get_domain_from_addr(rcpt)
if not identity then
identity = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['domain']
if is_outgoing then
-- Outgoing: domain of sender
if user then
identity = get_domain_from_addr(user)
end
if not identity then
identity = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['domain']
end
else
-- Incoming: domain of recipient
local rcpt = task:get_principal_recipient()
identity = get_domain_from_addr(rcpt)
end
elseif scope == 'esld' then
local rcpt = get_principal_recipient(task)
local d = get_domain_from_addr(rcpt)
if d then
identity = rspamd_util.get_tld(d)
end
if not identity then
local fd = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['domain']
if fd then identity = rspamd_util.get_tld(fd) end
if is_outgoing then
-- Outgoing: eSLD of sender domain
local d
if user then
d = get_domain_from_addr(user)
end
if not d then
d = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['domain']
end
if d then identity = rspamd_util.get_tld(d) end
else
-- Incoming: eSLD of recipient domain
local rcpt = task:get_principal_recipient()
local d = get_domain_from_addr(rcpt)
if d then
identity = rspamd_util.get_tld(d)
end
end
else
scope = 'user'
identity = task:get_user() or get_principal_recipient(task)
end
return identity
end
local function compute_identity(task, opts, debug_module)
local N = debug_module or 'llm_context'
local scope = opts.level or DEFAULTS.level
local identity = get_our_identity(task, scope)
if not identity or identity == '' then
return nil
end
-- Log direction for debugging
local user = task:get_user()
local ip = task:get_ip()
local is_outgoing = user or (ip and ip:is_local())
lua_util.debugm(N, task, 'computed identity for %s (%s): %s',
scope, is_outgoing and 'outgoing' or 'incoming', tostring(identity))
local key_prefix = opts.key_prefix or DEFAULTS.key_prefix
local key_suffix = opts.key_suffix or DEFAULTS.key_suffix
local key = string.format('%s:%s:%s', key_prefix, identity, key_suffix)
@ -110,10 +149,17 @@ local function compute_identity(task, opts)
}
end
local function parse_json(str)
if not str or str == '' then return nil end
local function parse_json(data)
if not data then return nil end
-- Redis can return userdata nil or empty string
if type(data) == 'userdata' then
data = tostring(data)
end
if type(data) ~= 'string' or data == '' then
return nil
end
local parser = ucl.parser()
local ok, err = parser:parse_string(str)
local ok, err = parser:parse_text(data)
if not ok then return nil, err end
return parser:get_object()
end
@ -129,7 +175,7 @@ end
local function truncate_text(txt, limit)
if not txt then return '' end
if #txt <= limit then return txt end
return string.sub(txt, 1, limit)
return txt:sub(1, limit)
end
local function has_flag(flags, flag_name)
@ -314,7 +360,8 @@ local function format_context_prompt(ctx)
return table.concat(parts, '\n')
end
function M.fetch(task, redis_params, opts, callback)
function M.fetch(task, redis_params, opts, callback, debug_module)
local N = debug_module or 'llm_context'
opts = lua_util.override_defaults(DEFAULTS, opts or {})
if not opts.enabled then
callback(nil, nil, nil)
@ -325,22 +372,28 @@ function M.fetch(task, redis_params, opts, callback)
return
end
local ident = compute_identity(task, opts)
local ident = compute_identity(task, opts, N)
if not ident then
lua_util.debugm(N, task, 'no identity computed, skipping context')
callback('no identity', nil, nil)
return
end
lua_util.debugm(N, task, 'fetching context for %s: %s',
tostring(ident.scope), tostring(ident.identity))
local function on_get(err, data)
if err then
rspamd_logger.errx(task, 'llm_context: get failed: %s', err)
rspamd_logger.errx(task, 'llm_context: get failed: %s', tostring(err))
callback(err, nil, nil)
return
end
local ctx
if data then
lua_util.debugm(N, task, 'got context data from redis, parsing')
ctx = ensure_defaults(select(1, parse_json(data)) or {})
else
lua_util.debugm(N, task, 'no context data in redis, using empty')
ctx = ensure_defaults({})
end
@ -348,12 +401,14 @@ function M.fetch(task, redis_params, opts, callback)
local min_msgs = opts.min_messages or DEFAULTS.min_messages
local msg_count = #(ctx.recent_messages or {})
if msg_count < min_msgs then
lua_util.debugm('llm_context', task, 'context has only %s messages (min: %s), not injecting into prompt',
msg_count, min_msgs)
lua_util.debugm(N, task, 'context has only %s messages (min: %s), not injecting into prompt',
tostring(msg_count), tostring(min_msgs))
callback(nil, ctx, nil) -- return ctx but no prompt snippet
return
end
lua_util.debugm(N, task, 'context warm-up OK: %s messages, generating snippet',
tostring(msg_count))
local prompt_snippet = format_context_prompt(ctx)
callback(nil, ctx, prompt_snippet)
end
@ -364,19 +419,22 @@ function M.fetch(task, redis_params, opts, callback)
end
end
function M.update_after_classification(task, redis_params, opts, result, sel_part)
function M.update_after_classification(task, redis_params, opts, result, sel_part, debug_module)
local N = debug_module or 'llm_context'
opts = lua_util.override_defaults(DEFAULTS, opts or {})
if not opts.enabled then return end
if not redis_params then return end
local ident = compute_identity(task, opts)
local ident = compute_identity(task, opts, N)
if not ident then return end
local function on_get(err, data)
if err then
rspamd_logger.errx(task, 'llm_context: get for update failed: %s', err)
rspamd_logger.errx(task, 'llm_context: get for update failed: %s', tostring(err))
return
end
lua_util.debugm(N, task, 'updating context for %s: %s',
tostring(ident.scope), tostring(ident.identity))
local ctx = ensure_defaults(select(1, parse_json(data)) or {})
local msg = build_message_summary(task, sel_part, opts)
@ -413,9 +471,33 @@ function M.update_after_classification(task, redis_params, opts, result, sel_par
local payload = encode_json(ctx)
local ttl = to_seconds(opts.ttl)
local expire_at = now() + ttl
-- Log what we're storing in context
lua_util.debugm(N, task,
'storing context for %s: %s messages, labels=%s, top_senders=%s, flagged=%s, payload_size=%s bytes, expiring at %s',
tostring(ident.identity or '(none)'),
tostring(#ctx.recent_messages),
table.concat(ctx.last_spam_labels or {}, ','),
table.concat(ctx.top_senders or {}, ','),
table.concat(ctx.flagged_phrases or {}, ','),
tostring(#payload),
os.date('%Y-%m-%d %H:%M:%S', expire_at))
if msg then
lua_util.debugm(N, task,
'added message: from=%s, subject=%s, keywords=%s',
tostring(msg.from or '(none)'),
tostring(msg.subject or '(none)'),
table.concat(msg.keywords or {}, ','))
end
local function on_set(set_err)
if set_err then
rspamd_logger.errx(task, 'llm_context: set failed: %s', set_err)
rspamd_logger.errx(task, 'llm_context: set failed: %s', tostring(set_err))
else
lua_util.debugm(N, task, 'context saved to redis: key=%s, ttl=%s, expiring at %s',
tostring(ident.key), tostring(ttl), os.date('%Y-%m-%d %H:%M:%S', expire_at))
end
end
local ok = lua_redis.redis_make_request(task, redis_params, ident.key, true, on_set, 'SETEX',

Loading…
Cancel
Save