Commit f249c534 authored by liaozan's avatar liaozan 🏀

Refactor request wrapper

parent e8e69aab
......@@ -33,22 +33,22 @@ public class ServletComponentConfiguration {
@Bean
@ConditionalOnMissingBean
public RequestWrapperFilter defaukltRequestWrapperFilter() {
return new RequestWrapperFilter();
public RequestContextFilter requestContextFilter() {
OrderedRequestContextFilter requestContextFilter = new OrderedRequestContextFilter();
requestContextFilter.setThreadContextInheritable(true);
return requestContextFilter;
}
@Bean
@ConditionalOnMissingBean
public RequestContextFilter defaultRequestContextFilter() {
OrderedRequestContextFilter requestContextFilter = new OrderedRequestContextFilter();
requestContextFilter.setThreadContextInheritable(true);
return requestContextFilter;
public RequestWrapperFilter requestWrapperFilter() {
return new RequestWrapperFilter();
}
@Bean
@ConditionalOnMissingBean
@ConditionalOnProperty(value = "schbrain.web.enable-request-logging", havingValue = "true", matchIfMissing = true)
public RequestLoggingFilter defaultRequestLoggingFilter() {
public RequestLoggingFilter requestLoggingFilter() {
return new RequestLoggingFilter();
}
......
......@@ -5,18 +5,18 @@ import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.schbrain.common.util.JacksonUtils;
import com.schbrain.common.web.annotation.BodyParam;
import com.schbrain.common.web.servlet.ContentCachingRequest;
import lombok.Setter;
import org.springframework.core.MethodParameter;
import org.springframework.util.Assert;
import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.method.annotation.AbstractNamedValueMethodArgumentResolver;
import org.springframework.web.util.ContentCachingRequestWrapper;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.lang.reflect.Type;
import static com.schbrain.common.web.utils.ContentCachingServletUtils.wrapRequestIfRequired;
import static com.schbrain.common.web.utils.RequestContentCachingUtils.wrapIfRequired;
import static org.springframework.web.context.request.RequestAttributes.SCOPE_REQUEST;
/**
......@@ -68,15 +68,15 @@ public class BodyParamMethodArgumentResolver extends AbstractNamedValueMethodArg
private JsonNode getRequestBody(NativeWebRequest nativeWebRequest) throws IOException {
JsonNode requestBody = (JsonNode) nativeWebRequest.getAttribute(REQUEST_BODY_CACHE, SCOPE_REQUEST);
if (requestBody == null) {
ContentCachingRequestWrapper request = wrapRequest(nativeWebRequest);
ContentCachingRequest request = wrapRequest(nativeWebRequest);
requestBody = objectMapper.readTree(request.getInputStream());
nativeWebRequest.setAttribute(REQUEST_BODY_CACHE, requestBody, SCOPE_REQUEST);
}
return requestBody;
}
private ContentCachingRequestWrapper wrapRequest(NativeWebRequest nativeWebRequest) {
return wrapRequestIfRequired(nativeWebRequest.getNativeRequest(HttpServletRequest.class));
private ContentCachingRequest wrapRequest(NativeWebRequest request) {
return wrapIfRequired(request.getNativeRequest(HttpServletRequest.class));
}
}
package com.schbrain.common.web.servlet;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.StreamUtils;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.IOException;
/**
* @author liaozan
* @since 2023/8/22
*/
@Slf4j
public class ContentCachingRequest extends HttpServletRequestWrapper {
private WrappedByteArrayInputStream inputStream;
public ContentCachingRequest(HttpServletRequest request) {
super(request);
}
@Override
public WrappedByteArrayInputStream getInputStream() throws IOException {
if (inputStream == null) {
byte[] bytes = StreamUtils.copyToByteArray(super.getInputStream());
this.inputStream = new WrappedByteArrayInputStream(bytes);
}
return inputStream;
}
/**
* Return the cached request content as a String. The Charset used to decode the cached content is the same as returned by getCharacterEncoding.
*/
public String getContentAsString() {
return getContentAsString(getCharacterEncoding());
}
/**
* Return the cached request content as a String
*/
public String getContentAsString(String charset) {
try {
return getInputStream().getContentAsString(charset);
} catch (IOException e) {
throw new IllegalStateException(e.getMessage(), e);
}
}
}
......@@ -14,8 +14,8 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import static com.schbrain.common.web.utils.ContentCachingServletUtils.getRequestBody;
import static com.schbrain.common.web.utils.ContentCachingServletUtils.wrapRequestIfRequired;
import static com.schbrain.common.web.utils.RequestContentCachingUtils.getRequestBody;
import static com.schbrain.common.web.utils.RequestContentCachingUtils.wrapIfRequired;
/**
* 请求日志拦截器
......@@ -35,7 +35,7 @@ public class RequestLoggingFilter extends OncePerRequestFilter implements Ordere
return;
}
request = wrapRequestIfRequired(request);
request = wrapIfRequired(request);
long startTime = System.currentTimeMillis();
try {
......@@ -55,7 +55,7 @@ public class RequestLoggingFilter extends OncePerRequestFilter implements Ordere
String method = request.getMethod();
String requestUri = request.getRequestURI();
String queryString = request.getQueryString();
String requestBody = getRequestBody(request, false);
String requestBody = getRequestBody(request);
StringBuilder builder = new StringBuilder();
builder.append("requestUri: ").append(method).append(CharPool.SPACE).append(requestUri);
if (StringUtils.isNotBlank(queryString)) {
......
......@@ -10,7 +10,7 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import static com.schbrain.common.web.utils.ContentCachingServletUtils.wrapRequestIfRequired;
import static com.schbrain.common.web.utils.RequestContentCachingUtils.wrapIfRequired;
/**
* @author liaozan
......@@ -25,7 +25,7 @@ public class RequestWrapperFilter extends OncePerRequestFilter implements Ordere
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain) throws ServletException, IOException {
chain.doFilter(wrapRequestIfRequired(request), response);
chain.doFilter(wrapIfRequired(request), response);
}
}
package com.schbrain.common.web.servlet;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import java.io.ByteArrayInputStream;
import java.nio.charset.Charset;
/**
* @author liaozan
* @since 2023/8/22
*/
public class WrappedByteArrayInputStream extends ServletInputStream {
private final ByteArrayInputStreamWrapper delegate;
public WrappedByteArrayInputStream(byte[] bytes) {
this.delegate = new ByteArrayInputStreamWrapper(bytes);
}
@Override
public boolean isFinished() {
return delegate.available() == 0;
}
@Override
public boolean isReady() {
return true;
}
@Override
public void setReadListener(ReadListener ignore) {
}
@Override
public int read() {
return delegate.read();
}
/**
* Return the cached request content as a String
*/
public String getContentAsString(String charset) {
return new String(delegate.getBytes(), Charset.forName(charset));
}
/**
* Simple {@link ByteArrayInputStream} wrapper that exposes the underlying byte array.
*/
private static class ByteArrayInputStreamWrapper extends ByteArrayInputStream {
public ByteArrayInputStreamWrapper(byte[] bytes) {
super(bytes);
}
public byte[] getBytes() {
return buf;
}
}
}
......@@ -14,7 +14,7 @@ import java.util.List;
import java.util.Objects;
import static cn.hutool.core.text.StrPool.UNDERLINE;
import static com.schbrain.common.web.utils.ContentCachingServletUtils.getRequestBody;
import static com.schbrain.common.web.utils.RequestContentCachingUtils.getRequestBody;
import static org.springframework.web.util.WebUtils.getNativeRequest;
public abstract class AbstractSignatureValidationInterceptor<T extends SignatureContext> extends BaseHandlerInterceptor {
......@@ -51,7 +51,7 @@ public abstract class AbstractSignatureValidationInterceptor<T extends Signature
String requestUri = wrappedRequest.getRequestURI();
String queryString = wrappedRequest.getQueryString();
String requestBody = getRequestBody(wrappedRequest, true);
String requestBody = getRequestBody(wrappedRequest);
// 校验签名
String calculatedSignature = signParams(requestUri, queryString, requestBody, timestamp, appKey, context.getAppSecret());
......
package com.schbrain.common.web.utils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.Assert;
import org.springframework.util.StreamUtils;
import org.springframework.web.util.ContentCachingRequestWrapper;
import org.springframework.web.util.ContentCachingResponseWrapper;
import javax.annotation.Nullable;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.nio.charset.Charset;
import static org.springframework.web.util.WebUtils.getNativeRequest;
/**
* @author liaozan
* @since 2023-05-08
*/
@Slf4j
public class ContentCachingServletUtils {
/**
* Make request content cacheable to avoid stream closed error after inputStream closed
*/
public static ContentCachingRequestWrapper wrapRequestIfRequired(HttpServletRequest request) {
Assert.notNull(request, "request must not be null");
if (request instanceof ContentCachingRequestWrapper) {
return (ContentCachingRequestWrapper) request;
} else {
return new ContentCachingRequestWrapper(request);
}
}
/**
* Make response content cacheable to avoid stream closed error after outputStream closed
*/
public static ContentCachingResponseWrapper wrapResponseIfRequired(HttpServletResponse response) {
Assert.notNull(response, "response must not be null");
if (response instanceof ContentCachingResponseWrapper) {
return (ContentCachingResponseWrapper) response;
} else {
return new ContentCachingResponseWrapper(response);
}
}
/**
* Get request body content
*/
@Nullable
public static String getRequestBody(HttpServletRequest request, boolean readFromInputStream) {
ContentCachingRequestWrapper nativeRequest = getNativeRequest(request, ContentCachingRequestWrapper.class);
if (nativeRequest == null) {
return null;
}
Charset charset = Charset.forName(nativeRequest.getCharacterEncoding());
if (readFromInputStream) {
try {
return StreamUtils.copyToString(request.getInputStream(), charset);
} catch (IOException e) {
log.warn("Failed to read body content from request inputStream", e);
return null;
}
}
return new String(nativeRequest.getContentAsByteArray(), charset);
}
}
package com.schbrain.common.web.utils;
import com.schbrain.common.web.servlet.ContentCachingRequest;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.Assert;
import javax.annotation.Nullable;
import javax.servlet.http.HttpServletRequest;
import static org.springframework.web.util.WebUtils.getNativeRequest;
/**
* @author liaozan
* @since 2023-05-08
*/
@Slf4j
public class RequestContentCachingUtils {
/**
* Make request content cacheable to avoid stream closed error after inputStream closed
*/
public static ContentCachingRequest wrapIfRequired(HttpServletRequest request) {
Assert.notNull(request, "request must not be null");
if (request instanceof ContentCachingRequest) {
return (ContentCachingRequest) request;
} else {
return new ContentCachingRequest(request);
}
}
/**
* Get request body content
*/
@Nullable
public static String getRequestBody(HttpServletRequest request) {
ContentCachingRequest requestToUse = getNativeRequest(request, ContentCachingRequest.class);
if (requestToUse == null) {
log.warn("request is not an instance of {}", ContentCachingRequest.class.getSimpleName());
return null;
}
return requestToUse.getContentAsString();
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment