mirror of https://github.com/rspamd/rspamd.git
Rapid spam filtering system
https://rspamd.com/
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
533 lines
16 KiB
533 lines
16 KiB
--[[
Context management for LLM-based spam detection

Provides:
  - fetch(task, redis_params, opts, callback, debug_module): load context JSON from Redis and format a prompt snippet
  - update_after_classification(task, redis_params, opts, result, sel_part, debug_module): update context after an LLM result

Opts (all optional, safe defaults applied):
  enabled: boolean
  level: 'user' | 'domain' | 'esld' (scope for context key)
  key_prefix: string (prefix before identity in the Redis key)
  key_suffix: string (suffix after identity in the Redis key)
  max_messages: number (sliding window size)
  min_messages: number (minimum stored messages before the context is injected into the prompt)
  message_ttl: seconds
  ttl: seconds (Redis key TTL)
  top_senders: number (how many to keep in top_senders)
  summary_max_chars: number (truncate stored text)
  flagged_phrases: array of strings (case-insensitive match)
  last_labels_count: number

debug_module: optional string, module name for debug logging (default: 'llm_context')
]]
|
|
|
|
local M = {}
|
|
|
|
local lua_redis = require "lua_redis"
|
|
local lua_util = require "lua_util"
|
|
local rspamd_logger = require "rspamd_logger"
|
|
local ucl = require "ucl"
|
|
local rspamd_util = require "rspamd_util"
|
|
local llm_common = require "llm_common"
|
|
|
|
-- Shared empty table used as a nil-safe fallback when chaining index operations.
local EMPTY = {}

-- Default option values; caller-supplied opts are merged over these.
local DEFAULTS = {
  -- Feature switch
  enabled = false,

  -- Redis key layout: <key_prefix>:<identity>:<key_suffix>
  level = 'user',
  key_prefix = 'user',
  key_suffix = 'mail_context',

  -- Sliding window of remembered messages
  max_messages = 40,
  min_messages = 5, -- minimum messages in context before injecting into prompt
  message_ttl = 14 * 86400, -- 14 days, in seconds
  ttl = 30 * 86400, -- 30 days, in seconds

  -- Aggregates kept alongside the window
  top_senders = 5,
  summary_max_chars = 512,
  last_labels_count = 10,

  -- Phrases recorded when they occur in a classified message
  flagged_phrases = {
    'reset your password',
    'click here to verify',
    'confirm your account',
    'urgent invoice',
    'wire transfer',
  },
}
|
|
|
|
-- Coerce a value to a number of seconds.
-- Numbers and numeric strings are converted with tonumber(); anything that
-- cannot be converted yields 0.
local function to_seconds(v)
  local seconds = tonumber(v)
  if seconds == nil then
    return 0
  end
  return seconds
end
|
|
|
|
-- Return the part of an address after its last '@', or nil when the address
-- is missing or contains no '@' followed by at least one character.
local function get_domain_from_addr(addr)
  if addr then
    local domain = string.match(addr, '.*@(.+)')
    return domain
  end
  return nil
end
|
|
|
|
-- Resolve the identity on "our" side of a message for the given scope.
--
-- A message is treated as outgoing when the task has an authenticated user or
-- its IP reports itself as local; otherwise it is treated as incoming.
--   scope 'user'   : outgoing -> authenticated user, else reply sender, else
--                    SMTP envelope sender address; incoming -> principal recipient
--   scope 'domain' : outgoing -> domain of the authenticated user, else SMTP
--                    envelope sender domain; incoming -> principal recipient domain
--   scope 'esld'   : rspamd_util.get_tld() applied to the 'domain' result
-- Any other scope yields nil.
local function get_our_identity(task, scope)
  local auth_user = task:get_user()
  local client_ip = task:get_ip()
  local outgoing = auth_user or (client_ip and client_ip:is_local())

  -- Field ('addr' or 'domain') of the first SMTP envelope sender, nil if absent
  local function envelope_from(field)
    return ((task:get_from('smtp') or EMPTY)[1] or EMPTY)[field]
  end

  -- Domain of the sender (outgoing) or of the principal recipient (incoming)
  local function our_domain()
    if not outgoing then
      local rcpt = task:get_principal_recipient()
      return get_domain_from_addr(rcpt)
    end
    local dom
    if auth_user then
      dom = get_domain_from_addr(auth_user)
    end
    if not dom then
      dom = envelope_from('domain')
    end
    return dom
  end

  local identity
  if scope == 'user' then
    if outgoing then
      identity = auth_user or task:get_reply_sender()
      if not identity then
        identity = envelope_from('addr')
      end
    else
      identity = task:get_principal_recipient()
    end
  elseif scope == 'domain' then
    identity = our_domain()
  elseif scope == 'esld' then
    local dom = our_domain()
    if dom then
      identity = rspamd_util.get_tld(dom)
    end
  end

  return identity
end
|
|
|
|
-- Build the descriptor of the Redis context key for this task.
-- Returns nil when no identity can be resolved (nil or empty string);
-- otherwise returns { scope, identity, key } where key is
-- "<key_prefix>:<identity>:<key_suffix>".
-- debug_module is the module name used for debug logging ('llm_context' by default).
local function compute_identity(task, opts, debug_module)
  local log_module = debug_module or 'llm_context'
  local scope = opts.level or DEFAULTS.level
  local identity = get_our_identity(task, scope)

  local have_identity = identity and identity ~= ''
  if not have_identity then
    return nil
  end

  -- Direction is re-derived here only for the debug message
  local auth_user = task:get_user()
  local client_ip = task:get_ip()
  local direction = 'incoming'
  if auth_user or (client_ip and client_ip:is_local()) then
    direction = 'outgoing'
  end
  lua_util.debugm(log_module, task, 'computed identity for %s (%s): %s',
    scope, direction, tostring(identity))

  return {
    scope = scope,
    identity = identity,
    key = string.format('%s:%s:%s',
      opts.key_prefix or DEFAULTS.key_prefix,
      identity,
      opts.key_suffix or DEFAULTS.key_suffix),
  }
end
|
|
|
|
-- Parse a serialized context with the UCL parser.
-- Returns the parsed object on success, nil for missing/empty/non-string
-- input, and nil plus the parser error when parsing fails.
local function parse_json(data)
  if not data then
    return nil
  end

  local payload = data
  -- Replies delivered as userdata are stringified before parsing
  if type(payload) == 'userdata' then
    payload = tostring(payload)
  end

  local usable = type(payload) == 'string' and payload ~= ''
  if not usable then
    return nil
  end

  local parser = ucl.parser()
  local parsed_ok, parse_err = parser:parse_text(payload)
  if parsed_ok then
    return parser:get_object()
  end
  return nil, parse_err
end
|
|
|
|
-- Serialize obj through ucl.to_format using the 'json-compact' format.
local function encode_json(obj)
  local output_format = 'json-compact'
  return ucl.to_format(obj, output_format, true)
end
|
|
|
|
-- Current time as returned by os.time().
local function now()
  local current = os.time()
  return current
end
|
|
|
|
-- Truncate txt to at most `limit` bytes.
--
-- @param txt   text to truncate; nil/false yields ''
-- @param limit maximum length in bytes
-- @return txt unchanged when it already fits, otherwise a prefix of it
--
-- For Lua strings the cut point is moved back (by at most 3 bytes) so that it
-- never falls inside a multi-byte UTF-8 sequence; a plain byte cut could leave
-- a dangling lead byte and store invalid UTF-8 in the serialized context.
-- If no character boundary is found within 3 bytes (input is not valid UTF-8)
-- the plain byte cut is used.
local function truncate_text(txt, limit)
  if not txt then return '' end
  if #txt <= limit then return txt end

  local cut = limit
  -- Boundary search is only done for plain strings; other text-like values
  -- keep the plain byte cut, as only #, and :sub are known to work on them.
  if type(txt) == 'string' then
    for candidate = limit, math.max(limit - 3, 0), -1 do
      -- A cut after `candidate` bytes is safe when the next byte does not
      -- continue a multi-byte sequence (continuation bytes are 0x80..0xBF).
      local next_byte = txt:byte(candidate + 1)
      local is_continuation = next_byte ~= nil and next_byte >= 0x80 and next_byte <= 0xBF
      if not is_continuation then
        cut = candidate
        break
      end
    end
  end

  return txt:sub(1, cut)
end
|
|
|
|
-- Tell whether flag_name occurs in the sequence `flags`.
-- Non-table `flags` values never match.
local function has_flag(flags, flag_name)
  local found = false
  if type(flags) == 'table' then
    for _, candidate in ipairs(flags) do
      if candidate == flag_name then
        found = true
        break
      end
    end
  end
  return found
end
|
|
|
|
-- Pick the most frequent words of a text part.
--
-- Words come from text_part:get_words('full'); for each entry, element [2] is
-- used as the word and element [4] as its flag list. A word is counted when
-- it is not flagged 'stop_word', is longer than 2 bytes and is flagged 'text'.
-- Returns up to `limit` (default 12) words ordered by descending count, ties
-- broken alphabetically. Returns an empty table when there is nothing to scan.
local function extract_keywords(text_part, limit)
  if not text_part then return {} end

  local words = text_part:get_words('full')
  if not words or #words == 0 then return {} end

  -- word -> number of occurrences
  local freq = {}
  for _, entry in ipairs(words) do
    local word = entry[2] or ''
    local flags = entry[4] or {}
    local eligible = not has_flag(flags, 'stop_word') and #word > 2 and has_flag(flags, 'text')
    if eligible then
      freq[word] = (freq[word] or 0) + 1
    end
  end

  local ranked = {}
  for word, count in pairs(freq) do
    ranked[#ranked + 1] = { w = word, c = count }
  end
  table.sort(ranked, function(left, right)
    if left.c ~= right.c then
      return left.c > right.c
    end
    return left.w < right.w
  end)

  local keywords = {}
  local wanted = math.min(limit or 12, #ranked)
  for rank = 1, wanted do
    keywords[#keywords + 1] = ranked[rank].w
  end
  return keywords
end
|
|
|
|
-- Return arr itself when it is a table, otherwise a fresh empty table.
local function safe_array(arr)
  if type(arr) == 'table' then
    return arr
  end
  return {}
end
|
|
|
|
-- Build the compact record stored in the context for the current message.
--
-- The record is derived from llm_common.build_llm_input(task, { max_tokens = 256 })
-- and holds: from, subject, ts (current time), keywords (from sel_part) and,
-- when non-empty, text truncated to opts.summary_max_chars.
-- Returns nil when build_llm_input does not return a table.
local function build_message_summary(task, sel_part, opts)
  local llm_input = llm_common.build_llm_input(task, { max_tokens = 256 })
  if type(llm_input) ~= 'table' then
    return nil
  end

  local body = llm_input.text or ''
  local char_limit = opts.summary_max_chars or DEFAULTS.summary_max_chars

  -- Fall back to the SMTP envelope sender when the LLM input carries no sender
  local sender = llm_input.from
  if not sender then
    sender = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['addr']
  end

  local summary = {
    from = sender,
    subject = llm_input.subject or '',
    ts = now(),
    keywords = extract_keywords(sel_part, 12),
  }
  if #body > 0 then
    summary.text = truncate_text(body, char_limit)
  end
  return summary
end
|
|
|
|
-- Produce the sliding window of stored messages.
--
-- @param recent_messages sequence of message records (tables with a numeric `ts`)
-- @param max_messages    maximum number of records to keep; negative values are
--                        treated as 0, non-numeric values disable the cap
-- @param min_ts          optional cutoff; when set, records without a numeric
--                        `ts` or with `ts` below it are dropped
-- @return a new sequence, newest first (records without `ts` sort as 0)
--
-- Entries come from stored JSON, so malformed ones (non-table entries,
-- non-numeric `ts`) are tolerated instead of raising. The cap is clamped
-- because removing from an already empty list never satisfies a negative
-- limit, which previously made the trimming loop spin forever.
local function trim_messages(recent_messages, max_messages, min_ts)
  local cap = tonumber(max_messages)
  if cap and cap < 0 then
    cap = 0
  end

  -- Numeric timestamp of a record, or nil when missing/malformed
  local function ts_of(m)
    local ts = m.ts
    if type(ts) == 'number' then
      return ts
    end
    return nil
  end

  local kept = {}
  for _, m in ipairs(recent_messages or {}) do
    if type(m) == 'table' then
      local ts = ts_of(m)
      if not min_ts or (ts and ts >= min_ts) then
        kept[#kept + 1] = m
      end
    end
  end

  table.sort(kept, function(a, b)
    return (ts_of(a) or 0) > (ts_of(b) or 0)
  end)

  -- cap is >= 0 here, so this always terminates
  while cap and #kept > cap do
    kept[#kept] = nil
  end

  return kept
end
|
|
|
|
-- Rank senders by message count.
-- Returns up to limit_n sender keys from sender_counts, highest count first,
-- ties broken by ascending key. A nil sender_counts yields an empty list.
local function recompute_top_senders(sender_counts, limit_n)
  local ranked = {}
  for sender, count in pairs(sender_counts or {}) do
    ranked[#ranked + 1] = { s = sender, c = count }
  end

  table.sort(ranked, function(left, right)
    if left.c ~= right.c then
      return left.c > right.c
    end
    return left.s < right.s
  end)

  local top = {}
  local wanted = math.min(limit_n, #ranked)
  for rank = 1, wanted do
    top[#top + 1] = ranked[rank].s
  end
  return top
end
|
|
|
|
-- Normalize a context object so every container field is a table.
--
-- @param ctx parsed context (any value); non-table input is replaced by a new table
-- @return the (possibly new) context with recent_messages, top_senders,
--         flagged_phrases, last_spam_labels and sender_counts guaranteed to be tables
--
-- sender_counts is validated like the other fields: the context is read back
-- from stored JSON, and a non-table value there would otherwise make later
-- indexing/assignment on it raise.
local function ensure_defaults(ctx)
  if type(ctx) ~= 'table' then ctx = {} end

  local container_fields = {
    'recent_messages',
    'top_senders',
    'flagged_phrases',
    'last_spam_labels',
    'sender_counts',
  }
  for _, field in ipairs(container_fields) do
    if type(ctx[field]) ~= 'table' then
      ctx[field] = {}
    end
  end

  return ctx
end
|
|
|
|
-- Case-insensitive plain-text (no pattern matching) substring test.
-- Returns false when either argument is nil/false.
local function contains_ci(haystack, needle)
  if haystack and needle then
    local hit = string.find(string.lower(haystack), string.lower(needle), 1, true)
    return hit ~= nil
  end
  return false
end
|
|
|
|
-- Record configured phrases that occur in the message text.
--
-- The words from text_part:get_words('norm') are joined with single spaces and
-- each phrase from opts.flagged_phrases (or the defaults) is searched in that
-- string case-insensitively. Matching phrases not yet present in
-- ctx.flagged_phrases (compared case-insensitively) are appended to it.
-- Does nothing when there is no text part or it has no words.
local function update_flagged_phrases(ctx, text_part, opts)
  local watchlist = opts.flagged_phrases or DEFAULTS.flagged_phrases
  if not text_part then return end

  local norm_words = text_part:get_words('norm')
  if not norm_words or #norm_words == 0 then return end

  local body = table.concat(norm_words, ' ')

  -- True when the phrase is already stored, ignoring case
  local function already_recorded(phrase)
    for _, known in ipairs(ctx.flagged_phrases) do
      if string.lower(known) == string.lower(phrase) then
        return true
      end
    end
    return false
  end

  for _, phrase in ipairs(watchlist) do
    if contains_ci(body, phrase) and not already_recorded(phrase) then
      table.insert(ctx.flagged_phrases, phrase)
    end
  end
end
|
|
|
|
-- Render the first limit_n messages as "- <from>: <subject>" lines joined by
-- newlines. The sender is taken from `from`, then `sender`; missing values
-- render as empty strings.
local function to_bullets_recent(recent_messages, limit_n)
  local bullets = {}
  local shown = math.min(limit_n, #recent_messages)
  for idx = 1, shown do
    local entry = recent_messages[idx]
    local who = entry.from or entry.sender or ''
    local what = entry.subject or ''
    bullets[#bullets + 1] = string.format('- %s: %s', who, what)
  end
  return table.concat(bullets, '\n')
end
|
|
|
|
-- Join a sequence with ', '; nil or empty input yields ''.
local function join_list(arr)
  if arr and #arr ~= 0 then
    return table.concat(arr, ', ')
  end
  return ''
end
|
|
|
|
-- Render the stored context as a plain-text snippet for the LLM prompt.
--
-- The snippet lists up to 5 recent messages, the top senders, flagged phrases
-- and last labels (the latter two only when non-empty), and a familiarity
-- label for the current SMTP envelope sender based on ctx.sender_counts:
-- 'frequent' (>= 10), 'known' (>= 3), 'occasional' (seen, < 3) or 'new'.
local function format_context_prompt(ctx, task)
  local recent_block = to_bullets_recent(ctx.recent_messages or {}, 5)
  local senders_line = join_list(ctx.top_senders or {})
  local phrases_line = join_list(ctx.flagged_phrases or {})
  local labels_line = join_list(ctx.last_spam_labels or {})

  -- Check if current sender is known
  local familiarity = 'new'
  local envelope_from = task and ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['addr']
  local seen = envelope_from and ctx.sender_counts and ctx.sender_counts[envelope_from]
  if seen then
    if seen >= 10 then
      familiarity = 'frequent'
    elseif seen >= 3 then
      familiarity = 'known'
    else
      familiarity = 'occasional'
    end
  end

  local out = {
    'User recent correspondence summary:',
    recent_block ~= '' and recent_block or '- (no recent messages)',
    string.format('Top senders in mailbox: %s', senders_line),
  }
  if phrases_line ~= '' then
    out[#out + 1] = string.format('Recently flagged suspicious phrases: %s', phrases_line)
  end
  if labels_line ~= '' then
    out[#out + 1] = string.format('Last detected spam types: %s', labels_line)
  end
  out[#out + 1] = string.format('Current sender: %s', familiarity)

  return table.concat(out, '\n')
end
|
|
|
|
-- Load the stored context for the task's identity and hand it to `callback`.
--
-- callback is invoked as callback(err, ctx, prompt_snippet):
--   * context disabled             -> callback(nil, nil, nil)
--   * no redis_params              -> callback('no redis', nil, nil)
--   * identity cannot be resolved  -> callback('no identity', nil, nil)
--   * GET could not be scheduled   -> callback('request not scheduled', nil, nil)
--   * Redis error                  -> callback(err, nil, nil)
--   * fewer than min_messages kept -> callback(nil, ctx, nil)
--   * otherwise                    -> callback(nil, ctx, prompt_snippet)
-- debug_module is the module name used for debug logging ('llm_context' by default).
function M.fetch(task, redis_params, opts, callback, debug_module)
  local log_module = debug_module or 'llm_context'
  opts = lua_util.override_defaults(DEFAULTS, opts or {})

  if not opts.enabled then
    callback(nil, nil, nil)
    return
  end

  if not redis_params then
    callback('no redis', nil, nil)
    return
  end

  local ident = compute_identity(task, opts, log_module)
  if not ident then
    lua_util.debugm(log_module, task, 'no identity computed, skipping context')
    callback('no identity', nil, nil)
    return
  end

  lua_util.debugm(log_module, task, 'fetching context for %s: %s',
    tostring(ident.scope), tostring(ident.identity))

  -- Redis GET completion handler
  local function handle_reply(err, data)
    if err then
      rspamd_logger.errx(task, 'llm_context: get failed: %s', tostring(err))
      callback(err, nil, nil)
      return
    end

    local stored
    if data then
      lua_util.debugm(log_module, task, 'got context data from redis, parsing')
      -- Only the parsed object is used; a parse failure falls back to an empty context
      stored = parse_json(data)
    else
      lua_util.debugm(log_module, task, 'no context data in redis, using empty')
    end
    local ctx = ensure_defaults(stored or {})

    -- Check if context has enough messages for warm-up
    local required = opts.min_messages or DEFAULTS.min_messages
    local available = #(ctx.recent_messages or {})
    if available < required then
      lua_util.debugm(log_module, task, 'context has only %s messages (min: %s), not injecting into prompt',
        tostring(available), tostring(required))
      callback(nil, ctx, nil) -- return ctx but no prompt snippet
      return
    end

    lua_util.debugm(log_module, task, 'context warm-up OK: %s messages, generating snippet',
      tostring(available))
    callback(nil, ctx, format_context_prompt(ctx, task))
  end

  local scheduled = lua_redis.redis_make_request(task, redis_params, ident.key, false,
    handle_reply, 'GET', { ident.key })
  if not scheduled then
    callback('request not scheduled', nil, nil)
  end
end
|
|
|
|
-- Update the stored context after a message has been classified.
--
-- Flow: GET the current context, prepend a summary of this message, bump the
-- sender counter, record flagged phrases, trim the message window by age and
-- size, recompute top senders, prepend the new labels, then SETEX the result
-- with the configured TTL. Failures are logged; nothing is returned.
--
-- @param task         rspamd task
-- @param redis_params Redis connection parameters; no-op when nil
-- @param opts         options merged over DEFAULTS; no-op unless opts.enabled
-- @param result       optional classification result; `categories` (sequence)
--                     and `probability` (> 0.5 -> 'spam', else 'ham') are used
-- @param sel_part     optional text part used for keywords and phrase matching
-- @param debug_module module name for debug logging ('llm_context' by default)
--
-- Size limits taken from opts are sanitised once: non-numeric values fall back
-- to the defaults and negative values are clamped to 0. A negative limit would
-- otherwise never be satisfied by removing elements from an empty list and the
-- trimming loops would spin forever.
function M.update_after_classification(task, redis_params, opts, result, sel_part, debug_module)
  local N = debug_module or 'llm_context'
  opts = lua_util.override_defaults(DEFAULTS, opts or {})
  if not opts.enabled then return end
  if not redis_params then return end

  local ident = compute_identity(task, opts, N)
  if not ident then return end

  -- Non-negative numeric limit, or `fallback` when the value is not numeric
  local function limit_of(value, fallback)
    local n = tonumber(value)
    if not n then n = fallback end
    if n < 0 then n = 0 end
    return n
  end
  local max_messages = limit_of(opts.max_messages, DEFAULTS.max_messages)
  local top_senders = limit_of(opts.top_senders, DEFAULTS.top_senders)
  local labels_limit = limit_of(opts.last_labels_count, DEFAULTS.last_labels_count)

  local function on_get(err, data)
    if err then
      rspamd_logger.errx(task, 'llm_context: get for update failed: %s', tostring(err))
      return
    end
    lua_util.debugm(N, task, 'updating context for %s: %s',
      tostring(ident.scope), tostring(ident.identity))
    -- Only the parsed object is used; unparsable data starts a fresh context
    local stored = parse_json(data)
    local ctx = ensure_defaults(stored or {})

    local msg = build_message_summary(task, sel_part, opts)
    if msg then
      -- Newest message goes first
      table.insert(ctx.recent_messages, 1, msg)
      local sender = msg.from or ''
      if sender ~= '' then
        ctx.sender_counts[sender] = (ctx.sender_counts[sender] or 0) + 1
      end
      update_flagged_phrases(ctx, sel_part, opts)
    end

    local min_ts = now() - to_seconds(opts.message_ttl)
    ctx.recent_messages = trim_messages(ctx.recent_messages, max_messages, min_ts)
    ctx.top_senders = recompute_top_senders(ctx.sender_counts, top_senders)

    local labels = {}
    if result then
      if result.categories and type(result.categories) == 'table' then
        for _, c in ipairs(result.categories) do
          labels[#labels + 1] = tostring(c)
        end
      end
      -- Coerce so a non-numeric probability is skipped instead of raising on comparison
      local probability = tonumber(result.probability)
      if probability then
        if probability > 0.5 then
          labels[#labels + 1] = 'spam'
        else
          labels[#labels + 1] = 'ham'
        end
      end
    end
    -- Each label is pushed to the front, so the last collected label ends up first
    for _, l in ipairs(labels) do
      table.insert(ctx.last_spam_labels, 1, l)
    end
    -- labels_limit is >= 0, so this always terminates
    while #ctx.last_spam_labels > labels_limit do
      table.remove(ctx.last_spam_labels)
    end

    ctx.updated_at = now()

    local payload = encode_json(ctx)
    local ttl = to_seconds(opts.ttl)
    local expire_at = now() + ttl

    -- Log what we're storing in context
    lua_util.debugm(N, task,
      'storing context for %s: %s messages, labels=%s, top_senders=%s, flagged=%s, payload_size=%s bytes, expiring at %s',
      tostring(ident.identity or '(none)'),
      tostring(#ctx.recent_messages),
      table.concat(ctx.last_spam_labels or {}, ','),
      table.concat(ctx.top_senders or {}, ','),
      table.concat(ctx.flagged_phrases or {}, ','),
      tostring(#payload),
      os.date('%Y-%m-%d %H:%M:%S', expire_at))

    if msg then
      lua_util.debugm(N, task,
        'added message: from=%s, subject=%s, keywords=%s',
        tostring(msg.from or '(none)'),
        tostring(msg.subject or '(none)'),
        table.concat(msg.keywords or {}, ','))
    end

    local function on_set(set_err)
      if set_err then
        rspamd_logger.errx(task, 'llm_context: set failed: %s', tostring(set_err))
      else
        lua_util.debugm(N, task, 'context saved to redis: key=%s, ttl=%s, expiring at %s',
          tostring(ident.key), tostring(ttl), os.date('%Y-%m-%d %H:%M:%S', expire_at))
      end
    end

    -- NOTE(review): the 4th argument is false for GET and true for SETEX here,
    -- presumably the "is write" flag of redis_make_request; confirm against lua_redis.
    local ok = lua_redis.redis_make_request(task, redis_params, ident.key, true, on_set, 'SETEX',
      { ident.key, tostring(ttl), payload })
    if not ok then
      rspamd_logger.errx(task, 'llm_context: set request was not scheduled')
    end
  end

  local ok = lua_redis.redis_make_request(task, redis_params, ident.key, false, on_get, 'GET', { ident.key })
  if not ok then
    rspamd_logger.errx(task, 'llm_context: initial get request was not scheduled')
  end
end
|
|
|
|
return M
|