创建签名验证自定义中间件 ,优化代码结构

This commit is contained in:
2021-02-26 16:40:51 +08:00
parent 7c3ce74a3b
commit d6617f47c0
8 changed files with 128 additions and 54 deletions

View File

@@ -10,6 +10,7 @@ using System.Threading.Tasks;
using MediatR; using MediatR;
using QRCodeService.Application.Queries; using QRCodeService.Application.Queries;
using QRCodeService.Application.Commands; using QRCodeService.Application.Commands;
using QRCodeService.Infrastructure.Middlewares;
namespace QRCodeService.Controllers.Api namespace QRCodeService.Controllers.Api
{ {
@@ -19,32 +20,28 @@ namespace QRCodeService.Controllers.Api
{ {
readonly IMediator mediator; readonly IMediator mediator;
readonly ILinkQueries linkQueries; readonly ILinkQueries linkQueries;
readonly IAppQueries appQueries;
public LinkController(IMediator mediator, ILinkQueries queries, IAppQueries appQueries) public LinkController(IMediator mediator, ILinkQueries queries)
{ {
this.mediator = mediator; this.mediator = mediator;
this.linkQueries = queries; this.linkQueries = queries;
this.appQueries = appQueries; }
[Route("{shortCode}")]
[HttpGet]
public async Task<IActionResult> Get(string shortCode)
{
var link = await linkQueries.GetLinkAsync(shortCode);
if (link == null)
{
return NotFound();
}
return Ok(link);
} }
[HttpGet] [CheckSign(typeof(CreateLinkModel))]
public IActionResult Get()
{
return Ok();
}
[HttpPost] [HttpPost]
public async Task<IActionResult> Create(CreateLinkModel input) public async Task<IActionResult> Create(CreateLinkModel input)
{ {
var app = await appQueries.GetAppAsync(input.AppId);
if (app == null)
{
return BadRequest();
}
if (! await input.CheckValidAsync(app.Appkey))
{
return BadRequest();
}
var command = new CreateLinkCommand(input.SuffixUrl,1); var command = new CreateLinkCommand(input.SuffixUrl,1);
var link = await mediator.Send(command); var link = await mediator.Send(command);
if (link==null) if (link==null)

View File

@@ -9,6 +9,7 @@ using System.IO;
using QRCodeService.Application.Queries; using QRCodeService.Application.Queries;
using QRCodeService.Models; using QRCodeService.Models;
using QRCodeService.Options; using QRCodeService.Options;
using QRCodeService.Infrastructure.Middlewares;
namespace QRCodeService.Controllers namespace QRCodeService.Controllers
{ {
@@ -25,20 +26,11 @@ namespace QRCodeService.Controllers
this.appQueries = appQueries; this.appQueries = appQueries;
this.option = option; this.option = option;
} }
[CheckSign(typeof(GetQRCodeModel))]
[Route("qrcode")] [Route("qrcode")]
[HttpPost] [HttpPost]
public async Task<IActionResult> GetImage(GetQRCodeModel input) public async Task<IActionResult> GetImage(GetQRCodeModel input)
{ {
var app = await appQueries.GetAppAsync(input.AppId);
if (app == null)
{
return BadRequest();
}
if (!await input.CheckValidAsync(app.Appkey))
{
return BadRequest();
}
var link = await linkQueries.GetLinkAsync(input.ShortCode); var link = await linkQueries.GetLinkAsync(input.ShortCode);
if (link.AppId != input.AppId) if (link.AppId != input.AppId)
{ {

View File

@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace QRCodeService.Extensions
{
public static class ControllerBaseExtension
{
}
}

View File

@@ -6,7 +6,7 @@ using System.Threading.Tasks;
namespace QRCodeService.Extensions namespace QRCodeService.Extensions
{ {
public static class GenericTypeExtensions public static class GenericTypeExtension
{ {
public static string GetGenericTypeName(this Type type) public static string GetGenericTypeName(this Type type)
{ {

View File

@@ -0,0 +1,18 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace QRCodeService.Infrastructure.Middlewares
{
public class CheckSignAttribute:Attribute
{
public Type ModelType { get; set; }
public CheckSignAttribute(Type type)
{
ModelType = type;
}
}
}

View File

@@ -0,0 +1,78 @@
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using QRCodeService.Application.Queries;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Text.Json;
using System.Text.Json.Serialization;
using QRCodeService.Extensions;
namespace QRCodeService.Infrastructure.Middlewares
{
/// <summary>
/// 报文签名验证中间件
/// </summary>
public class CheckSignMiddleware
{
readonly RequestDelegate next;
public CheckSignMiddleware(RequestDelegate next)
{
this.next = next;
}
public async Task Invoke(HttpContext context, IAppQueries appQueries)
{
if (context == null)
{
throw new ArgumentNullException(nameof(context));
}
var endpoint = context.GetEndpoint();
if (endpoint != null)
{
var attribute = endpoint.Metadata.GetMetadata<CheckSignAttribute>();
if (attribute != null)
{
context.Request.EnableBuffering();
var requestReader = new StreamReader(context.Request.Body);
var requestContent = await requestReader.ReadToEndAsync();
var param = JsonSerializer.Deserialize(requestContent,attribute.ModelType);
var props = attribute.ModelType.GetProperties();
var appid = (int)props.Single(p=>p.Name=="AppId").GetValue(param);
var time = props.Single(p => p.Name == "Time").GetValue(param) as string;
var sign = props.Single(p => p.Name == "Sign").GetValue(param) as string;
var timeDate = time.ToDate("yyyyMMddHHmmss");
if (timeDate == null||Math.Abs((timeDate.Value - DateTime.Now).TotalSeconds) > 60)//时间不同步
{
context.Response.StatusCode = StatusCodes.Status400BadRequest;
await context.Response.BodyWriter.WriteAsync(Encoding.UTF8.GetBytes("check sign failed"));
return;
}
var app = await appQueries.GetAppAsync(appid);
if (app == null)
{
context.Response.StatusCode = StatusCodes.Status400BadRequest;
await context.Response.BodyWriter.WriteAsync(Encoding.UTF8.GetBytes("check sign failed"));
return;
}
var appKey = app.Appkey;
var signStr = string.Join(null, props.Where(p => p.Name != "Sign").OrderBy(p => p.Name).Select(p => p.GetValue(param).ToString()));
var checkSign = BitConverter.ToString( (signStr + appKey).ToMD5()).Replace("-","");
if(checkSign.ToLower() != sign.ToLower())
{
context.Response.StatusCode = StatusCodes.Status400BadRequest;
await context.Response.BodyWriter.WriteAsync(Encoding.UTF8.GetBytes("check sign failed"));
return;
}
context.Request.Body.Position = 0;
}
}
await next(context);
}
}
}

View File

@@ -14,30 +14,5 @@ namespace QRCodeService.Models
public int AppId { get; set; } public int AppId { get; set; }
public string Time { get; set; } public string Time { get; set; }
public string Sign { get; set; } public string Sign { get; set; }
public async Task<bool> CheckSignAsync(string appkey)
{
//除Sign字段以外按key名排序
var props = this.GetType().GetProperties();
var signStr = string.Join(null,props.Where(p=>p.Name!="Sign").OrderBy(p => p.Name).Select(p=>p.GetValue(this).ToString()));
var sign = Convert.ToBase64String((signStr + appkey).ToMD5());
return sign == Sign;
}
public bool CheckTime()
{
var timeDate = Time.ToDate("yyyyMMddhhmmss");
if (timeDate == null)
{
return false;
}
if (Math.Abs((timeDate.Value - DateTime.Now).TotalSeconds) > 60)
{
return false;
}
return true;
}
public async Task<bool> CheckValidAsync(string appkey)
{
return CheckTime() && await CheckSignAsync(appkey);
}
} }
} }

View File

@@ -22,6 +22,7 @@ using QRCodeService.Application.Behaviors;
using QRCodeService.Application.Commands; using QRCodeService.Application.Commands;
using QRCodeService.Application.Queries; using QRCodeService.Application.Queries;
using QRCodeService.Application.Validations; using QRCodeService.Application.Validations;
using QRCodeService.Infrastructure.Middlewares;
using QRCodeService.Options; using QRCodeService.Options;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
@@ -104,7 +105,7 @@ namespace QRCodeService
app.UseRouting(); app.UseRouting();
app.UseAuthorization(); app.UseAuthorization();
app.UseMiddleware<CheckSignMiddleware>();
app.UseEndpoints(endpoints => app.UseEndpoints(endpoints =>
{ {
endpoints.MapControllers(); endpoints.MapControllers();