Commit 8800476f authored by liaozan's avatar liaozan 🏀

Polish request wrapper

parent 0a348e81
...@@ -3,6 +3,7 @@ package com.schbrain.common.web; ...@@ -3,6 +3,7 @@ package com.schbrain.common.web;
import com.schbrain.common.web.properties.WebProperties; import com.schbrain.common.web.properties.WebProperties;
import com.schbrain.common.web.servlet.CharacterEncodingServletContextInitializer; import com.schbrain.common.web.servlet.CharacterEncodingServletContextInitializer;
import com.schbrain.common.web.servlet.RequestLoggingFilter; import com.schbrain.common.web.servlet.RequestLoggingFilter;
import com.schbrain.common.web.servlet.RequestWrapperFilter;
import com.schbrain.common.web.servlet.TraceIdInitializeServletListener; import com.schbrain.common.web.servlet.TraceIdInitializeServletListener;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
...@@ -20,19 +21,25 @@ public class ServletComponentConfiguration { ...@@ -20,19 +21,25 @@ public class ServletComponentConfiguration {
@Bean @Bean
@ConditionalOnMissingBean @ConditionalOnMissingBean
public TraceIdInitializeServletListener traceIdInitializeServletListener() { public TraceIdInitializeServletListener defaultTraceIdInitializeServletListener() {
return new TraceIdInitializeServletListener(); return new TraceIdInitializeServletListener();
} }
@Bean @Bean
@ConditionalOnMissingBean @ConditionalOnMissingBean
public CharacterEncodingServletContextInitializer characterEncodingServletContextInitializer(WebProperties webProperties) { public CharacterEncodingServletContextInitializer defaultCharacterEncodingServletContextInitializer(WebProperties webProperties) {
return new CharacterEncodingServletContextInitializer(webProperties.getEncoding()); return new CharacterEncodingServletContextInitializer(webProperties.getEncoding());
} }
@Bean @Bean
@ConditionalOnMissingBean @ConditionalOnMissingBean
public RequestContextFilter requestContextFilter() { public RequestWrapperFilter defaukltRequestWrapperFilter() {
return new RequestWrapperFilter();
}
@Bean
@ConditionalOnMissingBean
public RequestContextFilter defaultRequestContextFilter() {
OrderedRequestContextFilter requestContextFilter = new OrderedRequestContextFilter(); OrderedRequestContextFilter requestContextFilter = new OrderedRequestContextFilter();
requestContextFilter.setThreadContextInheritable(true); requestContextFilter.setThreadContextInheritable(true);
return requestContextFilter; return requestContextFilter;
...@@ -41,7 +48,7 @@ public class ServletComponentConfiguration { ...@@ -41,7 +48,7 @@ public class ServletComponentConfiguration {
@Bean @Bean
@ConditionalOnMissingBean @ConditionalOnMissingBean
@ConditionalOnProperty(value = "schbrain.web.enable-request-logging", havingValue = "true", matchIfMissing = true) @ConditionalOnProperty(value = "schbrain.web.enable-request-logging", havingValue = "true", matchIfMissing = true)
public RequestLoggingFilter requestLoggingFilter() { public RequestLoggingFilter defaultRequestLoggingFilter() {
return new RequestLoggingFilter(); return new RequestLoggingFilter();
} }
......
...@@ -57,10 +57,10 @@ public class BodyParamMethodArgumentResolver extends AbstractNamedValueMethodArg ...@@ -57,10 +57,10 @@ public class BodyParamMethodArgumentResolver extends AbstractNamedValueMethodArg
if (value == null || value.isNull()) { if (value == null || value.isNull()) {
return null; return null;
} }
return objectMapper.convertValue(value, getJavaType(parameter)); return objectMapper.convertValue(value, toJavaType(parameter));
} }
private JavaType getJavaType(MethodParameter parameter) { private JavaType toJavaType(MethodParameter parameter) {
Type parameterType = parameter.getNestedGenericParameterType(); Type parameterType = parameter.getNestedGenericParameterType();
return objectMapper.constructType(parameterType); return objectMapper.constructType(parameterType);
} }
...@@ -68,11 +68,15 @@ public class BodyParamMethodArgumentResolver extends AbstractNamedValueMethodArg ...@@ -68,11 +68,15 @@ public class BodyParamMethodArgumentResolver extends AbstractNamedValueMethodArg
private JsonNode getRequestBody(NativeWebRequest nativeWebRequest) throws IOException { private JsonNode getRequestBody(NativeWebRequest nativeWebRequest) throws IOException {
JsonNode requestBody = (JsonNode) nativeWebRequest.getAttribute(REQUEST_BODY_CACHE, SCOPE_REQUEST); JsonNode requestBody = (JsonNode) nativeWebRequest.getAttribute(REQUEST_BODY_CACHE, SCOPE_REQUEST);
if (requestBody == null) { if (requestBody == null) {
ContentCachingRequestWrapper request = wrapRequestIfRequired(nativeWebRequest.getNativeRequest(HttpServletRequest.class)); ContentCachingRequestWrapper request = wrapRequest(nativeWebRequest);
requestBody = objectMapper.readTree(request.getInputStream()); requestBody = objectMapper.readTree(request.getInputStream());
nativeWebRequest.setAttribute(REQUEST_BODY_CACHE, requestBody, SCOPE_REQUEST); nativeWebRequest.setAttribute(REQUEST_BODY_CACHE, requestBody, SCOPE_REQUEST);
} }
return requestBody; return requestBody;
} }
private ContentCachingRequestWrapper wrapRequest(NativeWebRequest nativeWebRequest) {
return wrapRequestIfRequired(nativeWebRequest.getNativeRequest(HttpServletRequest.class));
}
} }
package com.schbrain.common.web.servlet; package com.schbrain.common.web.servlet;
import cn.hutool.core.text.CharPool; import cn.hutool.core.text.CharPool;
import com.schbrain.common.web.utils.ContentCachingServletUtils;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.boot.web.servlet.filter.OrderedFilter; import org.springframework.boot.web.servlet.filter.OrderedFilter;
...@@ -15,6 +14,7 @@ import javax.servlet.http.HttpServletRequest; ...@@ -15,6 +14,7 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import java.io.IOException; import java.io.IOException;
import static com.schbrain.common.web.utils.ContentCachingServletUtils.getRequestBody;
import static com.schbrain.common.web.utils.ContentCachingServletUtils.wrapRequestIfRequired; import static com.schbrain.common.web.utils.ContentCachingServletUtils.wrapRequestIfRequired;
/** /**
...@@ -55,14 +55,14 @@ public class RequestLoggingFilter extends OncePerRequestFilter implements Ordere ...@@ -55,14 +55,14 @@ public class RequestLoggingFilter extends OncePerRequestFilter implements Ordere
String method = request.getMethod(); String method = request.getMethod();
String requestUri = request.getRequestURI(); String requestUri = request.getRequestURI();
String queryString = request.getQueryString(); String queryString = request.getQueryString();
String body = ContentCachingServletUtils.getRequestBody(request, false); String requestBody = getRequestBody(request, false);
StringBuilder builder = new StringBuilder(); StringBuilder builder = new StringBuilder();
builder.append("requestUri: ").append(method).append(CharPool.SPACE).append(requestUri); builder.append("requestUri: ").append(method).append(CharPool.SPACE).append(requestUri);
if (StringUtils.isNotBlank(queryString)) { if (StringUtils.isNotBlank(queryString)) {
builder.append(", queryString: ").append(queryString); builder.append(", queryString: ").append(queryString);
} }
if (StringUtils.isNotBlank(body)) { if (StringUtils.isNotBlank(requestBody)) {
builder.append(", body: ").append(body); builder.append(", body: ").append(requestBody);
} }
builder.append(", startTime: ").append(startTime); builder.append(", startTime: ").append(startTime);
builder.append(", endTime: ").append(endTime); builder.append(", endTime: ").append(endTime);
......
package com.schbrain.common.web.servlet;
import org.springframework.boot.web.servlet.filter.OrderedFilter;
import org.springframework.core.Ordered;
import org.springframework.web.filter.OncePerRequestFilter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import static com.schbrain.common.web.utils.ContentCachingServletUtils.wrapRequestIfRequired;
/**
* @author liaozan
* @since 2023/8/20
*/
public class RequestWrapperFilter extends OncePerRequestFilter implements OrderedFilter {
@Override
public int getOrder() {
return Ordered.HIGHEST_PRECEDENCE;
}
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain) throws ServletException, IOException {
chain.doFilter(wrapRequestIfRequired(request), response);
}
}
package com.schbrain.common.web.support.signature; package com.schbrain.common.web.support.signature;
import cn.hutool.crypto.digest.DigestUtil; import cn.hutool.crypto.digest.DigestUtil;
import com.google.common.base.Joiner;
import com.schbrain.common.web.support.BaseHandlerInterceptor; import com.schbrain.common.web.support.BaseHandlerInterceptor;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.web.method.HandlerMethod; import org.springframework.web.method.HandlerMethod;
import org.springframework.web.util.ContentCachingRequestWrapper;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import java.util.Objects;
import static cn.hutool.core.text.StrPool.UNDERLINE; import static cn.hutool.core.text.StrPool.UNDERLINE;
import static com.schbrain.common.web.utils.ContentCachingServletUtils.getRequestBody; import static com.schbrain.common.web.utils.ContentCachingServletUtils.getRequestBody;
import static com.schbrain.common.web.utils.ContentCachingServletUtils.wrapRequestIfRequired; import static org.springframework.web.util.WebUtils.getNativeRequest;
public abstract class AbstractSignatureValidationInterceptor<T extends SignatureContext> extends BaseHandlerInterceptor { public abstract class AbstractSignatureValidationInterceptor<T extends SignatureContext> extends BaseHandlerInterceptor {
private static final Joiner JOINER = Joiner.on(UNDERLINE).skipNulls();
private static final String SCH_APP_KEY = "Sch-App-Key"; private static final String SCH_APP_KEY = "Sch-App-Key";
private static final String SCH_TIMESTAMP = "Sch-Timestamp"; private static final String SCH_TIMESTAMP = "Sch-Timestamp";
private static final String SCH_SIGNATURE = "Sch-Signature"; private static final String SCH_SIGNATURE = "Sch-Signature";
private static final String SCH_EXPIRE_TIME = "Sch-Expire-Time"; private static final String SCH_EXPIRE_TIME = "Sch-Expire-Time";
@Override @Override
protected boolean preHandle(HttpServletRequest request, HttpServletResponse response, HandlerMethod handlerMethod) { protected boolean preHandle(HttpServletRequest request, HttpServletResponse response, HandlerMethod handler) {
String appKey = request.getHeader(SCH_APP_KEY); ContentCachingRequestWrapper wrappedRequest = getWrappedRequest(request);
String timestamp = request.getHeader(SCH_TIMESTAMP);
String signature = request.getHeader(SCH_SIGNATURE); String appKey = wrappedRequest.getHeader(SCH_APP_KEY);
String expireTime = request.getHeader(SCH_EXPIRE_TIME); String timestamp = wrappedRequest.getHeader(SCH_TIMESTAMP);
String signature = wrappedRequest.getHeader(SCH_SIGNATURE);
// 空校验 // 空校验
if (StringUtils.isAnyBlank(appKey, timestamp, signature)) { if (StringUtils.isAnyBlank(appKey, timestamp, signature)) {
...@@ -32,6 +38,7 @@ public abstract class AbstractSignatureValidationInterceptor<T extends Signature ...@@ -32,6 +38,7 @@ public abstract class AbstractSignatureValidationInterceptor<T extends Signature
} }
// 过期校验 // 过期校验
String expireTime = wrappedRequest.getHeader(SCH_EXPIRE_TIME);
if (StringUtils.isNotBlank(expireTime) && System.currentTimeMillis() > Long.parseLong(expireTime)) { if (StringUtils.isNotBlank(expireTime) && System.currentTimeMillis() > Long.parseLong(expireTime)) {
throw new SignatureValidationException("请求信息已过期!"); throw new SignatureValidationException("请求信息已过期!");
} }
...@@ -39,16 +46,16 @@ public abstract class AbstractSignatureValidationInterceptor<T extends Signature ...@@ -39,16 +46,16 @@ public abstract class AbstractSignatureValidationInterceptor<T extends Signature
// 获取appSecret // 获取appSecret
SignatureContext context = getSignatureContext(appKey); SignatureContext context = getSignatureContext(appKey);
if (null == context || StringUtils.isBlank(context.getAppSecret())) { if (null == context || StringUtils.isBlank(context.getAppSecret())) {
throw new SignatureValidationException(); throw new SignatureValidationException("appSecret不存在!");
} }
request = wrapRequestIfRequired(request); String requestUri = wrappedRequest.getRequestURI();
String queryString = wrappedRequest.getQueryString();
String requestBody = getRequestBody(wrappedRequest, true);
// 校验签名 // 校验签名
String requestUri = request.getRequestURI(); String calculatedSignature = signParams(requestUri, queryString, requestBody, timestamp, appKey, context.getAppSecret());
String queryString = request.getQueryString(); if (!Objects.equals(signature, calculatedSignature)) {
String body = getRequestBody(request, true);
String compareSignature = signParams(requestUri, queryString, body, timestamp, appKey, context.getAppSecret());
if (!signature.equals(compareSignature)) {
throw new SignatureValidationException(); throw new SignatureValidationException();
} }
...@@ -61,19 +68,19 @@ public abstract class AbstractSignatureValidationInterceptor<T extends Signature ...@@ -61,19 +68,19 @@ public abstract class AbstractSignatureValidationInterceptor<T extends Signature
SignatureContextUtil.clear(); SignatureContextUtil.clear();
} }
protected abstract T getSignatureContext(String appKey);
protected String signParams(String requestUri, String queryString, String bodyString, String timestamp, String appKey, String appSecret) { protected String signParams(String requestUri, String queryString, String bodyString, String timestamp, String appKey, String appSecret) {
StringBuilder toSign = new StringBuilder(requestUri); String toSign = JOINER.join(requestUri, queryString, bodyString, timestamp, appKey, appSecret);
if (StringUtils.isNotBlank(queryString)) { return DigestUtil.sha256Hex(toSign);
toSign.append(UNDERLINE).append(queryString);
}
if (StringUtils.isNotBlank(bodyString)) {
toSign.append(UNDERLINE).append(bodyString);
} }
toSign.append(UNDERLINE).append(timestamp).append(UNDERLINE).append(appKey).append(UNDERLINE).append(appSecret);
return DigestUtil.sha256Hex(toSign.toString()); protected abstract T getSignatureContext(String appKey);
private ContentCachingRequestWrapper getWrappedRequest(HttpServletRequest request) {
ContentCachingRequestWrapper wrapper = getNativeRequest(request, ContentCachingRequestWrapper.class);
if (wrapper == null) {
throw new SignatureValidationException("请求异常");
}
return wrapper;
} }
} }
...@@ -5,13 +5,15 @@ import org.springframework.util.Assert; ...@@ -5,13 +5,15 @@ import org.springframework.util.Assert;
import org.springframework.util.StreamUtils; import org.springframework.util.StreamUtils;
import org.springframework.web.util.ContentCachingRequestWrapper; import org.springframework.web.util.ContentCachingRequestWrapper;
import org.springframework.web.util.ContentCachingResponseWrapper; import org.springframework.web.util.ContentCachingResponseWrapper;
import org.springframework.web.util.WebUtils;
import javax.annotation.Nullable;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import static org.springframework.web.util.WebUtils.getNativeRequest;
/** /**
* @author liaozan * @author liaozan
* @since 2023-05-08 * @since 2023-05-08
...@@ -43,8 +45,12 @@ public class ContentCachingServletUtils { ...@@ -43,8 +45,12 @@ public class ContentCachingServletUtils {
} }
} }
/**
* Get request body content
*/
@Nullable
public static String getRequestBody(HttpServletRequest request, boolean readFromInputStream) { public static String getRequestBody(HttpServletRequest request, boolean readFromInputStream) {
ContentCachingRequestWrapper nativeRequest = WebUtils.getNativeRequest(request, ContentCachingRequestWrapper.class); ContentCachingRequestWrapper nativeRequest = getNativeRequest(request, ContentCachingRequestWrapper.class);
if (nativeRequest == null) { if (nativeRequest == null) {
return null; return null;
} }
...@@ -53,7 +59,7 @@ public class ContentCachingServletUtils { ...@@ -53,7 +59,7 @@ public class ContentCachingServletUtils {
try { try {
return StreamUtils.copyToString(request.getInputStream(), charset); return StreamUtils.copyToString(request.getInputStream(), charset);
} catch (IOException e) { } catch (IOException e) {
log.warn("Failed to read body content from request inputStream"); log.warn("Failed to read body content from request inputStream", e);
return null; return null;
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment