Commit 8800476f authored by liaozan's avatar liaozan 🏀

Polish request wrapper

parent 0a348e81
......@@ -3,6 +3,7 @@ package com.schbrain.common.web;
import com.schbrain.common.web.properties.WebProperties;
import com.schbrain.common.web.servlet.CharacterEncodingServletContextInitializer;
import com.schbrain.common.web.servlet.RequestLoggingFilter;
import com.schbrain.common.web.servlet.RequestWrapperFilter;
import com.schbrain.common.web.servlet.TraceIdInitializeServletListener;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
......@@ -20,19 +21,25 @@ public class ServletComponentConfiguration {
@Bean
@ConditionalOnMissingBean
public TraceIdInitializeServletListener traceIdInitializeServletListener() {
public TraceIdInitializeServletListener defaultTraceIdInitializeServletListener() {
return new TraceIdInitializeServletListener();
}
@Bean
@ConditionalOnMissingBean
public CharacterEncodingServletContextInitializer characterEncodingServletContextInitializer(WebProperties webProperties) {
public CharacterEncodingServletContextInitializer defaultCharacterEncodingServletContextInitializer(WebProperties webProperties) {
return new CharacterEncodingServletContextInitializer(webProperties.getEncoding());
}
@Bean
@ConditionalOnMissingBean
public RequestContextFilter requestContextFilter() {
public RequestWrapperFilter defaukltRequestWrapperFilter() {
return new RequestWrapperFilter();
}
@Bean
@ConditionalOnMissingBean
public RequestContextFilter defaultRequestContextFilter() {
OrderedRequestContextFilter requestContextFilter = new OrderedRequestContextFilter();
requestContextFilter.setThreadContextInheritable(true);
return requestContextFilter;
......@@ -41,8 +48,8 @@ public class ServletComponentConfiguration {
@Bean
@ConditionalOnMissingBean
@ConditionalOnProperty(value = "schbrain.web.enable-request-logging", havingValue = "true", matchIfMissing = true)
public RequestLoggingFilter requestLoggingFilter() {
public RequestLoggingFilter defaultRequestLoggingFilter() {
return new RequestLoggingFilter();
}
}
\ No newline at end of file
}
......@@ -57,10 +57,10 @@ public class BodyParamMethodArgumentResolver extends AbstractNamedValueMethodArg
if (value == null || value.isNull()) {
return null;
}
return objectMapper.convertValue(value, getJavaType(parameter));
return objectMapper.convertValue(value, toJavaType(parameter));
}
private JavaType getJavaType(MethodParameter parameter) {
private JavaType toJavaType(MethodParameter parameter) {
Type parameterType = parameter.getNestedGenericParameterType();
return objectMapper.constructType(parameterType);
}
......@@ -68,11 +68,15 @@ public class BodyParamMethodArgumentResolver extends AbstractNamedValueMethodArg
private JsonNode getRequestBody(NativeWebRequest nativeWebRequest) throws IOException {
JsonNode requestBody = (JsonNode) nativeWebRequest.getAttribute(REQUEST_BODY_CACHE, SCOPE_REQUEST);
if (requestBody == null) {
ContentCachingRequestWrapper request = wrapRequestIfRequired(nativeWebRequest.getNativeRequest(HttpServletRequest.class));
ContentCachingRequestWrapper request = wrapRequest(nativeWebRequest);
requestBody = objectMapper.readTree(request.getInputStream());
nativeWebRequest.setAttribute(REQUEST_BODY_CACHE, requestBody, SCOPE_REQUEST);
}
return requestBody;
}
private ContentCachingRequestWrapper wrapRequest(NativeWebRequest nativeWebRequest) {
return wrapRequestIfRequired(nativeWebRequest.getNativeRequest(HttpServletRequest.class));
}
}
package com.schbrain.common.web.servlet;
import cn.hutool.core.text.CharPool;
import com.schbrain.common.web.utils.ContentCachingServletUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.boot.web.servlet.filter.OrderedFilter;
......@@ -15,6 +14,7 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import static com.schbrain.common.web.utils.ContentCachingServletUtils.getRequestBody;
import static com.schbrain.common.web.utils.ContentCachingServletUtils.wrapRequestIfRequired;
/**
......@@ -55,14 +55,14 @@ public class RequestLoggingFilter extends OncePerRequestFilter implements Ordere
String method = request.getMethod();
String requestUri = request.getRequestURI();
String queryString = request.getQueryString();
String body = ContentCachingServletUtils.getRequestBody(request, false);
String requestBody = getRequestBody(request, false);
StringBuilder builder = new StringBuilder();
builder.append("requestUri: ").append(method).append(CharPool.SPACE).append(requestUri);
if (StringUtils.isNotBlank(queryString)) {
builder.append(", queryString: ").append(queryString);
}
if (StringUtils.isNotBlank(body)) {
builder.append(", body: ").append(body);
if (StringUtils.isNotBlank(requestBody)) {
builder.append(", body: ").append(requestBody);
}
builder.append(", startTime: ").append(startTime);
builder.append(", endTime: ").append(endTime);
......
package com.schbrain.common.web.servlet;
import org.springframework.boot.web.servlet.filter.OrderedFilter;
import org.springframework.core.Ordered;
import org.springframework.web.filter.OncePerRequestFilter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import static com.schbrain.common.web.utils.ContentCachingServletUtils.wrapRequestIfRequired;
/**
* @author liaozan
* @since 2023/8/20
*/
public class RequestWrapperFilter extends OncePerRequestFilter implements OrderedFilter {
@Override
public int getOrder() {
return Ordered.HIGHEST_PRECEDENCE;
}
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain) throws ServletException, IOException {
chain.doFilter(wrapRequestIfRequired(request), response);
}
}
package com.schbrain.common.web.support.signature;
import cn.hutool.crypto.digest.DigestUtil;
import com.google.common.base.Joiner;
import com.schbrain.common.web.support.BaseHandlerInterceptor;
import org.apache.commons.lang3.StringUtils;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.util.ContentCachingRequestWrapper;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.Objects;
import static cn.hutool.core.text.StrPool.UNDERLINE;
import static com.schbrain.common.web.utils.ContentCachingServletUtils.getRequestBody;
import static com.schbrain.common.web.utils.ContentCachingServletUtils.wrapRequestIfRequired;
import static org.springframework.web.util.WebUtils.getNativeRequest;
public abstract class AbstractSignatureValidationInterceptor<T extends SignatureContext> extends BaseHandlerInterceptor {
private static final Joiner JOINER = Joiner.on(UNDERLINE).skipNulls();
private static final String SCH_APP_KEY = "Sch-App-Key";
private static final String SCH_TIMESTAMP = "Sch-Timestamp";
private static final String SCH_SIGNATURE = "Sch-Signature";
private static final String SCH_EXPIRE_TIME = "Sch-Expire-Time";
@Override
protected boolean preHandle(HttpServletRequest request, HttpServletResponse response, HandlerMethod handlerMethod) {
String appKey = request.getHeader(SCH_APP_KEY);
String timestamp = request.getHeader(SCH_TIMESTAMP);
String signature = request.getHeader(SCH_SIGNATURE);
String expireTime = request.getHeader(SCH_EXPIRE_TIME);
protected boolean preHandle(HttpServletRequest request, HttpServletResponse response, HandlerMethod handler) {
ContentCachingRequestWrapper wrappedRequest = getWrappedRequest(request);
String appKey = wrappedRequest.getHeader(SCH_APP_KEY);
String timestamp = wrappedRequest.getHeader(SCH_TIMESTAMP);
String signature = wrappedRequest.getHeader(SCH_SIGNATURE);
// 空校验
if (StringUtils.isAnyBlank(appKey, timestamp, signature)) {
......@@ -32,6 +38,7 @@ public abstract class AbstractSignatureValidationInterceptor<T extends Signature
}
// 过期校验
String expireTime = wrappedRequest.getHeader(SCH_EXPIRE_TIME);
if (StringUtils.isNotBlank(expireTime) && System.currentTimeMillis() > Long.parseLong(expireTime)) {
throw new SignatureValidationException("请求信息已过期!");
}
......@@ -39,16 +46,16 @@ public abstract class AbstractSignatureValidationInterceptor<T extends Signature
// 获取appSecret
SignatureContext context = getSignatureContext(appKey);
if (null == context || StringUtils.isBlank(context.getAppSecret())) {
throw new SignatureValidationException();
throw new SignatureValidationException("appSecret不存在!");
}
request = wrapRequestIfRequired(request);
String requestUri = wrappedRequest.getRequestURI();
String queryString = wrappedRequest.getQueryString();
String requestBody = getRequestBody(wrappedRequest, true);
// 校验签名
String requestUri = request.getRequestURI();
String queryString = request.getQueryString();
String body = getRequestBody(request, true);
String compareSignature = signParams(requestUri, queryString, body, timestamp, appKey, context.getAppSecret());
if (!signature.equals(compareSignature)) {
String calculatedSignature = signParams(requestUri, queryString, requestBody, timestamp, appKey, context.getAppSecret());
if (!Objects.equals(signature, calculatedSignature)) {
throw new SignatureValidationException();
}
......@@ -61,19 +68,19 @@ public abstract class AbstractSignatureValidationInterceptor<T extends Signature
SignatureContextUtil.clear();
}
protected String signParams(String requestUri, String queryString, String bodyString, String timestamp, String appKey, String appSecret) {
String toSign = JOINER.join(requestUri, queryString, bodyString, timestamp, appKey, appSecret);
return DigestUtil.sha256Hex(toSign);
}
protected abstract T getSignatureContext(String appKey);
protected String signParams(String requestUri, String queryString, String bodyString, String timestamp, String appKey, String appSecret) {
StringBuilder toSign = new StringBuilder(requestUri);
if (StringUtils.isNotBlank(queryString)) {
toSign.append(UNDERLINE).append(queryString);
private ContentCachingRequestWrapper getWrappedRequest(HttpServletRequest request) {
ContentCachingRequestWrapper wrapper = getNativeRequest(request, ContentCachingRequestWrapper.class);
if (wrapper == null) {
throw new SignatureValidationException("请求异常");
}
if (StringUtils.isNotBlank(bodyString)) {
toSign.append(UNDERLINE).append(bodyString);
}
toSign.append(UNDERLINE).append(timestamp).append(UNDERLINE).append(appKey).append(UNDERLINE).append(appSecret);
return DigestUtil.sha256Hex(toSign.toString());
return wrapper;
}
}
......@@ -5,13 +5,15 @@ import org.springframework.util.Assert;
import org.springframework.util.StreamUtils;
import org.springframework.web.util.ContentCachingRequestWrapper;
import org.springframework.web.util.ContentCachingResponseWrapper;
import org.springframework.web.util.WebUtils;
import javax.annotation.Nullable;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.nio.charset.Charset;
import static org.springframework.web.util.WebUtils.getNativeRequest;
/**
* @author liaozan
* @since 2023-05-08
......@@ -43,8 +45,12 @@ public class ContentCachingServletUtils {
}
}
/**
* Get request body content
*/
@Nullable
public static String getRequestBody(HttpServletRequest request, boolean readFromInputStream) {
ContentCachingRequestWrapper nativeRequest = WebUtils.getNativeRequest(request, ContentCachingRequestWrapper.class);
ContentCachingRequestWrapper nativeRequest = getNativeRequest(request, ContentCachingRequestWrapper.class);
if (nativeRequest == null) {
return null;
}
......@@ -53,7 +59,7 @@ public class ContentCachingServletUtils {
try {
return StreamUtils.copyToString(request.getInputStream(), charset);
} catch (IOException e) {
log.warn("Failed to read body content from request inputStream");
log.warn("Failed to read body content from request inputStream", e);
return null;
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment