The Complete Guide to Lua in OpenResty (12) In Practice: Building a Gateway Framework by Hand


Series index

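The Complete Guide to Lua in OpenResty (1): Lua Syntax Basics in Practice
The Complete Guide to Lua in OpenResty (2): Using Lua in OpenResty
The Complete Guide to Lua in OpenResty (3): Parsing JSON with the JSON Module in OpenResty
The Complete Guide to Lua in OpenResty (4): Using Redis in OpenResty
The Complete Guide to Lua in OpenResty (5): Using MySQL in OpenResty
The Complete Guide to Lua in OpenResty (6): Sending HTTP Requests from OpenResty
The Complete Guide to Lua in OpenResty (7): Using a Global Cache in OpenResty
The Complete Guide to Lua in OpenResty (8): OpenResty Execution Flow and Phases Explained
The Complete Guide to Lua in OpenResty (9) In Practice: Request Rate Limiting with nginx-lua-redis
The Complete Guide to Lua in OpenResty (10) In Practice: Dynamic IP Blocking with Lua + Redis
The Complete Guide to Lua in OpenResty (11) In Practice: API Signature Authentication with nginx
The Complete Guide to Lua in OpenResty (12) In Practice: Building a Gateway Framework by Hand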

I. Gateway Basics

1. Common gateway features

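A single entry point; security: IP blacklists, permission checks and identity authentication; rate limiting: measure the traffic reaching each microservice and throttle it according to one or more limiting rules; caching: cache data; logging: record request logs; monitoring: record request and response data, analyze API latency, and monitor performance; retries: retry failed calls; circuit breaking: degrade gracefully (fallback).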

2. Goal of this exercise

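Use OpenResty + Lua to implement part of what a gateway does. The main focus is a framework that can be extended later.

To understand gateways in depth, we build a small gateway framework on OpenResty. The article concentrates on the logic and on how it is implemented; before using it in production you will still need to adapt it.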

3. Introduction to Orange

Official site: http://orange.sumory.com

Dependencies:

1) OpenResty: version 1.9.7.3 or later, compiled with the --with-http_stub_status_module option

2) The lor framework

git clone https://github.com/sumory/lor
cd lor
make install

3) MySQL: install MySQL, create a database named orange, and run the database script that matches your Orange version.

4. Installing Orange

git clone https://github.com/sumory/orange.git
cd orange
make install

For v0.5.0 and later, make install installs Orange into the system. It installs two parts:

/usr/local/orange        # files Orange needs at runtime
/usr/local/bin/orange    # the orange command-line tool

Note: if you get the error "/usr/bin/env: 'resty': No such file or directory", create a symlink to resty in /usr/bin:

sudo ln -s /usr/local/openresty/bin/resty /usr/bin/resty

Run orange help:

Orange v0.6.4, OpenResty/Nginx API Gateway.

Usage: orange COMMAND [OPTIONS]

The commands are:

stop      Stop current Orange
version   Show the version of Orange
restart   Restart Orange
reload    Reload the config of Orange
store     Init/Update/Backup Orange store
help      Show help tips
start     Start the Orange Gateway

Start Orange with either sh start.sh or orange start. Once Orange is running, the Dashboard and the API Server start as well:

The built-in Dashboard is available at http://localhost:9999. The API Server listens on port 7777 by default; if you don't need it, remove the corresponding configuration from nginx.conf.

II. Building the Gateway

1. Main entry point

(1) nginx configuration

#user  nobody;
worker_processes  1;

error_log  logs/error.log;
error_log  logs/debug.log  debug;

#pid        logs/nginx.pid;

events {
    worker_connections  1024;
}


http {
    include       mime.types;
    default_type  text/html;
	charset utf-8;
    sendfile        on;
    keepalive_timeout  65;
	
	resolver 8.8.8.8;
	
	upstream default_upstream {
        server localhost:8080;
    }
	
	#----------------------------nginx gateway configuration-----------------------------
    # where the Lua files live
    lua_package_path '/usr/local/lua/?.lua;;';
    lua_code_cache on;

    lua_shared_dict shared_ip_blacklist 1m; # shared dict used as the local IP blacklist cache

    # load our gateway Lua module
	init_by_lua_block {
		
        local gateway = require("gateway.gateway")

        context = {
            gateway = gateway
        }
    }

    init_worker_by_lua_block {
        local gateway = context.gateway
        gateway.init_worker()
    }
	
    server {
        listen       80;
        #server_name  www.server1.com;
		
		location = /favicon.ico {
            log_not_found off;
            access_log off;
        }
		
		location / {
            set $upstream_host $host;
            set $upstream_url 'http://default_upstream';

            rewrite_by_lua_block {
                local gateway = context.gateway
                gateway.redirect()
                gateway.rewrite()
            }

            access_by_lua_block {
                local gateway = context.gateway
                gateway.access()
            }

            # proxy
            proxy_set_header X-Real-IP $remote_addr;
            proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
            proxy_set_header X-Forwarded-Scheme $scheme;
            proxy_set_header Host $upstream_host;
            proxy_pass $upstream_url;


            header_filter_by_lua_block {
                local gateway = context.gateway
                gateway.header_filter()
            }

            body_filter_by_lua_block {
                local gateway = context.gateway
                gateway.body_filter()
            }

            log_by_lua_block {
                local gateway = context.gateway
                gateway.log()
            }
        }
    }
	
	server {  
        listen 8080;  
  
        location /world {  
            echo "hello world";  
        }  
    }  
}

(2) Gateway Lua module: gateway.lua

Create the file /usr/local/lua/gateway/gateway.lua:

mkdir -p /usr/local/lua/gateway
vi /usr/local/lua/gateway/gateway.lua
local gateway = {}

function gateway.init()
	ngx.log(ngx.DEBUG, "===========gateway.init============");
end

function gateway.init_worker()
	ngx.log(ngx.DEBUG, "===========gateway.init_worker============");
end

function gateway.redirect()
	ngx.log(ngx.DEBUG, "===========gateway.redirect============");
end

function gateway.rewrite()
	ngx.log(ngx.DEBUG, "===========gateway.rewrite============");
end

function gateway.access()
	ngx.log(ngx.DEBUG, "===========gateway.access============");
end

function gateway.header_filter()
	ngx.log(ngx.DEBUG, "===========gateway.header_filter============");
end

function gateway.body_filter()
	ngx.log(ngx.DEBUG, "===========gateway.body_filter============");
end

function gateway.log()
	ngx.log(ngx.DEBUG, "===========gateway.log============");
end

return gateway;

(3) Start

Start nginx, send a request, and check the debug log (logs/debug.log).

This way, the nginx entry point calls into our Lua script. To add gateway features later, we only need to extend gateway.lua.

2. Plugin management in the gateway

(1) Approach

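We need a design that lets plugins be added and removed dynamically, with each plugin providing one feature. To make plugins easy to plug in and out, the directory structure is laid out as follows:

gateway/                     # gateway root directory
    lib/                     # third-party libraries
    common/                  # utility libraries
    conf/                    # configuration files
        gateway.conf         # gateway configuration
    plugins/                 # all plugins
        base_plugin.lua      # plugin base class
        XXXX/                # one directory per plugin
            handler.lua      # plugin implementation, always named handler.lua

Plugins are told apart by their directory name. This is deliberate: the loader can build a plugin's module path directly from its name.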

mkdir -p /usr/local/lua/gateway/plugins
mkdir -p /usr/local/lua/gateway/lib
mkdir -p /usr/local/lua/gateway/common
mkdir -p /usr/local/lua/gateway/conf

(2) Base class

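lib holds shared libraries. The classic script provides an Object base class, similar to Java's Object. Every class we design later simply extends it. You don't need to understand Object in detail; it is enough to know that it plays the role of the root base class.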

vi /usr/local/lua/gateway/lib/classic.lua
local Object = {}
Object.__index = Object


function Object:new()
end


function Object:extend()
  local cls = {}
  for k, v in pairs(self) do
    if k:find("__") == 1 then
      cls[k] = v
    end
  end
  cls.__index = cls
  cls.super = self
  setmetatable(cls, self)
  return cls
end


function Object:implement(...)
  for _, cls in pairs({...}) do
    for k, v in pairs(cls) do
      if self[k] == nil and type(v) == "function" then
        self[k] = v
      end
    end
  end
end


function Object:is(T)
  local mt = getmetatable(self)
  while mt do
    if mt == T then
      return true
    end
    mt = getmetatable(mt)
  end
  return false
end


function Object:__tostring()
  return "Object"
end


function Object:__call(...)
  local obj = setmetatable({}, self)
  obj:new(...)
  return obj
end


return Object
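To show how extend, super and the __call constructor fit together, here is a small demo. It is a sketch that carries a trimmed copy of the Object class above (only new, extend and __call), so it runs on plain Lua or LuaJIT without OpenResty. The classes Base and Child are hypothetical and are shaped like BasePlugin and a plugin handler.

```lua
-- Trimmed copy of the classic Object class (new, extend, __call only)
local Object = {}
Object.__index = Object

function Object:new() end

function Object:extend()
  local cls = {}
  -- copy metamethods such as __call so that subclasses are callable too
  for k, v in pairs(self) do
    if k:find("__") == 1 then
      cls[k] = v
    end
  end
  cls.__index = cls
  cls.super = self
  setmetatable(cls, self)
  return cls
end

function Object:__call(...)
  local obj = setmetatable({}, self)
  obj:new(...)
  return obj
end

-- hypothetical classes, shaped like BasePlugin and a plugin handler
local Base = Object:extend()

function Base:new(name)
  self._name = name
end

function Base:describe()
  return "plugin " .. self._name
end

local Child = Base:extend()

function Child:new()
  -- call the parent constructor explicitly, as the plugin handlers do
  Child.super.new(self, "child-plugin")
end

function Child:describe()
  return Child.super.describe(self) .. " (overridden)"
end

local c = Child()      -- __call creates the instance and runs Child:new()
print(c:describe())    --> plugin child-plugin (overridden)
```

Note that a parent method is called as Child.super.method(self, ...) with a dot and an explicit self. Writing Child.super:method() would pass the class table as self instead of the instance.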

(3) Utility module

The common directory holds scripts with shared helper functions.

vi /usr/local/lua/gateway/common/utils.lua
local string_find = string.find

local _M = {}

function _M.debug_log(msg)
    ngx.log(ngx.DEBUG, msg);
end

function _M.warn_log(msg)
    ngx.log(ngx.WARN, msg);
end

function _M.error_log(msg)
    ngx.log(ngx.ERR, msg);
end

function _M.load_module_if_exists(module_name)
    local status, res = pcall(require, module_name)
    if status then
        return true, res
        -- plain find (4th argument true), so characters such as '-' in the module name need no escaping
    elseif type(res) == "string" and string_find(res, "module '"..module_name.."' not found", nil, true) then
        return false
    else
        error(res)
    end
end

return _M
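load_module_if_exists treats one case as a soft failure: the requested module itself does not exist. Any other error is re-raised, including an error inside a plugin that exists. The sketch below demonstrates this with a standalone copy of the function. Its demo modules are hypothetical and are registered through package.preload, so no files are needed and it runs on plain Lua.

```lua
local string_find = string.find

-- standalone copy of utils.load_module_if_exists
local function load_module_if_exists(module_name)
    local status, res = pcall(require, module_name)
    if status then
        return true, res
    elseif type(res) == "string"
        and string_find(res, "module '" .. module_name .. "' not found", nil, true) then
        -- only "this exact module is missing" is a soft failure
        return false
    else
        -- syntax errors, failing nested requires, etc. are re-raised
        error(res)
    end
end

-- hypothetical modules used only for this demo
package.preload["demo.ok_plugin"] = function()
    return { name = "ok" }
end
package.preload["demo.broken_plugin"] = function()
    -- the plugin exists, but one of its dependencies does not
    return require("demo.missing_dependency")
end

local loaded, mod = load_module_if_exists("demo.ok_plugin")          -- true, { name = "ok" }
local missing = load_module_if_exists("demo.not_installed")          -- false
local called_ok = pcall(load_module_if_exists, "demo.broken_plugin") -- false: error re-raised
```

A plugin that is listed in the config but not installed is therefore only logged as a warning. A plugin that is installed but broken fails loudly.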

(4) Base plugin class

vi /usr/local/lua/gateway/plugins/base_plugin.lua
local utils = require("gateway.common.utils")
local Object = require("gateway.lib.classic")

local BasePlugin = Object:extend()

function BasePlugin:new(name)
    self._name = name
    utils.debug_log("BasePlugin executing plugin \""..self._name.."\": new")
end

function BasePlugin:get_name()
    return self._name
end

function BasePlugin:init_worker()
    utils.debug_log("BasePlugin executing plugin \""..self._name.."\": init_worker")
end

function BasePlugin:redirect()
    utils.debug_log("BasePlugin executing plugin \""..self._name.."\": redirect")
end

function BasePlugin:rewrite()
    utils.debug_log("BasePlugin executing plugin \""..self._name.."\": rewrite")
end

function BasePlugin:access()
    utils.debug_log("BasePlugin executing plugin \""..self._name.."\": access")
end

function BasePlugin:header_filter()
    utils.debug_log("BasePlugin executing plugin \""..self._name.."\": header_filter")
end

function BasePlugin:body_filter()
    utils.debug_log("BasePlugin executing plugin \""..self._name.."\": body_filter")
end

function BasePlugin:log()
    utils.debug_log("BasePlugin executing plugin \""..self._name.."\": log")
end

return BasePlugin

(5) In practice: a signature verification plugin

Create a sign_auth directory under plugins, and a handler.lua file inside it:

mkdir -p /usr/local/lua/gateway/plugins/sign_auth
vi /usr/local/lua/gateway/plugins/sign_auth/handler.lua
local utils = require("gateway.common.utils")
local BasePlugin = require("gateway.plugins.base_plugin")

local SignAuthHandler = BasePlugin:extend()
SignAuthHandler.PRIORITY = 0

function SignAuthHandler:new()
    SignAuthHandler.super.new(self, "sign_auth-plugin")
    utils.debug_log("===========SignAuthHandler.new============");
end

function SignAuthHandler:access()
    SignAuthHandler.super.access(self)
    utils.debug_log("===========SignAuthHandler.access============");
end

return SignAuthHandler;

(6) Configuration file

The plugin design is now complete. Next we design the configuration file, which defines what can be configured:

vi /usr/local/lua/gateway/conf/gateway.conf
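{
    "plugins": [
        "sign_auth"
    ]
}

"plugins" lists the enabled plugins. Each entry is the name of a plugin directory under plugins/.

Note: the file is parsed as JSON, which does not allow comments, so it must not contain any.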

With this, the basic skeleton is in place.

3. Loading plugins from the configuration file

(1) A library for reading the configuration file

vi /usr/local/lua/gateway/common/io.lua
--- 
-- :P some origin code is from https://github.com/Mashape/kong/blob/master/kong/tools/io.lua
-- modified by sumory.wu

local stringy = require("gateway.common.stringy")

local _M = {}

---
-- Checks existence of a file.
-- @param path path/file to check
-- @return `true` if found, `false` + error message otherwise
function _M.file_exists(path)
    local f, err = io.open(path, "r")
    if f ~= nil then
        io.close(f)
        return true
    else
        return false, err
    end
end

---
-- Execute an OS command and catch the output.
-- @param command OS command to execute
-- @return string containing command output (both stdout and stderr)
-- @return exitcode
function _M.os_execute(command, preserve_output)
    local n = os.tmpname() -- get a temporary file name to store output
    local f = os.tmpname() -- get a temporary file name to store script
    _M.write_to_file(f, command)
    local exit_code = os.execute("/bin/bash "..f.." > "..n.." 2>&1")
    local result = _M.read_file(n)
    os.remove(n)
    os.remove(f)
    return preserve_output and result or string.gsub(string.gsub(result, "^"..f..":[%s%w]+:%s*", ""), "[%\r%\n]", ""), exit_code / 256
end

---
-- Check existence of a command.
-- @param cmd command being searched for
-- @return `true` if found, `false` otherwise
function _M.cmd_exists(cmd)
    local _, code = _M.os_execute("hash "..cmd)
    return code == 0
end

--- Kill a process by PID.
-- Kills the process and waits until it's terminated
-- @param pid_file the file containing the pid to kill
-- @param signal the signal to use
-- @return `os_execute` results, see os_execute.
function _M.kill_process_by_pid_file(pid_file, signal)
    if _M.file_exists(pid_file) then
        local pid = stringy.strip(_M.read_file(pid_file))
        return _M.os_execute("while kill -0 "..pid.." >/dev/null 2>&1; do kill "..(signal and "-"..tostring(signal).." " or "")..pid.."; sleep 0.1; done")
    end
end

--- Read file contents.
-- @param path filepath to read
-- @return file contents as string, or `nil` if not successful
function _M.read_file(path)
    local contents
    local file = io.open(path, "rb")
    if file then
        contents = file:read("*all")
        file:close()
    end
    return contents
end

--- Write file contents.
-- @param path filepath to write to
-- @return `true` upon success, or `false` + error message on failure
function _M.write_to_file(path, value)
    local file, err = io.open(path, "w")
    if err then
        return false, err
    end

    file:write(value)
    file:close()
    return true
end


--- Get the filesize.
-- @param path path to file to check
-- @return size of file, or `nil` on failure
function _M.file_size(path)
    local size
    local file = io.open(path, "rb")
    if file then
        size = file:seek("end")
        file:close()
    end
    return size
end

return _M

(2) String utilities

vi /usr/local/lua/gateway/common/stringy.lua
local string_gsub = string.gsub
local string_find = string.find
local table_insert = table.insert

local _M = {}

function _M.trim_all(str)
    if not str or str == "" then return "" end
    local result = string_gsub(str, " ", "")
    return result
end

function _M.strip(str)
    if not str or str == "" then return "" end
    -- %s also removes tabs and newlines (e.g. the trailing "\n" in a pid file)
    local result = string_gsub(str, "^%s+", "")
    result = string_gsub(result, "%s+$", "")
    return result
end


function _M.split(str, delimiter)
    if not str or str == "" then return {} end
    if not delimiter or delimiter == "" then return { str } end

    local result = {}
    local start = 1
    while true do
        -- plain find: the delimiter is not treated as a Lua pattern
        local i, j = string_find(str, delimiter, start, true)
        if not i then
            table_insert(result, string.sub(str, start))
            break
        end
        table_insert(result, string.sub(str, start, i - 1))
        start = j + 1
    end
    return result
end

function _M.startswith(str, substr)
    if str == nil or substr == nil then
        return false
    end
    -- compare literally, so "." or "-" in substr are not pattern characters
    return string.sub(str, 1, #substr) == substr
end

function _M.endswith(str, substr)
    if str == nil or substr == nil then
        return false
    end
    return substr == "" or string.sub(str, -#substr) == substr
end

return _M
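Lua patterns treat characters such as "." and "-" specially. A helper that passes its argument to gmatch or find as a pattern can therefore misbehave: a pattern-based split("a.b.c", ".") does not split on dots at all. The standalone sketch below, runnable on plain Lua, shows a whitespace-aware strip, a plain-text split, and literal prefix/suffix checks.

```lua
local function strip(str)
    if not str or str == "" then return "" end
    -- %s covers spaces, tabs and newlines
    return (str:gsub("^%s+", ""):gsub("%s+$", ""))
end

local function split(str, delimiter)
    if not str or str == "" then return {} end
    if not delimiter or delimiter == "" then return { str } end
    local result, start = {}, 1
    while true do
        local i, j = string.find(str, delimiter, start, true)  -- plain find
        if not i then
            result[#result + 1] = str:sub(start)
            break
        end
        result[#result + 1] = str:sub(start, i - 1)
        start = j + 1
    end
    return result
end

local function startswith(str, prefix)
    if str == nil or prefix == nil then return false end
    return str:sub(1, #prefix) == prefix
end

local function endswith(str, suffix)
    if str == nil or suffix == nil then return false end
    -- str:sub(-0) would return the whole string, so "" is handled explicitly
    return suffix == "" or str:sub(-#suffix) == suffix
end

local parts = split("a.b.c", ".")                        -- { "a", "b", "c" }
local pid = strip("  1234\n")                            -- "1234"
local ok = startswith("gateway.plugins.sign_auth", "gateway.plugins.")  -- true
```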

(3) JSON encode/decode utilities

vi /usr/local/lua/gateway/common/json.lua
local cjson = require("cjson.safe")

local _M = {}

function _M.encode(data, empty_table_as_object)
    if not data then return nil end

    if cjson.encode_empty_table_as_object then
        -- empty tables are encoded as arrays ([]) unless empty_table_as_object is true
        cjson.encode_empty_table_as_object(empty_table_as_object or false)
    end

    if require("ffi").os ~= "Windows" then
        cjson.encode_sparse_array(true)
    end

    return cjson.encode(data)
end


function _M.decode(data)
    if not data then return nil end

    return cjson.decode(data)
end


return _M

(4) Configuration loader

vi /usr/local/lua/gateway/common/config_loader.lua
local json = require("gateway.common.json")
local IO = require("gateway.common.io")

local _M = {}

function _M.load(config_path)
    config_path = config_path or "/usr/local/lua/gateway/conf/gateway.conf"
    local config_contents = IO.read_file(config_path)

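    if not config_contents then
        -- raise an error instead of calling os.exit(); gateway.init() wraps this in pcall
        error("No configuration file at: " .. config_path)
    end

    local config, err = json.decode(config_contents)
    if not config then
        -- cjson.safe returns nil plus a message on invalid JSON (e.g. a file with comments)
        error("Invalid JSON in " .. config_path .. ": " .. tostring(err))
    end
    return config, config_path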
end

return _M
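The loader checks that the file exists and is valid JSON, but not that its content is usable. Below is a sketch of an extra check, assuming a hypothetical helper validate_config that is not part of the article's code. It requires plugins to be an array of plain names, because each name becomes part of the module path gateway.plugins.<name>.handler. It runs on plain Lua.

```lua
-- hypothetical helper: sanity-check a decoded gateway.conf table
local function validate_config(config)
    if type(config) ~= "table" then
        return nil, "config must be a JSON object"
    end
    local plugins = config.plugins
    if plugins == nil then
        return nil, "missing \"plugins\" array"
    end
    if type(plugins) ~= "table" then
        return nil, "\"plugins\" must be an array"
    end
    local seen = {}
    for i, name in ipairs(plugins) do
        -- letters, digits and "_" only, so a name cannot change the module path
        if type(name) ~= "string" or not name:match("^[%w_]+$") then
            return nil, "invalid plugin name at index " .. i
        end
        if seen[name] then
            return nil, "duplicate plugin: " .. name
        end
        seen[name] = true
    end
    return true
end
```

For example, gateway.init() could call validate_config(config) right after config_loader.load() and raise error(err) when the first return value is nil.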

(5) Update the entry module gateway.lua

vi /usr/local/lua/gateway/gateway.lua
local utils = require("gateway.common.utils")
local config_loader = require("gateway.common.config_loader")

-- plugins loaded in init(), shared by all phase functions of this module
local loaded_plugins = {}

local function load_node_plugins(config)
    utils.debug_log("===========load_node_plugins============")
    local plugins = config.plugins or {}  -- names of the enabled plugins
    local sorted_plugins = {}             -- plugins ordered by priority
    for _, v in ipairs(plugins) do
        -- the plugin name is its directory name: gateway/plugins/<name>/handler.lua
        local loaded, plugin_handler = utils.load_module_if_exists("gateway.plugins." .. v .. ".handler")
        if not loaded then
            utils.warn_log("The following plugin is not installed or has no handler: " .. v)
        else
            utils.debug_log("Loading plugin: " .. v)
            table.insert(sorted_plugins, {
                name = v,
                handler = plugin_handler(), -- calling the class creates an instance (runs :new())
            })
        end
    end
    -- sort by priority: a higher PRIORITY runs first
    table.sort(sorted_plugins, function(a, b)
        local priority_a = a.handler.PRIORITY or 0
        local priority_b = b.handler.PRIORITY or 0
        return priority_a > priority_b
    end)

    return sorted_plugins
end


local gateway = {}

function gateway.init(options)
    options = options or {}
    local status, err = pcall(function()
        -- path of the gateway configuration file (config_loader falls back to its default path)
        local conf_file_path = options.config
        utils.debug_log("Loading gateway conf : " .. tostring(conf_file_path))
        local config = config_loader.load(conf_file_path)
        -- load the configured plugins
        loaded_plugins = load_node_plugins(config)
    end)

    if not status then
        -- ngx.exit() is not available in init_by_lua*; raising an error makes startup fail
        error("Startup error: " .. tostring(err))
    end

    ngx.log(ngx.DEBUG, "===========gateway.init============")
end

function gateway.init_worker()
        ngx.log(ngx.DEBUG, "===========gateway.init_worker============");
end

function gateway.redirect()
        ngx.log(ngx.DEBUG, "===========gateway.redirect============");
end

function gateway.rewrite()
        ngx.log(ngx.DEBUG, "===========gateway.rewrite============");
end

function gateway.access()
        ngx.log(ngx.DEBUG, "===========gateway.access============");
end

function gateway.header_filter()
        ngx.log(ngx.DEBUG, "===========gateway.header_filter============");
end

function gateway.body_filter()
        ngx.log(ngx.DEBUG, "===========gateway.body_filter============");
end

function gateway.log()
        ngx.log(ngx.DEBUG, "===========gateway.log============");
end

return gateway;
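The ordering rule in load_node_plugins is that a higher PRIORITY runs earlier and a missing PRIORITY counts as 0. table.sort is not stable, so plugins with equal priority can come out in any order. The sketch below isolates the rule so it can be checked on plain Lua. It also adds a tie-break by name, which is not in the code above, to make the order deterministic.

```lua
-- plugins are { name = ..., handler = ... } entries, as built by load_node_plugins
local function sort_by_priority(plugins)
    table.sort(plugins, function(a, b)
        local pa = a.handler.PRIORITY or 0
        local pb = b.handler.PRIORITY or 0
        if pa ~= pb then
            return pa > pb            -- higher priority first
        end
        return a.name < b.name        -- added tie-break for equal priorities
    end)
    return plugins
end

-- priorities used later in this article: limit_ip 2, limit_frequency 1, sign_auth 0
local order = sort_by_priority({
    { name = "sign_auth",       handler = { PRIORITY = 0 } },
    { name = "limit_frequency", handler = { PRIORITY = 1 } },
    { name = "limit_ip",        handler = { PRIORITY = 2 } },
    { name = "no_priority",     handler = {} },
})
```

With these priorities, the cheap IP blacklist check runs before rate limiting, and both run before the signature check, which needs Redis and an MD5 computation.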

(6) Update nginx.conf

init_by_lua_block {

    local gateway = require("gateway.gateway")

    local config_file = "/usr/local/lua/gateway/conf/gateway.conf"

    gateway.init({
      config = config_file
    })

    context = {
        gateway = gateway
    }
}

(7) Reload nginx and check the log

nginx -s reload

Then send a request:

4. Completing the signature verification plugin

(1) Rework handler.lua

vi /usr/local/lua/gateway/plugins/sign_auth/handler.lua
local utils = require("gateway.common.utils")
local BasePlugin = require("gateway.plugins.base_plugin")
local redis = require("resty.redis")  -- lua-resty-redis client

local function close_redis(red)
    if not red then
        return
    end
    -- return the connection to the connection pool instead of closing it
    local pool_max_idle_time = 10000 -- max idle time in milliseconds
    local pool_size = 100 -- pool size
    local ok, err = red:set_keepalive(pool_max_idle_time, pool_size)
    if not ok then
        utils.error_log("set keepalive error : " .. tostring(err))
    end
end

-- reply 403 with a message: set the status before the body is sent, then end
-- the request (ngx.exit(ngx.HTTP_OK) after output, as in the lua-nginx-module docs)
local function forbidden(msg)
    ngx.status = ngx.HTTP_FORBIDDEN
    ngx.say(msg)
    return ngx.exit(ngx.HTTP_OK)
end

-- check whether the sign parameter of a request is correct
-- params: table of all request parameters
-- secret: the project's secret, looked up by appid
local function signcheck(params, secret)
    -- reject an empty parameter table
    if utils.isTableEmpty(params) then
        local mess = "params table is empty"
        utils.error_log(mess)
        return false, mess
    end

    -- the sign parameter is required
    local sign = params["sign"]
    if sign == nil then
        local mess = "params sign is nil"
        utils.error_log(mess)
        return false, mess
    end

    -- the time parameter (milliseconds) is required and must be numeric
    local timestamp = tonumber(params["time"])
    if timestamp == nil then
        local mess = "params timestamp is nil or not a number"
        utils.error_log(mess)
        return false, mess
    end
    -- the timestamp is valid for 10 seconds; abs() also rejects timestamps in the future
    local now_mill = ngx.now() * 1000 -- milliseconds
    if math.abs(now_mill - timestamp) > 10000 then
        local mess = "params timestamp is expired"
        utils.error_log(mess)
        return false, mess
    end

    local keys, tmp = {}, {}

    -- collect all parameter names except sign and sort them
    for k, _ in pairs(params) do
        if k ~= "sign" then
            keys[#keys + 1] = k
        end
    end
    table.sort(keys)
    -- join the values in sorted key order: key=value&key=value
    for _, k in ipairs(keys) do
        if type(params[k]) == "string" or type(params[k]) == "number" then
            tmp[#tmp + 1] = k .. "=" .. tostring(params[k])
        end
    end
    -- append the secret as the salt, compute the expected sign and compare
    local signchar = table.concat(tmp, "&") .. "&" .. secret
    local rightsign = ngx.md5(signchar)
    if sign ~= rightsign then
        -- log the details, but never send the expected sign or the secret to the client
        utils.error_log("sign error: sign," .. tostring(sign) .. " right sign:" .. rightsign .. " sign_char:" .. signchar)
        return false, "sign error"
    end
    return true
end

local SignAuthHandler = BasePlugin:extend()
SignAuthHandler.PRIORITY = 0

function SignAuthHandler:new()
    SignAuthHandler.super.new(self, "sign_auth-plugin")
    utils.debug_log("===========SignAuthHandler.new============")
end

function SignAuthHandler:access()
    SignAuthHandler.super.access(self)
    utils.debug_log("===========SignAuthHandler.access============")
    local params = {}

    local get_args = ngx.req.get_uri_args()

    local appid = get_args["appid"]

    -- a missing or repeated appid is rejected
    if type(appid) ~= "string" or appid == "" then
        return forbidden("appid is empty, illegal request")
    end

    ngx.req.read_body()
    local post_args = ngx.req.get_post_args()

    -- merge query-string and POST parameters (POST values win on duplicates)
    utils.union(params, get_args)
    utils.union(params, post_args or {})

    local red = redis:new()  -- create a client object (note the colon call)

    red:set_timeout(1000) -- timeout in milliseconds
    -- connect to Redis
    local host = "127.0.0.1"
    local port = 6379
    local ok, err = red:connect(host, port)
    if not ok then
        -- the connection was never established, so there is nothing to return to the pool
        utils.error_log("Cannot connect to redis: " .. tostring(err))
        return ngx.exit(ngx.HTTP_INTERNAL_SERVER_ERROR)
    end

    -- look up the secret for this appid
    local resp, err = red:hget("apphash", appid)
    close_redis(red) -- Redis is no longer needed: return the connection to the pool
    if not resp or resp == ngx.null then
        utils.error_log("no secret for appid " .. appid .. ": " .. tostring(err))
        return ngx.exit(ngx.HTTP_INTERNAL_SERVER_ERROR)
    end
    -- resp holds the secret of this appid
    local checkResult, mess = signcheck(params, resp)

    if not checkResult then
        return forbidden(mess)
    end
end

return SignAuthHandler
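A client has to build exactly the same string that signcheck() hashes. Here is a standalone sketch of that string and of the freshness rule, runnable on plain Lua. The final MD5 step (ngx.md5 on the server) is left out because plain Lua has no built-in MD5. Any MD5 library that returns a lowercase hex digest would fit there; that library is an assumption and is not shown.

```lua
-- the string that is hashed: sorted key=value pairs (without sign), then &secret
local function build_sign_string(params, secret)
    local keys = {}
    for k, _ in pairs(params) do
        if k ~= "sign" then
            keys[#keys + 1] = k
        end
    end
    table.sort(keys)  -- byte-wise ascending order of parameter names

    local parts = {}
    for _, k in ipairs(keys) do
        local v = params[k]
        -- only string and number values take part; repeated args (tables) are skipped
        if type(v) == "string" or type(v) == "number" then
            parts[#parts + 1] = k .. "=" .. tostring(v)
        end
    end
    return table.concat(parts, "&") .. "&" .. secret
end

-- "time" is in milliseconds; math.abs makes the window symmetric,
-- so a timestamp far in the future is rejected too
local function is_fresh(time_ms, now_ms, window_ms)
    local t = tonumber(time_ms)
    return t ~= nil and math.abs(now_ms - t) <= window_ms
end

local s = build_sign_string({
    appid = "app1", name = "tom", time = "1700000000000",
    tags = { "a", "b" }, sign = "ignored",
}, "secret")
-- s == "appid=app1&name=tom&time=1700000000000&secret"
```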

(2) Update the access phase in gateway.lua

vi /usr/local/lua/gateway/gateway.lua
function gateway.access()
    ngx.log(ngx.DEBUG, "===========gateway.access============");
    for _, plugin in ipairs(loaded_plugins) do
        ngx.log(ngx.DEBUG, "==gateway.access name==" .. plugin.name);
        plugin.handler:access()
    end
end
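The loop above wires up only the access phase. If you want the same pattern for the other phases, a small dispatcher avoids repeating the loop in each one. The sketch below assumes a hypothetical helper run_phase; the demo uses fake handlers so that it runs on plain Lua.

```lua
-- hypothetical helper: call one phase method on every plugin, in priority order
-- plugins are the { name = ..., handler = ... } entries built by load_node_plugins
local function run_phase(plugins, phase)
    for _, plugin in ipairs(plugins) do
        local method = plugin.handler[phase]
        if type(method) == "function" then
            method(plugin.handler)  -- same as plugin.handler:access() for "access"
        end
    end
end

-- demo with fake handlers that only record the calls
local calls = {}
local function fake(name)
    return {
        access = function(self) calls[#calls + 1] = name .. ":access" end,
    }
end

local plugins = {
    { name = "limit_ip",  handler = fake("limit_ip") },
    { name = "sign_auth", handler = fake("sign_auth") },
}
run_phase(plugins, "access")
run_phase(plugins, "log")   -- no handler defines log here, so nothing is called
```

In gateway.lua this would be called as, for example, run_phase(loaded_plugins, "rewrite") from gateway.rewrite(). In rewrite and access, a handler that calls ngx.exit() should end the request, so the remaining plugins are not reached.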

(3) Add table helpers to utils

vi /usr/local/lua/gateway/common/utils.lua
-- check whether a table is nil or empty
function _M.isTableEmpty(t)
    return t == nil or next(t) == nil
end

-- merge table2 into table1 (values from table2 win) and return table1
function _M.union(table1,table2)
    for k, v in pairs(table2) do
        table1[k] = v
    end
    return table1
end
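The signature plugin relies on two properties of these helpers: union modifies its first argument in place, and on a duplicate key the value from the second table wins. Here is a standalone copy that shows both on plain Lua. The parameter values are made up for the demo.

```lua
local function is_table_empty(t)
    return t == nil or next(t) == nil
end

local function union(table1, table2)
    for k, v in pairs(table2) do
        table1[k] = v
    end
    return table1
end

local params = {}
union(params, { appid = "app1", time = "1" })   -- query-string args
union(params, { time = "2", name = "tom" })     -- POST args override "time"
```

Because POST values override query-string values, the signature is computed over the merged table. A client must sign the values the server will actually see.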

(4) Start nginx and simulate requests with the Java client written earlier in this series.

5. Implementing an IP blacklist plugin

(1) Use a Redis wrapper library

vi /usr/local/lua/gateway/lib/redis.lua
local redis_c = require "resty.redis"

local ok, new_tab = pcall(require, "table.new")
if not ok or type(new_tab) ~= "function" then
    new_tab = function (narr, nrec) return {} end
end

local _M = new_tab(0, 155)
_M._VERSION = '0.01'

local commands = {
    "append",            "auth",              "bgrewriteaof",
    "bgsave",            "bitcount",          "bitop",
    "blpop",             "brpop",
    "brpoplpush",        "client",            "config",
    "dbsize",
    "debug",             "decr",              "decrby",
    "del",               "discard",           "dump",
    "echo",
    "eval",              "exec",              "exists",
    "expire",            "expireat",          "flushall",
    "flushdb",           "get",               "getbit",
    "getrange",          "getset",            "hdel",
    "hexists",           "hget",              "hgetall",
    "hincrby",           "hincrbyfloat",      "hkeys",
    "hlen",
    "hmget",              "hmset",      "hscan",
    "hset",
    "hsetnx",            "hvals",             "incr",
    "incrby",            "incrbyfloat",       "info",
    "keys",
    "lastsave",          "lindex",            "linsert",
    "llen",              "lpop",              "lpush",
    "lpushx",            "lrange",            "lrem",
    "lset",              "ltrim",             "mget",
    "migrate",
    "monitor",           "move",              "mset",
    "msetnx",            "multi",             "object",
    "persist",           "pexpire",           "pexpireat",
    "ping",              "psetex",            "psubscribe",
    "pttl",
    "publish",      --[[ "punsubscribe", ]]   "pubsub",
    "quit",
    "randomkey",         "rename",            "renamenx",
    "restore",
    "rpop",              "rpoplpush",         "rpush",
    "rpushx",            "sadd",              "save",
    "scan",              "scard",             "script",
    "sdiff",             "sdiffstore",
    "select",            "set",               "setbit",
    "setex",             "setnx",             "setrange",
    "shutdown",          "sinter",            "sinterstore",
    "sismember",         "slaveof",           "slowlog",
    "smembers",          "smove",             "sort",
    "spop",              "srandmember",       "srem",
    "sscan",
    "strlen",       --[[ "subscribe",  ]]     "sunion",
    "sunionstore",       "sync",              "time",
    "ttl",
    "type",         --[[ "unsubscribe", ]]    "unwatch",
    "watch",             "zadd",              "zcard",
    "zcount",            "zincrby",           "zinterstore",
    "zrange",            "zrangebyscore",     "zrank",
    "zrem",              "zremrangebyrank",   "zremrangebyscore",
    "zrevrange",         "zrevrangebyscore",  "zrevrank",
    "zscan",
    "zscore",            "zunionstore",       "evalsha"
}

local mt = { __index = _M }

local function is_redis_null( res )
    if type(res) == "table" then
        for k,v in pairs(res) do
            if v ~= ngx.null then
                return false
            end
        end
        return true
    elseif res == ngx.null then
        return true
    elseif res == nil then
        return true
    end

    return false
end

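function _M.close_redis(self, redis)
    if not redis then
        return
    end
    -- return the connection to the pool instead of closing it
    local pool_max_idle_time = self.pool_max_idle_time -- max idle time in milliseconds
    local pool_size = self.pool_size -- pool size

    local ok, err = redis:set_keepalive(pool_max_idle_time, pool_size)
    if not ok then
        -- log instead of ngx.say, so errors are not written into the client's response
        ngx.log(ngx.ERR, "set keepalive error : ", err)
    end
end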

-- change connect address as you need
function _M.connect_mod( self, redis )
    redis:set_timeout(self.timeout)

    local ok, err = redis:connect(self.ip, self.port)
    if not ok then
        ngx.log(ngx.ERR, "connect to redis error : ", err)
        return nil, err  -- report the failure so callers do not mistake it for an empty result
    end

    if self.password ~= "" then -- password authentication
        local count
        count, err = redis:get_reused_times()
        if 0 == count then -- new connection: authenticate
            ok, err = redis:auth(self.password)
            if not ok then
                ngx.log(ngx.ERR, "failed to auth: ", err)
                return nil, err
            end
        elseif err then
            ngx.log(ngx.ERR, "failed to get reused times: ", err)
            return nil, err
        end
        -- count > 0: connection taken from the pool, already authenticated
    end

    return ok, err
end

function _M.init_pipeline( self )
    self._reqs = {}
end

function _M.commit_pipeline( self )
    local reqs = self._reqs

    if nil == reqs or 0 == #reqs then
        return {}, "no pipeline"
    else
        self._reqs = nil
    end

    local redis, err = redis_c:new()
    if not redis then
        return nil, err
    end

    local ok, err = self:connect_mod(redis)
    if not ok then
        return {}, err
    end

    redis:init_pipeline()
    for _, vals in ipairs(reqs) do
        local fun = redis[vals[1]]
        table.remove(vals , 1)

        fun(redis, unpack(vals))
    end

    local results, err = redis:commit_pipeline()
    if not results or err then
        return {}, err
    end

    if is_redis_null(results) then
        results = {}
        ngx.log(ngx.WARN, "is null")
    end
    -- table.remove (results , 1)

    --self.set_keepalive_mod(redis)
    self:close_redis(redis)  

    for i,value in ipairs(results) do
        if is_redis_null(value) then
            results[i] = nil
        end
    end

    return results, err
end

local function do_command(self, cmd, ... )
    if self._reqs then
        table.insert(self._reqs, {cmd, ...})
        return
    end

    local redis, err = redis_c:new()
    if not redis then
        return nil, err
    end

    local ok, err = self:connect_mod(redis)
    if not ok or err then
        return nil, err
    end

    redis:select(self.db_index)
    
    local fun = redis[cmd]
    local result, err = fun(redis, ...)
    if not result or err then
        -- ngx.log(ngx.ERR, "pipeline result:", result, " err:", err)
        return nil, err
    end

    if is_redis_null(result) then
        result = nil
    end

    --self.set_keepalive_mod(redis)
    self:close_redis(redis)  

    return result, err
end

for i = 1, #commands do
    local cmd = commands[i]
    _M[cmd] =
            function (self, ...)
                return do_command(self, cmd, ...)
            end
end

function _M.new(self, opts)
    opts = opts or {}
    local timeout = (opts.timeout and opts.timeout * 1000) or 1000
    local db_index= opts.db_index or 0
    local ip = opts.ip or '127.0.0.1'
    local port = opts.port or 6379
    local password = opts.password or ""
    local pool_max_idle_time = opts.pool_max_idle_time or 60000
    local pool_size = opts.pool_size or 100

    return setmetatable({
            timeout = timeout,
            db_index = db_index,
            ip = ip,
            port = port,
            password = password,
            pool_max_idle_time = pool_max_idle_time,
            pool_size = pool_size,
            _reqs = nil }, mt)
end

return _M
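The wrapper does not write a method for each Redis command by hand. It loops over the commands list and generates one closure per name, and every closure funnels into do_command(). The sketch below shows that trick on plain Lua. The backend is a hypothetical in-memory table standing in for resty.redis, so it illustrates the pattern, not Redis behaviour.

```lua
local _M = {}
local mt = { __index = _M }

local commands = { "get", "set", "incr" }

-- hypothetical backend standing in for a resty.redis connection
local backend = {
    get  = function(store, key) return store[key] end,
    set  = function(store, key, value) store[key] = value; return "OK" end,
    incr = function(store, key)
        store[key] = (tonumber(store[key]) or 0) + 1
        return store[key]
    end,
}

-- single place for connecting, pooling, logging... (here: just counting calls)
local function do_command(self, cmd, ...)
    self.calls = self.calls + 1
    return backend[cmd](self.store, ...)
end

for i = 1, #commands do
    local cmd = commands[i]   -- a fresh local per iteration, captured by the closure
    _M[cmd] = function(self, ...)
        return do_command(self, cmd, ...)
    end
end

function _M.new(self)
    return setmetatable({ store = {}, calls = 0 }, mt)
end

local red = _M:new()
red:set("k", "1")
red:incr("k")
```

Adding support for another command then only means adding its name to the list.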

(2) Create the plugin

mkdir -p /usr/local/lua/gateway/plugins/limit_ip
vi /usr/local/lua/gateway/plugins/limit_ip/handler.lua
local utils = require("gateway.common.utils")
local redis = require("gateway.lib.redis")  -- the Redis wrapper above
local BasePlugin = require("gateway.plugins.base_plugin")

local opts = {
    ip = "127.0.0.1",
    port = 6379,
    -- password = "123456",
    db_index = 0
}

local LimitIpHandler = BasePlugin:extend()
LimitIpHandler.PRIORITY = 2

function LimitIpHandler:new()
    LimitIpHandler.super.new(self, "limit_ip-plugin")
    utils.debug_log("===========LimitIpHandler.new============")
end

function LimitIpHandler:access()
    LimitIpHandler.super.access(self)
    utils.debug_log("===========LimitIpHandler.access============")

    local key = "limit:ip:blacklist"
    local user_ip = utils.get_ip()
    local shared_ip_blacklist = ngx.shared.shared_ip_blacklist

    -- time of the last refresh of the local cache
    local last_update_time = shared_ip_blacklist:get("last_update_time")

    if last_update_time ~= nil and ngx.now() - last_update_time < 60 then
        -- the cache is less than 1 minute old: answer from it without asking Redis
        if shared_ip_blacklist:get(user_ip) then
            return ngx.exit(ngx.HTTP_FORBIDDEN) -- reply 403
        end
        return
    end

    local red = redis:new(opts)  -- create a client object (note the colon call)

    local ip_blacklist, err = red:smembers(key)
    if err then
        utils.error_log("limit ip smembers error: " .. tostring(err))
    else
        -- rebuild the local cache from Redis
        shared_ip_blacklist:flush_all()

        if ip_blacklist ~= nil then
            for _, bip in ipairs(ip_blacklist) do
                -- cache every blacklisted IP locally
                shared_ip_blacklist:set(bip, true)
            end
        end

        -- record when the local cache was refreshed
        shared_ip_blacklist:set("last_update_time", ngx.now())
    end

    if shared_ip_blacklist:get(user_ip) then
        return ngx.exit(ngx.HTTP_FORBIDDEN) -- reply 403
    end
end

return LimitIpHandler
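The caching rule in this plugin is: answer from the local copy of the blacklist while it is less than 60 seconds old, otherwise reload it from Redis. Here is a standalone sketch of that rule, runnable on plain Lua. The functions now and load_blacklist are injected stand-ins for ngx.now() and SMEMBERS limit:ip:blacklist. The cache is a per-process table, whereas the plugin uses a shared dict visible to all workers. It builds the new set first and swaps it in, instead of flushing and refilling.

```lua
local function new_blacklist_cache(ttl, now, load_blacklist)
    local cache = { set = {}, updated_at = nil }

    function cache.is_blocked(ip)
        local t = now()
        if cache.updated_at == nil or t - cache.updated_at >= ttl then
            local list = load_blacklist()
            if list then
                local set = {}
                for _, bip in ipairs(list) do
                    set[bip] = true
                end
                cache.set = set        -- swap in the new copy in one step
                cache.updated_at = t
            end
            -- if loading fails, the previous copy stays in use until the next attempt
        end
        return cache.set[ip] == true
    end

    return cache
end

-- demo: a fake clock and a fake Redis set
local clock = 0
local redis_set = { "10.0.0.1" }
local loads = 0
local cache = new_blacklist_cache(60,
    function() return clock end,
    function() loads = loads + 1; return redis_set end)
```

This also explains why a newly blacklisted IP is blocked only after a delay: it can take up to one cache period before the new entry is seen.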

(3) Extend the utilities

vi /usr/local/lua/gateway/common/utils.lua
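-- client IP: X-Real-IP, then the first X-Forwarded-For entry, then the socket address.
-- Both headers can be set by the client, so only rely on them behind a proxy you control.
function _M.get_ip()
    local headers = ngx.req.get_headers()
    local myIP = headers["X-Real-IP"]
    if type(myIP) == "table" then myIP = myIP[1] end  -- header sent more than once
    if myIP == nil then
        local xff = headers["X-Forwarded-For"]
        if type(xff) == "table" then xff = xff[1] end
        if xff then
            -- "client, proxy1, proxy2": keep the first address
            myIP = string.match(xff, "^%s*([^,%s]+)")
        end
    end
    if myIP == nil then
        myIP = ngx.var.remote_addr
    end
    return myIP
end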
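X-Forwarded-For usually holds a comma-separated chain such as "client, proxy1, proxy2", and a header sent twice arrives as a Lua table. The standalone sketch below picks the client address and can be tested on plain Lua. The plain headers table with lowercase keys and the remote_addr argument are stand-ins for ngx.req.get_headers() and ngx.var.remote_addr.

```lua
-- take the first value if a header was sent more than once
local function first_value(v)
    if type(v) == "table" then
        return v[1]
    end
    return v
end

local function pick_client_ip(headers, remote_addr)
    local real_ip = first_value(headers["x-real-ip"])
    if real_ip and real_ip ~= "" then
        return real_ip
    end
    local xff = first_value(headers["x-forwarded-for"])
    if xff then
        -- "client, proxy1, proxy2": the first entry is the original client
        local first = xff:match("^%s*([^,%s]+)")
        if first then
            return first
        end
    end
    return remote_addr
end
```

Since both headers come from the request, a client can put any address in them. Trust them only when a proxy you control sets them, for example with proxy_set_header X-Real-IP $remote_addr as in the nginx configuration above.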

(4) Add the plugin to the configuration file

vi /usr/local/lua/gateway/conf/gateway.conf
{
    "plugins": [
        "limit_ip"
    ]
}

(5) Test

First send a request:

Add the client IP to the blacklist (in redis-cli):

sadd limit:ip:blacklist 192.168.56.1

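After a while (up to one minute, until the local cache is refreshed) the change takes effect, and requests from that IP are rejected with 403: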

6. Implementing a rate-limiting plugin

(1) Create the plugin

mkdir -p /usr/local/lua/gateway/plugins/limit_frequency
vi /usr/local/lua/gateway/plugins/limit_frequency/handler.lua
local utils = require("gateway.common.utils")
local redis = require("gateway.lib.redis")  -- the Redis wrapper
local BasePlugin = require("gateway.plugins.base_plugin")

local opts = {
    ip = "127.0.0.1",
    port = 6379,
    db_index = 0
}

local LIMIT = 10   -- max requests per window
local WINDOW = 10  -- window length in seconds

local LimitFrequencyHandler = BasePlugin:extend()
LimitFrequencyHandler.PRIORITY = 1

function LimitFrequencyHandler:new()
    LimitFrequencyHandler.super.new(self, "LimitFrequency-plugin")
    utils.debug_log("===========LimitFrequencyHandler.new============")
end


function LimitFrequencyHandler:access()
    LimitFrequencyHandler.super.access(self)
    utils.debug_log("===========LimitFrequencyHandler.access============")

    local user_ip = utils.get_ip()

    local key = "limit:frequency:" .. user_ip

    local red = redis:new(opts)  -- create a client object (note the colon call)

    -- INCR creates the key with value 1 if it does not exist, so there is no
    -- separate GET/SET step that concurrent requests could race on
    local count, err = red:incr(key)
    if not count then
        utils.error_log("limit frequency incr error: " .. tostring(err))
        return ngx.exit(ngx.HTTP_INTERNAL_SERVER_ERROR) -- Redis error
    end

    if count == 1 then
        -- first request of a new window: start the 10-second expiry
        local ok, err = red:expire(key, WINDOW)
        if not ok then
            utils.error_log("limit frequency expire error: " .. tostring(err))
        end
    end

    if count > LIMIT then -- more than 10 requests in the current window
        return ngx.exit(ngx.HTTP_FORBIDDEN) -- reply 403
    end
end

return LimitFrequencyHandler
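This plugin implements a fixed-window counter: INCR the per-IP key, set the TTL when the counter is created, and reject once the count exceeds the limit. The standalone simulation below, runnable on plain Lua, uses a Lua table and an injected clock to model Redis INCR and EXPIRE. It is a model of the behaviour under those assumptions, not Redis itself.

```lua
local function new_limiter(limit, window, now)
    local store = {}  -- key -> { count = n, expires_at = t }

    -- models INCR: a missing or expired key starts again from 0
    local function incr(key)
        local t = now()
        local e = store[key]
        if e == nil or t >= e.expires_at then
            e = { count = 0, expires_at = math.huge }
            store[key] = e
        end
        e.count = e.count + 1
        return e.count
    end

    -- models EXPIRE key seconds
    local function expire(key, seconds)
        store[key].expires_at = now() + seconds
    end

    return function(ip)
        local key = "limit:frequency:" .. ip
        local count = incr(key)
        if count == 1 then
            expire(key, window)   -- the first request of a window starts the TTL
        end
        return count <= limit     -- true: allowed, false: answer 403
    end
end

local clock = 0
local allow = new_limiter(10, 10, function() return clock end)
```

Rejected requests also increment the counter, but INCR does not reset the TTL. A client that keeps sending requests is therefore blocked only until the key expires, 10 seconds after the first request of the window.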

(2) Add the plugin to the configuration file

vi /usr/local/lua/gateway/conf/gateway.conf
{
    "plugins": [
        "limit_ip",
        "limit_frequency"
    ]
}

(3) Test

Restart nginx and send requests: more than 10 requests within 10 seconds get 403 Forbidden.

III. Summary

In practice there is usually no need to build a gateway by hand: existing open-source gateways such as Kong and Orange are powerful and easy to use.
