基于AI的WAF小demo
字数 788 2025-08-24 07:48:33

基于AI的WAF实现教学文档

1. 项目概述

本项目实现了一个基于AI的Web应用防火墙(WAF)原型系统,主要功能是通过机器学习模型自动识别恶意URL请求。系统由两部分组成:

  • Go语言实现的Web服务器
  • Python实现的AI检测服务

2. 系统架构

2.1 整体架构

用户请求 → Go服务器 → Python AI服务 → 返回检测结果 → Go服务器响应

2.2 工作流程

  1. 用户访问Go服务器上的URL
  2. Go服务器将URL发送给Python AI服务进行检测
  3. AI服务返回检测结果(0=正常/1=恶意)
  4. Go服务器根据结果决定是否拦截请求

3. 数据集准备

3.1 恶意URL数据集生成

工具准备:

  • Xray漏洞扫描工具
  • Pikachu漏洞测试环境(Docker)

步骤:

  1. 启动Pikachu环境:
docker run --name pikachu -d -p 8000:80 area39/pikachu
  1. 使用Xray扫描生成恶意请求:
./xray_darwin_arm64 webscan --basic-crawler http://127.0.0.1:8000/ --html-output xx.html
  1. 使用Go程序捕获网络流量(lo0网卡):
package main

import (
    "bufio"
    "bytes"
    "fmt"
    "log"
    "net/http"
    "os"
    "github.com/google/gopacket"
    "github.com/google/gopacket/pcap"
)

func main() {
    handle, err := pcap.OpenLive("lo0", 65536, true, pcap.BlockForever)
    if err != nil {
        log.Fatal(err)
    }
    defer handle.Close()

    filter := "tcp and port 8000 and tcp[((tcp[12:1] & 0xf0) >> 2):4] = 0x47455420"
    if err := handle.SetBPFFilter(filter); err != nil {
        log.Fatal(err)
    }

    file, err := os.Create("url_info.txt")
    if err != nil {
        log.Fatal(err)
    }
    defer file.Close()

    packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
    for packet := range packetSource.Packets() {
        if applicationLayer := packet.ApplicationLayer(); applicationLayer != nil {
            if request, err := http.ReadRequest(bufio.NewReader(bytes.NewBuffer(applicationLayer.Payload()))); err == nil {
                url := request.URL.String()
                if _, err := file.WriteString(url + "\n"); err != nil {
                    log.Println(err)
                }
                fmt.Println(url)
            }
        }
    }
}

3.2 良性URL数据集生成

方法一: 模拟点击爬取

from requests_html import HTMLSession
import time

def save_links_to_file(links):
    with open('urls.txt', 'a') as f:
        for link in links:
            f.write(link + '\n')

def crawl(url, end_time):
    session = HTMLSession()
    r = session.get(url)
    links = r.html.absolute_links
    urls = []
    for link in links:
        if link.startswith('http'):
            urls.append(link)
    save_links_to_file(urls)
    for url in urls:
        if time.time() > end_time:
            return
        crawl(url, end_time)

start_url = 'http://127.0.0.1:8000/'
end_time = time.time() + 3
crawl(start_url, end_time)

方法二: 浏览器插件收集

// ==UserScript==
// @name Append URL to Local File
// @namespace http://tampermonkey
// @version 1.0
// @description Append current URL to local file on page load
// @match *://*/*
// @grant GM_xmlhttpRequest
// ==/UserScript==

(function() {
    'use strict';
    const currentUrl = window.location.href;
    GM_xmlhttpRequest({
        method: 'POST',
        url: 'http://your-server:7000',
        data: currentUrl,
        headers: {'Content-Type': 'text/plain'},
        onload: function(response) {
            console.log('URL sent to server');
        },
        onerror: function(error) {
            console.error('Error sending URL to server:', error);
        }
    });
})();

服务器端(Node.js):

const http = require('http');
const fs = require('fs');

const server = http.createServer((req, res) => {
    if (req.method === 'POST') {
        let body = '';
        req.on('data', chunk => {
            body += chunk.toString();
        });
        req.on('end', () => {
            fs.appendFile('urls.txt', body + '\n', err => {
                if (err) {
                    console.error('Error appending URL to file:', err);
                    res.statusCode = 500;
                    res.end('Error appending URL to file');
                } else {
                    console.log('URL appended to file:', body);
                    res.statusCode = 200;
                    res.end('URL appended to file');
                }
            });
        });
    } else {
        res.statusCode = 404;
        res.end('Not found');
    }
});

server.listen(7000, () => {
    console.log('Server running on port 7000');
});

4. 机器学习模型训练

4.1 数据预处理

import os
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
import urllib.parse

def load(name):
    filepath = os.path.join(str(os.getcwd()), name)
    with open(filepath, 'r') as f:
        alldata = f.readlines()
    ans = []
    for i in alldata:
        i = str(urllib.parse.unquote(i))
        ans.append(i)
    return ans

badqueries = load('badqueries.txt')
goodqueries = load('goodqueries.txt')

4.2 特征工程与模型训练

# 标签设置: 1=恶意, 0=良性
Y = [1 for i in range(0, len(badqueries))] + [0 for i in range(0, len(goodqueries))]

# 使用TF-IDF将URL文本转换为特征向量
vectorizer = TfidfVectorizer()
X = vectorizer.fit_transform(badqueries + goodqueries)

# 划分训练集和测试集
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=42)

# 使用逻辑回归模型(带类别权重平衡)
lgs = LogisticRegression(class_weight='balanced')
lgs.fit(X_train, Y_train)

# 评估模型
print(lgs.score(X_test, Y_test))  # 示例输出: 0.9941273293313274

5. 服务实现

5.1 Python AI服务(Flask)

from flask import Flask, request
import os
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
import urllib.parse

# ... (数据加载和模型训练代码同上)

app = Flask(__name__)

@app.route('/check_url', methods=['POST'])
def check_url():
    url = request.form['url']
    X_predict = vectorizer.transform([url])
    res = lgs.predict(X_predict)
    if res == 1:
        return '1\n'
    return '0\n'

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8889)

5.2 Go Web服务器

package main

import (
	"fmt"
	"io/ioutil"
	"net/http"
	"strings"
)

func main() {
	http.HandleFunc("/", handleRequest)
	http.ListenAndServe(":8080", nil)
}

func handleRequest(w http.ResponseWriter, r *http.Request) {
	url := r.URL.Path[1:]
	result, err := checkURL(url)
	if err != nil {
		http.Error(w, "500 Internal Server Error", http.StatusInternalServerError)
		return
	}
	if result == 1 {
		http.Error(w, "403 Forbidden", http.StatusForbidden)
	} else {
		fmt.Fprintf(w, "Hello, %s!", url)
	}
}

func checkURL(url string) (int, error) {
	data := fmt.Sprintf("url=%s", url)
	req, err := http.NewRequest("POST", "http://localhost:8889/check_url", strings.NewReader(data))
	if err != nil {
		return 0, err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return 0, err
	}
	defer resp.Body.Close()
	
	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return 0, err
	}
	
	result := string(body)
	if result == "1\n" {
		return 1, nil
	}
	return 0, nil
}

6. 改进方向

  1. 模型优化:

    • 尝试CNN将URL转化为图像进行分类
    • 使用更复杂的NLP模型进行多分类
  2. 系统健壮性:

    • 实现模型持久化,避免每次启动都重新训练
    • 添加Python服务宕机时的备用方案
  3. 安全增强:

    • 对频繁触发检测的用户进行IP封禁
    • 记录用户行为用于后续分析
  4. 功能扩展:

    • 支持检测请求体和请求头中的恶意内容
    • 实现更复杂的动态网站防护
  5. 性能优化:

    • 使用更高效的序列化协议替代HTTP
    • 实现批量检测减少请求次数

7. 总结

本系统展示了如何将机器学习应用于Web安全防护,核心思路是通过分析URL特征自动识别恶意请求。虽然当前实现较为简单,但提供了完整的AI WAF原型,可作为进一步开发的基础。

基于AI的WAF实现教学文档 1. 项目概述 本项目实现了一个基于AI的Web应用防火墙(WAF)原型系统,主要功能是通过机器学习模型自动识别恶意URL请求。系统由两部分组成: Go语言实现的Web服务器 Python实现的AI检测服务 2. 系统架构 2.1 整体架构 2.2 工作流程 用户访问Go服务器上的URL Go服务器将URL发送给Python AI服务进行检测 AI服务返回检测结果(0=正常/1=恶意) Go服务器根据结果决定是否拦截请求 3. 数据集准备 3.1 恶意URL数据集生成 工具准备 : Xray漏洞扫描工具 Pikachu漏洞测试环境(Docker) 步骤 : 启动Pikachu环境: 使用Xray扫描生成恶意请求: 使用Go程序捕获网络流量(lo0网卡): 3.2 良性URL数据集生成 方法一: 模拟点击爬取 方法二: 浏览器插件收集 服务器端(Node.js) : 4. 机器学习模型训练 4.1 数据预处理 4.2 特征工程与模型训练 5. 服务实现 5.1 Python AI服务(Flask) 5.2 Go Web服务器 6. 改进方向 模型优化 : 尝试CNN将URL转化为图像进行分类 使用更复杂的NLP模型进行多分类 系统健壮性 : 实现模型持久化,避免每次启动都重新训练 添加Python服务宕机时的备用方案 安全增强 : 对频繁触发检测的用户进行IP封禁 记录用户行为用于后续分析 功能扩展 : 支持检测请求体和请求头中的恶意内容 实现更复杂的动态网站防护 性能优化 : 使用更高效的序列化协议替代HTTP 实现批量检测减少请求次数 7. 总结 本系统展示了如何将机器学习应用于Web安全防护,核心思路是通过分析URL特征自动识别恶意请求。虽然当前实现较为简单,但提供了完整的AI WAF原型,可作为进一步开发的基础。