diff --git a/plugins/openresty/waf/config.lua b/plugins/openresty/waf/config.lua
index 8c1b703ff..0e9cb3056 100644
--- a/plugins/openresty/waf/config.lua
+++ b/plugins/openresty/waf/config.lua
@@ -3,6 +3,7 @@
 local lfs = require "lfs"
 local utils = require "utils"
 local cjson = require "cjson"
+
 local read_rule = file_utils.read_rule
 local read_file2string = file_utils.read_file2string
 local read_file2table = file_utils.read_file2table
@@ -81,10 +82,11 @@ local function load_ip_group()
             ip_group_list[entry] = group_value
         end
     end
-    local waf_dict = ngx.shared.waf
-    local ok , err = waf_dict:set("ip_group_list", cjson.encode(ip_group_list))
-    if not ok then
-        ngx.log(ngx.ERR, "Failed to set ip_group_list",err)
+    local ok, err = cache:set("ip_group_list", {
+        ipc_shm = "ipc_shared_dict",
+    }, ip_group_list)
+    if not ok then
+        ngx.log(ngx.ERR, "Failed to set ip_group_list: ", err)
     end
 end
 
@@ -136,20 +138,22 @@
     init_sites_config()
     load_ip_group()
 
-    local waf_dict = ngx.shared.waf
-    local ok,err = waf_dict:set("config", cjson.encode(config))
-    if not ok then
+    local ok, err = cache:set("config", {
+        ipc_shm = "ipc_shared_dict",
+    }, config)
+    if not ok then
         ngx.log(ngx.ERR, "Failed to set config",err)
     end
 end
 
 local function get_config()
-    local waf_dict = ngx.shared.waf
-    local cache_config = waf_dict:get("config")
+    local cache_config = cache:get("config", {
+        ipc_shm = "ipc_shared_dict",
+    })
     if not cache_config then
         return config
     end
-    return cjson.decode(cache_config)
+    return cache_config
 end
 
 function _M.get_site_config(website_key)
diff --git a/plugins/openresty/waf/init.lua b/plugins/openresty/waf/init.lua
index 2d1aa6afa..e856ce256 100644
--- a/plugins/openresty/waf/init.lua
+++ b/plugins/openresty/waf/init.lua
@@ -1,5 +1,16 @@
 local db = require "db"
 local config = require "config"
+local mlcache = require "resty.mlcache"
+
+local cache, err = mlcache.new("config", "waf", {
+    lru_size = 1000,
+    ipc_shm = "ipc_shared_dict",
+})
+if not cache then
+    error("could not create mlcache: " .. err)
+end
+_G.cache = cache
+
 
 config.load_config_file()
 db.init()
@@ -7,3 +18,4 @@ db.init()
 
 
 
+
diff --git a/plugins/openresty/waf/lib/lib.lua b/plugins/openresty/waf/lib/lib.lua
index a789be0e2..f1e8c7924 100644
--- a/plugins/openresty/waf/lib/lib.lua
+++ b/plugins/openresty/waf/lib/lib.lua
@@ -122,13 +122,13 @@ local function match_ip(ip_rule, ip, ipn)
     if ip_rule.ipGroup == nil or ip_rule.ipGroup == "" then
         return false
     end
-    local waf_dict = ngx.shared.waf
-    local ip_group_list = waf_dict:get("ip_group_list")
+    local ip_group_list = cache:get("ip_group_list", {
+        ipc_shm = "ipc_shared_dict",
+    })
     if ip_group_list == nil then
         return false
     end
-    local ip_group_obj = cjson.decode(ip_group_list)
-    local ip_group = ip_group_obj[ip_rule.ipGroup]
+    local ip_group = ip_group_list[ip_rule.ipGroup]
     if ip_group == nil then
         return false
     end
diff --git a/plugins/openresty/waf/lib/resty/mlcache.lua b/plugins/openresty/waf/lib/resty/mlcache.lua
new file mode 100644
index 000000000..0c1e5b2e2
--- /dev/null
+++ b/plugins/openresty/waf/lib/resty/mlcache.lua
@@ -0,0 +1,1441 @@
+-- vim: ts=4 sts=4 sw=4 et:
+
+local new_tab = require "table.new"
+local lrucache = require "resty.lrucache"
+local resty_lock = require "resty.lock"
+local tablepool
+do
+    local pok
+    pok, tablepool = pcall(require, "tablepool")
+    if not pok then
+        -- fallback for OpenResty < 1.15.8.1
+        tablepool = {
+            fetch = function(_, narr, nrec)
+                return new_tab(narr, nrec)
+            end,
+            release = function(_, _, _)
+                -- nop (obj will be subject to GC)
+            end,
+        }
+    end
+end
+local codec
+do
+    local pok
+    pok, codec = pcall(require, "string.buffer")
+    if not pok then
+        codec = require "cjson"
+    end
+end
+
+
+local now = ngx.now
+local min = math.min
+local ceil = math.ceil
+local fmt = string.format
+local sub = string.sub
+local find = string.find
+local type = type
+local pcall = pcall
+local xpcall = xpcall
+local traceback = debug.traceback
+local error = error
+local tostring = tostring
+local tonumber = tonumber
+local encode = codec.encode
+local decode = codec.decode
+local thread_spawn = ngx.thread.spawn
+local thread_wait = ngx.thread.wait
+local setmetatable = setmetatable
+local shared = ngx.shared
+local ngx_log = ngx.log
+local WARN = ngx.WARN
+local ERR = ngx.ERR
+
+
+local CACHE_MISS_SENTINEL_LRU = {}
+local LOCK_KEY_PREFIX = "lua-resty-mlcache:lock:"
+local LRU_INSTANCES = setmetatable({}, { __mode = "v" })
+local SHM_SET_DEFAULT_TRIES = 3
+local BULK_DEFAULT_CONCURRENCY = 3
+
+
+local TYPES_LOOKUP = {
+    number = 1,
+    boolean = 2,
+    string = 3,
+    table = 4,
+}
+
+
+local SHM_FLAGS = {
+    stale = 0x00000001,
+}
+
+
+local marshallers = {
+    shm_value = function(str_value, value_type, at, ttl)
+        return fmt("%d:%f:%f:%s", value_type, at, ttl, str_value)
+    end,
+
+    shm_nil = function(at, ttl)
+        return fmt("0:%f:%f:", at, ttl)
+    end,
+
+    [1] = function(number) -- number
+        return tostring(number)
+    end,
+
+    [2] = function(bool) -- boolean
+        return bool and "true" or "false"
+    end,
+
+    [3] = function(str) -- string
+        return str
+    end,
+
+    [4] = function(t) -- table
+        local pok, str = pcall(encode, t)
+        if not pok then
+            return nil, "could not encode table value: " .. str
+        end
+
+        return str
+    end,
+}
+
+
+local unmarshallers = {
+    shm_value = function(marshalled)
+        -- split our shm marshalled value by the hard-coded ":" tokens
+        -- "type:at:ttl:value"
+        -- 1:1501831735.052000:0.500000:123
+        local ttl_last = find(marshalled, ":", 21, true) - 1
+
+        local value_type = sub(marshalled, 1, 1) -- n:...
+ local at = sub(marshalled, 3, 19) -- n:1501831160 + local ttl = sub(marshalled, 21, ttl_last) + local str_value = sub(marshalled, ttl_last + 2) + + return str_value, tonumber(value_type), tonumber(at), tonumber(ttl) + end, + + [0] = function() -- nil + return nil + end, + + [1] = function(str) -- number + return tonumber(str) + end, + + [2] = function(str) -- boolean + return str == "true" + end, + + [3] = function(str) -- string + return str + end, + + [4] = function(str) -- table + local pok, t = pcall(decode, str) + if not pok then + return nil, "could not decode table value: " .. t + end + + return t + end, +} + + +local function rebuild_lru(self) + if self.lru then + if self.lru.flush_all then + self.lru:flush_all() + return + end + + -- fallback for OpenResty < 1.13.6.2 + -- Invalidate the entire LRU by GC-ing it. + LRU_INSTANCES[self.name] = nil + self.lru = nil + end + + -- Several mlcache instances can have the same name and hence, the same + -- lru instance. We need to GC such LRU instance when all mlcache instances + -- using them are GC'ed. We do this with a weak table. + local lru = LRU_INSTANCES[self.name] + if not lru then + lru = lrucache.new(self.lru_size) + LRU_INSTANCES[self.name] = lru + end + + self.lru = lru +end + + +local _M = { + _VERSION = "2.6.1", + _AUTHOR = "Thibault Charbonnier", + _LICENSE = "MIT", + _URL = "https://github.com/thibaultcha/lua-resty-mlcache", +} +local mt = { __index = _M } + + +function _M.new(name, shm, opts) + if type(name) ~= "string" then + error("name must be a string", 2) + end + + if type(shm) ~= "string" then + error("shm must be a string", 2) + end + + if opts ~= nil then + if type(opts) ~= "table" then + error("opts must be a table", 2) + end + + if opts.lru_size ~= nil and type(opts.lru_size) ~= "number" then + error("opts.lru_size must be a number", 2) + end + + if opts.ttl ~= nil then + if type(opts.ttl) ~= "number" then + error("opts.ttl must be a number", 2) + end + + if opts.ttl < 0 then + error("opts.ttl must be >= 0", 2) + end + end + + if opts.neg_ttl ~= nil then + if type(opts.neg_ttl) ~= "number" then + error("opts.neg_ttl must be a number", 2) + end + + if opts.neg_ttl < 0 then + error("opts.neg_ttl must be >= 0", 2) + end + end + + if opts.resurrect_ttl ~= nil then + if type(opts.resurrect_ttl) ~= "number" then + error("opts.resurrect_ttl must be a number", 2) + end + + if opts.resurrect_ttl < 0 then + error("opts.resurrect_ttl must be >= 0", 2) + end + end + + if opts.resty_lock_opts ~= nil + and type(opts.resty_lock_opts) ~= "table" + then + error("opts.resty_lock_opts must be a table", 2) + end + + if opts.ipc_shm ~= nil and type(opts.ipc_shm) ~= "string" then + error("opts.ipc_shm must be a string", 2) + end + + if opts.ipc ~= nil then + if opts.ipc_shm then + error("cannot specify both of opts.ipc_shm and opts.ipc", 2) + end + + if type(opts.ipc) ~= "table" then + error("opts.ipc must be a table", 2) + end + + if type(opts.ipc.register_listeners) ~= "function" then + error("opts.ipc.register_listeners must be a function", 2) + end + + if type(opts.ipc.broadcast) ~= "function" then + error("opts.ipc.broadcast must be a function", 2) + end + + if opts.ipc.poll ~= nil and type(opts.ipc.poll) ~= "function" then + error("opts.ipc.poll must be a function", 2) + end + end + + if opts.l1_serializer ~= nil + and type(opts.l1_serializer) ~= "function" + then + error("opts.l1_serializer must be a function", 2) + end + + if opts.shm_set_tries ~= nil then + if type(opts.shm_set_tries) ~= "number" then + 
error("opts.shm_set_tries must be a number", 2) + end + + if opts.shm_set_tries < 1 then + error("opts.shm_set_tries must be >= 1", 2) + end + end + + if opts.shm_miss ~= nil and type(opts.shm_miss) ~= "string" then + error("opts.shm_miss must be a string", 2) + end + + if opts.shm_locks ~= nil and type(opts.shm_locks) ~= "string" then + error("opts.shm_locks must be a string", 2) + end + else + opts = {} + end + + local dict = shared[shm] + if not dict then + return nil, "no such lua_shared_dict: " .. shm + end + + local dict_miss + if opts.shm_miss then + dict_miss = shared[opts.shm_miss] + if not dict_miss then + return nil, "no such lua_shared_dict for opts.shm_miss: " + .. opts.shm_miss + end + end + + if opts.shm_locks then + local dict_locks = shared[opts.shm_locks] + if not dict_locks then + return nil, "no such lua_shared_dict for opts.shm_locks: " + .. opts.shm_locks + end + end + + local self = { + name = name, + dict = dict, + shm = shm, + dict_miss = dict_miss, + shm_miss = opts.shm_miss, + shm_locks = opts.shm_locks or shm, + ttl = opts.ttl or 30, + neg_ttl = opts.neg_ttl or 5, + resurrect_ttl = opts.resurrect_ttl, + lru_size = opts.lru_size or 100, + resty_lock_opts = opts.resty_lock_opts, + l1_serializer = opts.l1_serializer, + shm_set_tries = opts.shm_set_tries or SHM_SET_DEFAULT_TRIES, + debug = opts.debug, + } + + if opts.ipc_shm or opts.ipc then + self.events = { + ["invalidation"] = { + channel = fmt("mlcache:invalidations:%s", name), + handler = function(key) + self.lru:delete(key) + end, + }, + ["purge"] = { + channel = fmt("mlcache:purge:%s", name), + handler = function() + rebuild_lru(self) + end, + } + } + + if opts.ipc_shm then + local mlcache_ipc = require "resty.mlcache.ipc" + + local ipc, err = mlcache_ipc.new(opts.ipc_shm, opts.debug) + if not ipc then + return nil, "failed to initialize mlcache IPC " .. + "(could not instantiate mlcache.ipc): " .. err + end + + for _, ev in pairs(self.events) do + ipc:subscribe(ev.channel, ev.handler) + end + + self.broadcast = function(channel, data) + return ipc:broadcast(channel, data) + end + + self.poll = function(timeout) + return ipc:poll(timeout) + end + + self.ipc = ipc + + else + -- opts.ipc + local ok, err = opts.ipc.register_listeners(self.events) + if not ok and err ~= nil then + return nil, "failed to initialize custom IPC " .. + "(opts.ipc.register_listeners returned an error): " + .. err + end + + self.broadcast = opts.ipc.broadcast + self.poll = opts.ipc.poll + + self.ipc = true + end + end + + if opts.lru then + self.lru = opts.lru + + else + rebuild_lru(self) + end + + return setmetatable(self, mt) +end + + +local function l1_serialize(value, l1_serializer) + if value ~= nil and l1_serializer then + local ok, err + ok, value, err = pcall(l1_serializer, value) + if not ok then + return nil, "l1_serializer threw an error: " .. 
value + end + + if err then + return nil, err + end + + if value == nil then + return nil, "l1_serializer returned a nil value" + end + end + + return value +end + + +local function set_lru(self, key, value, ttl, neg_ttl, l1_serializer) + local value, err = l1_serialize(value, l1_serializer) + if err then + return nil, err + end + + if value == nil then + value = CACHE_MISS_SENTINEL_LRU + ttl = neg_ttl + end + + if ttl == 0 then + -- indefinite ttl for lua-resty-lrucache is 'nil' + ttl = nil + end + + self.lru:set(key, value, ttl) + + return value +end + + +local function marshall_for_shm(value, ttl, neg_ttl) + local at = now() + + if value == nil then + return marshallers.shm_nil(at, neg_ttl), nil, true -- is_nil + end + + -- serialize insertion time + Lua types for shm storage + + local value_type = TYPES_LOOKUP[type(value)] + + if not marshallers[value_type] then + error("cannot cache value of type " .. type(value)) + end + + local str_marshalled, err = marshallers[value_type](value) + if not str_marshalled then + return nil, "could not serialize value for lua_shared_dict insertion: " + .. err + end + + return marshallers.shm_value(str_marshalled, value_type, at, ttl) +end + + +local function unmarshall_from_shm(shm_v) + local str_serialized, value_type, at, ttl = unmarshallers.shm_value(shm_v) + + local value, err = unmarshallers[value_type](str_serialized) + if err then + return nil, err + end + + return value, nil, at, ttl +end + + +local function set_shm(self, shm_key, value, ttl, neg_ttl, flags, shm_set_tries, + throw_no_mem) + local shm_value, err, is_nil = marshall_for_shm(value, ttl, neg_ttl) + if not shm_value then + return nil, err + end + + local shm = self.shm + local dict = self.dict + + if is_nil then + ttl = neg_ttl + + if self.dict_miss then + shm = self.shm_miss + dict = self.dict_miss + end + end + + -- we will call `set()` N times to work around potential shm fragmentation. + -- when the shm is full, it will only evict about 30 to 90 items (via + -- LRU), which could lead to a situation where `set()` still does not + -- have enough memory to store the cached value, in which case we + -- try again to try to trigger more LRU evictions. + + local tries = 0 + local ok, err + + while tries < shm_set_tries do + tries = tries + 1 + + ok, err = dict:set(shm_key, shm_value, ttl, flags or 0) + if ok or err and err ~= "no memory" then + break + end + end + + if not ok then + if err ~= "no memory" or throw_no_mem then + return nil, "could not write to lua_shared_dict '" .. shm + .. "': " .. err + end + + ngx_log(WARN, "could not write to lua_shared_dict '", + shm, "' after ", tries, " tries (no memory), ", + "it is either fragmented or cannot allocate more ", + "memory, consider increasing 'opts.shm_set_tries'") + end + + return true +end + + +local function set_shm_set_lru(self, key, shm_key, value, ttl, neg_ttl, flags, + shm_set_tries, l1_serializer, throw_no_mem) + + local ok, err = set_shm(self, shm_key, value, ttl, neg_ttl, flags, + shm_set_tries, throw_no_mem) + if not ok then + return nil, err + end + + return set_lru(self, key, value, ttl, neg_ttl, l1_serializer) +end + + +local function get_shm_set_lru(self, key, shm_key, l1_serializer) + local v, shmerr, went_stale = self.dict:get_stale(shm_key) + if v == nil and shmerr then + -- shmerr can be 'flags' upon successful get_stale() calls, so we + -- also check v == nil + return nil, "could not read from lua_shared_dict: " .. 
shmerr + end + + if self.shm_miss and v == nil then + -- if we cache misses in another shm, maybe it is there + v, shmerr, went_stale = self.dict_miss:get_stale(shm_key) + if v == nil and shmerr then + -- shmerr can be 'flags' upon successful get_stale() calls, so we + -- also check v == nil + return nil, "could not read from lua_shared_dict: " .. shmerr + end + end + + if v ~= nil then + local value, err, at, ttl = unmarshall_from_shm(v) + if err then + return nil, "could not deserialize value after lua_shared_dict " .. + "retrieval: " .. err + end + + if went_stale then + value, err = l1_serialize(value, l1_serializer) + if err then + return nil, err + end + + return value, nil, went_stale + end + + -- 'shmerr' is 'flags' on :get_stale() success + local is_stale = shmerr == SHM_FLAGS.stale + + local remaining_ttl + if ttl == 0 then + -- indefinite ttl, keep '0' as it means 'forever' + remaining_ttl = 0 + + else + -- compute elapsed time to get remaining ttl for LRU caching + remaining_ttl = ttl - (now() - at) + + if remaining_ttl <= 0 then + -- value has less than 1ms of lifetime in the shm, avoid + -- setting it in LRU which would be wasteful and could + -- indefinitely cache the value when ttl == 0 + value, err = l1_serialize(value, l1_serializer) + if err then + return nil, err + end + + return value, nil, nil, is_stale + end + end + + value, err = set_lru(self, key, value, remaining_ttl, remaining_ttl, + l1_serializer) + if err then + return nil, err + end + + return value, nil, nil, is_stale + end +end + + +local function check_opts(self, opts) + local ttl + local neg_ttl + local resurrect_ttl + local l1_serializer + local shm_set_tries + local resty_lock_opts + + if opts ~= nil then + if type(opts) ~= "table" then + error("opts must be a table", 3) + end + + ttl = opts.ttl + if ttl ~= nil then + if type(ttl) ~= "number" then + error("opts.ttl must be a number", 3) + end + + if ttl < 0 then + error("opts.ttl must be >= 0", 3) + end + end + + neg_ttl = opts.neg_ttl + if neg_ttl ~= nil then + if type(neg_ttl) ~= "number" then + error("opts.neg_ttl must be a number", 3) + end + + if neg_ttl < 0 then + error("opts.neg_ttl must be >= 0", 3) + end + end + + resurrect_ttl = opts.resurrect_ttl + if resurrect_ttl ~= nil then + if type(resurrect_ttl) ~= "number" then + error("opts.resurrect_ttl must be a number", 3) + end + + if resurrect_ttl < 0 then + error("opts.resurrect_ttl must be >= 0", 3) + end + end + + l1_serializer = opts.l1_serializer + if l1_serializer ~= nil and type(l1_serializer) ~= "function" then + error("opts.l1_serializer must be a function", 3) + end + + shm_set_tries = opts.shm_set_tries + if shm_set_tries ~= nil then + if type(shm_set_tries) ~= "number" then + error("opts.shm_set_tries must be a number", 3) + end + + if shm_set_tries < 1 then + error("opts.shm_set_tries must be >= 1", 3) + end + end + + resty_lock_opts = opts.resty_lock_opts + if resty_lock_opts ~= nil then + if type(resty_lock_opts) ~= "table" then + error("opts.resty_lock_opts must be a table", 3) + end + end + end + + if not ttl then + ttl = self.ttl + end + + if not neg_ttl then + neg_ttl = self.neg_ttl + end + + if not resurrect_ttl then + resurrect_ttl = self.resurrect_ttl + end + + if not l1_serializer then + l1_serializer = self.l1_serializer + end + + if not shm_set_tries then + shm_set_tries = self.shm_set_tries + end + + if not resty_lock_opts then + resty_lock_opts = self.resty_lock_opts + end + + return ttl, neg_ttl, resurrect_ttl, l1_serializer, shm_set_tries, + resty_lock_opts +end + + 
+local function unlock_and_ret(lock, res, err, hit_lvl) + local ok, lerr = lock:unlock() + if not ok and lerr ~= "unlocked" then + return nil, "could not unlock callback: " .. lerr + end + + return res, err, hit_lvl +end + + +local function run_callback(self, key, shm_key, data, ttl, neg_ttl, + went_stale, l1_serializer, resurrect_ttl, shm_set_tries, rlock_opts, cb, ...) + + local lock, err = resty_lock:new(self.shm_locks, rlock_opts) + if not lock then + return nil, "could not create lock: " .. err + end + + local elapsed, lerr = lock:lock(LOCK_KEY_PREFIX .. shm_key) + if not elapsed and lerr ~= "timeout" then + return nil, "could not acquire callback lock: " .. lerr + end + + do + -- check for another worker's success at running the callback, but + -- do not return data if it is still the same stale value (this is + -- possible if the value was still not evicted between the first + -- get() and this one) + + local data2, err, went_stale2, stale2 = get_shm_set_lru(self, key, + shm_key, + l1_serializer) + if err then + return unlock_and_ret(lock, nil, err) + end + + if data2 ~= nil and not went_stale2 then + -- we got a fresh item from shm: other worker succeeded in running + -- the callback + if data2 == CACHE_MISS_SENTINEL_LRU then + data2 = nil + end + + return unlock_and_ret(lock, data2, nil, stale2 and 4 or 2) + end + end + + -- we are either the 1st worker to hold the lock, or + -- a subsequent worker whose lock has timed out before the 1st one + -- finished to run the callback + + if lerr == "timeout" then + local errmsg = "could not acquire callback lock: timeout" + + -- no stale data nor desire to resurrect it + if not went_stale or not resurrect_ttl then + return nil, errmsg + end + + -- do not resurrect the value here (another worker is running the + -- callback and will either get the new value, or resurrect it for + -- us if the callback fails) + + ngx_log(WARN, errmsg) + + -- went_stale is true, hence the value cannot be set in the LRU + -- cache, and cannot be CACHE_MISS_SENTINEL_LRU + + return data, nil, 4 + end + + -- still not in shm, we are the 1st worker to hold the lock, and thus + -- responsible for running the callback + + local pok, perr, err, new_ttl = xpcall(cb, traceback, ...) + if not pok then + return unlock_and_ret(lock, nil, "callback threw an error: " .. + tostring(perr)) + end + + if err then + -- callback returned nil + err + + -- be resilient in case callbacks return wrong error type + err = tostring(err) + + -- no stale data nor desire to resurrect it + if not went_stale or not resurrect_ttl then + return unlock_and_ret(lock, perr, err) + end + + -- we got 'data' from the shm, even though it is stale + -- 1. log as warn that the callback returned an error + -- 2. resurrect: insert it back into shm if 'resurrect_ttl' + -- 3. signify the staleness with a high hit_lvl of '4' + + ngx_log(WARN, "callback returned an error (", err, ") but stale ", + "value found in shm will be resurrected for ", + resurrect_ttl, "s (resurrect_ttl)") + + local res_data, res_err = set_shm_set_lru(self, key, shm_key, + data, resurrect_ttl, + resurrect_ttl, + SHM_FLAGS.stale, + shm_set_tries, l1_serializer) + if res_err then + ngx_log(WARN, "could not resurrect stale data (", res_err, ")") + end + + if res_data == CACHE_MISS_SENTINEL_LRU then + res_data = nil + end + + return unlock_and_ret(lock, res_data, nil, 4) + end + + -- successful callback run returned 'data, nil, new_ttl?' 
+ + data = perr + + -- override ttl / neg_ttl + + if type(new_ttl) == "number" then + if new_ttl < 0 then + -- bypass cache + return unlock_and_ret(lock, data, nil, 3) + end + + if data == nil then + neg_ttl = new_ttl + + else + ttl = new_ttl + end + end + + data, err = set_shm_set_lru(self, key, shm_key, data, ttl, neg_ttl, nil, + shm_set_tries, l1_serializer) + if err then + return unlock_and_ret(lock, nil, err) + end + + if data == CACHE_MISS_SENTINEL_LRU then + data = nil + end + + -- unlock and return + + return unlock_and_ret(lock, data, nil, 3) +end + + +function _M:get(key, opts, cb, ...) + if type(key) ~= "string" then + error("key must be a string", 2) + end + + if cb ~= nil and type(cb) ~= "function" then + error("callback must be nil or a function", 2) + end + + -- worker LRU cache retrieval + + local data = self.lru:get(key) + if data == CACHE_MISS_SENTINEL_LRU then + return nil, nil, 1 + end + + if data ~= nil then + return data, nil, 1 + end + + -- not in worker's LRU cache, need shm lookup + + -- restrict this key to the current namespace, so we isolate this + -- mlcache instance from potential other instances using the same + -- shm + local namespaced_key = self.name .. key + + -- opts validation + + local ttl, neg_ttl, resurrect_ttl, l1_serializer, shm_set_tries, + rlock_opts = check_opts(self, opts) + + local err, went_stale, is_stale + data, err, went_stale, is_stale = get_shm_set_lru(self, key, namespaced_key, + l1_serializer) + if err then + return nil, err + end + + if data ~= nil and not went_stale then + if data == CACHE_MISS_SENTINEL_LRU then + data = nil + end + + return data, nil, is_stale and 4 or 2 + end + + -- not in shm either + + if cb == nil then + -- no L3 callback, early exit + return nil, nil, -1 + end + + -- L3 callback, single worker to run it + + return run_callback(self, key, namespaced_key, data, ttl, neg_ttl, + went_stale, l1_serializer, resurrect_ttl, + shm_set_tries, rlock_opts, cb, ...) 
+end + + +do + local function run_thread(self, ops, from, to) + for i = from, to do + local ctx = ops[i] + + ctx.data, ctx.err, ctx.hit_lvl = run_callback(self, ctx.key, + ctx.shm_key, ctx.data, + ctx.ttl, ctx.neg_ttl, + ctx.went_stale, + ctx.l1_serializer, + ctx.resurrect_ttl, + ctx.shm_set_tries, + ctx.rlock_opts, + ctx.cb, ctx.arg) + end + end + + + local bulk_mt = {} + bulk_mt.__index = bulk_mt + + + function _M.new_bulk(n_ops) + local bulk = new_tab((n_ops or 2) * 4, 1) -- 4 slots per op + bulk.n = 0 + + return setmetatable(bulk, bulk_mt) + end + + + function bulk_mt:add(key, opts, cb, arg) + local i = (self.n * 4) + 1 + self[i] = key + self[i + 1] = opts + self[i + 2] = cb + self[i + 3] = arg + self.n = self.n + 1 + end + + + local function bulk_res_iter(res, i) + local idx = i * 3 + 1 + if idx > res.n then + return + end + + i = i + 1 + + local data = res[idx] + local err = res[idx + 1] + local hit_lvl = res[idx + 2] + + return i, data, err, hit_lvl + end + + + function _M.each_bulk_res(res) + if not res.n then + error("res must have res.n field; is this a get_bulk() result?", 2) + end + + return bulk_res_iter, res, 0 + end + + + function _M:get_bulk(bulk, opts) + if type(bulk) ~= "table" then + error("bulk must be a table", 2) + end + + if not bulk.n then + error("bulk must have n field", 2) + end + + if opts then + if type(opts) ~= "table" then + error("opts must be a table", 2) + end + + if opts.concurrency then + if type(opts.concurrency) ~= "number" then + error("opts.concurrency must be a number", 2) + end + + if opts.concurrency <= 0 then + error("opts.concurrency must be > 0", 2) + end + end + end + + local n_bulk = bulk.n * 4 + local res = new_tab(n_bulk - n_bulk / 4, 1) + local res_idx = 1 + + -- only used if running L3 callbacks + local n_cbs = 0 + local cb_ctxs + + -- bulk + -- { "key", opts, cb, arg } + -- + -- res + -- { data, "err", hit_lvl } + + for i = 1, n_bulk, 4 do + local b_key = bulk[i] + local b_opts = bulk[i + 1] + local b_cb = bulk[i + 2] + + if type(b_key) ~= "string" then + error("key at index " .. i .. " must be a string for operation " .. + ceil(i / 4) .. " (got " .. type(b_key) .. ")", 2) + end + + if type(b_cb) ~= "function" then + error("callback at index " .. i + 2 .. " must be a function " .. + "for operation " .. ceil(i / 4) .. " (got " .. type(b_cb) .. + ")", 2) + end + + -- worker LRU cache retrieval + + local data = self.lru:get(b_key) + if data ~= nil then + if data == CACHE_MISS_SENTINEL_LRU then + data = nil + end + + res[res_idx] = data + --res[res_idx + 1] = nil + res[res_idx + 2] = 1 + + else + local pok, ttl, neg_ttl, resurrect_ttl, l1_serializer, + shm_set_tries, rlock_opts = pcall(check_opts, self, b_opts) + if not pok then + -- strip the stacktrace + local err = ttl:match("mlcache%.lua:%d+:%s(.*)") + error("options at index " .. i + 1 .. " for operation " .. + ceil(i / 4) .. " are invalid: " .. err, 2) + end + + -- not in worker's LRU cache, need shm lookup + -- we will prepare a task for each cache miss + local namespaced_key = self.name .. 
b_key + + local err, went_stale, is_stale + data, err, went_stale, is_stale = get_shm_set_lru(self, b_key, + namespaced_key, + l1_serializer) + if err then + --res[res_idx] = nil + res[res_idx + 1] = err + --res[res_idx + 2] = nil + + elseif data ~= nil and not went_stale then + if data == CACHE_MISS_SENTINEL_LRU then + data = nil + end + + res[res_idx] = data + --res[res_idx + 1] = nil + res[res_idx + 2] = is_stale and 4 or 2 + + else + -- not in shm either, we have to prepare a task to run the + -- L3 callback + + n_cbs = n_cbs + 1 + + if n_cbs == 1 then + cb_ctxs = tablepool.fetch("bulk_cb_ctxs", 1, 0) + end + + local ctx = tablepool.fetch("bulk_cb_ctx", 0, 15) + ctx.res_idx = res_idx + ctx.cb = b_cb + ctx.arg = bulk[i + 3] -- arg + ctx.key = b_key + ctx.shm_key = namespaced_key + ctx.data = data + ctx.ttl = ttl + ctx.neg_ttl = neg_ttl + ctx.went_stale = went_stale + ctx.l1_serializer = l1_serializer + ctx.resurrect_ttl = resurrect_ttl + ctx.shm_set_tries = shm_set_tries + ctx.rlock_opts = rlock_opts + ctx.data = data + ctx.err = nil + ctx.hit_lvl = nil + + cb_ctxs[n_cbs] = ctx + end + end + + res_idx = res_idx + 3 + end + + if n_cbs == 0 then + -- no callback to run, all items were in L1/L2 + res.n = res_idx - 1 + return res + end + + -- some L3 callbacks have to run + -- schedule threads as per our concurrency settings + -- we will use this thread as well + + local concurrency + if opts then + concurrency = opts.concurrency + end + + if not concurrency then + concurrency = BULK_DEFAULT_CONCURRENCY + end + + local threads + local threads_idx = 0 + + do + -- spawn concurrent threads + local thread_size + local n_threads = min(n_cbs, concurrency) - 1 + + if n_threads > 0 then + threads = tablepool.fetch("bulk_threads", n_threads, 0) + thread_size = ceil(n_cbs / concurrency) + end + + if self.debug then + ngx.log(ngx.DEBUG, "spawning ", n_threads, " threads to run ", + n_cbs, " callbacks") + end + + local from = 1 + local rest = n_cbs + + for i = 1, n_threads do + local to + if rest >= thread_size then + rest = rest - thread_size + to = from + thread_size - 1 + else + rest = 0 + to = from + end + + if self.debug then + ngx.log(ngx.DEBUG, "thread ", i, " running callbacks ", from, + " to ", to) + end + + threads_idx = threads_idx + 1 + threads[i] = thread_spawn(run_thread, self, cb_ctxs, from, to) + + from = from + thread_size + + if rest == 0 then + break + end + end + + if rest > 0 then + -- use this thread as one of our concurrent threads + local to = from + rest - 1 + + if self.debug then + ngx.log(ngx.DEBUG, "main thread running callbacks ", from, + " to ", to) + end + + run_thread(self, cb_ctxs, from, to) + end + end + + -- wait for other threads + + for i = 1, threads_idx do + local ok, err = thread_wait(threads[i]) + if not ok then + -- when thread_wait() fails, we don't get res_idx, and thus + -- cannot populate the appropriate res indexes with the + -- error + ngx_log(ERR, "failed to wait for thread number ", i, ": ", err) + end + end + + for i = 1, n_cbs do + local ctx = cb_ctxs[i] + local ctx_res_idx = ctx.res_idx + + res[ctx_res_idx] = ctx.data + res[ctx_res_idx + 1] = ctx.err + res[ctx_res_idx + 2] = ctx.hit_lvl + + tablepool.release("bulk_cb_ctx", ctx, true) -- no clear tab + end + + tablepool.release("bulk_cb_ctxs", cb_ctxs) + + if threads then + tablepool.release("bulk_threads", threads) + end + + res.n = res_idx - 1 + + return res + end + + +end -- get_bulk() + + +function _M:peek(key, stale) + if type(key) ~= "string" then + error("key must be a string", 2) + end + + -- 
restrict this key to the current namespace, so we isolate this + -- mlcache instance from potential other instances using the same + -- shm + local namespaced_key = self.name .. key + + local v, err, went_stale = self.dict:get_stale(namespaced_key) + if v == nil and err then + -- err can be 'flags' upon successful get_stale() calls, so we + -- also check v == nil + return nil, "could not read from lua_shared_dict: " .. err + end + + -- if we specified shm_miss, it might be a negative hit cached + -- there + if self.dict_miss and v == nil then + v, err, went_stale = self.dict_miss:get_stale(namespaced_key) + if v == nil and err then + -- err can be 'flags' upon successful get_stale() calls, so we + -- also check v == nil + return nil, "could not read from lua_shared_dict: " .. err + end + end + + if went_stale and not stale then + return nil + end + + if v ~= nil then + local value, err, at, ttl = unmarshall_from_shm(v) + if err then + return nil, "could not deserialize value after lua_shared_dict " .. + "retrieval: " .. err + end + + local remaining_ttl = 0 + + if ttl > 0 then + remaining_ttl = ttl - (now() - at) + + if remaining_ttl == 0 then + -- guarantee a non-zero remaining_ttl if ttl is set + remaining_ttl = 0.001 + end + end + + return remaining_ttl, nil, value, went_stale + end +end + + +function _M:set(key, opts, value) + if not self.broadcast then + error("no ipc to propagate update, specify opts.ipc_shm or opts.ipc", 2) + end + + if type(key) ~= "string" then + error("key must be a string", 2) + end + + do + -- restrict this key to the current namespace, so we isolate this + -- mlcache instance from potential other instances using the same + -- shm + local ttl, neg_ttl, _, l1_serializer, shm_set_tries = check_opts(self, + opts) + local namespaced_key = self.name .. key + + if self.dict_miss then + -- since we specified a separate shm for negative caches, we + -- must make sure that we clear any value that may have been + -- set in the other shm + local dict = value == nil and self.dict or self.dict_miss + + -- TODO: there is a potential race-condition here between this + -- :delete() and the subsequent :set() in set_shm() + local ok, err = dict:delete(namespaced_key) + if not ok then + return nil, "could not delete from shm: " .. err + end + end + + local _, err = set_shm_set_lru(self, key, namespaced_key, value, ttl, + neg_ttl, nil, shm_set_tries, + l1_serializer, true) + if err then + return nil, err + end + end + + local _, err = self.broadcast(self.events.invalidation.channel, key) + if err then + return nil, "could not broadcast update: " .. err + end + + return true +end + + +function _M:delete(key) + if not self.broadcast then + error("no ipc to propagate deletion, specify opts.ipc_shm or opts.ipc", + 2) + end + + if type(key) ~= "string" then + error("key must be a string", 2) + end + + -- delete from shm first + do + -- restrict this key to the current namespace, so we isolate this + -- mlcache instance from potential other instances using the same + -- shm + local namespaced_key = self.name .. key + + local ok, err = self.dict:delete(namespaced_key) + if not ok then + return nil, "could not delete from shm: " .. err + end + + -- instance uses shm_miss for negative caches, since we don't know + -- where the cached value is (is it nil or not?), we must remove it + -- from both + if self.dict_miss then + ok, err = self.dict_miss:delete(namespaced_key) + if not ok then + return nil, "could not delete from shm: " .. 
err + end + end + end + + -- delete from LRU and propagate + self.lru:delete(key) + + local _, err = self.broadcast(self.events.invalidation.channel, key) + if err then + return nil, "could not broadcast deletion: " .. err + end + + return true +end + + +function _M:purge(flush_expired) + if not self.broadcast then + error("no ipc to propagate purge, specify opts.ipc_shm or opts.ipc", 2) + end + + if not self.lru.flush_all and LRU_INSTANCES[self.name] ~= self.lru then + error("cannot purge when using custom LRU cache with " .. + "OpenResty < 1.13.6.2", 2) + end + + -- clear shm first + self.dict:flush_all() + + -- clear negative caches shm if specified + if self.dict_miss then + self.dict_miss:flush_all() + end + + if flush_expired then + self.dict:flush_expired() + + if self.dict_miss then + self.dict_miss:flush_expired() + end + end + + -- clear LRU content and propagate + rebuild_lru(self) + + local _, err = self.broadcast(self.events.purge.channel, "") + if err then + return nil, "could not broadcast purge: " .. err + end + + return true +end + + +function _M:update(timeout) + if not self.poll then + error("no polling configured, specify opts.ipc_shm or opts.ipc.poll", 2) + end + + local _, err = self.poll(timeout) + if err then + return nil, "could not poll ipc events: " .. err + end + + return true +end + + +return _M \ No newline at end of file diff --git a/plugins/openresty/waf/lib/resty/mlcache/ipc.lua b/plugins/openresty/waf/lib/resty/mlcache/ipc.lua new file mode 100644 index 000000000..55cf0113b --- /dev/null +++ b/plugins/openresty/waf/lib/resty/mlcache/ipc.lua @@ -0,0 +1,257 @@ +-- vim: ts=4 sts=4 sw=4 et: + +local ERR = ngx.ERR +local WARN = ngx.WARN +local INFO = ngx.INFO +local sleep = ngx.sleep +local shared = ngx.shared +local worker_pid = ngx.worker.pid +local ngx_log = ngx.log +local fmt = string.format +local sub = string.sub +local find = string.find +local min = math.min +local type = type +local pcall = pcall +local error = error +local insert = table.insert +local tonumber = tonumber +local setmetatable = setmetatable + + +local INDEX_KEY = "lua-resty-ipc:index" +local FORCIBLE_KEY = "lua-resty-ipc:forcible" +local POLL_SLEEP_RATIO = 2 + + +local function marshall(worker_pid, channel, data) + return fmt("%d:%d:%s%s", worker_pid, #data, channel, data) +end + + +local function unmarshall(str) + local sep_1 = find(str, ":", nil , true) + local sep_2 = find(str, ":", sep_1 + 1, true) + + local pid = tonumber(sub(str, 1 , sep_1 - 1)) + local data_len = tonumber(sub(str, sep_1 + 1, sep_2 - 1)) + + local channel_last_pos = #str - data_len + + local channel = sub(str, sep_2 + 1, channel_last_pos) + local data = sub(str, channel_last_pos + 1) + + return pid, channel, data +end + + +local function log(lvl, ...) + return ngx_log(lvl, "[ipc] ", ...) +end + + +local _M = {} +local mt = { __index = _M } + + +function _M.new(shm, debug) + local dict = shared[shm] + if not dict then + return nil, "no such lua_shared_dict: " .. 
shm + end + + local self = { + dict = dict, + pid = debug and 0 or worker_pid(), + idx = 0, + callbacks = {}, + } + + return setmetatable(self, mt) +end + + +function _M:subscribe(channel, cb) + if type(channel) ~= "string" then + error("channel must be a string", 2) + end + + if type(cb) ~= "function" then + error("callback must be a function", 2) + end + + if not self.callbacks[channel] then + self.callbacks[channel] = { cb } + + else + insert(self.callbacks[channel], cb) + end +end + + +function _M:broadcast(channel, data) + if type(channel) ~= "string" then + error("channel must be a string", 2) + end + + if type(data) ~= "string" then + error("data must be a string", 2) + end + + local marshalled_event = marshall(worker_pid(), channel, data) + + local idx, err = self.dict:incr(INDEX_KEY, 1, 0) + if not idx then + return nil, "failed to increment index: " .. err + end + + local ok, err, forcible = self.dict:set(idx, marshalled_event) + if not ok then + return nil, "failed to insert event in shm: " .. err + end + + if forcible then + -- take note that eviction has started + -- we repeat this flagging to avoid this key from ever being + -- evicted itself + local ok, err = self.dict:set(FORCIBLE_KEY, true) + if not ok then + return nil, "failed to set forcible flag in shm: " .. err + end + end + + return true +end + + +-- Note: if this module were to be used by users (that is, users can implement +-- their own pub/sub events and thus, callbacks), this method would then need +-- to consider the time spent in callbacks to prevent long running callbacks +-- from penalizing the worker. +-- Since this module is currently only used by mlcache, whose callback is an +-- shm operation, we only worry about the time spent waiting for events +-- between the 'incr()' and 'set()' race condition. +function _M:poll(timeout) + if timeout ~= nil and type(timeout) ~= "number" then + error("timeout must be a number", 2) + end + + local shm_idx, err = self.dict:get(INDEX_KEY) + if err then + return nil, "failed to get index: " .. err + end + + if shm_idx == nil then + -- no events to poll yet + return true + end + + if type(shm_idx) ~= "number" then + return nil, "index is not a number, shm tampered with" + end + + if not timeout then + timeout = 0.3 + end + + if self.idx == 0 then + local forcible, err = self.dict:get(FORCIBLE_KEY) + if err then + return nil, "failed to get forcible flag from shm: " .. 
err + end + + if forcible then + -- shm lru eviction occurred, we are likely a new worker + -- skip indexes that may have been evicted and resume current + -- polling idx + self.idx = shm_idx - 1 + end + + else + -- guard: self.idx <= shm_idx + self.idx = min(self.idx, shm_idx) + end + + local elapsed = 0 + + for _ = self.idx, shm_idx - 1 do + -- fetch event from shm with a retry policy in case + -- we run our :get() in between another worker's + -- :incr() and :set() + + local v + local idx = self.idx + 1 + + do + local perr + local pok = true + local sleep_step = 0.001 + + while elapsed < timeout do + v, err = self.dict:get(idx) + if v ~= nil or err then + break + end + + if pok then + log(INFO, "no event data at index '", idx, "', ", + "retrying in: ", sleep_step, "s") + + -- sleep is not available in all ngx_lua contexts + -- if we fail once, never retry to sleep + pok, perr = pcall(sleep, sleep_step) + if not pok then + log(WARN, "could not sleep before retry: ", perr, + " (note: it is safer to call this function ", + "in contexts that support the ngx.sleep() ", + "API)") + end + end + + elapsed = elapsed + sleep_step + sleep_step = min(sleep_step * POLL_SLEEP_RATIO, + timeout - elapsed) + end + end + + -- fetch next event on next iteration + -- even if we timeout, we might miss 1 event (we return in timeout and + -- we don't retry that event), but it's better than being stuck forever + -- on an event that might have been evicted from the shm. + self.idx = idx + + if elapsed >= timeout then + return nil, "timeout" + end + + if err then + log(ERR, "could not get event at index '", self.idx, "': ", err) + + elseif type(v) ~= "string" then + log(ERR, "event at index '", self.idx, "' is not a string, ", + "shm tampered with") + + else + local pid, channel, data = unmarshall(v) + + if self.pid ~= pid then + -- coming from another worker + local cbs = self.callbacks[channel] + if cbs then + for j = 1, #cbs do + local pok, perr = pcall(cbs[j], data) + if not pok then + log(ERR, "callback for channel '", channel, + "' threw a Lua error: ", perr) + end + end + end + end + end + end + + return true +end + + +return _M \ No newline at end of file