协程的协作等待:替代 CountDownLatch
在 Java 线程中,当一个线程【A】需要等待多个其他线程【B】的完成才能继续时,我们会使用 CountDownLatch(Latch 的意思是门闩)。
使用方式:首先需要初始化它的计数值,在线程【A】中调用 await(),然后在每个线程【B】完成时调用 countdown() 即可。
只有当调用次数达到预设值时,await() 才会返回。这就实现了只有所有线程【B】执行完毕后,线程【A】才能接着执行的效果。
kotlin
fun main() {
val count = 2
val latch = CountDownLatch(count)
val b1 = thread {
Thread.sleep(1000)
latch.countDown()
}
val b2 = thread {
Thread.sleep(3000)
latch.countDown()
}
val a = thread {
latch.await()
println("Thread A finished")
}
// 等待线程 A 完成
a.join()
}
Job.join() 与 Deferred.await()
在协程中,我们当然会想到 Job.join() 与 Deferred.await(),这两个函数都可用于等待协程的完成。
kotlin
fun main(): Unit = runBlocking {
// 协程的等待
val job1 = launch {
delay(2000)
}
val job2 = launch {
delay(3000)
}
val job3 = launch {
// 等待 job1 和 job2 结束
job1.join()
job2.join()
delay(1000)
}
job3.join()
// 协程的等待
val deferred1 = async {
delay(500)
return@async "async1"
}
val deferred2 = async {
delay(3500)
return@async "async2"
}
launch {
// 获取 deferred1 和 deferred2 的结果
val result1 = deferred1.await()
val result2 = deferred2.await()
delay(300)
println("launch4 result: $result1 $result2")
}
}
Job.join() 的写法在线程中非常少见,因为线程的管理成本很高,且没有协程中的结构化取消。
Channel
在形式上,只关心完成的次数,Channel 与 CountDownLatch 更接近:
kotlin
fun main(): Unit = runBlocking {
val count = 2
val channel = Channel<Unit>(capacity = count)
launch { // 协程A
repeat(count) {
channel.receive() // 挂起等待 count 次
}
delay(300)
println("A Done")
}
launch { // 协程B1
delay(1000)
println("B1 Done")
channel.send(Unit)
}
launch { // 协程B2
delay(2000)
println("B2 Done")
channel.send(Unit)
}
}
select 表达式:先到先得
对于多个任务,只想获得最快完成的结果时,就可以使用 select 表达式。
比如说我们可以给 Job 设置 onJoin 回调:
kotlin
fun main(): Unit = runBlocking {
val job1 = launch {
delay(1000)
println("job1 done")
}
val job2 = launch {
delay(2000)
println("job2 done")
}
val result = select {
job1.onJoin {
"job1"
}
job2.onJoin {
"job2"
}
}
println("The first completed coroutine is $result")
}
这个回调会监听 Job 的状态,在 Job 结束(无论是正常完成还是被取消)后,会执行回调的代码块,并将这个代码块的返回值作为 select 表达式的返回值。
另外,Deferred 也有类似的 onAwait,它的回调的参数是 async 函数的返回值。
Channel 中有着 onSend、onReceive、onReceiveCatching,分别代表着 Channel 发送一条数据并发送成功、Channel 成功接收到一条数据、Channel 通道被关闭或成功收到一条数据。
kotlin
val channel1 = Channel<String>()
val channel2 = Channel<String>()
val channel3 = Channel<String>()
select {
channel1.onSend("message") { sendChannel: SendChannel<String> ->
println("The message was successfully sent")
}
channel2.onReceive { message: String ->
println("The message was successfully received")
}
channel3.onReceiveCatching { result: ChannelResult<String> ->
println("The message was successfully received")
}
}
注意: 在同一个
select中,同一个对象最多只能设置一个监听。
最后,我们可以用 onTimeout 设置一个总超时的监听,以防无限等待。
kotlin
import kotlinx.coroutines.selects.onTimeout
select {
onTimeout(5.seconds) {
println("Timeout")
}
}
共享变量与互斥锁:Mutex 与 synchronized
如果对 Java 多线程不了解,可以看我的这篇博客:Java 多线程指南:从基础用法到线程安全
竞争条件 (Race Condition),又称竞态条件,指的是多个线程访问或修改共享资源时,由于执行顺序的不确定性,导致程序的行为不一致或出现错误。
比如运行这段代码,发现 count 最终的值竟然不是零。
kotlin
fun main(): Unit = runBlocking {
var count = 0
val scope = CoroutineScope(Dispatchers.Default) // 使用多线程调度器
val job1 = scope.launch {
repeat(100_000_000) {
count++
}
}
val job2 = scope.launch {
repeat(100_000_000) {
count--
}
}
job1.join()
job2.join()
println("the final count is $count")
}
这时,我们可以使用 Java 的 synchronized 关键字(在 Kotlin 中是一个函数),这样临界区将会被多个线程所互斥。
kotlin
fun main(): Unit = runBlocking {
var count = 0
val lock = Any()
val scope = CoroutineScope(Dispatchers.Default) // 使用多线程调度器
val job1 = scope.launch {
repeat(1_000_000) {
synchronized(lock = lock) {
count++
}
}
}
val job2 = scope.launch {
repeat(1_000_000) {
synchronized(lock = lock) {
count--
}
}
}
job1.join()
job2.join()
println("the final count is $count")
}
虽然可以 在协程用 synchronized,因为 Kotlin 协程的本质是线程,最终协程的代码还是运行在线程中的。synchronized 会锁定当前的线程,自然也会卡住在该线程上运行的协程。
但并不推荐使用。因为你使用 synchronized,协程在等待锁时,会阻塞整个线程,而不是让出当前线程,导致了资源浪费。
使用 Mutex 互斥锁
在协程中,我们会使用 Mutex 互斥锁,相当于 Java 中的 Lock。
kotlin
fun main(): Unit = runBlocking {
var count = 0
val mutex = Mutex()
val scope = CoroutineScope(Dispatchers.Default) // 使用多线程调度器
val job1 = scope.launch {
repeat(1_000_000) {
mutex.lock() // 挂起函数
try {
count++
} catch (e: Exception) {
e.printStackTrace()
} finally {
mutex.unlock() // 保证锁的释放
}
}
}
val job2 = scope.launch {
repeat(1_000_000) {
mutex.lock() // 挂起函数
try {
count--
} catch (e: Exception) {
e.printStackTrace()
} finally {
mutex.unlock() // 保证锁的释放
}
}
}
job1.join()
job2.join()
println("the final count is $count")
}
也可以使用便捷函数 withLock,它会自动上锁和解锁。
Mutex 与 synchronized 的区别在于:当锁被占用时,synchronized 会阻塞 线程;而 mutex.lock() 会挂起协程,释放所占用的线程,让线程可以去干别的事。
锁的选择
实际中,该如何选择?
-
在纯协程环境中,永远使用
Mutex,因为不卡线程。 -
在纯线程环境中,就使用
synchronized或ReentrantLock。 -
如果共享变量既会被协程访问,又会被线程访问,就统一使用
synchronized或ReentrantLock,因为mutex.lock()是挂起函数,无法在线程中使用。
协程中的 ThreadLocal 陷阱与解决方案
我们都知道 ThreadLocal Java 中线程的"局部变量"。
如果在协程中直接使用 ThreadLocal,可能会导致数据丢失等问题:
kotlin
val myThreadLocal = ThreadLocal<String>()
fun main(): Unit = runBlocking {
val scope = CoroutineScope(Dispatchers.Default)
val job = scope.launch {
myThreadLocal.set("hello")
println("Before: ${myThreadLocal.get()} on ${Thread.currentThread().name}")
delay(300)
// 可能切换到了另一个线程上
println("After: ${myThreadLocal.get()} on ${Thread.currentThread().name}")
}
// 尽可能抢占线程
repeat(30) {
scope.launch {
delay(500)
}
}
job.join()
}
运行结果:
vbnet
Before: hello on DefaultDispatcher-worker-1
After: null on DefaultDispatcher-worker-19
所以,我们不能直接使用 ThreadLocal。
在协程中的局部变量机制是 CoroutineContext,所以当要和 ThreadLocal 交互时,可以使用 asContextElement。
kotlin
val myThreadLocal = ThreadLocal<String>()
fun main(): Unit = runBlocking {
val scope = CoroutineScope(Dispatchers.Default)
// 使用 asContextElement 将 ThreadLocal 包装成 CoroutineContext.Element
val job = scope.launch(myThreadLocal.asContextElement(value = "hello")) {
println("Before: ${myThreadLocal.get()} on ${Thread.currentThread().name}")
delay(1000)
// 无论切换到哪个线程,ThreadLocal 都会保持正确的值
println("After: ${myThreadLocal.get()} on ${Thread.currentThread().name}")
}
// 尽可能抢占线程
repeat(30) {
scope.launch {
delay(500)
}
}
job.join()
}
运行结果:
vbnet
Before: hello on DefaultDispatcher-worker-1
After: hello on DefaultDispatcher-worker-17
其实 asContextElement 只是保证了协程所运行的线程会自动同步 ThreadLocal 的值,每当协程进行线程切换时,它会自动设置和恢复 ThreadLocal 的值。