一,需求点描述:
1,支持大文件(5G、10G等)下载,且页面不崩溃
2,在资源紧张时如何进行大文件下载
二,对于大文件下载需要考虑的点
1,大文件(5G、10G等)下载时,内存溢出、超时等问题
2,定义异常信息、分块传输等
3,并发下载时的并发数量限制
4,前后端对于大文件下载的瓶颈点(如浏览器内存限制等)
三,代码编写
核心代码类
java
/**
 * Streams the file at {@code filePath} to the client as a download attachment.
 *
 * <p>The file is never loaded into memory as a whole. It is copied to the servlet output stream
 * through a fixed-size buffer, so heap usage does not grow with the file size.
 *
 * @param filePath path of the file to send
 * @param response response that receives the download headers and the file bytes
 * @throws CimException if the file does not exist, or if its size in MB exceeds {@code maxLength}
 */
private void doDownloadFile(String filePath, HttpServletResponse response) {
    log.info("The down load file is:【{}】", filePath);
    File file = new File(filePath);
    if (!FileUtil.exist(file)) {
        throw new CimException(-1, TipsCodeEnum.FILE_DOWNLOAD_FAIL.getTipsCode(), FILE_NOT_EXISTS);
    }
    // file.length() is in bytes; fileLength is the same size in MB, the unit maxLength is compared in.
    long length = file.length();
    double fileLength = NumberUtils.divideDouble(length, 1024 * 1024);
    log.info("file length is :{} B --> {} MB --> {} GB", length, fileLength,
            NumberUtils.divideDouble(length, 1024 * 1024 * 1024, 4));
    log.info("maxLength:{}", maxLength);
    if (fileLength > maxLength) {
        throw new CimException(-1, TipsCodeEnum.FILE_DOWNLOAD_FAIL.getTipsCode(), FILE_MORE_THAN_MAX_LENGTH);
    }
    try {
        // URLEncoder produces form encoding ("+" for a space), but browsers expect RFC 3986 "%20".
        // A literal "+" in the name is already encoded as "%2B", so this replacement is safe.
        String encodedFileName = URLEncoder.encode(file.getName(), StandardCharsets.UTF_8.toString())
                .replace("+", "%20");
        // All headers are set before the first byte is written, i.e. before the response is committed.
        // filename* (RFC 6266 / RFC 5987) carries the UTF-8 name for modern browsers;
        // the original fileName parameter is kept for existing clients.
        response.setHeader(HttpHeaders.CONTENT_DISPOSITION,
                "attachment; filename*=UTF-8''" + encodedFileName + "; fileName=" + encodedFileName);
        response.setHeader("Content-Type", MediaType.APPLICATION_OCTET_STREAM_VALUE);
        // Announcing the size lets the browser show progress and detect a truncated download.
        response.setContentLengthLong(length);
        try (ServletOutputStream sos = response.getOutputStream();
             FileInputStream fis = new FileInputStream(file);
             BufferedInputStream bis = new BufferedInputStream(fis)) {
            // Copy in fixed-size chunks so only one buffer's worth of the file is in memory at a time.
            byte[] buffer = new byte[8192];
            int bytesRead;
            while ((bytesRead = bis.read(buffer)) != -1) {
                sos.write(buffer, 0, bytesRead);
            }
        }
        log.info("doDownloadFile completed!");
    } catch (IOException e) {
        // E.g. the client aborted the download. Part of the body may already have been sent,
        // so the failure is logged rather than rethrown.
        log.warn("The down load file is fail:【{}】", filePath, e);
    }
}
注:这里采用流式传输:不会把文件一次性读入内存(否则大文件会导致内存溢出),而是通过固定大小的缓冲区边读边向客户端(浏览器)输出。若响应未设置 Content-Length,HTTP/1.1 下 Servlet 容器会在内容超出响应缓冲区时自动改用分块传输编码(Transfer-Encoding: chunked)。
NumberUtils
java
import lombok.extern.slf4j.Slf4j;
import java.math.BigDecimal;
import java.math.RoundingMode;
/**
 * Division helpers backed by {@link BigDecimal}.
 *
 * <p>Results are rounded to an explicit scale with an explicit {@link RoundingMode}.
 */
@Slf4j
public class NumberUtils {

    /**
     * Divides {@code a} by {@code b}, rounding HALF_UP to a whole number.
     *
     * @throws RuntimeException if {@code b} is zero
     */
    public static int divide(double a, double b) {
        return divide(a, b, 0, RoundingMode.HALF_UP);
    }

    /**
     * Divides {@code a} by {@code b}, rounds to {@code scale} digits with {@code roundingMode}, then
     * converts with {@link BigDecimal#intValue()}.
     *
     * <p>{@code intValue()} discards any fractional digits left by a positive scale.
     *
     * @throws RuntimeException if {@code b} is zero
     */
    public static int divide(double a, double b, int scale, RoundingMode roundingMode) {
        return divideBigDecimal(a, b, scale, roundingMode).intValue();
    }

    /**
     * Divides {@code a} by {@code b}, rounded HALF_UP to 2 decimal places.
     *
     * @throws RuntimeException if {@code b} is zero
     */
    public static double divideDouble(double a, double b) {
        return divideBigDecimal(a, b, 2, RoundingMode.HALF_UP).doubleValue();
    }

    /**
     * Divides {@code a} by {@code b}, rounded HALF_UP to {@code scale} decimal places.
     *
     * @throws RuntimeException if {@code b} is zero
     */
    public static double divideDouble(double a, double b, int scale) {
        return divideBigDecimal(a, b, scale, RoundingMode.HALF_UP).doubleValue();
    }

    /**
     * Divides {@code a} by {@code b}, rounded to {@code scale} decimal places with {@code roundingMode}.
     *
     * @throws RuntimeException if {@code b} is zero
     */
    public static double divideDouble(double a, double b, int scale, RoundingMode roundingMode) {
        return divideBigDecimal(a, b, scale, roundingMode).doubleValue();
    }

    /**
     * Divides {@code a} by {@code b} as {@link BigDecimal}s and rounds the quotient to {@code scale}
     * decimal places using {@code roundingMode}.
     *
     * @param a dividend
     * @param b divisor; must not be zero
     * @param scale number of decimal places in the result
     * @param roundingMode rounding applied at {@code scale}
     * @return the rounded quotient
     * @throws RuntimeException if {@code b} is zero (an error is logged first)
     */
    public static BigDecimal divideBigDecimal(double a, double b, int scale, RoundingMode roundingMode) {
        if (b == 0) {
            log.error("divide -> b = 0");
            throw new RuntimeException("数据异常,请联系管理员!");
        }
        // BigDecimal.valueOf goes through Double.toString, i.e. the shortest decimal that maps to the
        // double. 0.145 is therefore divided as 0.145, not as its exact binary expansion
        // 0.14499999999999999000..., which would wrongly round HALF_UP to 0.14 instead of 0.15.
        return BigDecimal.valueOf(a).divide(BigDecimal.valueOf(b), scale, roundingMode);
    }
}
以上即完成了大文件下载功能。此外还需要限制并发下载的数量,这里采用信号量(Semaphore)+ 拦截器的方式来实现。
自定义拦截器RequestInterceptor
java
import com.taia.yms.config.ThreadLocalConf;
import com.taia.yms.exception.CustomException;
import com.taia.yms.exception.TipsCodeEnum;
import com.taia.yms.util.MyStopWatch;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.Optional;
import java.util.concurrent.Semaphore;
/**
 * Request interceptor. It:
 * <ul>
 *   <li>echoes the incoming {@code request-id} header back on the response,</li>
 *   <li>caps the number of concurrently processed requests with a {@link Semaphore},</li>
 *   <li>clears per-thread state when the request completes.</li>
 * </ul>
 *
 * @ClassName RequestInterceptor
 * Description request interception configuration
 * Date 2021/3/2 8:59
 * Version 1.0
 **/
@Slf4j
public class RequestInterceptor implements HandlerInterceptor {

    private static final String REQUEST_ID_HEADER = "request-id";

    /** One permit is held by every request currently between preHandle and afterCompletion. */
    private final Semaphore semaphore;

    /**
     * Creates an interceptor that admits at most {@code count} concurrent requests.
     *
     * @param count maximum number of requests processed at the same time on the intercepted paths
     */
    public RequestInterceptor(int count) {
        this.semaphore = new Semaphore(count);
    }

    /** Creates an interceptor with no effective concurrency limit ({@code Integer.MAX_VALUE} permits). */
    public RequestInterceptor() {
        this.semaphore = new Semaphore(Integer.MAX_VALUE);
    }

    /**
     * Echoes the request id and tries to take a concurrency permit without waiting.
     *
     * @throws CustomException with {@code TOO_MANY_DOWNLOAD} when no permit is available
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        // setHeader (not addHeader): several instances of this class can intercept the same request,
        // and each would otherwise append another copy of the header. Absent ids are not echoed.
        String requestId = request.getHeader(REQUEST_ID_HEADER);
        if (requestId != null) {
            response.setHeader(REQUEST_ID_HEADER, requestId);
        }
        // Fail fast instead of queueing when the limit is reached.
        if (!semaphore.tryAcquire()) {
            throw new CustomException(TipsCodeEnum.TOO_MANY_DOWNLOAD.getMessage(),
                    Integer.valueOf(TipsCodeEnum.TOO_MANY_DOWNLOAD.getTipsCode()));
        }
        return true;
    }

    /** Post-handle callback; intentionally not implemented. */
    @Override
    public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception {
        // Not implemented for now.
    }

    /**
     * Clears the thread-local context and stopwatch, then returns the permit taken in preHandle.
     *
     * <p>Spring invokes afterCompletion only for interceptors whose preHandle returned {@code true}.
     * Each release here therefore matches a successful acquire.
     */
    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
        try {
            // Clear the thread-local context.
            if (ThreadLocalConf.get() != null) {
                ThreadLocalConf.remove();
            }
            // Clear the thread-local stopwatch.
            if (MyStopWatch.get() != null) {
                MyStopWatch.remove();
            }
        } finally {
            // Always release, even if cleanup throws; otherwise the limiter would leak permits.
            semaphore.release();
        }
    }
}
自定义异常类信息
java
/**
 * Unchecked business exception carrying a numeric error code alongside its message.
 */
public class CustomException extends RuntimeException {

    private static final long serialVersionUID = 1L;

    /** Business error code associated with this failure. */
    private final int errorCode;

    /**
     * @param message human-readable description of the failure
     * @param errorCode business error code
     */
    public CustomException(String message, int errorCode) {
        super(message);
        this.errorCode = errorCode;
    }

    /**
     * Same as {@link #CustomException(String, int)}, but preserves the underlying cause.
     *
     * @param message human-readable description of the failure
     * @param errorCode business error code
     * @param cause the exception that triggered this one
     */
    public CustomException(String message, int errorCode, Throwable cause) {
        super(message, cause);
        this.errorCode = errorCode;
    }

    /** @return the business error code */
    public int getErrorCode() {
        return errorCode;
    }
}
在全局异常处理器(@RestControllerAdvice)中统一处理该异常
java
/**
 * Converts {@link CustomException}s thrown while a request is processed into a {@code JsonResult}
 * error body.
 */
@RestControllerAdvice
public class GlobalExceptionHandler {

    /**
     * Builds the error payload from the exception's code and message.
     *
     * <p>No {@code @ResponseStatus} is declared, so this handler does not set the HTTP status itself.
     *
     * @param e the exception raised while handling the request
     * @return error result carrying {@code e}'s error code and message
     */
    @ExceptionHandler(CustomException.class)
    public JsonResult handleCustomException(CustomException e) {
        final int code = e.getErrorCode();
        final String message = e.getMessage();
        return JsonResult.err(code, message);
    }
}
加载拦截器
java
import com.taia.yms.interceptor.RequestInterceptor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.CorsRegistry;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.PathMatchConfigurer;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
/**
 * Spring MVC configuration: CORS, request interceptors (including the download concurrency limit)
 * and path matching.
 */
@Configuration
@Slf4j
public class WebConfig implements WebMvcConfigurer {

    /** Maximum number of download requests processed at the same time. */
    @Value("${spring.download.file.maxCount}")
    private int maxCount;

    /**
     * Cross-origin (CORS) support for all paths.
     *
     * @param corsRegistry registry the mapping is added to
     */
    @Override
    public void addCorsMappings(CorsRegistry corsRegistry) {
        corsRegistry.addMapping("/**")
                // NOTE(review): credentials combined with the "*" origin pattern lets scripts on any
                // origin make credentialed requests and read the responses; consider listing origins.
                .allowCredentials(true)
                .allowedHeaders("*")
                .allowedMethods("GET","POST", "PUT", "DELETE")
                .allowedOriginPatterns("*")
                // Browsers hide non-safelisted response headers from cross-origin scripts unless exposed.
                // Content-Disposition carries the download file name; request-id is echoed by
                // RequestInterceptor. "Header1"/"Header2" are kept from the original configuration.
                .exposedHeaders("Header1", "Header2", "Content-Disposition", "request-id");
    }

    /**
     * Registers the request interceptors:
     * <ul>
     *   <li>{@code /**}: a {@link RequestInterceptor} with no effective concurrency limit
     *       (request-id echo and per-thread cleanup).</li>
     *   <li>{@code /**}{@code /download}: a second {@link RequestInterceptor} whose semaphore admits
     *       at most {@code maxCount} concurrent requests.</li>
     * </ul>
     *
     * @param registry interceptor registry
     */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(new RequestInterceptor())
                .addPathPatterns("/**");
        // Limit the number of concurrent downloads.
        log.info("maxCount:{}",maxCount);
        registry.addInterceptor(new RequestInterceptor(maxCount))
                .addPathPatterns("/**/download");
    }

    /**
     * Path matching configuration.
     *
     * <p>This is currently a no-op, so the framework defaults apply. For example,
     * {@code configurer.setUseTrailingSlashMatch(true)} would make {@code /user} and {@code /user/}
     * equivalent.
     *
     * @param configurer path match configurer
     */
    @Override
    public void configurePathMatch(PathMatchConfigurer configurer) {
        // Intentionally empty.
    }
}
至此,大文件下载功能即完成了!
编写单元测试
java
import cn.hutool.json.JSONUtil;
import com.taia.yms.YmsApplication;
import com.taia.yms.entity.reqbody.FileDownReqBody;
import lombok.extern.slf4j.Slf4j;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.http.MediaType;
import org.springframework.test.context.ActiveProfiles;
import org.springframework.test.context.junit4.SpringRunner;
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
import org.springframework.test.web.servlet.result.MockMvcResultHandlers;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
@ActiveProfiles("dev")
@RunWith(SpringRunner.class)
@SpringBootTest(classes = YmsApplication.class)
@Slf4j
public class CommonApiControllerTest extends AbstractControllerTest{
@Test
public void download() throws Exception {
FileDownReqBody fileDownReqBody = new FileDownReqBody();
fileDownReqBody.setAbsolutePath("");
fileDownReqBody.setQualityViewId(2155509586L);
String requestBody = JSONUtil.toJsonStr(fileDownReqBody);
mockMvc.perform(
MockMvcRequestBuilders.post("/v2/commonApi/download")
.contentType(MediaType.APPLICATION_JSON).content(requestBody)
.header("Authorization",token)
).andDo(MockMvcResultHandlers.print());
}
@Test
public void testConcurrentDownloads() {
// 假设我们模拟10个并发下载请求
ExecutorService executorService = Executors.newFixedThreadPool(10);
for (int i = 0; i < 2; i++) {
executorService.submit(() -> {
try {
download();
} catch (Exception e) {
log.error("业务执行异常:{}",e.getMessage());
throw new RuntimeException(e);
}
});
}
// 关闭线程池,等待所有任务完成
executorService.shutdown();
try {
if (!executorService.awaitTermination(60, TimeUnit.SECONDS)) {
log.warn("超时未执行");
executorService.shutdownNow();
}
} catch (InterruptedException e) {
log.error("异常了:{}",e.getMessage());
executorService.shutdownNow();
Thread.currentThread().interrupt();
}
}
}
AbstractControllerTest核心代码
java
/** Web application context used to build {@code mockMvc}. */
@Autowired
private WebApplicationContext webApplicationContext;
/** Application filter added to the MockMvc filter chain. */
@Autowired
private ApiGlobalVarsFilter apiGlobalVarsFilter;
/** HTTP client used to call the real login endpoint before each test. */
@Autowired
RestTemplate restTemplate;

/** MockMvc instance rebuilt before each test; used by subclass tests. */
public MockMvc mockMvc;
/** Login token that subclass tests send in the Authorization header. */
public String token;
/**
 * Logs in via {@code LOGIN_URL_DEV} to obtain a token, then builds a MockMvc instance.
 *
 * <p>The MockMvc instance has Spring Security support and {@code apiGlobalVarsFilter} applied.
 *
 * @throws IllegalStateException if the login response carries no body or no token
 */
@Before
public void setupMockMvc(){
    // Obtain a login token.
    LoginRequest request = new LoginRequest();
    request.setUserNo(USER_NO);
    request.setPassword(PASSWORD);
    ResponseEntity<LoginResponse> loginResponseResponseEntity =
            restTemplate.postForEntity(LOGIN_URL_DEV, request, LoginResponse.class);
    LoginResponse loginResponse = loginResponseResponseEntity.getBody();
    // Fail fast with a clear message instead of an NPE here, or a null Authorization header later.
    if (loginResponse == null || loginResponse.getToken() == null) {
        throw new IllegalStateException("Login failed: no token returned from " + LOGIN_URL_DEV
                + " (HTTP " + loginResponseResponseEntity.getStatusCode() + ")");
    }
    token = loginResponse.getToken();
    mockMvc = MockMvcBuilders.webAppContextSetup(webApplicationContext)
            .apply(SecurityMockMvcConfigurers.springSecurity())
            .addFilters(apiGlobalVarsFilter).build();
}