fupanc's blog

rate

一个go语言中的limiter,通过这个来实现对速率的限制,下面来了解一下相关的使用方式。

直接下载相关库即可:

go get golang.org/x/time/rate

基本了解

令牌桶算法

这个算法常见于各种限流方式中,经常用于qps限制,以代码形式表示令牌桶的结果大概如下:

type TokenBucket struct {
    capacity       int64     // 桶最大容量
    tokens         float64   // 当前令牌数量
    rate           float64   // 每秒生成令牌速度
    lastRefillTime int64     // 上次补充时间
}

这四个变量就可以用于描述整个限流状态。此算法实现的大致原理是:每次请求到达时,都会根据当前时间-上次补充时间(lastRefillTime),结合rate来进行计算需要更新的令牌并补充(但不会超过capacity),然后如果桶中的令牌足够,就会允许请求进行下去,否则就拒绝/让请求等待token补充,流程大致如下:

sequenceDiagram
    participant Client as 请求
    participant Bucket as TokenBucket

    Client->>Bucket: Allow()

    Bucket->>Bucket: 获取当前时间 now

    Bucket->>Bucket: now - lastRefillTime

    Bucket->>Bucket: 计算新增 token

    Bucket->>Bucket: 更新 tokens

    Bucket->>Bucket: 判断 token 是否足够

    alt token 足够
        Bucket->>Bucket: 扣减 token
        Bucket-->>Client: 放行
    else token 不够
        Bucket-->>Client: 拒绝
    end

    Bucket->>Bucket: 更新 lastRefillTime

是一种非常好的工程实践,用于限流。

库了解

这个rate库很好地实现了令牌桶算法,其源代码也是很好地表示了令牌桶需要的相关变量:

type Limiter struct {
	mu     sync.Mutex
	limit  Limit
	burst  int
	tokens float64
	// last is the last time the limiter's tokens field was updated
	last time.Time
	// lastEvent is the latest time of a rate-limited event (past or future)
	lastEvent time.Time
}

下面来简单看看这个库中一些函数的定义,后续直接用于代码场景中实现。

初始化:

limiter := rate.NewLimiter(2, 5)

参数解析:

相关函数定义:

理解一下下面这个代码就行:

package main

import (
	"context"
	"fmt"
	"golang.org/x/time/rate"
	"time"
)

func main() {

	//  0.5 个 token每秒
	// 桶容量最大 5
	limiter := rate.NewLimiter(0.5, 5)

	ctx := context.Background()

	fmt.Println("====== 初始 burst 测试 ======")

	// 初始桶是满的
	for i := 1; i <= 7; i++ {

		ok := limiter.Allow()

		fmt.Printf(
			"时间=%s 第%d次 Allow() => %v 当前Token≈%.6f\n",
			time.Now().Format("15:04:05.000"),
			i,
			ok,
			limiter.Tokens(),
		)
	}

	fmt.Println()
	fmt.Println("====== 等待 token 自动恢复 ======")

	// 等待 2 秒
	time.Sleep(2 * time.Second)

	fmt.Printf(
		"2秒后 Token≈%.6f\n",
		limiter.Tokens(),
	)

	fmt.Println()

	// 再次消费
	for i := 1; i <= 3; i++ {

		ok := limiter.Allow()

		fmt.Printf(
			"恢复后 第%d次 Allow() => %v Token≈%.6f\n",
			i,
			ok,
			limiter.Tokens(),
		)
	}

	fmt.Println()
	fmt.Println("====== Wait() 测试 ======")

	// 清空 token
	for limiter.Allow() {
	}

	fmt.Printf(
		"清空后 Token≈%.6f\n",
		limiter.Tokens(),
	)

	fmt.Println()

	// Wait 会阻塞直到 token 恢复
	for i := 1; i <= 3; i++ {

		start := time.Now()

		err := limiter.Wait(ctx)
		if err != nil {
			panic(err)
		}

		cost := time.Since(start)

		fmt.Printf(
			"第%d次 Wait() 等待时间=%v Token≈%.6f 时间=%s\n",
			i,
			cost,
			limiter.Tokens(),
			time.Now().Format("15:04:05.000"),
		)
	}
}

运行效果如下:

====== 初始 burst 测试 ======
时间=15:42:43.967 第1次 Allow() => true 当前Token≈4.000147
时间=15:42:43.968 第2次 Allow() => true 当前Token≈3.000177
时间=15:42:43.968 第3次 Allow() => true 当前Token≈2.000178
时间=15:42:43.968 第4次 Allow() => true 当前Token≈1.000179
时间=15:42:43.968 第5次 Allow() => true 当前Token≈0.000180
时间=15:42:43.968 第6次 Allow() => false 当前Token≈0.000181
时间=15:42:43.968 第7次 Allow() => false 当前Token≈0.000182

====== 等待 token 自动恢复 ======
2秒后 Token≈1.000648

恢复后 第1次 Allow() => true Token≈0.000685
恢复后 第2次 Allow() => false Token≈0.000689
恢复后 第3次 Allow() => false Token≈0.000690

====== Wait() 测试 ======
清空后 Token≈0.000695

第1次 Wait() 等待时间=1.998998459s Token≈0.000196 时间=15:42:47.968
第2次 Wait() 等待时间=2.00060725s Token≈0.000546 时间=15:42:49.968
第3次 Wait() 等待时间=1.999912667s Token≈0.000529 时间=15:42:51.968

在wait输出时有明显的顿挫感,并且看token数量能很明显看出每次toekn的更新都是通过计算来的,并不是什么定时器来更新token。

功能场景

主要是用来限制qps的,来看看一些场景下的代码实现方式。

QPS限制

主要是针对dast中的扫描模块的qps限制。下面来分别记录一下dast扫描模块的qps限制的代码实现,扫描模块目前实现了如下几个:

后面来看看具体的个模块的qps限制实现方式,原代码这里不多说,直接去看DAST文章设置的代码就行。

端口扫描

代码实现如下:

package main

import (
	"context"
	"fmt"
	"net"
	"strconv"
	"strings"
	"time"

	"golang.org/x/time/rate"
)

const Timeout = 2 * time.Second

// ---------------- 解析端口 ----------------
func parsePorts(portInput string) ([]int, error) {
	var ports []int

	// 判断是否是范围
	if strings.Contains(portInput, "-") {
		parts := strings.Split(portInput, "-")
		if len(parts) != 2 {
			return nil, fmt.Errorf("端口范围格式错误")
		}

		start, err1 := strconv.Atoi(parts[0])
		end, err2 := strconv.Atoi(parts[1])

		if err1 != nil || err2 != nil || start > end {
			return nil, fmt.Errorf("端口范围非法")
		}

		for i := start; i <= end; i++ {
			ports = append(ports, i)
		}
	} else {
		// 单个端口
		p, err := strconv.Atoi(portInput)
		if err != nil {
			return nil, fmt.Errorf("端口格式错误")
		}
		ports = append(ports, p)
	}

	return ports, nil
}

// ---------------- 扫描函数 ----------------
func scanPorts(ip string, ports []string) {
	// 创建速率限制器:每秒最多 10 个请求
	limiter := rate.NewLimiter(rate.Limit(10), 10)
	ctx := context.Background()

	for _, port := range ports {
		port_final, err := parsePorts(port)
		if err != nil {
			fmt.Println("解析错误:", err)
			return
		}
		for _, port_scan := range port_final {
			// ==================== QPS 限制 ====================
			// 每次扫描端口前等待 token
			if err := limiter.Wait(ctx); err != nil {
				fmt.Printf("速率限制错误: %v\n", err)
				continue
			}

			address := fmt.Sprintf("%s:%d", ip, port_scan)

			conn, err := net.DialTimeout("tcp", address, Timeout)
			if err == nil {
				conn.Close()
				fmt.Printf("✔ %d open\n", port_scan)
			}
		}
	}
}

// ---------------- main ----------------
func main() {
	ip := "www.baidu.com"
	port_field := []string{"80", "443"}

	scanPorts(ip, port_field)
}

测试效果如下:

✔ 80 open
✔ 443 open

很容易加逻辑,自行针对上述加相关运行时间逻辑+修改limiter就可以进行验证,这里就不多说了。

如果是多个目标,直接对目标分组然后并发就行,只需要在goroutine的逻辑里面加上如上代码相关的逻辑就可以针对但目标设置qps了。

服务识别

由于nmap官法并没有提供相关sdk,并且现在市面上大部分都是伪sdk,其实底层还是需要nmap二进制文件,故我就直接通过exec来调用nmap二进制文件,需要注意的是这里要做好对输入的检测,不要被恶意输入导致命令执行。

在前面就说了准备基于nmap的–max-rate选项来实现qps限制,这个选项的官法解释是:每秒发送数据包不超过多少个,很方便做qps限制。

最后测试代码如下:

package main

import (
	"bytes"
	"context"
	"encoding/xml"
	"fmt"
	"os/exec"
	"strings"
	"time"
)

type ServiceResult struct {
	IP      string
	Port    int
	State   string
	Service string
	Product string
	Version string
}

// 只保留最关键的 XML 结构
type nmapResult struct {
	Hosts []struct {
		Ports []struct {
			PortID   int    `xml:"portid,attr"`
			Protocol string `xml:"protocol,attr"`

			State struct {
				State string `xml:"state,attr"`
			} `xml:"state"`

			Service struct {
				Name    string `xml:"name,attr"`
				Product string `xml:"product,attr"`
				Version string `xml:"version,attr"` //将product+version其存入数据库用于前端显示,但这里并不参与后续的逻辑处理
			} `xml:"service"`
		} `xml:"ports>port"`
	} `xml:"host"`
}

func IdentifyService(host string, ports []string) ([]ServiceResult, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) ////给足够长的时间用于程序运行,覆盖dns解析等消耗的时长
	defer cancel()

	portScan := strings.Join(ports, ",")
	if portScan == "" {
		return nil, fmt.Errorf("port不能为空")
	}

	cmd := exec.CommandContext(ctx,
		"nmap",
		"-sV",               //探测开放端口以获取服务/版本信息,对于服务探测非常有必要
		"-Pn",               //将所有主机视为已上线状态,跳过主机发现步骤(因为前面的主机存活以及端口扫描模块已经完成了对应的内容)
		"-n",                //禁止反向域名解析
		"--max-rate", "100", //每秒最多发送多少请求包
		"-p", portScan,
		"-oX", "-", //结果输出为xml格式到stdout中
		host,
	)

	var out bytes.Buffer
	cmd.Stdout = &out

	if err := cmd.Run(); err != nil {
		return nil, fmt.Errorf("nmap 执行失败: %v", err)
	}

	var result nmapResult
	if err := xml.Unmarshal(out.Bytes(), &result); err != nil {
		return nil, fmt.Errorf("解析失败: %v", err)
	}

	if len(result.Hosts) == 0 {
		return nil, fmt.Errorf("无结果")
	}

	h := result.Hosts[0] //从流程来看只需要一个即可,但是看处理方式要不要将多个host的结果合并在一起,就可以使用数组,后面综合起来的时候再看看

	if len(h.Ports) == 0 {
		return nil, fmt.Errorf("无端口信息")
	}

	var results []ServiceResult

	for _, p := range h.Ports {
		results = append(results, ServiceResult{
			IP:      host,
			Port:    p.PortID,
			State:   p.State.State,
			Service: p.Service.Name,
			Product: p.Service.Product,
			Version: p.Service.Version,
		})
	}

	return results, nil
}

func main() {
	ports := []string{"3306", "8888"}

	start := time.Now()

	results, err := IdentifyService("127.0.0.1", ports)
	if err != nil {
		fmt.Println(err)
		return
	}

	for _, r := range results {
		fmt.Printf("IP=%s Port=%d Service=%s State=%s Version=%s\n",
			r.IP, r.Port, r.Service, r.State, r.Product+" "+r.Version)
	}

	end := time.Since(start)
	fmt.Println("整体运行消耗时间为:", end)
}

在命令执行部分加一个选项即可,然后再加了一个时间周期判断。

最后对比消耗时间的周期为:

//未加选项正常运行
6.25863975s
6.251501791s
6.22986275s

如果加上–max-rate选项来设置qps,分别的测试效果如下:

100:6.220211s
50:6.273320042s
30:6.277651709s
20:6.353169666s
10:6.339030583s
5:6.4969995s

可知一般nmap正常发包也就每秒几十到几百左右,故可以直接将qps默认设置为150,然后在实际扫描需求中再修改。

漏洞扫描

这就直接基于nuclei sdk提供的一些选项来进行配置,一般会基于如下几种配置来进行限速:

nuclei.WithGlobalRateLimitCtx(ctx, 100, time.Second) // 先定总上限
nuclei.WithConcurrency(nuclei.Concurrency{
    HostConcurrency:             50,
    TemplateConcurrency:         10,
    TemplatePayloadConcurrency:  10,
    ProbeConcurrency:            50,
})

大致的意思就是并发性控制要发送的请求队列长度,但是还是要受全局速率的限制。基于我到时候准备实现的dast扫描思路去每个扫描器接收的目标数量小于等于10个,所以整体的并发请求队列不会太多,所以这里考虑直接将并发性用作默认情况,然后只需要限制全局速率即可,最后的测试代码如下:

package main

import (
	"context"
	"fmt"
	"time"

	nuclei "github.com/projectdiscovery/nuclei/v3/lib"
	"github.com/projectdiscovery/nuclei/v3/pkg/catalog/disk"
	"github.com/projectdiscovery/nuclei/v3/pkg/output"
	"os"
	"path/filepath"
)

func main() {
	// 扫描目标数组(支持单目标或多目标)
	targets := []string{
		"127.0.0.1:8888",
	}

	// 扫描模板 ID 数组;为空则表示全量扫描
	templateIDs := []string{"local-flask", "multi-request-check"}

	//定义模版目录
	home, err := os.UserHomeDir()
	if err != nil {
		fmt.Println(err)
	}
	templateDir := filepath.Join(home, "nuclei-templates")

	start := time.Now()

	if err := runNucleiScan(context.Background(), targets, templateIDs, templateDir); err != nil {
		fmt.Fprintln(os.Stderr, "scan failed:", err)
		os.Exit(1)
	}

	end := time.Since(start)
	fmt.Println("运行消耗的时间为:", end)

}

func runNucleiScan(ctx context.Context, targets []string, templateIDs []string, templateDir string) error {
	if len(targets) == 0 {
		return fmt.Errorf("targets is empty")
	}

	opts := []nuclei.NucleiSDKOptions{
		nuclei.WithCatalog(disk.NewCatalog(templateDir)), //模版目录
		nuclei.DisableUpdateCheck(),

		// 全局速率限制:200 QPS
		nuclei.WithGlobalRateLimitCtx(ctx, 200, time.Second),
	}

	// 指定模板 ID 时,相当于 CLI 的 -id id1,id2
	if len(templateIDs) > 0 {
		opts = append(opts, nuclei.WithTemplateFilters(nuclei.TemplateFilters{
			IDs: templateIDs,
		}))
	}

	engine, err := nuclei.NewNucleiEngineCtx(ctx, opts...)
	if err != nil {
		return fmt.Errorf("create nuclei engine failed: %w", err)
	}
	defer engine.Close()

	// 全量扫描 / 指定 ID 扫描,注意这里只是加载了全量模版,但是如果指定了filter,就只会扫描指定id的模版文件
	if err := engine.LoadAllTemplates(); err != nil {
		return fmt.Errorf("load templates failed: %w", err)
	}

	// 这里直接加载对应的目标数组即可,另类实现-l
	engine.LoadTargets(targets, true) //如果是false的话就输入目标为ip就必须加上http,只走https了,故设置为true用于自行探测是什么协议

	// 回调里直接输出命中的结果
	return engine.ExecuteCallbackWithCtx(ctx, func(ev *output.ResultEvent) {
		if ev == nil || !ev.MatcherStatus {
			return
		}
		sev := ev.Info.SeverityHolder.Severity.String()
		if sev == "" {
			sev = "unknown"
		}
		fmt.Printf("[VULN] severity=%s template=%s name=%s Matched=%s\n",
			sev,
			ev.TemplateID,
			ev.Info.Name,
			ev.Matched,
		)
	})
}

运行效果如下:

[VULN] severity=info template=multi-request-check name=两个请求关联逻辑判断 POC Matched=http://127.0.0.1:8888
[VULN] severity=info template=local-flask name=两个请求关联逻辑判断 POC Matched=http://127.0.0.1:8888/?name=admin
运行消耗的时间为: 1.272003s

时间对比如下

原正常运行时间为:

运行消耗的时间为: 2.749350167s
运行消耗的时间为: 1.165794875s
运行消耗的时间为: 1.518048625s

加上全局速率限制运行时间为:

200:运行消耗的时间为: 1.324775792s
10:运行消耗的时间为: 2.172227166s
5:运行消耗的时间为: 3.335938666s

有效实现整体速率的限制。

——————————

弱口令扫描

这个很好加qps限制,理清代码逻辑就知道可以加在哪里了,原来的模块文件结构为:

.
├── main.go                      # 主程序入口(包含测试目标)
├── dictionary/
│   └── dictionary.go            # 弱口令字典(用户名和密码列表及排列组合逻辑)
├── model/
│   └── task.go                  # 数据模型(Task 和 Result 结构体定义)
├── checker/
│   ├── checker.go               # Checker 接口定义
│   └── dispatcher.go            # 任务分发器
├── services/
│   ├── ssh.go                   # SSH 弱口令爆破实现
│   ├── mysql.go                 # MySQL 弱口令爆破实现
│   └── redis.go                 # Redis 弱口令爆破实现
.

这里需要改一下main.go和dispatcher.go文件,分别如下:

package main

import (
	"context"
	"fmt"
	"sync"
	"time"

	"Go_ENV/checker"
	"Go_ENV/dictionary"
	"Go_ENV/model"
	"golang.org/x/time/rate"
)

// ServiceInfo 服务信息
type ServiceInfo struct {
	Service string
	Port    int
}

// HostTarget 主机目标(双层结构)
type HostTarget struct {
	Host     string
	Services []ServiceInfo
}

func main() {
	// ==================== 双层数据结构 ====================
	// 第一层:Host
	// 第二层:该 Host 上识别出的服务
	targets := []HostTarget{
		{
			Host: "127.0.0.1",
			Services: []ServiceInfo{
				{"ssh", 22},
				{"mysql", 3306},
				{"redis", 6379},
			},
		},
		// 可以添加更多 Host
		// {
		// 	Host: "192.168.1.100",
		// 	Services: []ServiceInfo{
		// 		{"ssh", 22},
		// 		{"mysql", 3306},
		// 	},
		// },
	}

	// QPS 配置
	const qpsLimit = 1
	const burstSize = 10

	// 使用 WaitGroup 支持多 Host 并发
	var wg sync.WaitGroup

	// 遍历每个 Host(可并发)
	for _, hostTarget := range targets {
		wg.Add(1)
		go func(ht HostTarget) {
			defer wg.Done()
			scanHost(ht, qpsLimit, burstSize)
		}(hostTarget)
	}

	wg.Wait()
	fmt.Println("\n所有主机扫描完成!")
}

// scanHost 扫描单个主机的所有服务(带 QPS 限制)
func scanHost(hostTarget HostTarget, qpsLimit, burstSize int) {
	fmt.Printf("\n========================================\n")
	fmt.Printf("开始扫描主机: %s\n", hostTarget.Host)
	fmt.Printf("服务数量: %d\n", len(hostTarget.Services))
	fmt.Printf("QPS 限制: %d 请求/秒\n", qpsLimit)
	fmt.Printf("========================================\n")

	// ==================== 为该 Host 创建独立的速率限制器 ====================
	limiter := rate.NewLimiter(rate.Limit(qpsLimit), burstSize)
	ctx := context.Background()

	// 遍历该 Host 的所有服务
	for _, svc := range hostTarget.Services {
		fmt.Printf("\n[%s] 开始爆破 %s://%s:%d\n", hostTarget.Host, svc.Service, hostTarget.Host, svc.Port)

		// 获取对应服务的弱口令字典
		credentials := dictionary.GetCredentials(svc.Service)
		fmt.Printf("[%s] 字典大小: %d 组凭证\n", hostTarget.Host, len(credentials))

		// 根据字典生成任务列表
		tasks := make([]model.Task, 0, len(credentials))
		for _, cred := range credentials {
			task := model.Task{
				Service:  svc.Service,
				Host:     hostTarget.Host,
				Port:     svc.Port,
				Username: cred.Username,
				Password: cred.Password,
				Timeout:  3 * time.Second,
			}

			// 针对不同服务设置特定参数
			switch svc.Service {
			case "mysql":
				task.MySQLDB = ""
			case "redis":
				task.RedisDB = 0
			}

			tasks = append(tasks, task)
		}

		// ==================== 应用 QPS 限制 ====================
		dispatcher := checker.NewDispatcher()
		startTime := time.Now()

		for _, task := range tasks {
			if err := limiter.Wait(ctx); err != nil {
				continue
			}

			result := dispatcher.Run(ctx, task)
			if result.OK {
				fmt.Printf("[%s] ✓ service=%s user='%s' pass='%s'\n",
					hostTarget.Host, result.Service, result.Username, result.Password)
			}
		}

		elapsed := time.Since(startTime)
		actualQPS := float64(len(tasks)) / elapsed.Seconds()
		fmt.Printf("[%s] %s 完成: 请求数=%d 耗时=%.2fs 实际QPS=%.2f\n",
			hostTarget.Host, svc.Service, len(tasks), elapsed.Seconds(), actualQPS)
	}

	fmt.Printf("\n[%s] 主机扫描完成!\n", hostTarget.Host)
}
package checker

import (
	"context"
	"fmt"
	"strings"

	"Go_ENV/model"
	"Go_ENV/services"
)

type Dispatcher struct {
	registry map[string]Checker
}

func NewDispatcher() *Dispatcher {
	return &Dispatcher{
		registry: map[string]Checker{
			"ssh":   services.NewSSHChecker(),
			"mysql": services.NewMySQLChecker(),
			"redis": services.NewRedisChecker(),
		},
	}
}

func (d *Dispatcher) Run(ctx context.Context, task model.Task) model.Result {
	task.Service = strings.ToLower(strings.TrimSpace(task.Service))
	checker, ok := d.registry[task.Service]
	if !ok {
		return model.Result{
			Service:  task.Service,
			Target:   fmt.Sprintf("%s:%d", task.Host, task.Port),
			Username: task.Username,
			Password: task.Password,
			OK:       false,
			Err:      fmt.Errorf("unsupported service: %s", task.Service),
		}
	}

	return checker.Check(ctx, task)
}

这里改成了基于host的扫描,更好设置针对不同主机的qps设置,逻辑很正确,最后运行效果如下:

========================================
开始扫描主机: 127.0.0.1
服务数量: 3
QPS 限制: 1 请求/秒
========================================

[127.0.0.1] 开始爆破 ssh://127.0.0.1:22
[127.0.0.1] 字典大小: 66 组凭证
[127.0.0.1] ssh 完成: 请求数=66 耗时=56.00s 实际QPS=1.18

[127.0.0.1] 开始爆破 mysql://127.0.0.1:3306
[127.0.0.1] 字典大小: 32 组凭证
[127.0.0.1] ✓ service=mysql user='root' pass='root'
[127.0.0.1] mysql 完成: 请求数=32 耗时=32.01s 实际QPS=1.00

[127.0.0.1] 开始爆破 redis://127.0.0.1:6379
[127.0.0.1] 字典大小: 5 组凭证
[127.0.0.1] ✓ service=redis user='' pass='redis'
[127.0.0.1] redis 完成: 请求数=5 耗时=5.00s 实际QPS=1.00

[127.0.0.1] 主机扫描完成!

所有主机扫描完成!

成功实现了针对不同主机的qps限流。

如下是针对ncueli扫描模块的qps限制:

opts := []nuclei.NucleiSDKOptions{
		nuclei.WithCatalog(disk.NewCatalog(templateDir)),
		nuclei.DisableUpdateCheck(),

		nuclei.WithGlobalRateLimitCtx(ctx, qpsPerTarget, time.Second),

		nuclei.WithConcurrency(nuclei.Concurrency{
			TemplateConcurrency:           25,
			HostConcurrency:               qpsPerTarget,
			HeadlessHostConcurrency:       qpsPerTarget,
			HeadlessTemplateConcurrency:   10,
			JavascriptTemplateConcurrency: 10,
			TemplatePayloadConcurrency:    25,
			ProbeConcurrency:              50,
		}),
	}