开发高并发 高扩展的ai WAF尝试
字数 1129 2025-08-24 07:48:33
高并发高扩展AI WAF系统开发教学文档
1. 系统概述
本教学文档详细讲解如何开发一个基于AI的高并发、高扩展性Web应用防火墙(WAF)系统。该系统采用微服务架构,结合机器学习技术,实现了对恶意URL请求的实时检测和防护。
2. 系统架构
2.1 整体架构
系统由以下核心组件构成:
- Web服务器:Go语言编写,处理用户请求
- WAF中间件:Go语言编写,负责请求过滤和分发
- AI服务器:Python编写,执行机器学习检测
- 消息队列:RabbitMQ实现组件间通信
- 缓存系统:Redis存储黑名单和URL缓存
- 数据库:MySQL存储用户凭证
2.2 数据流设计
- 用户请求 → Web服务器
- Web服务器 → RabbitMQ队列1 (携带URL和Session)
- WAF消费队列1 → 检查Redis缓存
- 缓存命中:直接返回结果
- 缓存未命中:发送到RabbitMQ队列2
- AI服务器消费队列2 → 机器学习检测 → 结果写入Redis
- 检测到恶意请求 → 用户加入黑名单
3. 核心组件实现
3.1 Web服务器实现
3.1.1 用户认证
// 用户结构体
type User struct {
Username string `json:"username"`
Password string `json:"password"`
}
// JWT Claims结构体
type Claims struct {
Username string `json:"username"`
jwt.StandardClaims
}
// 登录处理
http.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) {
// 解析请求体
var user User
err := json.NewDecoder(r.Body).Decode(&user)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// 验证用户凭证
var user1 User
_ = db.QueryRow("SELECT username, password FROM users WHERE username = ?",
user.Username).Scan(&user1.Username, &user1.Password)
if user != user1 {
http.Error(w, "Invalid username or password", http.StatusUnauthorized)
return
}
// 生成JWT Token
claims := &Claims{
Username: user.Username,
StandardClaims: jwt.StandardClaims{
ExpiresAt: jwt.TimeFunc().Add(time.Hour * 24).Unix(),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte("secret"))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// 返回Token
w.Header().Set("Authorization", fmt.Sprintf("Bearer %s", tokenString))
w.Write([]byte("Hello World"))
})
3.1.2 请求处理
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
// 验证Token
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
http.Error(w, "Missing authorization header", http.StatusUnauthorized)
return
}
tokenString := authHeader[len("Bearer "):]
token, err := jwt.ParseWithClaims(tokenString, &Claims{},
func(token *jwt.Token) (interface{}, error) {
return []byte("secret"), nil
})
// 检查黑名单
claims, ok := token.Claims.(*Claims)
if !ok || redisClient.SIsMember("blacklist", claims.Username).Val() {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
// 发送到消息队列
err = ch.Publish("", q.Name, false, false, amqp.Publishing{
ContentType: "text/plain",
Body: []byte(fmt.Sprintf("Session: %s, URL: %s", claims.Username, r.URL)),
})
w.Write([]byte("Hello World"))
})
3.2 WAF实现
// 消费消息队列
for msg := range msgs {
// 解析Session和URL
var Session, URL string
fields := strings.Split(string(msg.Body), ",")
for _, field := range fields {
if strings.HasPrefix(field, "Session:") {
Session = strings.TrimSpace(strings.TrimPrefix(field, "Session:"))
}
}
re := regexp.MustCompile(`URL:\s*(.*)`)
match := re.FindStringSubmatch(string(msg.Body))
if len(match) > 1 {
URL = match[1]
}
// 检查缓存
val, _ := redisClient.HGet("cache", URL).Result()
if val == "" {
// 未命中缓存,发送到AI服务器
err = ch.Publish("", "newurls", false, false, amqp.Publishing{
ContentType: "text/plain",
Body: []byte(fmt.Sprintf("Session: %s, URL: %s", Session, URL)),
})
} else if val == "1" {
// 恶意URL,加入黑名单
redisClient.SAdd("blacklist", Session)
}
}
3.3 AI服务器实现
3.3.1 机器学习模型
import joblib
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
# 加载训练数据
def load(name):
with open(name, 'r') as f:
alldata = f.readlines()
return [str(urllib.parse.unquote(i)) for i in alldata]
badqueries = load('badqueries.txt')
goodqueries = load('goodqueries.txt')
# 特征提取和模型训练
vectorizer = TfidfVectorizer()
X = vectorizer.fit_transform(badqueries + goodqueries)
lgs = LogisticRegression(class_weight='balanced')
lgs.fit(X, [1]*len(badqueries) + [0]*len(goodqueries))
# 保存模型
joblib.dump(lgs, 'lgs.model')
3.3.2 队列消费
def callback(ch, method, properties, body):
# 解析Session和URL
session, url = None, None
for field in body.decode().split(", "):
if field.startswith("Session:"):
session = field.split(":")[1].strip()
elif field.startswith("URL:"):
url = field.split(":")[1].strip()
# 预测URL
tmp = int(check(url))
# 更新缓存
if not r.hexists('cache', url):
r.hset('cache', url, tmp)
print(f"{url} 已加入cache数据库")
if tmp: # 恶意URL
r.sadd('blacklist', session)
print(f"{session} 已拉黑")
# 启动消费者
channel.basic_consume(queue='newurls', on_message_callback=callback, auto_ack=True)
channel.start_consuming()
4. 环境部署
4.1 Docker容器部署
# Redis
docker run -d -p 6379:6379 redis
# MySQL
docker run -e MYSQL_ROOT_PASSWORD=123456 -p 3306:3306 -d mysql
# RabbitMQ
docker run -it --rm --name rabbitmq -p 5672:5672 -p 15672:15672 rabbitmq:3.11-management
4.2 服务启动顺序
- 启动基础设施服务(Redis, MySQL, RabbitMQ)
- 启动AI服务器(python)
- 启动WAF服务(go)
- 启动Web服务器(go)
5. 关键设计决策
5.1 架构设计优势
- 异步处理:通过消息队列解耦,避免同步阻塞
- 缓存优化:Redis缓存减少重复计算
- 扩展性:可轻松添加更多AI服务器实例
- 性能隔离:机器学习处理不影响Web请求响应
5.2 技术选型理由
- Go语言:高性能Web服务器和中间件
- Python:丰富的机器学习生态
- RabbitMQ:数据可靠性高于Kafka
- Redis:低延迟缓存和黑名单管理
6. 性能优化建议
-
模型优化:
- 使用更高效的模型(如XGBoost)
- 实现模型热更新
- 添加特征工程
-
架构扩展:
- 将队列2改为交换机模式,支持多类型检测
- 添加负载均衡和多AI服务器实例
- 实现分布式缓存
-
安全增强:
- 添加IP黑名单功能
- 实现速率限制
- 增加更多检测维度(如请求头、请求体)
7. 总结
本系统通过结合微服务架构和机器学习技术,实现了高性能、高扩展性的AI WAF解决方案。关键创新点包括:
- 异步处理流水线设计
- 机器学习实时检测
- 多层级缓存优化
- 模块化可扩展架构
该系统可作为企业级Web安全防护的基础框架,通过持续迭代优化,可满足各类高并发场景下的安全防护需求。