Browse Source

[Minor] Lua_tensor: Add deserialisation

pull/3462/head
Vsevolod Stakhov 5 years ago
parent
commit
6a1692499f
  1. 56
      src/lua/lua_tensor.c

56
src/lua/lua_tensor.c

@ -396,10 +396,62 @@ lua_tensor_mul (lua_State *L)
static gint
lua_tensor_load (lua_State *L)
{
struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
const guchar *data;
gsize sz;
if (t) {
if (lua_type (L, 1) == LUA_TUSERDATA) {
struct rspamd_lua_text *t = lua_check_text (L, 1);
if (!t) {
return luaL_error (L, "invalid argument");
}
data = (const guchar *)t->start;
sz = t->len;
}
else {
data = (const guchar *)lua_tolstring (L, 1, &sz);
}
if (sz >= sizeof (gint) * 4) {
int ndims, nelts, dims[2];
memcpy (&ndims, data, sizeof (int));
memcpy (&nelts, data + sizeof (int), sizeof (int));
memcpy (dims, data + sizeof (int) * 2, sizeof (int) * 2);
if (sz == nelts * sizeof (rspamd_tensor_num_t) + sizeof (int) * 4) {
if (ndims == 1) {
if (nelts == dims[0]) {
struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false);
memcpy (t->data, data + sizeof (int) * 4, nelts *
sizeof (rspamd_tensor_num_t));
}
else {
return luaL_error (L, "invalid argument: bad dims: %d x %d != %d",
dims[0], 1, nelts);
}
}
else if (ndims == 2) {
if (nelts == dims[0] * dims[1]) {
struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false);
memcpy (t->data, data + sizeof (int) * 4, nelts *
sizeof (rspamd_tensor_num_t));
}
else {
return luaL_error (L, "invalid argument: bad dims: %d x %d != %d",
dims[0], dims[1], nelts);
}
}
else {
return luaL_error (L, "invalid argument: bad ndims: %d", ndims);
}
}
else {
return luaL_error (L, "invalid size: %d, %d required, %d elts", (int)sz,
(int)(nelts * sizeof (rspamd_tensor_num_t) + sizeof (int) * 4),
nelts);
}
}
else {
return luaL_error (L, "invalid arguments");

Loading…
Cancel
Save