之前已经实现了批量执行 `INSERT INTO`、`REPLACE INTO` 和 `INSERT IGNORE INTO` 的方法。
`REPLACE INTO` 理论上也可以用来做更新，但有两个问题：一是会把原本不存在的数据新增进去；二是无法只更新指定字段，而是整行覆盖（语句中未提供的字段会被重置为默认值）。
最近遇到需要优化批量更新的场景，请教 AI 后发现可以用 `CASE ... WHEN ... THEN ... END` 的方式实现批量更新，SQL 形式如下：
-- Template: one CASE expression per updated column maps each id to that row's
-- new value, and the WHERE clause limits the statement to the ids in the batch.
-- A simple CASE without ELSE yields NULL for an id that has no matching WHEN,
-- so the WHEN lists and the IN list must contain exactly the same ids.
UPDATE table
SET
column1 = CASE id
WHEN {id1} THEN {column1value1}
WHEN {id2} THEN {column1value2}
END,
column2 = CASE id
WHEN {id1} THEN {column2value1}
WHEN {id2} THEN {column2value2}
END
WHERE id IN (?,?,?)
代码如下。

接口层（Java）：
import org.springframework.data.repository.NoRepositoryBean;
@NoRepositoryBean
public interface BatchSaveRepository<T> {

    /**
     * Writes only the named fields of the given entities back to their table, grouping
     * the entities into statements of the form:
     * <pre>
     * UPDATE table
     * SET
     * column1 = CASE id
     * WHEN {id1} THEN {column1value1}
     * WHEN {id2} THEN {column1value2}
     * END,
     * column2 = CASE id
     * WHEN {id1} THEN {column2value1}
     * WHEN {id2} THEN {column2value2}
     * END
     * WHERE id IN (?,?,?)
     * </pre>
     * Unlike {@code REPLACE INTO}, this never inserts rows and leaves unlisted columns untouched.
     *
     * @param var1             the entities to update; each is matched to its row by id
     * @param updateFieldNames comma-separated names of the entity fields to write
     * @param batchInt         how many entities go into one generated statement
     * @param <S>              the concrete entity type
     * @throws Exception       when reading entity values or executing the statement fails
     */
    <S extends T> void batchUpdate(Iterable<S> var1, String updateFieldNames, int batchInt) throws Exception;
}
实现层（Java）：
import java.io.Serializable;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import javax.persistence.Column;
import javax.persistence.EntityManager;
import javax.persistence.Id;
import javax.persistence.Query;
import javax.persistence.Table;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.jpa.repository.support.JpaEntityInformation;
import org.springframework.data.jpa.repository.support.SimpleJpaRepository;
import org.springframework.data.repository.NoRepositoryBean;
import org.springframework.transaction.annotation.Transactional;
/**
 * Base repository that adds a multi-row
 * {@code UPDATE ... SET col = CASE id WHEN ? THEN ? ... END ... WHERE id IN (...)}
 * statement on top of {@link SimpleJpaRepository}.
 *
 * <ul>
 *   <li>Identifiers are quoted with backticks, i.e. the generated SQL uses MySQL/MariaDB syntax.</li>
 *   <li>The statement runs as a native query; entities already managed by the persistence context
 *       are not refreshed by this class afterwards.</li>
 *   <li>Values are read straight from entity fields via reflection. NOTE(review): an uninitialized
 *       lazy proxy would presumably expose unset fields — pass fully loaded entity instances.</li>
 * </ul>
 */
@NoRepositoryBean
public class BatchSaveRepositoryImpl<T, ID extends Serializable> extends SimpleJpaRepository<T, ID>
        implements BatchSaveRepository<T> {

    private static final Logger logger = LoggerFactory.getLogger(BatchSaveRepositoryImpl.class);

    /**
     * Resolved mapping metadata keyed by entity class + requested field list. The map is static,
     * so it is shared by every repository instance and thread; hence a concurrent map.
     */
    private static final Map<String, SqlUpdateStorage> updateSqlMap = new ConcurrentHashMap<>();

    private static final String SQL_UPDATE = "update";

    private final EntityManager em;

    public BatchSaveRepositoryImpl(JpaEntityInformation entityInformation, EntityManager entityManager) {
        super(entityInformation, entityManager);
        this.em = entityManager;
    }

    /**
     * Updates only the listed fields of the given entities, at most {@code batchInt} entities per
     * statement:
     * <pre>
     * UPDATE table
     * SET
     * column1 = CASE id
     * WHEN {id1} THEN {column1value1}
     * WHEN {id2} THEN {column1value2}
     * END,
     * column2 = CASE id
     * WHEN {id1} THEN {column2value1}
     * WHEN {id2} THEN {column2value2}
     * END
     * WHERE id IN (?,?,?)
     * </pre>
     * Rows are never inserted. Entities whose id is {@code null} are skipped. If one id occurs twice
     * in a batch, the first occurrence's values win because WHEN branches are evaluated in order.
     *
     * @param var1             entities to update; {@code null} or empty is a no-op. Table and column
     *                         metadata is resolved from the first element's runtime class
     * @param updateFieldNames comma-separated Java field names, e.g. {@code "name, status"}; blank
     *                         means nothing to do. The {@code @Id} field itself is never updated
     * @param batchInt         maximum number of entities per statement; must be positive
     * @param <S>              concrete entity type
     * @throws IllegalArgumentException if {@code batchInt <= 0} while there is data, or the mapping
     *                                  cannot be resolved (no {@code @Table} name, not exactly one
     *                                  {@code @Id} field, or none of the requested fields exist)
     * @throws Exception                if reflective access or statement execution fails
     */
    @Transactional
    @Override
    public <S extends T> void batchUpdate(Iterable<S> var1, String updateFieldNames, int batchInt)
            throws Exception {
        if (updateFieldNames == null || updateFieldNames.trim().isEmpty() || var1 == null) {
            return;
        }
        Iterator<S> iterator = var1.iterator();
        if (!iterator.hasNext()) {
            return;
        }
        if (batchInt <= 0) {
            // Previously this surfaced as an ArithmeticException from "index % 0".
            throw new IllegalArgumentException("batchInt must be positive but was " + batchInt);
        }

        SqlUpdateStorage sqlStorage = null;
        List<S> batch = new ArrayList<>();
        // Single pass over one iterator, so one-shot Iterables are supported as well.
        while (iterator.hasNext()) {
            S entity = iterator.next();
            if (sqlStorage == null) {
                sqlStorage = getSqlUpdateStorage(entity.getClass(), updateFieldNames);
            }
            batch.add(entity);
            if (batch.size() == batchInt) {
                executeBatchUpdateSql(sqlStorage, batch);
                batch = new ArrayList<>();
            }
        }
        if (!batch.isEmpty()) {
            executeBatchUpdateSql(sqlStorage, batch);
        }
    }

    /** Builds and executes one CASE-based UPDATE statement for a single batch of entities. */
    private <S extends T> void executeBatchUpdateSql(SqlUpdateStorage sqlStorage, List<S> batch)
            throws Exception {
        Field idField = sqlStorage.getIdField();
        List<S> rows = new ArrayList<>(batch.size());
        List<Object> ids = new ArrayList<>(batch.size());
        for (S entity : batch) {
            Object id = idField.get(entity);
            if (id == null) {
                // "CASE id WHEN NULL" and "id IN (NULL)" never match, so such a row could not be
                // updated anyway; skipping it also avoids binding an untyped null id.
                logger.warn("Skipping entity without id in batch update of table {}", sqlStorage.getTableName());
                continue;
            }
            rows.add(entity);
            ids.add(id);
        }
        if (rows.isEmpty()) {
            return;
        }

        String idColumn = sqlStorage.getIdColumnName();
        List<String> columns = sqlStorage.getUpdateColumnName();
        List<Field> fields = sqlStorage.getUpdateField();
        StringBuilder sql = new StringBuilder("update `").append(sqlStorage.getTableName()).append("` set ");
        // Parameters are collected in exactly the order their '?' placeholders appear in the SQL.
        List<Object> parameters = new ArrayList<>();
        for (int c = 0; c < columns.size(); c++) {
            if (c > 0) {
                sql.append(',');
            }
            sql.append('`').append(columns.get(c)).append("` = case `").append(idColumn).append('`');
            Field field = fields.get(c);
            for (int r = 0; r < rows.size(); r++) {
                sql.append(" WHEN ? THEN ?");
                parameters.add(ids.get(r));
                // NOTE(review): binding a null value to a native query may fail on some JPA
                // providers because the SQL type cannot be inferred — confirm for your setup.
                parameters.add(field.get(rows.get(r)));
            }
            sql.append(" end");
        }
        sql.append(" where `").append(idColumn).append("` in (");
        for (int r = 0; r < ids.size(); r++) {
            sql.append(r == 0 ? "?" : ",?");
            parameters.add(ids.get(r));
        }
        sql.append(')');

        Query query = em.createNativeQuery(sql.toString());
        for (int i = 0; i < parameters.size(); i++) {
            query.setParameter(i + 1, parameters.get(i));
        }
        query.executeUpdate();
    }

    /**
     * Returns the mapping metadata for {@code clazz} and the requested field list, building and
     * caching it on first use.
     */
    private static SqlUpdateStorage getSqlUpdateStorage(Class<?> clazz, String updateFieldNames) {
        // Keyed by class rather than table name so two entities on the same table cannot collide.
        String key = SQL_UPDATE + clazz.getName() + ":" + updateFieldNames;
        return updateSqlMap.computeIfAbsent(key, k -> buildSqlUpdateStorage(clazz, updateFieldNames));
    }

    /** Resolves table name, id column and update columns for {@code clazz} via reflection. */
    private static SqlUpdateStorage buildSqlUpdateStorage(Class<?> clazz, String updateFieldNames) {
        Table table = findTable(clazz);
        if (table == null || table.name().isEmpty()) {
            throw new IllegalArgumentException(
                    "No @Table(name = ...) found on " + clazz.getName() + " or its superclasses");
        }
        // Trim each entry so "name, status" behaves like "name,status".
        Set<String> requested = new LinkedHashSet<>();
        for (String name : updateFieldNames.split(",")) {
            String trimmed = name.trim();
            if (!trimmed.isEmpty()) {
                requested.add(trimmed);
            }
        }

        SqlUpdateStorage sqlStorage = new SqlUpdateStorage();
        sqlStorage.setTableName(table.name());
        for (Field f : getAllFields(clazz)) {
            // Getter annotations are looked up on the class that declares the field.
            Class<?> owner = f.getDeclaringClass();
            if (hasAnnotation(owner, f, Id.class)) {
                if (sqlStorage.getIdField() != null) {
                    // Matching on only one part of a composite key could update unrelated rows.
                    throw new IllegalArgumentException("Composite ids are not supported: " + clazz.getName());
                }
                f.setAccessible(true);
                sqlStorage.setIdField(f);
                sqlStorage.setIdColumnName(getColumnName(owner, f));
                // Listing the id among the update fields is silently ignored, as before.
                requested.remove(f.getName());
            } else if (requested.remove(f.getName())) {
                f.setAccessible(true);
                sqlStorage.getUpdateField().add(f);
                sqlStorage.getUpdateColumnName().add(getColumnName(owner, f));
            }
        }

        if (sqlStorage.getIdField() == null) {
            throw new IllegalArgumentException("No @Id field found on " + clazz.getName());
        }
        if (sqlStorage.getUpdateField().isEmpty()) {
            throw new IllegalArgumentException(
                    "None of the fields [" + updateFieldNames + "] exist on " + clazz.getName());
        }
        if (!requested.isEmpty()) {
            // Stay lenient about unknown names (as before), but make them visible.
            logger.warn("Ignoring unknown fields {} for batch update of {}", requested, clazz.getName());
        }
        return sqlStorage;
    }

    /** Returns the first {@code @Table} found on {@code clazz} or its superclasses, or null. */
    private static Table findTable(Class<?> clazz) {
        for (Class<?> c = clazz; c != null; c = c.getSuperclass()) {
            Table table = c.getAnnotation(Table.class);
            if (table != null) {
                return table;
            }
        }
        return null;
    }

    /** Returns the declared fields of {@code clazz} and its superclasses, superclass fields first. */
    private static List<Field> getAllFields(Class<?> clazz) {
        Deque<Class<?>> hierarchy = new ArrayDeque<>();
        for (Class<?> c = clazz; c != null && c != Object.class; c = c.getSuperclass()) {
            hierarchy.push(c);
        }
        List<Field> fields = new ArrayList<>();
        for (Class<?> c : hierarchy) {
            fields.addAll(Arrays.asList(c.getDeclaredFields()));
        }
        return fields;
    }

    /**
     * Reads field {@code fieldName} of {@code t}, searching its class and then its superclasses.
     *
     * @throws NoSuchFieldException if no class in the hierarchy declares the field
     */
    public static <T> Object getField(T t, String fieldName) throws Exception {
        for (Class<?> c = t.getClass(); c != null; c = c.getSuperclass()) {
            try {
                Field f = c.getDeclaredField(fieldName);
                f.setAccessible(true);
                return f.get(t);
            } catch (NoSuchFieldException ignored) {
                // Not declared here; continue with the superclass.
            }
        }
        throw new NoSuchFieldException(fieldName);
    }

    /**
     * Resolves the column name of {@code f}: {@code @Column(name)} on the field, otherwise on its
     * {@code getXxx()} getter declared in {@code clazz}, otherwise the field name itself.
     */
    private static String getColumnName(Class<?> clazz, Field f) {
        Column column = f.getAnnotation(Column.class);
        if (column == null) {
            Method getter = findGetter(clazz, f);
            column = getter == null ? null : getter.getAnnotation(Column.class);
        }
        if (column != null && !column.name().isEmpty()) {
            return column.name();
        }
        // Previously "" was returned here, producing "``" in the SQL. The field name is the JPA
        // default; a physical naming strategy (e.g. snake_case) is NOT applied — use @Column(name=...).
        return f.getName();
    }

    /** Returns true if {@code f} or its {@code getXxx()} getter in {@code clazz} carries the annotation. */
    private static boolean hasAnnotation(Class<?> clazz, Field f, Class<? extends Annotation> annotationClass) {
        if (f.isAnnotationPresent(annotationClass)) {
            return true;
        }
        Method getter = findGetter(clazz, f);
        return getter != null && getter.isAnnotationPresent(annotationClass);
    }

    /** Looks up the no-arg getter {@code getXxx()} for {@code f} declared in {@code clazz}, or null. */
    private static Method findGetter(Class<?> clazz, Field f) {
        String name = f.getName();
        String getterName = "get" + name.substring(0, 1).toUpperCase(Locale.ROOT) + name.substring(1);
        try {
            return clazz.getDeclaredMethod(getterName);
        } catch (NoSuchMethodException ignored) {
            // No annotated getter to consult; callers fall back to the field itself.
            return null;
        }
    }

    /** Reflection metadata needed to build the batch UPDATE for one entity class and field list. */
    static class SqlUpdateStorage {
        private String tableName;
        private String idColumnName;
        private List<String> updateColumnName = new ArrayList<>();
        private Field idField;
        private List<Field> updateField = new ArrayList<>();

        public String getTableName() {
            return tableName;
        }

        public void setTableName(String tableName) {
            this.tableName = tableName;
        }

        public String getIdColumnName() {
            return idColumnName;
        }

        public void setIdColumnName(String idColumnName) {
            this.idColumnName = idColumnName;
        }

        public List<String> getUpdateColumnName() {
            return updateColumnName;
        }

        public void setUpdateColumnName(List<String> updateColumnName) {
            this.updateColumnName = updateColumnName;
        }

        public Field getIdField() {
            return idField;
        }

        public void setIdField(Field idField) {
            this.idField = idField;
        }

        public List<Field> getUpdateField() {
            return updateField;
        }

        public void setUpdateField(List<Field> updateField) {
            this.updateField = updateField;
        }
    }
}
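使用时，让 DAO 层的 Repository 接口继承 `BatchSaveRepository` 即可调用该方法。由于实现类继承了 `SimpleJpaRepository`，通常还需要通过 `@EnableJpaRepositories(repositoryBaseClass = BatchSaveRepositoryImpl.class)` 把它配置为仓库基类。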