文章目录
-
- 主键生成策略源码
-
- KeyGenerateAlgorithm
- 源码入口
- 实现
- [扩展 自定义分布式主键生成策略](#扩展 自定义分布式主键生成策略)
- 分片算法
-
- ShardingAlgorithm
- 实现
- [扩展 自定义分片算法](#扩展 自定义分片算法)
- 踩的坑
主键生成策略源码
KeyGenerateAlgorithm
全限定类名org.apache.shardingsphere.sharding.spi.KeyGenerateAlgorithm
分布式主键生成算法,已知实现
配置标识 | 详细说明 | 全限定类名 |
---|---|---|
SNOWFLAKE | 基于雪花算法的分布式主键生成算法 | org.apache.shardingsphere.sharding.algorithm.keygen.SnowflakeKeyGenerateAlgorithm |
UUID | 基于 UUID 的分布式主键生成算法 | org.apache.shardingsphere.sharding.algorithm.keygen.UUIDKeyGenerateAlgorithm |
NANOID | 基于 NanoId 的分布式主键生成算法 | org.apache.shardingsphere.sharding.nanoid.algorithm.keygen.NanoIdKeyGenerateAlgorithm |
COSID | 基于 CosId 的分布式主键生成算法 | org.apache.shardingsphere.sharding.cosid.algorithm.keygen.CosIdKeyGenerateAlgorithm |
COSID_SNOWFLAKE | 基于 CosId 的雪花算法分布式主键生成算法 | org.apache.shardingsphere.sharding.cosid.algorithm.keygen.CosIdSnowflakeKeyGenerateAlgorithm |
源码入口
java
package org.apache.shardingsphere.sharding.factory;
/**
* Key generate algorithm factory.
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class KeyGenerateAlgorithmFactory {
//加载所有的主键生成策略
static {
ShardingSphereServiceLoader.register(KeyGenerateAlgorithm.class);
}
/**
* 根据配置的主键生成策略,获取一个主键生成算法
* 例如:spring.shardingsphere.rules.sharding.key-generators.usercourse_keygen.type=SNOWFLAKE
*/
public static KeyGenerateAlgorithm newInstance(final AlgorithmConfiguration keyGenerateAlgorithmConfig) {
return ShardingSphereAlgorithmFactory.createAlgorithm(keyGenerateAlgorithmConfig, KeyGenerateAlgorithm.class);
}
/**
* 判断是否包含配置的算法
*/
public static boolean contains(final String keyGenerateAlgorithmType) {
return TypedSPIRegistry.findRegisteredService(KeyGenerateAlgorithm.class, keyGenerateAlgorithmType).isPresent();
}
}
先来看主键生成策略是如何加载的:ShardingSphereServiceLoader.register(KeyGenerateAlgorithm.class);
java
public final class ShardingSphereServiceLoader {
//线程安全Map,缓存所有主键生成器
private static final Map<Class<?>, Collection<Object>> SERVICES = new ConcurrentHashMap<>();
// 进入到register()方法中
public static void register(final Class<?> serviceInterface) {
if (!SERVICES.containsKey(serviceInterface)) {
// 调用下方的load()方法
SERVICES.put(serviceInterface, load(serviceInterface));
}
}
//使用java的SPI机制加载接口的所有实现类
private static <T> Collection<Object> load(final Class<T> serviceInterface) {
Collection<Object> result = new LinkedList<>();
for (T each : ServiceLoader.load(serviceInterface)) {
result.add(each);
}
return result;
}
}
实现
ShardingJDBC是通过SPI机制,加载org.apache.shardingsphere.sharding.spi.KeyGenerateAlgorithm
接口的实现类,也就是上方表格中的内容
我们就可以直接在yml配置文件中进行配置分布式主键生成算法
接下来就以SNOWFLAKE雪花算法举例,下方就列举出了几个关键方法
java
// 实现了KeyGenerateAlgorithm接口
public final class SnowflakeKeyGenerateAlgorithm implements KeyGenerateAlgorithm, InstanceContextAware {
// 在init方法中,会把我们yml配置文件中定义的props的配置项,保存在下面方法的形参中,并赋值给props成员属性
// 其他地方再用props对象获取我们的配置项
@Override
public void init(final Properties props) {
this.props = props;
maxVibrationOffset = getMaxVibrationOffset(props);
maxTolerateTimeDifferenceMilliseconds = getMaxTolerateTimeDifferenceMilliseconds(props);
}
// 实现KeyGenerateAlgorithm接口中的抽象方法generateKey()
// 也就是在这个方法中具体生成分布式主键值的
@Override
public synchronized Long generateKey() {
long currentMilliseconds = timeService.getCurrentMillis();
if (waitTolerateTimeDifferenceIfNeed(currentMilliseconds)) {
currentMilliseconds = timeService.getCurrentMillis();
}
if (lastMilliseconds == currentMilliseconds) {
if (0L == (sequence = (sequence + 1) & SEQUENCE_MASK)) {
currentMilliseconds = waitUntilNextTime(currentMilliseconds);
}
} else {
vibrateSequenceOffset();
sequence = sequenceOffset;
}
lastMilliseconds = currentMilliseconds;
return ((currentMilliseconds - EPOCH) << TIMESTAMP_LEFT_SHIFT_BITS) | (getWorkerId() << WORKER_ID_LEFT_SHIFT_BITS) | sequence;
}
// getType() 方法中返回的字符串就是我们上方yml配置文件中type配置项填写的值
@Override
public String getType() {
return "SNOWFLAKE";
}
}
其他几个实现类也是一样的格式
扩展 自定义分布式主键生成策略
java
package com.hs.sharding.algorithm;
import com.google.common.base.Preconditions;
import org.apache.shardingsphere.infra.instance.InstanceContext;
import org.apache.shardingsphere.infra.instance.InstanceContextAware;
import org.apache.shardingsphere.sharding.algorithm.keygen.TimeService;
import org.apache.shardingsphere.sharding.spi.KeyGenerateAlgorithm;
import java.util.Calendar;
import java.util.Properties;
/**
* 改进雪花算法,让他能够 %4 均匀分布。
* @auth hs
*/
public final class MySnowFlakeAlgorithm implements KeyGenerateAlgorithm, InstanceContextAware {
public static final long EPOCH;
private static final String MAX_VIBRATION_OFFSET_KEY = "max-vibration-offset";
private static final String MAX_TOLERATE_TIME_DIFFERENCE_MILLISECONDS_KEY = "max-tolerate-time-difference-milliseconds";
private static final long SEQUENCE_BITS = 12L;
private static final long WORKER_ID_BITS = 10L;
private static final long SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1;
private static final long WORKER_ID_LEFT_SHIFT_BITS = SEQUENCE_BITS;
private static final long TIMESTAMP_LEFT_SHIFT_BITS = WORKER_ID_LEFT_SHIFT_BITS + WORKER_ID_BITS;
private static final int DEFAULT_VIBRATION_VALUE = 1;
private static final int MAX_TOLERATE_TIME_DIFFERENCE_MILLISECONDS = 10;
private static final long DEFAULT_WORKER_ID = 0;
private static TimeService timeService = new TimeService();
public static void setTimeService(TimeService timeService) {
MySnowFlakeAlgorithm.timeService = timeService;
}
private Properties props;
@Override
public Properties getProps() {
return props;
}
private int maxVibrationOffset;
private int maxTolerateTimeDifferenceMilliseconds;
private volatile int sequenceOffset = -1;
private volatile long sequence;
private volatile long lastMilliseconds;
private volatile InstanceContext instanceContext;
static {
Calendar calendar = Calendar.getInstance();
calendar.set(2016, Calendar.NOVEMBER, 1);
calendar.set(Calendar.HOUR_OF_DAY, 0);
calendar.set(Calendar.MINUTE, 0);
calendar.set(Calendar.SECOND, 0);
calendar.set(Calendar.MILLISECOND, 0);
EPOCH = calendar.getTimeInMillis();
}
@Override
public void init(final Properties props) {
this.props = props;
maxVibrationOffset = getMaxVibrationOffset(props);
maxTolerateTimeDifferenceMilliseconds = getMaxTolerateTimeDifferenceMilliseconds(props);
}
@Override
public void setInstanceContext(final InstanceContext instanceContext) {
this.instanceContext = instanceContext;
if (null != instanceContext) {
instanceContext.generateWorkerId(props);
}
}
private int getMaxVibrationOffset(final Properties props) {
int result = Integer.parseInt(props.getOrDefault(MAX_VIBRATION_OFFSET_KEY, DEFAULT_VIBRATION_VALUE).toString());
Preconditions.checkArgument(result >= 0 && result <= SEQUENCE_MASK, "Illegal max vibration offset.");
return result;
}
private int getMaxTolerateTimeDifferenceMilliseconds(final Properties props) {
return Integer.parseInt(props.getOrDefault(MAX_TOLERATE_TIME_DIFFERENCE_MILLISECONDS_KEY, MAX_TOLERATE_TIME_DIFFERENCE_MILLISECONDS).toString());
}
@Override
public synchronized Long generateKey() {
long currentMilliseconds = timeService.getCurrentMillis();
if (waitTolerateTimeDifferenceIfNeed(currentMilliseconds)) {
currentMilliseconds = timeService.getCurrentMillis();
}
if (lastMilliseconds == currentMilliseconds) {
// if (0L == (sequence = (sequence + 1) & SEQUENCE_MASK)) {
currentMilliseconds = waitUntilNextTime(currentMilliseconds);
// }
} else {
vibrateSequenceOffset();
// sequence = sequenceOffset;
sequence = sequence >= SEQUENCE_MASK ? 0:sequence+1;
}
lastMilliseconds = currentMilliseconds;
return ((currentMilliseconds - EPOCH) << TIMESTAMP_LEFT_SHIFT_BITS) | (getWorkerId() << WORKER_ID_LEFT_SHIFT_BITS) | sequence;
}
private boolean waitTolerateTimeDifferenceIfNeed(final long currentMilliseconds) {
if (lastMilliseconds <= currentMilliseconds) {
return false;
}
long timeDifferenceMilliseconds = lastMilliseconds - currentMilliseconds;
Preconditions.checkState(timeDifferenceMilliseconds < maxTolerateTimeDifferenceMilliseconds,
"Clock is moving backwards, last time is %d milliseconds, current time is %d milliseconds", lastMilliseconds, currentMilliseconds);
try {
Thread.sleep(timeDifferenceMilliseconds);
} catch (InterruptedException e) {
}
return true;
}
private long waitUntilNextTime(final long lastTime) {
long result = timeService.getCurrentMillis();
while (result <= lastTime) {
result = timeService.getCurrentMillis();
}
return result;
}
@SuppressWarnings("NonAtomicOperationOnVolatileField")
private void vibrateSequenceOffset() {
sequenceOffset = sequenceOffset >= maxVibrationOffset ? 0 : sequenceOffset + 1;
}
private long getWorkerId() {
return null == instanceContext ? DEFAULT_WORKER_ID : instanceContext.getWorkerId();
}
@Override
public String getType() {
return "MYSNOWFLAKE";
}
@Override
public boolean isDefault() {
return true;
}
}
使用spi机制加载我们上方定义的类
yml配置文件中使用我们自己定义的类
分片算法
ShardingAlgorithm
全限定类名org.apache.shardingsphere.sharding.spi.ShardingAlgorithm
分片算法,已知实现
配置标识 | 自动分片算法 | 详细说明 | 类名 |
---|---|---|---|
MOD | Y | 基于取模的分片算法 | ModShardingAlgorithm |
HASH_MOD | Y | 基于哈希取模的分片算法 | HashModShardingAlgorithm |
BOUNDARY_RANGE | Y | 基于分片边界的范围分片算法 | BoundaryBasedRangeShardingAlgorithm |
VOLUME_RANGE | Y | 基于分片容量的范围分片算法 | VolumeBasedRangeShardingAlgorithm |
AUTO_INTERVAL | Y | 基于可变时间范围的分片算法 | AutoIntervalShardingAlgorithm |
INTERVAL | N | 基于固定时间范围的分片算法 | IntervalShardingAlgorithm |
CLASS_BASED | N | 基于自定义类的分片算法 | ClassBasedShardingAlgorithm |
INLINE | N | 基于行表达式的分片算法 | InlineShardingAlgorithm |
COMPLEX_INLINE | N | 基于行表达式的复合分片算法 | ComplexInlineShardingAlgorithm |
HINT_INLINE | N | 基于行表达式的 Hint 分片算法 | HintInlineShardingAlgorithm |
COSID_MOD | N | 基于 CosId 的取模分片算法 | CosIdModShardingAlgorithm |
COSID_INTERVAL | N | 基于 CosId 的固定时间范围的分片算法 | CosIdIntervalShardingAlgorithm |
COSID_INTERVAL_SNOWFLAKE | N | 基于 CosId 的雪花ID固定时间范围的分片算法 | CosIdSnowflakeIntervalShardingAlgorithm |
实现
这里就拿CLASS_BASED自定义分片策略来举例。我们之前的配置项如下所示。
这里就有一个问题,props的值我怎么知道写什么,我又怎么知道我自定义的类需要实现什么接口?
我们现在进入到CLASS_BASED分片算法的实现类中ClassBasedShardingAlgorithm
去看看它的源码
java
public final class ClassBasedShardingAlgorithm implements StandardShardingAlgorithm<Comparable<?>>, ComplexKeysShardingAlgorithm<Comparable<?>>, HintShardingAlgorithm<Comparable<?>> {
// 定义两个常量,我们会发现这里就是props中我们进行配置的值
private static final String STRATEGY_KEY = "strategy";
private static final String ALGORITHM_CLASS_NAME_KEY = "algorithmClassName";
@Getter
private Properties props;
private ClassBasedShardingAlgorithmStrategyType strategy;
private String algorithmClassName;
private StandardShardingAlgorithm standardShardingAlgorithm;
private ComplexKeysShardingAlgorithm complexKeysShardingAlgorithm;
private HintShardingAlgorithm hintShardingAlgorithm;
// init()方法中会获取到props对象,props对象中保存了我们yml配置文件中的配置内容
// 这里就会取出来,赋值给 strategy 和 algorithmClassName 成员属性
@Override
public void init(final Properties props) {
this.props = props;
strategy = getStrategy(props);
algorithmClassName = getAlgorithmClassName(props);
initAlgorithmInstance(props);
}
private ClassBasedShardingAlgorithmStrategyType getStrategy(final Properties props) {
String strategy = props.getProperty(STRATEGY_KEY);
Preconditions.checkNotNull(strategy, "Properties `%s` can not be null when uses class based sharding strategy.", STRATEGY_KEY);
return ClassBasedShardingAlgorithmStrategyType.valueOf(strategy.toUpperCase().trim());
}
private String getAlgorithmClassName(final Properties props) {
String result = props.getProperty(ALGORITHM_CLASS_NAME_KEY);
Preconditions.checkNotNull(result, "Properties `%s` can not be null when uses class based sharding strategy.", ALGORITHM_CLASS_NAME_KEY);
return result;
}
// 这里就会判断 strategy 属性是哪一个 STANDARD、COMPLEX、HINT
// 然后在进行具体的实例 StandardShardingAlgorithm、ComplexKeysShardingAlgorithm、HintShardingAlgorithm
private void initAlgorithmInstance(final Properties props) {
switch (strategy) {
case STANDARD:
standardShardingAlgorithm = ClassBasedShardingAlgorithmFactory.newInstance(algorithmClassName, StandardShardingAlgorithm.class, props);
break;
case COMPLEX:
complexKeysShardingAlgorithm = ClassBasedShardingAlgorithmFactory.newInstance(algorithmClassName, ComplexKeysShardingAlgorithm.class, props);
break;
case HINT:
hintShardingAlgorithm = ClassBasedShardingAlgorithmFactory.newInstance(algorithmClassName, HintShardingAlgorithm.class, props);
break;
default:
break;
}
}
// doSharding()方法,具体的分片算法逻辑
@SuppressWarnings("unchecked")
@Override
public String doSharding(final Collection<String> availableTargetNames, final PreciseShardingValue<Comparable<?>> shardingValue) {
return standardShardingAlgorithm.doSharding(availableTargetNames, shardingValue);
}
@SuppressWarnings("unchecked")
@Override
public Collection<String> doSharding(final Collection<String> availableTargetNames, final RangeShardingValue<Comparable<?>> shardingValue) {
return standardShardingAlgorithm.doSharding(availableTargetNames, shardingValue);
}
@SuppressWarnings("unchecked")
@Override
public Collection<String> doSharding(final Collection<String> availableTargetNames, final ComplexKeysShardingValue<Comparable<?>> shardingValue) {
return complexKeysShardingAlgorithm.doSharding(availableTargetNames, shardingValue);
}
@SuppressWarnings("unchecked")
@Override
public Collection<String> doSharding(final Collection<String> availableTargetNames, final HintShardingValue<Comparable<?>> shardingValue) {
return hintShardingAlgorithm.doSharding(availableTargetNames, shardingValue);
}
// 返回trye为 CLASS_BASED 这里也就是和yml配置文件中的type对应上了
@Override
public String getType() {
return "CLASS_BASED";
}
}
其他的分片算法也是类似的实现
扩展 自定义分片算法
自定义一个java类,实现ShardingAlgorithm接口,或者是它的子接口 StandardShardingAlgorithm、ComplexKeysShardingAlgorithm、HintShardingAlgorithm都行,重写其中的doSharding()
方法,我们自己指定分片逻辑
重写getType()
方法,返回一个字符串,能够让我们在yml配置文件中进行配置
java
@Override
public String getType() {
return "MY_COMPLEX_ALGORITHM";
}
例如我现在自定义的分片类如下
java
package com.hs.sharding.algorithm;
import com.google.common.base.Preconditions;
import com.google.common.collect.Range;
import org.apache.shardingsphere.sharding.api.sharding.standard.PreciseShardingValue;
import org.apache.shardingsphere.sharding.api.sharding.standard.RangeShardingValue;
import org.apache.shardingsphere.sharding.api.sharding.standard.StandardShardingAlgorithm;
import java.util.*;
/**
* 自定义分片策略 , 我们这里实现标准的分片算法接口StandardShardingAlgorithm
* 我这里是分片逻辑就是按照个数取模,在分发到sys_user1 sys_user2数据表中
*/
public class HsComplexAlgorithm implements StandardShardingAlgorithm<Long> {
/**
* 数据库个数
*/
private final String DB_COUNT = "db-count";
/**
* 数据表个数
*/
private final String TAB_COUNT = "tab-count";
/**
* 真实数据表前缀
*/
private final String PERTAB = "pertab";
private Integer dbCount;
private Integer tabCount;
private String pertab;
private Properties props;
@Override
public void init(Properties props) {
this.props = props;
this.dbCount = getDbCount(props);
this.tabCount = getTabCount(props);
this.pertab = getPertab(props);
// 校验条件
Preconditions.checkState(null != pertab && !pertab.isEmpty(),
"Inline hsComplex algorithm expression cannot be null or empty.");
}
/**
* 精确查询分片执行接口(对应的sql是where ??=值)
* @param collection 可用的分片名集合(分库就是库名,分表就是表名)
* @param preciseShardingValue 分片键
*/
@Override
public String doSharding(Collection<String> collection, PreciseShardingValue<Long> preciseShardingValue) {
Long uid = preciseShardingValue.getValue();
String resultTableName = pertab + ((uid + 1) % (dbCount * tabCount) / tabCount + 1);
if (collection.contains(resultTableName)){
return resultTableName;
}
throw new UnsupportedOperationException("route: " + resultTableName + " is not supported, please check your config");
}
/**
* 范围分片规则(对应的是where ??>='XXX' and ??<='XXX')
* 范围查询分片算法(分片键涉及区间查询时会进入该方法进行分片计算)
*/
@Override
public Collection<String> doSharding(Collection<String> collection, RangeShardingValue<Long> rangeShardingValue) {
List<String> result = new ArrayList<>();
Range<Long> valueRange = rangeShardingValue.getValueRange();
Long upperEndpoint = valueRange.upperEndpoint();
Long aLong = valueRange.lowerEndpoint();
// TODO 进行相应的分片判断
// return result;
return collection;
}
private String getPertab(Properties props) {
return props.getProperty(PERTAB);
}
private Integer getDbCount(Properties props) {
String count = props.getProperty(DB_COUNT);
return count == null || count.isEmpty() ? 0 : Integer.valueOf(count);
}
private Integer getTabCount(Properties props) {
String count = props.getProperty(TAB_COUNT);
return count == null || count.isEmpty() ? 0 : Integer.valueOf(count);
}
@Override
public Properties getProps() {
return props;
}
@Override
public String getType() {
return "HS";
}
}
需要添加一个SPI的配置文件org.apache.shardingsphere.sharding.spi.ShardingAlgorithm
,在该文件中指定我们上方创建的java类
yml配置文件中进行相应的更改
踩的坑
我先是自定义的类实现的是ComplexKeysShardingAlgorithm接口,但是我们yml配置类中还是一直按照standard
的配置,导致我自定义的类中的doSharding()
方法所以就一直没有调用到
之后我修改了complex就能调用了
在配置还是standard时,我通过debug,发现init()
和 getType()
方法都能够调用,证明API机制相关的文件没问题。
我就想会不会是单分片键、精确查询、范围查询相关问题导致的?
我修改了实现接口,改为了StandardShardingAlgorithm,然后就进入了其中单分片键的doSharding()
方法。最后就一点一点的排查,再到了这上面的配置