基于安全解码防御越狱攻击
字数 1330 2025-08-20 18:18:39
基于安全解码防御大模型越狱攻击的教学文档
1. 前言
本教学文档详细介绍了基于安全解码(SafeDecoding)防御大型语言模型(LLMs)越狱攻击的方法。该方法通过设计一种安全意识解码策略,在不牺牲模型对良性查询响应能力的同时,有效降低越狱攻击成功率。
2. 背景知识
2.1 大模型中的解码机制
在LLMs中,解码是生成响应序列的关键步骤。给定自回归模型θ,对于令牌序列x₁:n-₁,第n个令牌xₙ的概率表示为:
p(xₙ|x₁:n-₁) = softmax(f(x₁:n-₁))
常用解码策略包括:
- 贪婪解码:选择最高概率令牌
- 束搜索:维护多个可能序列
- Top-k解码:仅考虑概率最高的k个令牌
- Top-p(Nucleus)解码:选择概率总和达到p的令牌子集
2.2 越狱攻击目标
越狱攻击旨在诱导LLM产生不安全的响应,其成功度通过攻击成功率(ASR)衡量:
ASR = (#成功攻击)/(#总攻击尝试)
攻击者通过解决以下优化问题构造攻击序列:
argmax p(xₙ:|x₁:n-₁), s.t. xₙ: ∈ H
其中H表示与攻击目标一致的提示集合。
3. SafeDecoding方法
3.1 核心思想
关键观察:
- 攻击成功源于攻击目标令牌序列概率占主导
- 安全免责声明(如"抱歉,我不能...")仍存在于令牌样本空间中
SafeDecoding通过:
- 减弱攻击目标一致的令牌概率
- 增强安全相关令牌概率
3.2 方案总览
3.2.1 训练阶段
- 收集有害查询数据集
- 使用原始模型生成拒绝响应
- 用GPT-4过滤有效拒绝响应
- 使用LoRA等方法微调原始模型,创建专家模型
3.2.2 推理阶段
- 用户查询同时发送给原始模型和专家模型
- 构建新的令牌分布
- 基于新分布采样生成响应
3.3 详细实现
3.3.1 样本空间构建
对于第n步解码:
- 获取原始模型和专家模型的前k个令牌Vₖⁿ和V'ₖⁿ
- 构建样本空间Vⁿ(c) = Vₖⁿ ∩ V'ₖⁿ
- 参数c控制样本空间大小
3.3.2 概率函数定义
对于x ∈ Vⁿ(c),定义概率:
Pₙ(x) ∝ exp(log pθ(x|x₁:n-₁) + α·(log pθ'(x|x₁:n-₁) - log pθ(x|x₁:n-₁)))
其中α ≥ 0是权重参数
3.3.3 优化策略
- 仅在前m个解码步骤应用SafeDecoding
- 后续步骤使用常规解码方法
- 平衡安全性和计算效率
4. 代码实现关键点
4.1 类定义
class SafeDecoding:
def __init__(self, model, tokenizer, adapter_names, alpha=1.0,
first_m=3, top_k=50, num_common_tokens=10, verbose=False):
# 初始化参数
self.model = model
self.tokenizer = tokenizer
self.adapter_names = adapter_names
self.alpha = alpha
self.first_m = first_m
self.top_k = top_k
self.num_common_tokens = num_common_tokens
self.verbose = verbose
4.2 安全解码核心逻辑
-
生成配置设置:
- max_new_tokens=1
- do_sample=False (使用贪婪解码)
-
样本空间构建:
# 获取基础模型和专家模型的前k个令牌 topk_base = torch.topk(output_base.scores[0][0], self.top_k) topk_expert = torch.topk(output_expert.scores[0][0], self.top_k) # 寻找共享令牌 common_tokens = set() iter_range = 1 while len(common_tokens) < self.num_common_tokens: current_indices_base = range(iter_range * self.top_k) current_indices_expert = range(iter_range * self.top_k) common_in_iteration = set(topk_base.indices[current_indices_base].tolist()) & set(topk_expert.indices[current_indices_expert].tolist()) common_tokens.update(common_in_iteration) iter_range += 1 -
得分更新与采样:
# 计算更新后的得分 updated_scores = [] for token_id in intersection_indices: p_base = torch.softmax(output_base.scores[0][0], dim=-1)[token_id] p_expert = torch.softmax(output_expert.scores[0][0], dim=-1)[token_id] updated_p = p_base * (p_expert / p_base) ** self.alpha updated_scores.append(updated_p.item()) # 归一化得分 probs = torch.softmax(torch.tensor(updated_scores), dim=-1) # 采样策略 if not gen_config.do_sample: next_token_id = intersection_indices[torch.argmax(probs)] elif gen_config.top_p is not None: # Top-p采样实现 sorted_probs, sorted_indices = torch.sort(probs, descending=True) cumulative_probs = torch.cumsum(sorted_probs, dim=-1) sorted_token_ids = intersection_indices[sorted_indices] top_p_mask = cumulative_probs <= gen_config.top_p next_token_id = sorted_token_ids[top_p_mask][torch.multinomial( sorted_probs[top_p_mask], 1)].item()
5. 实验结果
在AdvBench测试集上:
- 原始ASR: ~60%
- 应用SafeDecoding后ASR降至16%
- 对良性查询响应质量无明显影响
典型日志输出示例:
[安全解码] 基础模型Top-5:
ID: 1234, Token: Sure, LogProb: -0.2, Prob: 0.82
ID: 5678, Token: I, LogProb: -1.5, Prob: 0.22
[安全解码] 专家模型Top-5:
ID: 5678, Token: I, LogProb: -0.8, Prob: 0.45
ID: 9101, Token: Sorry, LogProb: -1.2, Prob: 0.30
[安全解码] 选择Token: Sorry (ID: 9101)
6. 总结
SafeDecoding通过:
- 构建安全专家模型
- 在解码阶段融合原始模型和专家模型输出
- 动态调整令牌概率分布
实现了在不影响正常使用的前提下有效防御越狱攻击的目标。该方法计算效率高,易于与现有解码策略结合,是LLM安全防护的有效解决方案。