通过中间件实现简易的 WAF 功能

@zgcwkj  2025年04月26日

分类:

网站 代码 

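C# MVC:通过全局 Action 过滤器(ActionFilterAttribute)实现简易的 WAF 功能。注意:这里使用的是 MVC 过滤器而非中间件,因此只对进入 MVC 控制器 Action 的请求生效,静态文件等不经过 Action 的请求不会被检查。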

创建以下三个过滤器

WafSCFiter.cs
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.Mvc.Filters;
using System.Text.RegularExpressions;

namespace WAFTest.Extensions
{
    /// <summary>
    /// WAF security-check filter.
    /// <para>
    /// Intended to run first among the WAF filters. Requests from an allowed IP or under an allowed
    /// path prefix are marked with <c>HttpContext.Items["WAF_Skip"] = true</c> (the flag the other WAF
    /// filters check) and passed through. Otherwise the request is blocked with a logged incident id
    /// when its path matches a sensitive-path pattern, or when any query-string or form value matches
    /// an SQL-injection or XSS pattern.
    /// </para>
    /// </summary>
    public class WafSCFiter : ActionFilterAttribute
    {
        /// <summary>
        /// <c>HttpContext.Items</c> key shared by the WAF filters; <c>true</c> means "skip the remaining WAF checks".
        /// </summary>
        private const string SkipItemKey = "WAF_Skip";

        /// <summary>
        /// Logging helper used to record blocked requests.
        /// </summary>
        public LogFunction _LogFunc { get; }

        /// <summary>
        /// HTML helper used to compress the block page.
        /// </summary>
        private HtmlFunction _HtmlFunc { get; }

        /// <summary>
        /// Creates the filter with its logging and HTML helpers.
        /// </summary>
        public WafSCFiter(LogFunction logFunc, HtmlFunction htmlFunc)
        {
            this._LogFunc = logFunc;
            this._HtmlFunc = htmlFunc;
        }

        /// <summary>
        /// Lower-case path prefixes that bypass all WAF checks.
        /// <para>NOTE(review): this is a plain string-prefix test, so "/api" also matches e.g. "/apix".</para>
        /// </summary>
        private static readonly List<string> _allowUrls = new()
        {
            "/api", // API endpoints
            "/verify", // captcha image
            "/wafverify", // captcha submission
            "/img",
            "/lib",
            "/css",
            "/js",
        };

        /// <summary>
        /// Client IPs (as produced by <c>IPAddress.ToString()</c>) that bypass all WAF checks.
        /// </summary>
        private static readonly List<string> _allowIps = new()
        {
            //"::1",
            //"127.0.0.1",
        };

        /// <summary>
        /// Collapses runs of whitespace before SQL-injection matching.
        /// </summary>
        private static readonly Regex _whitespaceRegex = new(@"\s+", RegexOptions.Compiled);

        /// <summary>
        /// SQL-injection patterns, matched against the URL-decoded, whitespace-collapsed, lower-cased value.
        /// <para>
        /// Compiled once and shared: the static <c>Regex.IsMatch(string, string)</c> cache holds only 15
        /// patterns by default, fewer than this filter uses, so per-call patterns would be re-parsed constantly.
        /// </para>
        /// <para>
        /// NOTE(review): the keyword, logical-operator and boolean patterns also match ordinary prose
        /// ("select", "and", "or", "not", "true", ...), so expect false positives on free-text fields.
        /// </para>
        /// </summary>
        private static readonly Regex[] _sqlInjectionPatterns = CompilePatterns(
            @"\b(SELECT|INSERT|UPDATE|DELETE|DROP|UNION|EXEC|ALTER|CREATE|TRUNCATE|DECLARE|WAITFOR|CAST|CONVERT)\b", // common SQL keywords
            @"[;]\s*(SELECT|INSERT|UPDATE|DELETE|DROP)", // statement after a semicolon
            @"--[^\r\n]*", // SQL line comment
            @"/\*.*?\*/", // SQL block comment
            @"\band\b|\bor\b|\bxor\b|\bnot\b", // logical operators
            @"[+\-*/%%]\s*\d+\s*[=<>]", // arithmetic followed by a comparison
            @"\b(true|false)\b", // boolean literals
            @"'\s*[+\-*/%%]\s*'" // string concatenation
        );

        /// <summary>
        /// XSS patterns, matched case-insensitively against the raw value.
        /// </summary>
        private static readonly Regex[] _xssPatterns = CompilePatterns(
            @"<script[^>]*>.*?</script>", // script element
            @"javascript:", // javascript: URL scheme
            @"vbscript:", // vbscript: URL scheme
            @"onload=", // load event handler attribute
            @"onerror=" // error event handler attribute
        );

        /// <summary>
        /// Request-path patterns that are always blocked (config/env files, admin areas).
        /// </summary>
        private static readonly Regex[] _sensitivePathPatterns = CompilePatterns(
            @"\.config$", // .config files
            @"\.conf$", // .conf files
            @"\.ini$", // INI files
            @"\.env$", // environment files
            @"\badmin\b", // admin paths
            @"\bmanage\b" // management paths
        );

        /// <summary>
        /// Compiles each pattern as a case-insensitive, culture-invariant regex.
        /// </summary>
        private static Regex[] CompilePatterns(params string[] patterns) =>
            patterns
                .Select(pattern => new Regex(pattern, RegexOptions.IgnoreCase | RegexOptions.CultureInvariant | RegexOptions.Compiled))
                .ToArray();

        /// <summary>
        /// Checks the request before the action runs and either lets it through or sets
        /// <see cref="ActionExecutingContext.Result"/> to the block page.
        /// </summary>
        /// <param name="context">Action execution context.</param>
        public override void OnActionExecuting(ActionExecutingContext context)
        {
            var httpContext = context.HttpContext;
            var request = httpContext.Request;
            var userIP = httpContext.Connection.RemoteIpAddress?.ToString() ?? "unknown";
            var path = request.Path.ToString().ToLowerInvariant();

            // Already marked as trusted earlier in the pipeline: nothing to check.
            if (httpContext.Items.TryGetValue(SkipItemKey, out var skipValue) && skipValue?.To<bool>() == true)
            {
                base.OnActionExecuting(context);
                return;
            }

            // Allowed IP or allowed path prefix: mark the request so the later WAF filters skip it too.
            if (_allowIps.Contains(userIP) || _allowUrls.Any(prefix => path.StartsWith(prefix, StringComparison.Ordinal)))
            {
                // Indexer assignment, not Items.Add: Add throws if the key is already present (e.g. set to false).
                httpContext.Items[SkipItemKey] = true;
                base.OnActionExecuting(context);
                return;
            }

            if (IsSensitivePath(request.Path))
            {
                BlockRequest(context, userIP, "路径包含敏感信息");
                return;
            }

            if (request.Query.Any(query => IsMalicious(query.Value)))
            {
                BlockRequest(context, userIP, "参数包含恶意代码");
                return;
            }

            // NOTE(review): Request.Form reads the body synchronously if model binding has not read it yet.
            if (request.HasFormContentType && request.Form.Any(field => IsMalicious(field.Value)))
            {
                BlockRequest(context, userIP, "参数包含恶意代码");
                return;
            }

            base.OnActionExecuting(context);
        }

        /// <summary>
        /// Logs the blocked request under a fresh incident id and short-circuits the action with the block page.
        /// </summary>
        /// <param name="context">Action execution context whose result is replaced.</param>
        /// <param name="userIP">Client IP, for the log line.</param>
        /// <param name="reason">Human-readable reason appended to the log line.</param>
        private void BlockRequest(ActionExecutingContext context, string userIP, string reason)
        {
            var guid = Guid.NewGuid().ToString();
            _LogFunc.BaseLog("warn", "WAF", $"IP {userIP} 访问 {context.HttpContext.Request.Path} {reason}({guid})");
            context.Result = GoBlocked(guid);
        }

        /// <summary>
        /// Returns true when the value looks like an SQL-injection or XSS payload.
        /// </summary>
        private static bool IsMalicious(string? input) => ContainsSqlInjection(input) || ContainsXSS(input);

        /// <summary>
        /// Checks whether the input contains an SQL-injection pattern.
        /// </summary>
        /// <param name="input">Value to inspect; null or empty is never malicious.</param>
        /// <returns>True if any SQL-injection pattern matches.</returns>
        private static bool ContainsSqlInjection(string? input)
        {
            if (string.IsNullOrEmpty(input))
            {
                return false;
            }
            // Decode once more so percent-encoded (e.g. double-encoded) payloads are inspected as well.
            var decodedInput = System.Web.HttpUtility.UrlDecode(input);
            var normalized = _whitespaceRegex.Replace(decodedInput, " ").Trim().ToLowerInvariant();
            return _sqlInjectionPatterns.Any(regex => regex.IsMatch(normalized));
        }

        /// <summary>
        /// Checks whether the input contains an XSS pattern.
        /// </summary>
        /// <param name="input">Value to inspect; null or empty is never malicious.</param>
        /// <returns>True if any XSS pattern matches.</returns>
        private static bool ContainsXSS(string? input)
        {
            if (string.IsNullOrEmpty(input))
            {
                return false;
            }
            return _xssPatterns.Any(regex => regex.IsMatch(input));
        }

        /// <summary>
        /// Checks whether the request path points at a sensitive resource.
        /// </summary>
        /// <param name="path">Request path.</param>
        /// <returns>True if any sensitive-path pattern matches.</returns>
        private static bool IsSensitivePath(PathString path)
        {
            var pathStr = path.ToString().ToLowerInvariant();
            return _sensitivePathPatterns.Any(regex => regex.IsMatch(pathStr));
        }

        /// <summary>
        /// Builds the "request blocked" page.
        /// </summary>
        /// <param name="id">Incident id shown to the user (matches the log entry).</param>
        /// <returns>An HTML content result.</returns>
        public ContentResult GoBlocked(string id)
        {
            var blockedHtml = @"
<!DOCTYPE html>

<html>
<head>
    <meta charset='UTF-8'>
    <meta name='viewport' content='width=device-width, initial-scale=1.0'>
    <title>安全拦截</title>
    <style>
        * { margin: 0;padding: 0; }
        body { width: 100vw;height: 100vh;background: #ffebee;display: flex;align-items: center;justify-content: center; }
        .content { height: 300px;display: flex;color: #f44336;align-items: center;flex-direction: column; }
        .content > svg { width: 80px;height: 80px;fill: #f44336; }
        .content > div { line-height: 50px;display: flex;align-items: center;flex-direction: column; }
    </style>
</head>
<body>
    <div class='content'>
        <svg viewBox='0 0 24 24'>
            <path d='M12 2L1 21h22L12 2zm0 3.45l8.27 14.32H3.73L12 5.45zm-1.5 8.09v-4h3v4h-3zm0 4h3v-2h-3v2z' />
        </svg>
        <div>
            <h1>访问已被拦截</h1>
            <p>id: " + id + @"</p>
        </div>
    </div>
    <script>console.log('by zgcwkj')</script>
</body>
</html>";
            return new ContentResult
            {
                Content = _HtmlFunc.Compress(blockedHtml),
                // Valid MIME type (no space inside "text/html"), otherwise browsers may not render the page.
                ContentType = "text/html; charset=utf-8"
            };
        }
    }
}
WafJSFiter.cs
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.Mvc.Filters;
using System.Collections.Concurrent;
using System.Security.Cryptography;
using System.Text;

namespace WAFTest.Extensions
{
    /// <summary>
    /// WAF JavaScript-challenge filter.
    /// <para>
    /// Unless an earlier filter set <c>HttpContext.Items["WAF_Skip"] = true</c>, a request must carry a
    /// <c>.WafJsVerify.Cookies</c> cookie holding an unexpired token issued by this filter for the same
    /// client fingerprint (IP + User-Agent, salted with the current date). Otherwise a challenge page is
    /// returned whose script stores a fresh token in that cookie and reloads the page.
    /// </para>
    /// </summary>
    public class WafJSFiter : ActionFilterAttribute
    {
        /// <summary>
        /// Issued tokens mapped to their expiration time (local time).
        /// </summary>
        private static readonly ConcurrentDictionary<string, DateTime> _tokenStore = new();

        /// <summary>
        /// Ticks of the last expired-token sweep; read and written with <c>Interlocked</c>.
        /// </summary>
        private static long _lastSweepTicks;

        /// <summary>
        /// Minimum time between two sweeps of <see cref="_tokenStore"/>.
        /// </summary>
        private static readonly TimeSpan SweepInterval = TimeSpan.FromMinutes(1);

        /// <summary>
        /// Encryption helper used to derive the per-day fingerprint salt.
        /// </summary>
        private SecurityFunction _SecurityFunc { get; }

        /// <summary>
        /// HTML helper used to compress the page and obfuscate the challenge script.
        /// </summary>
        private HtmlFunction _HtmlFunc { get; }

        /// <summary>
        /// Creates the filter with its security and HTML helpers.
        /// </summary>
        public WafJSFiter(SecurityFunction securityFunc, HtmlFunction htmlFunc)
        {
            this._SecurityFunc = securityFunc;
            this._HtmlFunc = htmlFunc;
        }

        /// <summary>
        /// Lets the request through when it carries a valid challenge token; otherwise issues a new token
        /// and short-circuits the action with the challenge page.
        /// </summary>
        /// <param name="context">Action execution context.</param>
        public override void OnActionExecuting(ActionExecutingContext context)
        {
            var request = context.HttpContext.Request;
            var userIP = context.HttpContext.Connection.RemoteIpAddress?.ToString() ?? "unknown";
            // The client's User-Agent is a request header; response headers never carry it.
            var userAgent = request.Headers.UserAgent;
            // Skip if an earlier WAF filter marked this request as trusted.
            var isSkip = context.HttpContext.Items.TryGetValue("WAF_Skip", out var skipValue);
            if (isSkip && skipValue?.To<bool>() == true)
            {
                base.OnActionExecuting(context);
                return;
            }
            var now = DateTime.Now;
            // Tokens issued to clients that never come back would otherwise accumulate forever.
            RemoveExpiredTokens(now);
            // Client fingerprint and its per-day salt (the salt changes at midnight, invalidating old tokens).
            var userIDKey = $"{userIP}_{userAgent}".ToMD5().ToLower();
            var userIDSalt = _SecurityFunc.Encrypt($"{now:yyyy-MM-dd}", userIDKey);
            // Already verified? The token must exist, be unexpired, and belong to this fingerprint.
            var wafVerifyKey = ".WafJsVerify.Cookies";
            var wafVerifyValue = request.Cookies[wafVerifyKey];
            if (wafVerifyValue != null)
            {
                if (_tokenStore.TryGetValue(wafVerifyValue, out var expirationTime)
                    && now <= expirationTime
                    && wafVerifyValue.Split('_')[0].Equals(userIDSalt))
                {
                    base.OnActionExecuting(context);
                    return;
                }
                // Unknown, expired or foreign token: drop it.
                _tokenStore.TryRemove(wafVerifyValue, out _);
            }
            // Issue a new token, prefixed with the fingerprint salt, valid for 30 minutes.
            var verifyTokenExpiration = 30;
            var verifyToken = $"{userIDSalt}_{GenerateVerifyToken()}";
            _tokenStore[verifyToken] = now.AddMinutes(verifyTokenExpiration);
            // Challenge page: the script sets the token cookie and reloads.
            var verifyCss = @"
.verify-wrapper {text-align: center;padding: 50px 20px;font-family: 'Microsoft YaHei', sans-serif;}
.verify-title {font-size: 24px;color: #333;margin-bottom: 20px;}
.verify-message {font-size: 16px;color: #666;margin-bottom: 30px;line-height: 1.6;}
.loading-spinner {display: inline-block;width: 40px;height: 40px;border: 4px solid #f3f3f3;border-top: 4px solid #3498db;border-radius: 50%;animation: spin 1s linear infinite;}
@keyframes spin {0% {transform: rotate(0deg);}100% {transform: rotate(360deg);}}
.verify-footer {text-align: center;}";
            var verifyJS = @"
(function() {
    try {
        // 基本JS执行验证
        var container = document.getElementById('verify-container');
        if (!container) throw new Error('DOM操作失败');
        // 数组操作验证
        var arr = [1, 2, 3, 4, 5];
        var sum = arr.reduce((a, b) => a + b, 0);
        if (sum !== 15) throw new Error('数组操作失败');
        // Promise验证
        new Promise(resolve => resolve(true))
            .then(() => {
                // 设置验证结果Cookie
                document.cookie = '" + wafVerifyKey + @"=;path=/;expires=Thu, 01 Jan 1970 00:00:00 UTC;';
                document.cookie = '" + wafVerifyKey + @"=' + encodeURIComponent('" + verifyToken + @"') + ';path=/;';
                // 验证通过后刷新页面
                setTimeout(() => window.location.reload(), 500);
            });
    } catch (err) {
        console.error('JS验证失败:', err);
    }
})();";
            var verifyHtml = $@"
<!DOCTYPE html>
<html>
<head>
    <title>安全验证</title>
    <meta charset='UTF-8'>
    <meta http-equiv='X-UA-Compatible' content='IE=edge'>
    <meta name='viewport' content='width=device-width, initial-scale=1.0'>
    <meta http-equiv='Cache-Control' content='no-cache'>
    <style>{verifyCss}</style>
</head>
<body>
    <div id='verify-container' style='display:none'></div>
    <div class='verify-wrapper'>
        <h1 class='verify-title'>安全验证中</h1>
        <p class='verify-message'>系统正在进行安全验证,请稍候...<br>验证通过后将自动跳转到目标页面</p>
        <div class='loading-spinner'></div>
    </div>
    <hr />
    <div class='verify-footer'>
        <p><label>ID:</label>" + userIDKey + "<br><label>IP:</label>" + userIP + $@"</p>
    </div>
    <script>{_HtmlFunc.ObfuscatorJS(verifyJS)}</script>
</body>
</html>";
            context.Result = new ContentResult
            {
                Content = _HtmlFunc.Compress(verifyHtml),
                ContentType = "text/html; charset=utf-8"
            };
        }

        /// <summary>
        /// Generates an unguessable token: 32 bytes from the OS CSPRNG, Base64-encoded.
        /// </summary>
        private static string GenerateVerifyToken()
        {
            return Convert.ToBase64String(RandomNumberGenerator.GetBytes(32));
        }

        /// <summary>
        /// Removes expired tokens from <see cref="_tokenStore"/>, at most once per <see cref="SweepInterval"/>.
        /// Concurrent callers race on a compare-exchange so only one of them performs each sweep.
        /// </summary>
        /// <param name="now">Current local time.</param>
        private static void RemoveExpiredTokens(DateTime now)
        {
            var last = System.Threading.Interlocked.Read(ref _lastSweepTicks);
            if (now.Ticks - last < SweepInterval.Ticks)
            {
                return;
            }
            if (System.Threading.Interlocked.CompareExchange(ref _lastSweepTicks, now.Ticks, last) != last)
            {
                return;
            }
            // Enumerating a ConcurrentDictionary while removing from it is safe.
            foreach (var entry in _tokenStore)
            {
                if (entry.Value < now)
                {
                    _tokenStore.TryRemove(entry.Key, out _);
                }
            }
        }
    }
}
WafCCFiter.cs
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.Mvc.Filters;
using Microsoft.Extensions.Caching.Memory;
using System.Net;

namespace WAFTest.Extensions
{
    /// <summary>
    /// WAF request-rate filter.
    /// <para>
    /// Counts requests per client IP in <see cref="IMemoryCache"/>. Within a one-minute window, once the
    /// count exceeds a threshold drawn at random from [10, 30) on every request, the counter is pinned at
    /// <see cref="int.MaxValue"/>. From then on, the client gets a captcha page until its cache entry
    /// expires, which happens one window after its last counted request. IPs on the temporary whitelist
    /// (see <see cref="AddToWhiteList"/>) are not counted.
    /// </para>
    /// </summary>
    public class WafCCFiter : ActionFilterAttribute
    {
        /// <summary>
        /// Temporary whitelist: key is the IP address, value is the expiration time (local time).
        /// Concurrent because it is read and written by simultaneous requests and by the static helpers.
        /// </summary>
        private static readonly System.Collections.Concurrent.ConcurrentDictionary<string, DateTime> _whiteList = new();

        /// <summary>
        /// Logging helper used to record rate-limited requests.
        /// </summary>
        public LogFunction _LogFunc { get; }

        /// <summary>
        /// Cache holding the per-IP request counters.
        /// </summary>
        private IMemoryCache _IMemoryCache { get; }

        /// <summary>
        /// Encryption helper used to derive the captcha code from the IP.
        /// </summary>
        private SecurityFunction _SecurityFunc { get; }

        /// <summary>
        /// HTML helper used to compress the captcha page.
        /// </summary>
        private HtmlFunction _HtmlFunc { get; }

        /// <summary>
        /// Creates the filter with its logging, cache, security and HTML helpers.
        /// </summary>
        public WafCCFiter(LogFunction logFunc, IMemoryCache memoryCache, SecurityFunction securityFunc, HtmlFunction htmlFunc)
        {
            this._LogFunc = logFunc;
            this._IMemoryCache = memoryCache;
            this._SecurityFunc = securityFunc;
            this._HtmlFunc = htmlFunc;
        }

        /// <summary>
        /// Counts the request and short-circuits the action with the captcha page when the IP is over its limit.
        /// </summary>
        /// <param name="context">Action execution context.</param>
        public override void OnActionExecuting(ActionExecutingContext context)
        {
            var userIP = context.HttpContext.Connection.RemoteIpAddress?.ToString() ?? "unknown";
            // Skip if an earlier WAF filter marked this request as trusted.
            var isSkip = context.HttpContext.Items.TryGetValue("WAF_Skip", out var skipValue);
            if (isSkip && skipValue?.To<bool>() == true)
            {
                base.OnActionExecuting(context);
                return;
            }
            // Whitelisted and not yet expired: no rate limiting.
            if (_whiteList.TryGetValue(userIP, out var expireTime))
            {
                if (expireTime > DateTime.Now)
                {
                    base.OnActionExecuting(context);
                    return;
                }
                // Expired: remove it, but only if the IP was not re-added concurrently with a new expiry.
                _whiteList.TryRemove(new KeyValuePair<string, DateTime>(userIP, expireTime));
            }
            // Rate rule: window length in minutes; the threshold is re-drawn on every request
            // (10-29 requests per window; a looser alternative is 50-100).
            var maxRequestsMinute = 1;
            var window = TimeSpan.FromMinutes(maxRequestsMinute);
            var maxRequestsPer = Random.Shared.Next(10, 30);
            var cacheKey = $"CC_FILTER_{userIP}";
            var requestInfo = _IMemoryCache.GetOrCreate(cacheKey, entry =>
            {
                entry.AbsoluteExpirationRelativeToNow = window;
                return new RequestInfo { Count = 0, FirstRequestTime = DateTime.Now };
            })!;
            if (requestInfo.Count > maxRequestsPer)
            {
                // Pin the counter so a later, higher random threshold cannot un-block the client.
                requestInfo.Count = int.MaxValue;
            }
            else if ((DateTime.Now - requestInfo.FirstRequestTime).TotalMinutes >= maxRequestsMinute)
            {
                // Window elapsed: start a new one with this request.
                requestInfo.Count = 1;
                requestInfo.FirstRequestTime = DateTime.Now;
            }
            else
            {
                requestInfo.Count++;
            }
            // Re-store WITH an expiration: Set(key, value) alone creates an entry that never expires,
            // which would keep every IP in memory forever and make blocks permanent.
            _IMemoryCache.Set(cacheKey, requestInfo, window);
            if (requestInfo.Count > maxRequestsPer)
            {
                _LogFunc.BaseLog("warn", "WAFCC", $"IP {userIP} 访问频率过高,已被拦截");
                context.Result = GoVerify(userIP);
                return;
            }
            base.OnActionExecuting(context);
        }

        /// <summary>
        /// Per-IP request counter stored in the cache.
        /// </summary>
        private class RequestInfo
        {
            /// <summary>Requests counted in the current window (<see cref="int.MaxValue"/> once blocked).</summary>
            public int Count { get; set; }

            /// <summary>Start of the current window.</summary>
            public DateTime FirstRequestTime { get; set; }
        }

        /// <summary>
        /// Builds the captcha page shown to rate-limited clients.
        /// </summary>
        /// <param name="ip">Client IP; encrypted into the captcha code and displayed on the page.</param>
        /// <returns>An HTML content result.</returns>
        public ContentResult GoVerify(string ip)
        {
            var code = _SecurityFunc.Encrypt(ip);
            // The code goes into a URL query string, so it must be percent-encoded: characters such as
            // '+', '/' or '=' would otherwise be altered when the query is decoded.
            // NOTE(review): assumes Encrypt may return such characters (e.g. Base64) - confirm.
            var urlCode = Uri.EscapeDataString(code);
            var blockedHtml = @"
<!DOCTYPE html>

<html>
<head>
    <meta charset='UTF-8'>
    <meta name='viewport' content='width=device-width, initial-scale=1.0'>
    <title>安全验证</title>
    <style>
        * { margin: 0;padding: 0; }
        body { width: 100vw;height: 100vh;background: #ffebee;display: flex;align-items: center;justify-content: center; }
        .content { height: 300px;display: flex;color: #f44336;align-items: center;flex-direction: column; }
        .content > svg { width: 80px;height: 80px;fill: #f44336; }
        .content > div { line-height: 50px;display: flex;align-items: center;flex-direction: column; }
        .verify-container { display: flex;align-items: center;gap: 10px;margin: 10px 0; }
        #verify { padding: 8px 12px;border: 2px solid #f44336;border-radius: 4px;outline: none;font-size: 16px; }
        #verify:focus { border-color: #d32f2f;box-shadow: 0 0 0 2px rgba(244, 67, 54, 0.2); }
        #verifyImg { height: 40px;border-radius: 4px;cursor: pointer; }
        button { padding: 8px 24px;background: #f44336;color: white;border: none;border-radius: 4px;font-size: 16px;cursor: pointer; }
        button:hover { background: #d32f2f; }
    </style>
</head>
<body>
    <div class='content'>
        <svg class='logo' viewBox='0 0 24 24'>
            <path d='M12 2L1 21h22L12 2zm0 3.45l8.27 14.32H3.73L12 5.45zm-1.5 8.09v-4h3v4h-3zm0 4h3v-2h-3v2z' />
        </svg>
        <div>
            <h1>请完成验证</h1>
            <h3>检测到您的访问频率较高,需要进行人机验证。</h3>
            <div class='verify-container'>
                <input id='verify' placeholder='请输入验证码'>
                <img id='verifyImg' src='/Verify?code=" + urlCode + @"'>
            </div>
            <button id='submit'>提交</button>
            <p>ip: " + ip + @"</p>
        </div>
    </div>
    <script>
        // 刷新验证码
        let verifyImg = document.querySelector('#verifyImg');
        let verifyImgSrc = verifyImg.src;
        verifyImg.addEventListener('click', function () {
            this.src = verifyImgSrc + '&t=' + Math.random();
        });
        // 提交验证码
        let verify = document.querySelector('#verify');
        let submit = document.querySelector('#submit');
        submit.addEventListener('click', async function () {
            if (verify.value.length == 0) {
                alert('请输入验证码');
                return;
            }
            // 提交验证
            try {
                const formData = new FormData();
                formData.append('code', '" + code + @"');
                formData.append('verify', verify.value);
                const response = await fetch('/WafVerify', {
                    method: 'POST',
                    body: formData
                }).then(res => {
                    if (!res.ok) throw new Error('请求失败');
                    return res.text();
                });
                // 验证结果
                let result = JSON.parse(response);
                if (result.data === true) {
                    window.location.reload();
                } else {
                    alert('验证失败');
                }
            } catch (error) {
                alert(error.message);
            }
        });
    </script>
</body>
</html>";
            return new ContentResult
            {
                Content = _HtmlFunc.Compress(blockedHtml),
                // Valid MIME type (no space inside "text/html"), otherwise browsers may not render the page.
                ContentType = "text/html; charset=utf-8"
            };
        }

        #region 对外函数

        /// <summary>
        /// Adds (or refreshes) an IP on the temporary whitelist. Invalid IP strings are ignored.
        /// </summary>
        /// <param name="ip">IP address.</param>
        /// <param name="expireMinutes">Minutes until the entry expires; defaults to 5.</param>
        public static void AddToWhiteList(string ip, int expireMinutes = 5)
        {
            if (IPAddress.TryParse(ip, out _))
            {
                _whiteList[ip] = DateTime.Now.AddMinutes(expireMinutes);
            }
        }

        /// <summary>
        /// Removes an IP from the temporary whitelist; does nothing if it is not present.
        /// </summary>
        /// <param name="ip">IP address.</param>
        public static void RemoveFromWhiteList(string ip)
        {
            _whiteList.TryRemove(ip, out _);
        }

        #endregion 对外函数
    }
}

启用过滤器(在 Program.cs 中全局注册。注意注册顺序:WafSCFiter 必须最先注册,因为它写入的 WAF_Skip 标记会被后两个过滤器读取)

// Register the WAF filters globally. WafSCFiter goes first because it sets
// HttpContext.Items["WAF_Skip"], which WafJSFiter and WafCCFiter read.
builder.Services.AddControllers(mvcOptions =>
{
    var globalFilters = mvcOptions.Filters;
    globalFilters.Add(typeof(WafSCFiter)); // security check (SQLi / XSS / sensitive paths)
    globalFilters.Add(typeof(WafJSFiter)); // JavaScript challenge
    globalFilters.Add(typeof(WafCCFiter)); // request-rate limiting
});

完整示例源码

WAFTest(.NET Core):

内容已隐藏,需要评论并且审核通过后,才能阅读隐藏内容

WAFRunJS(GoLang):WAFRunJS



添加新评论

Top