package Middlewares
import (
"github.com/gin-gonic/gin"
"strconv"
"time"
"voteapi/pkg/app/response"
"voteapi/pkg/gredis"
"voteapi/pkg/util"
)
// Redis hash names (without prefix) used by the IP rate limiter.
const (
	// IP_LIMIT_NUM_KEY is the hash that stores per-item request counters.
	IP_LIMIT_NUM_KEY = "ipLimit:ipLimitNum"
	// IP_BLACK_LIST_KEY is the hash that stores per-item ban expiry timestamps.
	IP_BLACK_LIST_KEY = "ipLimit:ipBlackList"
)

// Tunables of the IP rate limiter.
var (
	// prefix is prepended to both Redis keys.
	prefix = "{gateway}"
	// delaySeconds is the observation window, in seconds.
	delaySeconds int64 = 60
	// maxAttempts is the number of requests allowed per window.
	maxAttempts int64 = 10000
	// blackSeconds is the ban duration in seconds; 0 disables banning.
	blackSeconds int64 = 0
)
// GateWayPlus returns a gin middleware that rate-limits requests per route
// and client IP.
//
// For every request it hands the matched route (c.FullPath) and the client IP
// (c.ClientIP) to main. When main reports false, the handler chain is aborted
// and an error response is written through response.JsonResponseError.
func GateWayPlus() gin.HandlerFunc {
	return func(c *gin.Context) {
		// NOTE(review): the original comment at this point said "required when
		// Redis is configured as a cluster"; it presumably refers to the
		// "{gateway}" key prefix — confirm.
		param := map[string]string{
			"path":     c.FullPath(),
			"clientIp": c.ClientIP(),
		}
		if main(param) {
			return
		}
		c.Abort()
		response.JsonResponseError(c, "当前IP请求过于频繁,暂时被封禁~")
	}
}
func main(param map[string]string) bool {
// 预知的IP黑名单
var blackList []string
if util.InStringArray(param["clientIp"], blackList) {
return false
}
// 预知的IP白名单
var whiteList []string
if util.InStringArray(param["clientIp"], whiteList) {
return false
}
blackKey := prefix + ":" + IP_BLACK_LIST_KEY
limitKey := prefix + ":" + IP_LIMIT_NUM_KEY
curr := time.Now().Unix()
item := util.Md5(param["path"] + "|" + param["clientIp"])
return normal(blackKey, limitKey, item, curr)
}
// 普通模式
func normal(blackKey string, limitKey string, item string, time int64) (res bool) {
if blackSeconds > 0 {
timeout, _ := gredis.RawCommand("HGET", blackKey, item)
if timeout != nil {
to, _ := strconv.Atoi(string(timeout.([]uint8)))
if int64(to) > time {
// 未解封
return false
}
// 已解封,移除黑名单
gredis.RawCommand("HDEL", blackKey, item)
}
}
l, _ := gredis.RawCommand("HGET", limitKey, item)
if l != nil {
last, _ := strconv.Atoi(string(l.([]uint8)))
if int64(last) >= maxAttempts {
return false
}
}
num, _ := gredis.RawCommand("HINCRBY", limitKey, item, 1)
if ttl, _ := gredis.TTLKey(limitKey); ttl == int64(-1) {
gredis.Expire(limitKey, int64(delaySeconds))
}
if num.(int64) >= maxAttempts blackSeconds > 0 {
// 加入黑名单
gredis.RawCommand("HSET", blackKey, item, time+blackSeconds)
// 删除记录
gredis.RawCommand("HDEL", limitKey, item)
}
return true
}
// luaScript is the "Lua script mode" limiter: the whole check-and-increment
// sequence runs inside a single Redis EVAL call. According to the original
// note it is meant to support Redis cluster deployments.
//
// KEYS[1]/KEYS[2] are blackKey/limitKey. ARGV carries, in order: item, time
// (current Unix seconds), maxAttempts, delaySeconds and blackSeconds. The
// result is true only when the command succeeds and the script replies with
// integer 1; a command error or any other reply yields false.
func luaScript(blackKey string, limitKey string, item string, time int64) (res bool) {
	const rateLimitLua = `
local blackSeconds = tonumber(ARGV[5])
if(blackSeconds > 0)
then
local timeout = redis.call('hget', KEYS[1], ARGV[1])
if(timeout ~= false)
then
if(tonumber(timeout) > tonumber(ARGV[2]))
then
return false
end
redis.call('hdel', KEYS[1], ARGV[1])
end
end
local last = redis.call('hget', KEYS[2], ARGV[1])
if(last ~= false and tonumber(last) >= tonumber(ARGV[3]))
then
return false
end
local num = redis.call('hincrby', KEYS[2], ARGV[1], 1)
local ttl = redis.call('ttl', KEYS[2])
if(ttl == -1)
then
redis.call('expire', KEYS[2], ARGV[4])
end
if(tonumber(num) >= tonumber(ARGV[3]) and blackSeconds > 0)
then
redis.call('hset', KEYS[1], ARGV[1], ARGV[2] + ARGV[5])
redis.call('hdel', KEYS[2], ARGV[1])
end
return true
`
	reply, err := gredis.RawCommand("EVAL", rateLimitLua, 2, blackKey, limitKey, item, time, maxAttempts, delaySeconds, blackSeconds)
	return err == nil && reply == int64(1)
}
// ExecLimit runs f under the write lock l, throttled so that consecutive
// executions are at least perDuration/maxTimes apart.
//
// lastExecTime is the shared timestamp of the previous execution; it is read
// to compute how long to wait and overwritten after f returns. Because the
// gap is measured from the end of the previous f, the achieved rate will most
// likely stay below maxTimes per perDuration. The caller blocks (holding l)
// while waiting. A non-positive maxTimes disables throttling and f runs
// immediately.
func ExecLimit(lastExecTime *time.Time, l *sync.RWMutex, maxTimes int, perDuration time.Duration, f func()) {
	l.Lock()
	defer l.Unlock()

	if maxTimes > 0 {
		// Minimum spacing between two executions.
		minGap := perDuration / time.Duration(maxTimes)
		if wait := minGap - time.Since(*lastExecTime); wait > 0 {
			time.Sleep(wait)
		}
	}
	f()
	*lastExecTime = time.Now()
}
// ExecLimit2 schedules f through the official limiter from
// "golang.org/x/time/rate". Per the original note it gets closer to the
// configured rate than the hand-written ExecLimit.
//
// The call returns immediately: a new goroutine waits for one token from l
// and then runs f. If the limiter refuses the wait (l.Wait returns an error)
// f is not run, so work is never executed unthrottled. Every call starts a
// goroutine, so callers should bound how many calls they have in flight.
func ExecLimit2(l *rate.Limiter, f func()) {
	go func() {
		if err := l.Wait(context.Background()); err != nil {
			return
		}
		f()
	}()
}
func TestExecLimit(t *testing.T) {
runtime.GOMAXPROCS(runtime.NumCPU())
go func() {
var lastExecTime time.Time
var l sync.RWMutex
for {
ExecLimit(lastExecTime, l, 10, time.Second, func() {
fmt.Println("do")
})
}
}()
select {
case -time.After(1 * time.Second):
fmt.Println("1秒到时")
}
}
func TestExecLimit2(t *testing.T) {
runtime.GOMAXPROCS(runtime.NumCPU())
l := rate.NewLimiter(1, 30)
go func() {
for {
ExecLimit2(l, func() {
fmt.Println("do")
})
}
}()
select {
case -time.After(1 * time.Second):
fmt.Println("1秒到时")
}
}
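// Note on the usage above: lastExecTime is a global variable that lives inside a
// single service node, so it only throttles the calls to f() made by that node.
// Behind a load balancer, several identical nodes may have to respect one
// cumulative limit on a third-party API, for example three nodes pulling the same
// API with a combined limit of 30 requests/s. A per-node limiter cannot enforce
// that combined figure on its own: the budget has to be split between the nodes
// (10 requests/s each in this example) or tracked in a store shared by all nodes.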