作者简介:大家好,我是smart哥,前中兴通讯、美团架构师,现某互联网公司CTO
联系qq:184480602,加我进群,大家一起学习,一起进步,一起对抗互联网寒冬
最核心的内容前两篇已经讲完了,这一篇只有代码:
先看demo目录下的三个文件:
DemoApplication.java
package com.example.demo;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
/**
 * Spring Boot entry point of the demo application.
 */
@SpringBootApplication
public class DemoApplication {

    /**
     * Starts the Spring application context.
     *
     * @param args command-line arguments, forwarded unchanged to Spring Boot
     */
    public static void main(String... args) {
        SpringApplication.run(DemoApplication.class, args);
    }
}
User.java
package com.example.demo;
import com.example.demo.mybatisplus.annotations.TableName;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.Date;
/**
 * Entity mapped to the {@code t_user} table.
 *
 * <p>Field names are used directly as column names when SQL is generated and when result
 * rows are mapped back onto the object. Lombok generates accessors, equals/hashCode/toString,
 * a no-arg constructor and an all-args constructor (in field declaration order).
 *
 * @author mx
 */
@TableName("t_user")
@Data
@NoArgsConstructor
@AllArgsConstructor
public class User {
    /** Column {@code id}. */
    private Long id;
    /** Column {@code name}. */
    private String name;
    /** Column {@code age}. */
    private Integer age;
    /** Column {@code birthday}. */
    private Date birthday;
}
UserMapper.java
package com.example.demo;
import com.example.demo.mybatisplus.AbstractBaseMapper;
/**
 * Mapper for {@link User}.
 *
 * <p>The entity class is resolved reflectively from the generic superclass
 * {@code AbstractBaseMapper<User>}, so no code is required here.
 *
 * @author mx
 */
public class UserMapper extends AbstractBaseMapper<User> {
    // Intentionally empty: every operation is inherited from AbstractBaseMapper.
}
mybatisplus包下的AbstractBaseMapper.java
package com.example.demo.mybatisplus;
import com.example.demo.mybatisplus.annotations.TableName;
import com.example.demo.mybatisplus.core.JdbcTemplate;
import com.example.demo.mybatisplus.query.QueryWrapper;
import com.example.demo.mybatisplus.query.SqlParam;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.StringJoiner;
/**
 * Mapper base class: generates simple SELECT / INSERT / UPDATE statements for an entity via
 * reflection and executes them through {@link JdbcTemplate}.
 *
 * <p>Subclasses must bind the type parameter directly, e.g.
 * {@code class UserMapper extends AbstractBaseMapper<User>}. The entity class is read from that
 * generic superclass, the table name from the entity's {@link TableName} annotation, and the
 * column names are the entity's (non-static) field names.
 *
 * @param <T> entity type
 * @author mx
 */
public abstract class AbstractBaseMapper<T> {
    private static final Logger logger = LoggerFactory.getLogger(AbstractBaseMapper.class);
    /** Connector placed between WHERE conditions. */
    private static final String DEFAULT_LOGICAL_TYPE = " and ";

    private final JdbcTemplate<T> jdbcTemplate = new JdbcTemplate<T>();
    /** Entity class resolved from the subclass's generic superclass. */
    private final Class<T> beanClass;
    /** Table name taken from the entity's {@link TableName} annotation. */
    private final String tableName;
    /** Entity instance fields that map to columns, already made accessible. */
    private final List<Field> columnFields;

    /**
     * Resolves the entity class, its table name and its column fields.
     *
     * @throws IllegalStateException if the subclass does not bind {@code T} to a concrete class,
     *                               or if the entity class lacks a {@link TableName} annotation
     */
    @SuppressWarnings("unchecked")
    public AbstractBaseMapper() {
        Type superType = getClass().getGenericSuperclass();
        Type entityType = superType instanceof ParameterizedType
                ? ((ParameterizedType) superType).getActualTypeArguments()[0]
                : null;
        if (!(entityType instanceof Class)) {
            throw new IllegalStateException(getClass().getName()
                    + " must directly extend AbstractBaseMapper<T> with a concrete entity class");
        }
        beanClass = (Class<T>) entityType;
        TableName table = beanClass.getAnnotation(TableName.class);
        if (table == null) {
            throw new IllegalStateException(beanClass.getName() + " is missing the @TableName annotation");
        }
        tableName = table.value();
        columnFields = Collections.unmodifiableList(resolveColumnFields(beanClass));
    }

    /**
     * Returns the first entity matching the wrapper's conditions.
     *
     * @param queryWrapper conditions; {@code null} or empty matches every row
     * @return the first match, or {@code null} if nothing matches or the query fails (logged)
     */
    public T select(QueryWrapper<T> queryWrapper) {
        List<T> matches = list(queryWrapper);
        return matches.isEmpty() ? null : matches.get(0);
    }

    /**
     * Returns all entities matching the wrapper's conditions, combined with AND.
     *
     * @param queryWrapper conditions; {@code null} or empty selects every row
     * @return matching entities, or an empty list if the query fails (the failure is logged)
     */
    public List<T> list(QueryWrapper<T> queryWrapper) {
        List<Object> paramList = new ArrayList<>();
        String sql = "SELECT * FROM " + tableName + buildWhereClause(queryWrapper, paramList) + ";";
        logger.info("sql: {}", sql);
        logger.info("params: {}", paramList);
        try {
            return jdbcTemplate.queryForList(sql, paramList, beanClass);
        } catch (Exception e) {
            // queryForList(..., Class) declares a bare Exception, so no narrower catch is possible.
            logger.error("query failed", e);
            return Collections.emptyList();
        }
    }

    /**
     * Inserts {@code bean}, writing every column field (null values included).
     *
     * @param bean entity to insert
     * @return number of affected rows, or 0 if the insert fails (the failure is logged)
     * @throws NullPointerException if {@code bean} is null
     */
    public int insert(T bean) {
        Objects.requireNonNull(bean, "bean");
        // Columns are named explicitly: getDeclaredFields() has no guaranteed order, so a bare
        // "VALUES(?, ...)" could bind values to the wrong columns.
        StringJoiner columns = new StringJoiner(", ", " (", ")");
        StringJoiner placeholders = new StringJoiner(",", " VALUES(", ")");
        List<Object> paramList = new ArrayList<>(columnFields.size());
        try {
            for (Field field : columnFields) {
                columns.add(field.getName());
                placeholders.add("?");
                paramList.add(field.get(bean));
            }
        } catch (IllegalAccessException e) {
            logger.error("insert failed", e);
            return 0;
        }
        String sql = "INSERT INTO " + tableName + columns + placeholders;
        return executeUpdate(sql, paramList, "insert");
    }

    /**
     * Updates the non-null fields of {@code bean} on the rows matching the wrapper's conditions.
     *
     * <p>Nothing is executed, and 0 is returned, when {@code bean} has no non-null field or when
     * the wrapper has no conditions. The latter guards against accidentally updating the whole table.
     *
     * @param bean         carries the new values; null fields are left untouched
     * @param queryWrapper WHERE conditions, combined with AND
     * @return number of affected rows, or 0 if nothing was executed or the update failed (logged)
     * @throws NullPointerException if {@code bean} is null
     */
    public int updateSelective(T bean, QueryWrapper<T> queryWrapper) {
        Objects.requireNonNull(bean, "bean");
        // SET parameters come first, WHERE parameters after them, matching placeholder order.
        List<Object> paramList = new ArrayList<>();
        StringJoiner setClause = new StringJoiner(", ", " SET ", "");
        try {
            for (Field field : columnFields) {
                Object fieldValue = field.get(bean);
                if (fieldValue != null) {
                    setClause.add(field.getName() + " = ?");
                    paramList.add(fieldValue);
                }
            }
        } catch (IllegalAccessException e) {
            logger.error("update failed", e);
            return 0;
        }
        if (paramList.isEmpty()) {
            logger.warn("update skipped: {} has no non-null fields to set", beanClass.getSimpleName());
            return 0;
        }
        String whereClause = buildWhereClause(queryWrapper, paramList);
        if (whereClause.isEmpty()) {
            logger.warn("update skipped: refusing to update table {} without WHERE conditions", tableName);
            return 0;
        }
        String sql = "UPDATE " + tableName + setClause + whereClause + ";";
        return executeUpdate(sql, paramList, "update");
    }

    /** Collects the entity's instance fields (static and synthetic ones skipped) and makes them accessible. */
    private static List<Field> resolveColumnFields(Class<?> type) {
        List<Field> fields = new ArrayList<>();
        for (Field field : type.getDeclaredFields()) {
            if (Modifier.isStatic(field.getModifiers()) || field.isSynthetic()) {
                continue;
            }
            field.setAccessible(true);
            fields.add(field);
        }
        return fields;
    }

    /**
     * Renders {@code " WHERE col op ? and col op ?"} from the wrapper and appends the bound values
     * to {@code params} in placeholder order.
     *
     * @return the WHERE clause, or an empty string when there are no conditions
     */
    private static String buildWhereClause(QueryWrapper<?> queryWrapper, List<Object> params) {
        Map<String, SqlParam> conditionMap = queryWrapper == null ? null : queryWrapper.build();
        if (conditionMap == null || conditionMap.isEmpty()) {
            return "";
        }
        StringJoiner where = new StringJoiner(DEFAULT_LOGICAL_TYPE, " WHERE ", "");
        for (Map.Entry<String, SqlParam> condition : conditionMap.entrySet()) {
            SqlParam param = condition.getValue();
            // The map key is the operator, already padded with spaces (e.g. " = ").
            where.add(param.getColumnName() + condition.getKey() + "?");
            params.add(param.getValue());
        }
        return where.toString();
    }

    /** Logs and executes a write statement; returns the affected rows, or 0 on SQLException (logged). */
    private int executeUpdate(String sql, List<Object> params, String action) {
        logger.info("sql: {}", sql);
        logger.info("params: {}", params);
        try {
            int affectedRows = jdbcTemplate.update(sql, params);
            logger.info("{} success, affectedRows: {}", action, affectedRows);
            return affectedRows;
        } catch (SQLException e) {
            logger.error("{} failed", action, e);
            return 0;
        }
    }
}
annotations包下的TableName.java
package com.example.demo.mybatisplus.annotations;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
 * Declares the database table that an entity class maps to.
 *
 * <p>Applicable to types only, and retained at runtime so that mappers can read it
 * reflectively via {@code Class.getAnnotation(TableName.class)}.
 *
 * @author mx
 */
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
public @interface TableName {
    /**
     * @return the table name, e.g. {@code "t_user"}
     */
    String value();
}
core包下的JdbcTemplate.java
package com.example.demo.mybatisplus.core;
import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
/**
 * Minimal JdbcTemplate that removes plain-JDBC boilerplate. It opens a connection, binds
 * positional parameters, executes the statement, maps rows, and always releases every resource.
 *
 * <p>Each call opens its own connection through {@link DriverManager} and closes it before
 * returning. Connection settings default to a local MySQL database. They can be overridden with
 * the system properties {@code jdbc.url}, {@code jdbc.user} and {@code jdbc.password}.
 *
 * @param <T> type produced by queries
 * @author mx
 */
public class JdbcTemplate<T> {

    private static final String DEFAULT_URL = "jdbc:mysql://localhost:3306/demo";
    private static final String DEFAULT_USER = "root";
    private static final String DEFAULT_PASSWORD = "123456";

    /** Runs {@code sql} and maps every row with {@code rowMapper}. */
    public List<T> queryForList(String sql, List<Object> params, RowMapper<T> rowMapper) throws SQLException {
        return query(sql, params, rowMapper);
    }

    /** Runs {@code sql} and returns the first row mapped by {@code rowMapper}, or {@code null} if there are no rows. */
    public T queryForObject(String sql, List<Object> params, RowMapper<T> rowMapper) throws SQLException {
        List<T> result = query(sql, params, rowMapper);
        return result.isEmpty() ? null : result.get(0);
    }

    /** Runs {@code sql} and maps every row reflectively onto a new {@code clazz} instance (column label = field name). */
    public List<T> queryForList(String sql, List<Object> params, Class<T> clazz) throws Exception {
        return query(sql, params, clazz);
    }

    /** Runs {@code sql} and returns the first reflectively mapped row, or {@code null} if there are no rows. */
    public T queryForObject(String sql, List<Object> params, Class<T> clazz) throws Exception {
        List<T> result = query(sql, params, clazz);
        return result.isEmpty() ? null : result.get(0);
    }

    /**
     * Executes an INSERT / UPDATE / DELETE statement.
     *
     * @param sql    SQL with {@code ?} placeholders
     * @param params placeholder values in order; {@code null} means no parameters
     * @return number of affected rows
     * @throws SQLException if the connection or the statement fails
     */
    public int update(String sql, List<Object> params) throws SQLException {
        requireArgument(sql, "sql");
        try (Connection conn = getConnection();
             PreparedStatement ps = conn.prepareStatement(sql)) {
            bindParams(ps, params);
            return ps.executeUpdate();
        }
    }

    // ************************* private methods **************************

    private List<T> query(String sql, List<Object> params, RowMapper<T> rowMapper) throws SQLException {
        // Mapping rules supplied by the caller.
        return baseQuery(sql, params, rowMapper);
    }

    private List<T> query(String sql, List<Object> params, Class<T> clazz) throws Exception {
        // Reflection-based mapping onto clazz.
        return baseQuery(sql, params, new BeanHandler<>(clazz));
    }

    /**
     * Base query method. The caller must supply the row-mapping rule.
     *
     * @param sql       SQL with {@code ?} placeholders
     * @param params    placeholder values in order; {@code null} means no parameters
     * @param rowMapper maps the current row of the result set
     * @return one mapped element per row, in result-set order
     * @throws SQLException if the connection, the statement or reading the result fails
     */
    private List<T> baseQuery(String sql, List<Object> params, RowMapper<T> rowMapper) throws SQLException {
        requireArgument(sql, "sql");
        requireArgument(rowMapper, "rowMapper");
        try (Connection conn = getConnection();
             PreparedStatement ps = conn.prepareStatement(sql)) {
            bindParams(ps, params);
            try (ResultSet rs = ps.executeQuery()) {
                List<T> result = new ArrayList<>();
                // This loop owns cursor movement; rowMapper only reads the current row.
                while (rs.next()) {
                    result.add(rowMapper.mapRow(rs));
                }
                return result;
            }
        }
    }

    /**
     * Reflection-based RowMapper. Each column of the current row is written into the bean field
     * whose name equals the column label.
     *
     * @param <R> bean type
     */
    private static class BeanHandler<R> implements RowMapper<R> {
        // Type of the bean each row is mapped to.
        private final Class<R> clazz;

        BeanHandler(Class<R> clazz) {
            this.clazz = clazz;
        }

        @Override
        public R mapRow(ResultSet rs) {
            // Do NOT call rs.next() here: the caller has already positioned the cursor on the row.
            try {
                ResultSetMetaData metaData = rs.getMetaData();
                R bean = clazz.getDeclaredConstructor().newInstance();
                for (int column = 1; column <= metaData.getColumnCount(); column++) {
                    // The label honours SQL aliases ("... AS name"); for SELECT * it equals the column name.
                    Field field = clazz.getDeclaredField(metaData.getColumnLabel(column));
                    field.setAccessible(true);
                    field.set(bean, rs.getObject(column));
                }
                return bean;
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
    }

    /** Binds {@code params} to the statement's placeholders in order; {@code null} means no parameters. */
    private static void bindParams(PreparedStatement ps, List<Object> params) throws SQLException {
        if (params == null) {
            return;
        }
        for (int i = 0; i < params.size(); i++) {
            // JDBC parameter indexes are 1-based.
            ps.setObject(i + 1, params.get(i));
        }
    }

    /** Throws IllegalArgumentException if {@code value} is null. */
    private static void requireArgument(Object value, String name) {
        if (value == null) {
            throw new IllegalArgumentException(name + " must not be null");
        }
    }

    private Connection getConnection() throws SQLException {
        // TODO: these settings could be moved to a properties file.
        String url = System.getProperty("jdbc.url", DEFAULT_URL);
        String user = System.getProperty("jdbc.user", DEFAULT_USER);
        String password = System.getProperty("jdbc.password", DEFAULT_PASSWORD);
        return DriverManager.getConnection(url, user, password);
    }
}
RowMapper.java
package com.example.demo.mybatisplus.core;
import java.sql.ResultSet;
/**
 * Maps a single row of a {@link ResultSet} to an object.
 *
 * <p>When used by JdbcTemplate, the caller advances the cursor with {@code next()} before each
 * call. Implementations should therefore read only the current row and not move the cursor.
 *
 * @param <T> type produced for each row
 * @author mx
 */
@FunctionalInterface
public interface RowMapper<T> {
    /**
     * Converts the current row of {@code resultSet} into a {@code T}.
     *
     * @param resultSet result set positioned on the row to map
     * @return the mapped object
     */
    T mapRow(ResultSet resultSet);
}
query包下的QueryWrapper.java
package com.example.demo.mybatisplus.query;
import com.example.demo.mybatisplus.utils.ConditionFunction;
import com.example.demo.mybatisplus.utils.Reflections;
import java.util.HashMap;
import java.util.Map;
/**
 * Simplified imitation of MyBatis-Plus's LambdaQueryWrapper. It is similar in usage only; the
 * implementation is entirely different. Collects WHERE conditions from getter method references,
 * e.g. {@code new QueryWrapper<User>().like(User::getName, "bravo").gt(User::getAge, 18)}.
 *
 * <p>Limitation: conditions are keyed by operator, so each operator holds at most one condition.
 * Reusing an operator for the same column replaces the earlier value. Reusing it for a different
 * column throws {@link IllegalArgumentException} instead of silently dropping the first condition.
 *
 * @param <T> entity type
 * @author mx
 */
public class QueryWrapper<T> {
    // Operators, padded with spaces so they can be concatenated straight into SQL ("name LIKE ?").
    private static final String OPERATOR_EQ = " = ";
    private static final String OPERATOR_GT = " > ";
    private static final String OPERATOR_LT = " < ";
    private static final String OPERATOR_LIKE = " LIKE ";

    // operator -> (column, value), kept in the order the conditions were first added, e.g.
    // {
    //   " LIKE ": { "name": "%bravo1988%" },
    //   " = ":    { "age": 18 }
    // }
    private final Map<String, SqlParam> conditionMap = new java.util.LinkedHashMap<>();

    /** Adds the condition {@code column = value}. */
    public QueryWrapper<T> eq(ConditionFunction<T, ?> fn, Object value) {
        return addCondition(OPERATOR_EQ, fn, value);
    }

    /** Adds the condition {@code column > value}. */
    public QueryWrapper<T> gt(ConditionFunction<T, ?> fn, Object value) {
        return addCondition(OPERATOR_GT, fn, value);
    }

    /** Adds the condition {@code column < value}. */
    public QueryWrapper<T> lt(ConditionFunction<T, ?> fn, Object value) {
        return addCondition(OPERATOR_LT, fn, value);
    }

    /** Adds the substring condition {@code column LIKE %value%}. */
    public QueryWrapper<T> like(ConditionFunction<T, ?> fn, Object value) {
        return addCondition(OPERATOR_LIKE, fn, "%" + value + "%");
    }

    /**
     * Returns the collected conditions as operator -> (column, value), in insertion order.
     *
     * @return a read-only view of the conditions
     */
    public Map<String, SqlParam> build() {
        return java.util.Collections.unmodifiableMap(conditionMap);
    }

    /** Resolves the column name from {@code fn} and records the condition under {@code operator}. */
    private QueryWrapper<T> addCondition(String operator, ConditionFunction<T, ?> fn, Object value) {
        String columnName = Reflections.fnToColumnName(fn);
        SqlParam existing = conditionMap.get(operator);
        if (existing != null && !columnName.equals(existing.getColumnName())) {
            throw new IllegalArgumentException("QueryWrapper supports one condition per operator: '"
                    + operator.trim() + "' is already used for column '" + existing.getColumnName()
                    + "', cannot add it for column '" + columnName + "'");
        }
        conditionMap.put(operator, new SqlParam(columnName, value));
        return this;
    }
}
SqlParam.java
package com.example.demo.mybatisplus.query;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * One WHERE-clause operand: a column name plus the value bound to its {@code ?} placeholder.
 * Lombok generates the accessors, equals/hashCode/toString, and both constructors.
 *
 * @author mx
 */
@AllArgsConstructor
@NoArgsConstructor
@Data
public class SqlParam {
    /** Column name (QueryWrapper derives it from a getter method reference). */
    private String columnName;
    /** Value bound to the column's placeholder. */
    private Object value;
}
utils包下的ConditionFunction.java
package com.example.demo.mybatisplus.utils;
import java.io.Serializable;
import java.util.function.Function;
/**
 * A {@link Function} that is also {@link Serializable}.
 *
 * <p>Because the target interface is serializable, lambdas and method references assigned to it
 * get a {@code writeReplace} method. {@code Reflections.fnToColumnName} invokes that method to
 * obtain a {@code SerializedLambda}, and from it the name of the referenced method.
 *
 * @param <T> input type
 * @param <R> result type
 * @author mx
 */
@FunctionalInterface
public interface ConditionFunction<T, R> extends Serializable, Function<T, R> {
}
Reflections.java
package com.example.demo.mybatisplus.utils;
import java.beans.Introspector;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Method;
import java.util.regex.Pattern;
/**
 * Derives column names from serializable getter method references such as {@code User::getName}.
 *
 * @author mx
 */
public class Reflections {
    private static final Pattern GET_PATTERN = Pattern.compile("^get[A-Z].*");
    private static final Pattern IS_PATTERN = Pattern.compile("^is[A-Z].*");
    /** Prefix javac gives the synthetic methods that implement lambda expression bodies. */
    private static final String LAMBDA_METHOD_PREFIX = "lambda$";

    /**
     * Returns the property name behind a getter method reference.
     *
     * <p>{@code getName} maps to {@code name} and {@code isActive} maps to {@code active}. Any other
     * method name is only decapitalized. Note: properties that are not lowerCamelCase may map incorrectly.
     *
     * @param fn  a method reference such as {@code User::getName}, not a lambda expression
     * @param <T> entity type
     * @return the derived column name
     * @throws IllegalArgumentException if {@code fn} is a lambda expression; its synthetic method
     *                                  name (e.g. {@code lambda$main$0}) would otherwise leak into SQL
     * @throws RuntimeException         wrapping any reflective failure while reading the lambda metadata
     */
    public static <T> String fnToColumnName(ConditionFunction<T, ?> fn) {
        String methodName = implMethodName(fn);
        if (methodName.startsWith(LAMBDA_METHOD_PREFIX)) {
            throw new IllegalArgumentException(
                    "Expected a getter method reference such as User::getName, but got a lambda: " + methodName);
        }
        String property = methodName;
        if (GET_PATTERN.matcher(methodName).matches()) {
            property = methodName.substring(3);
        } else if (IS_PATTERN.matcher(methodName).matches()) {
            property = methodName.substring(2);
        }
        return Introspector.decapitalize(property);
    }

    /** Reads the implementation method name via the serializable lambda's generated writeReplace(). */
    private static String implMethodName(ConditionFunction<?, ?> fn) {
        try {
            Method writeReplace = fn.getClass().getDeclaredMethod("writeReplace");
            writeReplace.setAccessible(true);
            SerializedLambda serializedLambda = (SerializedLambda) writeReplace.invoke(fn);
            return serializedLambda.getImplMethodName();
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException(e);
        }
    }
}
其实第一篇的内容是最难的:它不只是从0到1,而是从0到90;后面两篇只是从90到100,在此基础上稍作扩展而已。
AbstractBaseMapper代码还有冗余,有兴趣的同学可以自行完善。但还是那句话:如果你的目的是锻炼封装能力,可以精益求精,但我们的AbstractBaseMapper注定不能用于生产,即使要优化,也点到为止即可。
学习必须往深处挖,挖的越深,基础越扎实!