live-forum/server/webapi/LiveForum/LiveForum.Code/MiddlewareExtend/ResponseCacheMiddleware.cs
2026-03-24 11:27:37 +08:00

303 lines
12 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

using LiveForum.Code.AttributeExtend;
using LiveForum.Code.Base;
using LiveForum.Code.Redis.Contract;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.Routing;
using Microsoft.Extensions.Logging;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using System;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace LiveForum.Code.MiddlewareExtend
{
/// <summary>
/// Response caching middleware.
/// GET requests under <c>/api/{controller}/{action}</c> whose action carries an enabled
/// <see cref="ResponseCacheExtendAttribute"/> are served from / stored into Redis.
/// </summary>
public class ResponseCacheMiddleware
{
    private const string CacheKeyPrefix = "cache:api:";
    private const string ApiPathPrefix = "/api/";
    private const string ControllerSuffix = "controller";

    private readonly RequestDelegate _next;
    private readonly IRedisService _redisService;
    private readonly ILogger<ResponseCacheMiddleware> _logger;

    /// <summary>
    /// Creates the middleware.
    /// </summary>
    /// <param name="next">Next delegate in the request pipeline.</param>
    /// <param name="redisService">Redis access used to read and write cached responses.</param>
    /// <param name="logger">Logger for cache hit / store / failure diagnostics.</param>
    public ResponseCacheMiddleware(
        RequestDelegate next,
        IRedisService redisService,
        ILogger<ResponseCacheMiddleware> logger)
    {
        _next = next;
        _redisService = redisService;
        _logger = logger;
    }

    /// <summary>
    /// Handles one request: serves a cached body when available, otherwise runs the rest of
    /// the pipeline, forwards its response to the client and stores it in Redis when eligible.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    /// <remarks>
    /// Cache failures (read or write) are logged and never fail the request. Exceptions thrown
    /// by the downstream pipeline are NOT swallowed; they propagate to the caller so that
    /// error-handling middleware can deal with them.
    /// </remarks>
    public async Task Invoke(HttpContext context)
    {
        // 1-2. Only GET requests shaped like /api/{controller}/{action} are candidates.
        if (!TryGetRouteTarget(context.Request, out var controller, out var action))
        {
            await _next(context);
            return;
        }

        // 3. The action must be marked with an enabled cache attribute.
        var cacheAttribute = GetResponseCacheAttribute(context, controller, action);
        if (cacheAttribute == null || !cacheAttribute.Enabled)
        {
            await _next(context);
            return;
        }

        // 4. Build the cache key.
        var cacheKey = GenerateCacheKey(controller, action, context.Request.Query, cacheAttribute);

        // 5. Look up Redis. A failed lookup is treated as a miss so the request still runs.
        var cachedResponse = await TryReadCacheAsync(cacheKey);
        if (!string.IsNullOrEmpty(cachedResponse))
        {
            context.Response.ContentType = "application/json; charset=utf-8";
            context.Response.StatusCode = StatusCodes.Status200OK;
            await context.Response.WriteAsync(cachedResponse, Encoding.UTF8);
            _logger.LogInformation("[ResponseCache] 缓存命中。Key: {CacheKey}", cacheKey);
            return;
        }

        // 6-7. Cache miss: run the pipeline while buffering the response body.
        var responseBodyText = await CaptureResponseAsync(context);

        // 8-9. Store the response when it is a 200 with a non-empty body.
        if (context.Response.StatusCode == StatusCodes.Status200OK && !string.IsNullOrEmpty(responseBodyText))
        {
            await TryStoreAsync(cacheKey, responseBodyText, cacheAttribute);
        }
    }

    /// <summary>
    /// Extracts controller and action names from a GET request path of the form
    /// <c>/api/{controller}/{action}</c>.
    /// </summary>
    /// <returns><c>true</c> when the request is a GET with a matching path; otherwise <c>false</c>.</returns>
    private static bool TryGetRouteTarget(HttpRequest request, out string controller, out string action)
    {
        controller = null;
        action = null;

        if (!string.Equals(request.Method, "GET", StringComparison.OrdinalIgnoreCase))
        {
            return false;
        }

        var path = request.Path.Value ?? "";
        if (!path.StartsWith(ApiPathPrefix, StringComparison.OrdinalIgnoreCase))
        {
            return false;
        }

        var segments = path.TrimStart('/').Split('/', StringSplitOptions.RemoveEmptyEntries);
        if (segments.Length < 3)
        {
            return false;
        }

        controller = segments[1]; // api/{controller}/...
        action = segments[2];     // api/{controller}/{action}

        // Strip a trailing "Controller" suffix if the URL carries one.
        if (controller.EndsWith(ControllerSuffix, StringComparison.OrdinalIgnoreCase))
        {
            controller = controller.Substring(0, controller.Length - ControllerSuffix.Length);
        }

        return true;
    }

    /// <summary>
    /// Reads the cached body for <paramref name="cacheKey"/>. Any Redis failure is logged and
    /// reported as a miss (<c>null</c>) so the caller can fall through to the normal pipeline.
    /// </summary>
    private async Task<string> TryReadCacheAsync(string cacheKey)
    {
        try
        {
            return await _redisService.GetAsync(cacheKey);
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "[ResponseCache] 缓存处理异常。Key: {CacheKey}", cacheKey);
            return null;
        }
    }

    /// <summary>
    /// Runs the rest of the pipeline with the response body redirected into a memory buffer,
    /// copies the buffered bytes to the real response stream and returns them as text.
    /// </summary>
    /// <remarks>
    /// The original body stream is always restored, even when the pipeline throws, so code
    /// running after this middleware never sees the disposed buffer.
    /// </remarks>
    private async Task<string> CaptureResponseAsync(HttpContext context)
    {
        var originalBodyStream = context.Response.Body;
        using (var buffer = new MemoryStream())
        {
            context.Response.Body = buffer;
            try
            {
                await _next(context);

                buffer.Seek(0, SeekOrigin.Begin);
                string text;
                // leaveOpen: the buffer is still needed for the copy below.
                using (var reader = new StreamReader(buffer, Encoding.UTF8, true, 1024, true))
                {
                    text = await reader.ReadToEndAsync();
                }

                buffer.Seek(0, SeekOrigin.Begin);
                await buffer.CopyToAsync(originalBodyStream);
                return text;
            }
            finally
            {
                context.Response.Body = originalBodyStream;
            }
        }
    }

    /// <summary>
    /// Stores the response in Redis when its JSON <c>code</c> field equals
    /// <see cref="ResponseCode.Success"/>. All failures are logged and swallowed because the
    /// response has already been sent to the client at this point.
    /// </summary>
    private async Task TryStoreAsync(
        string cacheKey,
        string responseBodyText,
        ResponseCacheExtendAttribute cacheAttribute)
    {
        try
        {
            var responseJson = JObject.Parse(responseBodyText);
            var codeValue = responseJson["code"]?.Value<int>();

            if (!codeValue.HasValue || codeValue.Value != (int)ResponseCode.Success)
            {
                _logger.LogInformation(
                    "[ResponseCache] 响应Code不为0不缓存。Key: {CacheKey}, Code: {Code}",
                    cacheKey, codeValue);
                return;
            }

            // Add 1-100 seconds of jitter so entries cached together do not all expire together.
            var random = new Random();
            var randomSeconds = random.Next(1, 101);
            var actualDuration = cacheAttribute.Duration + randomSeconds;
            await _redisService.SetAsync(cacheKey, responseBodyText, TimeSpan.FromSeconds(actualDuration));
            _logger.LogInformation(
                "[ResponseCache] 响应已缓存。Key: {CacheKey}, 基础Duration: {Duration}秒, 随机数: {RandomSeconds}秒, 实际Duration: {ActualDuration}秒",
                cacheKey, cacheAttribute.Duration, randomSeconds, actualDuration);
        }
        catch (JsonException ex)
        {
            _logger.LogWarning(ex, "[ResponseCache] 解析响应JSON失败不缓存。Key: {CacheKey}", cacheKey);
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "[ResponseCache] 缓存处理异常。Key: {CacheKey}", cacheKey);
        }
    }

    /// <summary>
    /// Finds the <see cref="ResponseCacheExtendAttribute"/> for the requested action.
    /// Endpoint metadata is consulted first; if it yields nothing, the controller method is
    /// located by reflection as a fallback.
    /// </summary>
    /// <returns>The attribute, or <c>null</c> when none is found or the lookup fails.</returns>
    private ResponseCacheExtendAttribute GetResponseCacheAttribute(
        HttpContext context,
        string controller,
        string action)
    {
        try
        {
            var endpoint = context.GetEndpoint();
            if (endpoint?.Metadata != null)
            {
                var fromEndpoint = endpoint.Metadata
                    .OfType<ResponseCacheExtendAttribute>()
                    .FirstOrDefault();
                if (fromEndpoint != null)
                {
                    return fromEndpoint;
                }
            }

            // Fallback: resolve the controller type and action method by name.
            var methodInfo = GetControllerType(controller)?.GetMethod(action,
                System.Reflection.BindingFlags.Public |
                System.Reflection.BindingFlags.Instance |
                System.Reflection.BindingFlags.IgnoreCase);
            if (methodInfo != null)
            {
                return methodInfo.GetCustomAttributes(typeof(ResponseCacheExtendAttribute), false)
                    .FirstOrDefault() as ResponseCacheExtendAttribute;
            }
        }
        catch (Exception ex)
        {
            _logger.LogWarning(ex, "[ResponseCache] 获取ResponseCacheAttribute特性失败。Controller: {Controller}, Action: {Action}",
                controller, action);
        }

        return null;
    }

    /// <summary>
    /// Resolves <c>LiveForum.WebApi.Controllers.{controllerName}Controller</c> via
    /// <see cref="Type.GetType(string)"/>, then by scanning every loaded assembly.
    /// </summary>
    /// <returns>The controller type, or <c>null</c> when it cannot be resolved.</returns>
    private Type GetControllerType(string controllerName)
    {
        try
        {
            var controllerFullName = $"LiveForum.WebApi.Controllers.{controllerName}Controller";
            var controllerType = Type.GetType(controllerFullName);
            if (controllerType != null)
            {
                return controllerType;
            }

            foreach (var assembly in AppDomain.CurrentDomain.GetAssemblies())
            {
                controllerType = assembly.GetType(controllerFullName);
                if (controllerType != null)
                {
                    return controllerType;
                }
            }

            return null;
        }
        catch
        {
            // Deliberate best effort: the name comes from the request path, so any failure
            // while resolving it simply means "no such controller".
            return null;
        }
    }

    /// <summary>
    /// Builds the Redis key for a request.
    /// Format: <c>cache:api:{CacheKeyPrefix or Controller:Action}[:k1=v1&amp;k2=v2]</c>, where the
    /// optional suffix lists the attribute's VaryByQueryKeys that are present in the query,
    /// sorted by key name so the same parameters always yield the same key.
    /// </summary>
    private string GenerateCacheKey(
        string controller,
        string action,
        IQueryCollection query,
        ResponseCacheExtendAttribute attribute)
    {
        // A custom prefix on the attribute replaces the default {Controller}:{Action} part.
        var baseKey = string.IsNullOrWhiteSpace(attribute.CacheKeyPrefix)
            ? $"{controller}:{action}"
            : attribute.CacheKeyPrefix;
        var cacheKey = $"{CacheKeyPrefix}{baseKey}";

        var varyByKeys = attribute.VaryByQueryKeys;
        if (varyByKeys == null || varyByKeys.Length == 0)
        {
            return cacheKey;
        }

        var paramPairs = varyByKeys
            .Where(name => query.ContainsKey(name))
            .OrderBy(name => name)
            .Select(name => $"{name}={query[name]}")
            .ToList();

        return paramPairs.Any()
            ? cacheKey + ":" + string.Join("&", paramPairs)
            : cacheKey;
    }
}
}