|
|
|
@ -396,10 +396,62 @@ lua_tensor_mul (lua_State *L) |
|
|
|
static gint |
|
|
|
lua_tensor_load (lua_State *L) |
|
|
|
{ |
|
|
|
struct rspamd_lua_tensor *t = lua_check_tensor (L, 1); |
|
|
|
const guchar *data; |
|
|
|
gsize sz; |
|
|
|
|
|
|
|
if (t) { |
|
|
|
if (lua_type (L, 1) == LUA_TUSERDATA) { |
|
|
|
struct rspamd_lua_text *t = lua_check_text (L, 1); |
|
|
|
|
|
|
|
if (!t) { |
|
|
|
return luaL_error (L, "invalid argument"); |
|
|
|
} |
|
|
|
|
|
|
|
data = (const guchar *)t->start; |
|
|
|
sz = t->len; |
|
|
|
} |
|
|
|
else { |
|
|
|
data = (const guchar *)lua_tolstring (L, 1, &sz); |
|
|
|
} |
|
|
|
|
|
|
|
if (sz >= sizeof (gint) * 4) { |
|
|
|
int ndims, nelts, dims[2]; |
|
|
|
|
|
|
|
memcpy (&ndims, data, sizeof (int)); |
|
|
|
memcpy (&nelts, data + sizeof (int), sizeof (int)); |
|
|
|
memcpy (dims, data + sizeof (int) * 2, sizeof (int) * 2); |
|
|
|
|
|
|
|
if (sz == nelts * sizeof (rspamd_tensor_num_t) + sizeof (int) * 4) { |
|
|
|
if (ndims == 1) { |
|
|
|
if (nelts == dims[0]) { |
|
|
|
struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false); |
|
|
|
memcpy (t->data, data + sizeof (int) * 4, nelts * |
|
|
|
sizeof (rspamd_tensor_num_t)); |
|
|
|
} |
|
|
|
else { |
|
|
|
return luaL_error (L, "invalid argument: bad dims: %d x %d != %d", |
|
|
|
dims[0], 1, nelts); |
|
|
|
} |
|
|
|
} |
|
|
|
else if (ndims == 2) { |
|
|
|
if (nelts == dims[0] * dims[1]) { |
|
|
|
struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false); |
|
|
|
memcpy (t->data, data + sizeof (int) * 4, nelts * |
|
|
|
sizeof (rspamd_tensor_num_t)); |
|
|
|
} |
|
|
|
else { |
|
|
|
return luaL_error (L, "invalid argument: bad dims: %d x %d != %d", |
|
|
|
dims[0], dims[1], nelts); |
|
|
|
} |
|
|
|
} |
|
|
|
else { |
|
|
|
return luaL_error (L, "invalid argument: bad ndims: %d", ndims); |
|
|
|
} |
|
|
|
} |
|
|
|
else { |
|
|
|
return luaL_error (L, "invalid size: %d, %d required, %d elts", (int)sz, |
|
|
|
(int)(nelts * sizeof (rspamd_tensor_num_t) + sizeof (int) * 4), |
|
|
|
nelts); |
|
|
|
} |
|
|
|
} |
|
|
|
else { |
|
|
|
return luaL_error (L, "invalid arguments"); |
|
|
|
|