ChouBox/ChouBox.Code/MiddlewareExtend/SignatureVerifyMiddleware.cs
2025-05-18 15:20:26 +08:00

395 lines
24 KiB
C#
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

using ChouBox.Code.AppExtend;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Logging;
using Newtonsoft.Json;
using StackExchange.Redis;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Security.Cryptography;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
namespace ChouBox.Code.MiddlewareExtend
{
/// <summary>
/// Request signature verification middleware.
/// </summary>
/// <remarks>
/// For every request that is not whitelisted (see <c>IsWhitelistedPath</c>) the middleware collects the
/// request parameters: the query string for GET, and for POST either the form fields or the top-level
/// properties of a JSON object body. It then requires <c>timestamp</c>, <c>nonce</c> and <c>sign</c> and
/// checks that:
/// <list type="bullet">
/// <item><c>timestamp</c> (Unix seconds) is within <c>TIMESTAMP_TOLERANCE</c> seconds of the server clock;</item>
/// <item><c>sign</c> equals the lowercase-hex MD5 of the ordinal-sorted <c>key=value</c> pairs (excluding
/// <c>sign</c> and <c>s</c>) joined with <c>&amp;</c>, followed by the request host and the timestamp;</item>
/// <item><c>nonce</c> has not been used within <c>NONCE_EXPIRE_TIME</c> seconds (tracked in Redis).</item>
/// </list>
/// Failures are answered with HTTP 200 and a JSON body <c>{"status":0,"msg":...,"data":null}</c>, and the
/// rest of the pipeline is not invoked.
/// NOTE(review): the signing key is the Host header plus the timestamp, both known to any client, so the
/// signature proves freshness/integrity of the parameter set but not the caller's identity. MD5 is kept
/// because existing clients compute it; a shared-secret HMAC would be needed for real authentication.
/// </remarks>
public class SignatureVerifyMiddleware
{
    private readonly RequestDelegate _next;
    private readonly IConfiguration _configuration;
    private readonly ILogger<SignatureVerifyMiddleware> _logger;
    private readonly IConnectionMultiplexer _redisConnection;

    /// <summary>
    /// Redis key prefix under which used nonces are recorded.
    /// </summary>
    private const string REDIS_KEY_PREFIX = "api_nonce:";

    /// <summary>
    /// Lifetime of a recorded nonce, in seconds (10 minutes). It is longer than the accepted timestamp
    /// window, so a nonce cannot be replayed while its timestamp would still be accepted.
    /// </summary>
    private const int NONCE_EXPIRE_TIME = 600;

    /// <summary>
    /// Allowed difference between the client timestamp and server time, in seconds (1 minute).
    /// </summary>
    private const int TIMESTAMP_TOLERANCE = 60;

    /// <summary>
    /// Configuration key that must be set to "true" for the client-supplied <c>is_test=true</c>
    /// query bypass to be honored. When absent or not "true", the bypass is ignored.
    /// </summary>
    private const string TEST_BYPASS_CONFIG_KEY = "Api:AllowTestBypass";

    /// <summary>
    /// JSON settings that keep date-like strings as plain strings. Newtonsoft's default would turn them
    /// into <see cref="DateTime"/> values and re-render them differently from what the client signed.
    /// </summary>
    private static readonly JsonSerializerSettings JsonParseSettings = new JsonSerializerSettings
    {
        DateParseHandling = DateParseHandling.None
    };

    /// <summary>
    /// Creates the middleware.
    /// </summary>
    /// <param name="next">Next middleware in the pipeline.</param>
    /// <param name="configuration">Source of the <c>Api:*</c> whitelist and bypass settings.</param>
    /// <param name="logger">Logger for verification and parsing failures.</param>
    /// <param name="redisConnection">Redis connection used to record used nonces.</param>
    public SignatureVerifyMiddleware(
        RequestDelegate next,
        IConfiguration configuration,
        ILogger<SignatureVerifyMiddleware> logger,
        IConnectionMultiplexer redisConnection)
    {
        _next = next;
        _configuration = configuration;
        _logger = logger;
        _redisConnection = redisConnection;
    }

    /// <summary>
    /// Redis database used for nonce bookkeeping.
    /// </summary>
    private IDatabase Redis => _redisConnection.GetDatabase();

    /// <summary>
    /// Verifies the request signature and either forwards the request or writes an error response.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    public async Task InvokeAsync(HttpContext context)
    {
        var path = context.Request.Path.Value?.TrimStart('/') ?? string.Empty;
        if (IsWhitelistedPath(path, context))
        {
            await _next(context);
            return;
        }

        var parameters = await ReadRequestParametersAsync(context.Request);
        if (parameters == null)
        {
            // The body could not be parsed; the helper has already logged the cause.
            await Error(context, "无效的请求格式");
            return;
        }

        try
        {
            await VerifySignature(context, parameters);
        }
        catch (SignatureVerificationException ex)
        {
            // Client-caused failure: its message is intended for the caller.
            _logger.LogWarning(ex, "签名验证失败");
            await Error(context, ex.Message);
            return;
        }
        catch (Exception ex)
        {
            // Infrastructure failure (e.g. Redis): log the details but do not leak them to the client.
            _logger.LogError(ex, "签名验证失败");
            await Error(context, "签名验证失败");
            return;
        }

        await _next(context);
    }

    /// <summary>
    /// Collects the parameters that take part in the signature.
    /// GET: query string values. POST: form fields, otherwise the top-level properties of a JSON object
    /// body (the body is buffered and rewound so later middleware can read it again).
    /// Other methods: no parameters.
    /// </summary>
    /// <param name="request">The incoming request.</param>
    /// <returns>The parameter map, or <c>null</c> when a POST body is malformed.</returns>
    private async Task<Dictionary<string, string>> ReadRequestParametersAsync(HttpRequest request)
    {
        var parameters = new Dictionary<string, string>();

        if (HttpMethods.IsGet(request.Method))
        {
            foreach (var item in request.Query)
            {
                parameters[item.Key] = item.Value.ToString();
            }
            return parameters;
        }

        if (!HttpMethods.IsPost(request.Method))
        {
            return parameters;
        }

        if (request.HasFormContentType)
        {
            IFormCollection form;
            try
            {
                // Async read: the Request.Form getter blocks a thread on the body read.
                form = await request.ReadFormAsync();
            }
            catch (System.IO.InvalidDataException ex)
            {
                _logger.LogError(ex, "解析请求表单失败");
                return null;
            }

            foreach (var item in form)
            {
                parameters[item.Key] = item.Value.ToString();
            }
            return parameters;
        }

        // Buffer the body so it can be rewound for downstream middleware after we read it.
        request.EnableBuffering();
        string bodyText;
        using (var reader = new System.IO.StreamReader(request.Body, Encoding.UTF8, true, 1024, true))
        {
            bodyText = await reader.ReadToEndAsync();
        }
        request.Body.Position = 0;

        if (string.IsNullOrEmpty(bodyText))
        {
            return parameters;
        }

        Dictionary<string, object> jsonData;
        try
        {
            jsonData = JsonConvert.DeserializeObject<Dictionary<string, object>>(bodyText, JsonParseSettings);
        }
        catch (JsonException ex)
        {
            _logger.LogError(ex, "解析请求JSON失败");
            return null;
        }

        if (jsonData != null)
        {
            foreach (var item in jsonData)
            {
                // Invariant culture so numbers render the same regardless of server locale.
                parameters[item.Key] = Convert.ToString(item.Value, System.Globalization.CultureInfo.InvariantCulture) ?? string.Empty;
            }
        }
        return parameters;
    }

    /// <summary>
    /// Returns true when the request may skip signature verification.
    /// This is the case when:
    /// <list type="bullet">
    /// <item>the test bypass is enabled in configuration and the query contains <c>is_test=true</c>;</item>
    /// <item>the client IP is whitelisted; or</item>
    /// <item>the path matches a whitelisted pattern.</item>
    /// </list>
    /// </summary>
    /// <param name="path">Request path without the leading '/'.</param>
    /// <param name="context">The current HTTP context.</param>
    private bool IsWhitelistedPath(string path, HttpContext context)
    {
        // is_test is supplied by the client, so it is only honored when explicitly enabled in configuration.
        if (IsTestBypassEnabled()
            && context.Request.Query.TryGetValue("is_test", out var isTest)
            && isTest == "true")
        {
            return true;
        }

        // NOTE(review): behind a reverse proxy without ForwardedHeaders middleware, RemoteIpAddress is the
        // proxy's address (e.g. 127.0.0.1 for a proxy on the same host), which would whitelist every request.
        var remoteIp = context.Connection.RemoteIpAddress;
        if (remoteIp != null && remoteIp.IsIPv4MappedToIPv6)
        {
            // Dual-stack sockets report IPv4 clients as ::ffff:a.b.c.d; normalize so IPv4 entries match.
            remoteIp = remoteIp.MapToIPv4();
        }
        var clientIp = remoteIp?.ToString() ?? string.Empty;
        if (GetIpWhitelist().Contains(clientIp))
        {
            return true;
        }

        return GetWhitelistPaths().Any(whitePath => PathMatch(whitePath, path));
    }

    /// <summary>
    /// Reads the test-bypass switch from configuration; anything other than a parseable "true" disables it.
    /// </summary>
    private bool IsTestBypassEnabled()
    {
        return bool.TryParse(_configuration[TEST_BYPASS_CONFIG_KEY], out var enabled) && enabled;
    }

    /// <summary>
    /// Validates timestamp, signature and nonce.
    /// </summary>
    /// <param name="context">The current HTTP context (its Host is part of the signing key).</param>
    /// <param name="parameters">Collected request parameters.</param>
    /// <exception cref="SignatureVerificationException">A check failed; the message is client-facing.</exception>
    private async Task VerifySignature(HttpContext context, Dictionary<string, string> parameters)
    {
        if (!parameters.TryGetValue("timestamp", out var timestampText)
            || !parameters.TryGetValue("sign", out var requestSign)
            || !parameters.TryGetValue("nonce", out var nonce)
            || string.IsNullOrEmpty(requestSign)
            || string.IsNullOrEmpty(nonce))
        {
            // Empty nonces are rejected too; otherwise all of them would share a single Redis key.
            throw new SignatureVerificationException("缺少必要的签名参数");
        }

        if (!long.TryParse(timestampText, System.Globalization.NumberStyles.Integer,
                System.Globalization.CultureInfo.InvariantCulture, out long timestamp))
        {
            throw new SignatureVerificationException("无效的时间戳格式");
        }

        // Compare against the window bounds rather than Math.Abs(now - timestamp). The subtraction can
        // overflow for extreme client values, and Math.Abs(long.MinValue) throws.
        var now = DateTimeOffset.UtcNow.ToUnixTimeSeconds();
        if (timestamp < now - TIMESTAMP_TOLERANCE || timestamp > now + TIMESTAMP_TOLERANCE)
        {
            throw new SignatureVerificationException("请求时间戳超出允许范围");
        }

        // Signing key = request host + timestamp, appended directly after the joined parameters.
        var host = context.Request.Host.Value;
        var signSource = BuildCanonicalString(parameters)
            + host
            + timestamp.ToString(System.Globalization.CultureInfo.InvariantCulture);
        var localSign = GetMd5Hash(signSource);

        if (!SignatureEquals(localSign, requestSign))
        {
            throw new SignatureVerificationException("签名验证失败");
        }

        // Claim the nonce only after the signature is valid, so unsigned traffic cannot burn nonces. The claim
        // is an atomic SET NX, so two concurrent requests with the same nonce cannot both pass.
        var claimed = await Redis.StringSetAsync(
            REDIS_KEY_PREFIX + nonce,
            "1",
            TimeSpan.FromSeconds(NONCE_EXPIRE_TIME),
            When.NotExists);
        if (!claimed)
        {
            throw new SignatureVerificationException("无效的请求nonce已被使用");
        }
    }

    /// <summary>
    /// Builds the <c>key=value&amp;key=value</c> string that is signed. It omits <c>sign</c> and the
    /// <c>s</c> routing parameter and orders keys ordinally (byte order), independent of server culture.
    /// </summary>
    /// <param name="parameters">Collected request parameters.</param>
    private static string BuildCanonicalString(Dictionary<string, string> parameters)
    {
        var pairs = parameters
            .Where(p => p.Key != "sign" && p.Key != "s")
            .OrderBy(p => p.Key, StringComparer.Ordinal)
            .Select(p => p.Key + "=" + CanonicalizeValue(p.Value));
        return string.Join("&", pairs);
    }

    /// <summary>
    /// Re-serializes values that look like JSON objects/arrays in compact form, so whitespace differences
    /// do not affect the signature. Values that are not valid JSON are returned unchanged.
    /// </summary>
    /// <param name="value">A parameter value.</param>
    private static string CanonicalizeValue(string value)
    {
        if (string.IsNullOrEmpty(value)
            || !(value.StartsWith("{", StringComparison.Ordinal) || value.StartsWith("[", StringComparison.Ordinal)))
        {
            return value ?? string.Empty;
        }

        try
        {
            var parsed = JsonConvert.DeserializeObject(value, JsonParseSettings);
            return JsonConvert.SerializeObject(parsed, Formatting.None);
        }
        catch (JsonException)
        {
            // Looked like JSON but is not: sign the raw value.
            return value;
        }
    }

    /// <summary>
    /// Compares two signatures in time independent of where they first differ (length still leaks).
    /// Comparison is exact and case-sensitive.
    /// </summary>
    private static bool SignatureEquals(string expected, string actual)
    {
        return CryptographicOperations.FixedTimeEquals(
            Encoding.UTF8.GetBytes(expected),
            Encoding.UTF8.GetBytes(actual));
    }

    /// <summary>
    /// Writes an error response: HTTP 200 with JSON <c>{"status":code,"msg":message,"data":null}</c>.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    /// <param name="message">Message returned in <c>msg</c>.</param>
    /// <param name="code">Value returned in <c>status</c>; defaults to 0.</param>
    private async Task Error(HttpContext context, string message, int code = 0)
    {
        var result = new
        {
            status = code,
            msg = message,
            data = (object)null
        };
        context.Response.StatusCode = 200;
        context.Response.ContentType = "application/json; charset=utf-8";
        await context.Response.WriteAsync(JsonConvert.SerializeObject(result));
    }

    /// <summary>
    /// Path patterns that skip verification: built-in defaults plus <c>Api:WhitelistPaths</c> from configuration.
    /// </summary>
    private IEnumerable<string> GetWhitelistPaths()
    {
        // Built-in defaults (notification callbacks, health checks, etc.).
        var defaultWhitelist = new List<string>
        {
            "notify/*",          // payment and other notification callbacks
            "health",            // health check
            "debug",             // debug endpoint
            "generate_urllinks",
            "webhook/*",         // webhooks
            "internal/*",        // internal endpoints
        };

        try
        {
            var configWhitelist = _configuration.GetSection("Api:WhitelistPaths").Get<List<string>>();
            if (configWhitelist != null && configWhitelist.Any())
            {
                return defaultWhitelist.Concat(configWhitelist);
            }
        }
        catch (Exception ex)
        {
            // Best effort: a broken config section falls back to the defaults.
            _logger.LogError(ex, "获取API白名单路径配置失败");
        }
        return defaultWhitelist;
    }

    /// <summary>
    /// Client IPs that skip verification: IPv4/IPv6 loopback plus <c>Api:IpWhitelist</c> from configuration.
    /// </summary>
    private IEnumerable<string> GetIpWhitelist()
    {
        var defaultIpWhitelist = new List<string>
        {
            "127.0.0.1",  // IPv4 loopback
            "::1",        // IPv6 loopback
        };

        try
        {
            var configIpWhitelist = _configuration.GetSection("Api:IpWhitelist").Get<List<string>>();
            if (configIpWhitelist != null && configIpWhitelist.Any())
            {
                return defaultIpWhitelist.Concat(configIpWhitelist);
            }
        }
        catch (Exception ex)
        {
            // Best effort: a broken config section falls back to the defaults.
            _logger.LogError(ex, "获取API白名单IP配置失败");
        }
        return defaultIpWhitelist;
    }

    /// <summary>
    /// Matches a path against a whitelist pattern. Exact matches and '*' wildcards (any character
    /// sequence, e.g. <c>notify/*</c>) are both compared case-insensitively, like ASP.NET Core routing.
    /// </summary>
    /// <param name="pattern">Whitelist pattern; a null entry (e.g. from configuration) never matches.</param>
    /// <param name="path">Request path without the leading '/'.</param>
    private static bool PathMatch(string pattern, string path)
    {
        if (pattern == null)
        {
            return false;
        }

        if (string.Equals(pattern, path, StringComparison.OrdinalIgnoreCase))
        {
            return true;
        }

        if (pattern.Contains("*"))
        {
            var regex = "^" + Regex.Escape(pattern).Replace("\\*", ".*") + "$";
            return Regex.IsMatch(path, regex, RegexOptions.IgnoreCase);
        }

        return false;
    }

    /// <summary>
    /// Returns the lowercase hexadecimal MD5 digest of the UTF-8 bytes of <paramref name="input"/>.
    /// </summary>
    private static string GetMd5Hash(string input)
    {
        using (var md5 = MD5.Create())
        {
            var hashBytes = md5.ComputeHash(Encoding.UTF8.GetBytes(input));
            var sb = new StringBuilder(hashBytes.Length * 2);
            foreach (var b in hashBytes)
            {
                sb.Append(b.ToString("x2"));
            }
            return sb.ToString();
        }
    }

    /// <summary>
    /// A client-caused verification failure whose message is safe to return to the caller.
    /// </summary>
    private sealed class SignatureVerificationException : Exception
    {
        public SignatureVerificationException(string message) : base(message)
        {
        }
    }
}
}