让 Spring 服务器作为一个反向代理,将一些请求转发给其他服务来完成响应,实现类似于 nginx 的功能。
思路:
1、写一个 Filter,根据请求路径判断并转发符合规则的请求(只转发后端请求中符合特定规则的请求)
2、需要判断后端的服务是否存活
3、转发时需要将表单参数之间被转义的 `&amp;` 还原为最初的 `&`:
HTML 中的 `&` 用 `&amp;` 来表示,转发过程中需要用
StringEscapeUtils.unescapeHtml3(queryString)
来反解,将 `&amp;` 还原为 `&`
代码(Java):
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import org.apache.commons.text.StringEscapeUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.annotation.Order;
import org.springframework.http.ContentDisposition;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StreamUtils;
import org.springframework.web.client.RestTemplate;
import jakarta.annotation.Resource;
import jakarta.servlet.Filter;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ReadListener;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletInputStream;
import jakarta.servlet.ServletOutputStream;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletRequestWrapper;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
@Component
@Order(2)
@Slf4j
public class ForwardFilter implements Filter {
@Value("${ha.service.host:}")
private String masterServiceHost;
@Value("${server.servlet.context-path}")
private String apiBasePath;
@Resource
private HighAvailableService highAvailableService;
@Override
public void doFilter(ServletRequest servletRequest,
ServletResponse servletResponse,
FilterChain filterChain)
throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) servletRequest;
HttpServletResponse response = (HttpServletResponse) servletResponse;
if (request.getRequestURI().startsWith(apiBasePath) &&
highAvailableService.isMasterAlive()) {
// 转发请求
try {
final ResponseEntity<byte[]> responseEntity =
forward(request, response, masterServiceHost);
HttpStatusCode statusCode = responseEntity.getStatusCode();
if (statusCode.isError()) {
// get from self
filterChain.doFilter(servletRequest, servletResponse);
} else {
// transfer response information
response.setContentType(
responseEntity.getHeaders().getContentType()
.toString());
response.setContentLengthLong(
responseEntity.getHeaders().getContentLength());
response.setCharacterEncoding(
StandardCharsets.UTF_8.name());
response.setStatus(responseEntity.getStatusCodeValue());
final ContentDisposition contentDisposition =
responseEntity.getHeaders().getContentDisposition();
response.setHeader("Content-Disposition",
contentDisposition.toString());
ServletOutputStream outputStream = response.getOutputStream();
outputStream.write(responseEntity.getBody());
outputStream.flush();
outputStream.close();
}
} catch (Exception e) {
log.error("{}", e);
response.setStatus(400);
PrintWriter writer = response.getWriter();
writer.write(e.getClass().getName());
writer.flush();
}
} else {
filterChain.doFilter(servletRequest, servletResponse);
}
}
public ResponseEntity<byte[]> forward(HttpServletRequest request,
HttpServletResponse response,
String routeUrl) {
try {
// build up the forward URL
String forwardUrl = createForwardUrl(request, routeUrl);
RequestEntity requestEntity =
createRequestEntity(request, forwardUrl);
return route(requestEntity);
} catch (Exception e) {
return new ResponseEntity("FORWARD ERROR",
HttpStatus.INTERNAL_SERVER_ERROR);
}
}
private String createForwardUrl(HttpServletRequest request,
String routeUrl) {
String queryString = request.getQueryString();
final String decode = StringEscapeUtils.unescapeHtml3(queryString);
String requestURI = request.getRequestURI();
return routeUrl + requestURI + (decode != null ?
"?" + decode : "");
}
private RequestEntity createRequestEntity(HttpServletRequest request,
String url)
throws URISyntaxException, IOException {
String method = request.getMethod();
HttpMethod httpMethod = HttpMethod.valueOf(method);
MultiValueMap<String, String> headers = parseRequestHeader(request);
byte[] body = parseRequestBody(request);
return new RequestEntity<>(body, headers, httpMethod, new URI(url));
}
private ResponseEntity<byte[]> route(RequestEntity requestEntity) {
RestTemplate restTemplate = new RestTemplate();
return restTemplate.exchange(requestEntity, byte[].class);
}
private byte[] parseRequestBody(HttpServletRequest request)
throws IOException {
InputStream inputStream = request.getInputStream();
return StreamUtils.copyToByteArray(inputStream);
}
private MultiValueMap<String, String> parseRequestHeader(
HttpServletRequest request) {
HttpHeaders headers = new HttpHeaders();
List<String> headerNames = Collections.list(request.getHeaderNames());
for (String headerName : headerNames) {
List<String> headerValues =
Collections.list(request.getHeaders(headerName));
for (String headerValue : headerValues) {
headers.add(headerName, headerValue);
}
}
return headers;
}
}