【并发编程】CountDownLatch 源码分析

喜欢ヅ旅行 2022-03-20 14:16 391阅读 0赞

前言

Github:https://github.com/yihonglei/jdk-source-code-reading(java-concurrent)

一 CountDownLatch 概述

1、介绍

CountDownLatch (同步工具类) 允许一个或多个线程等待其他线程完成操作后才继续执行。

20190213184542781.jpg

使用 CountDownLatch 时,需要指定一个整数值 N,此值是线程将要等待的操作数。当线程 M 为了要等待操作 A 完成时,

线程 M 需要调用 await() 方法。await() 方法让线程 M 进入等待状态直到所有操作 A 完成为止,M 才被唤醒继续执行。

当操作 A 执行完成(每一个处理),调用 countDown() 方法来减少 CountDownLatch 类的内部计数器,N 每次减少 1。

当内部计数器递减为 0 时,CountDownLatch 会唤醒所有调用 await() 方法挂起的线程,即会唤醒 M,

从而实现 M 等待操作 A 执行完成后再继续执行 M 操作的效果。

2、原理

CountDownLatch 的构造函数接收一个 int 类型的参数作为计数器构造参数,如果你想等待 N 个点完成,这里就传入 N。

当我们调用 CountDownLatch 的 countDown() 方法时,N 就会减 1,CountDownLatch 的 await() 方法会阻塞当前线程,

直到 N 变成零被唤醒继续执行。由于 countDown() 方法可以用在任何地方,所以这里说的 N 个点,可以是 N 个线程,

也可以是 1 个线程里的 N 个执行步骤。用在多个线程时,只需要把这个 CountDownLatch 的引用传递到线程里即可。

3、核心方法

countDown():用于减少计数器次数,每调用一次就默认会减少 1,当锁释放完时,将等待线程唤醒。

await():负责线程的阻塞,当 CountDownLatch 计数的值为 0 时,获取到锁,才返回主线程执行。

4、典型场景

CountDownLatch 使用场景主要用于控制主线程等待所有子线程全部执行完成然后恢复主线程执行。

二 CountDownLatch 实例

1、实例场景

我们需要批量的从数据库查询出数据进行处理。一般会想到用多线程去处理,但是,有一个问题就是我们如何保证每一次查询的

数据不是正在处理的数据?方法有很多种,可以在每一批数据处理完之后再去数据库取下一批数据,每一批数据采取多线程处理的方式。

我们也可以采用别的方案,这里只针对使用 CountDownLatch 来实现批量处理。CountDownLatch 控制主线程必须等待线程池子线程

执行完才恢复执行主线程。

2、实例代码

  1. package com.jpeony.concurrent.countdownlatch;
  2. import java.util.ArrayList;
  3. import java.util.List;
  4. import java.util.concurrent.CompletableFuture;
  5. import java.util.concurrent.CountDownLatch;
  6. import java.util.concurrent.ExecutorService;
  7. import java.util.concurrent.Executors;
  8. /**
  9. * 多线程+CountDownLatch演示
  10. *
  11. * @author yihonglei
  12. */
  13. public class CountDownLatchTest {
  14. // 线程池
  15. private static ExecutorService executorService = Executors.newFixedThreadPool(10);
  16. public static void main(String[] args) {
  17. int counterBatch = 1;
  18. try {
  19. // 数据循环处理
  20. while (true) {
  21. // 模拟数据库查询出的List
  22. List<String> list = new ArrayList<>();
  23. for (int i = 0; i < 10; i++) {
  24. list.add("user" + i);
  25. }
  26. // 计数器大小定义为集合大小,避免处理不一致导致主线程无限等待
  27. CountDownLatch countDownLatch = new CountDownLatch(list.size());
  28. // 循环处理List
  29. list.parallelStream().forEach(userId -> {
  30. // 任务提交线程池
  31. CompletableFuture.supplyAsync(() -> {
  32. try {
  33. // 处理用户数据
  34. dealUser(userId);
  35. } finally {
  36. countDownLatch.countDown();
  37. }
  38. return 1;
  39. }, executorService);
  40. });
  41. // 主线程等待所有子线程都执行完成时,恢复执行主线程
  42. countDownLatch.await();
  43. System.out.println("========================恢复主线程执行==========================");
  44. // 数据批次计数器
  45. counterBatch++;
  46. // 模拟执行5批
  47. if (counterBatch > 5) {
  48. break;
  49. }
  50. }
  51. System.out.println("循环退出,程序执行完成,counterBatch=" + counterBatch);
  52. // 关闭线程池
  53. executorService.shutdown();
  54. } catch (Exception e) {
  55. System.out.println("异常日志");
  56. }
  57. }
  58. /**
  59. * 模拟根据用户Id处理用户数据的逻辑
  60. */
  61. public static void dealUser(String userId) {
  62. System.out.println("ThreadName:" + Thread.currentThread().getName() + ", userId:" + userId + " 处理完成!");
  63. }
  64. }

运行结果:

  1. ThreadName:pool-1-thread-3, userId:user4 处理完成!
  2. ThreadName:pool-1-thread-7, userId:user9 处理完成!
  3. ThreadName:pool-1-thread-2, userId:user7 处理完成!
  4. ThreadName:pool-1-thread-9, userId:user3 处理完成!
  5. ThreadName:pool-1-thread-5, userId:user2 处理完成!
  6. ThreadName:pool-1-thread-4, userId:user8 处理完成!
  7. ThreadName:pool-1-thread-1, userId:user0 处理完成!
  8. ThreadName:pool-1-thread-8, userId:user6 处理完成!
  9. ThreadName:pool-1-thread-10, userId:user5 处理完成!
  10. ThreadName:pool-1-thread-6, userId:user1 处理完成!
  11. ========================恢复主线程执行==========================
  12. ThreadName:pool-1-thread-3, userId:user8 处理完成!
  13. ThreadName:pool-1-thread-7, userId:user1 处理完成!
  14. ThreadName:pool-1-thread-9, userId:user9 处理完成!
  15. ThreadName:pool-1-thread-2, userId:user4 处理完成!
  16. ThreadName:pool-1-thread-7, userId:user6 处理完成!
  17. ThreadName:pool-1-thread-9, userId:user0 处理完成!
  18. ThreadName:pool-1-thread-4, userId:user3 处理完成!
  19. ThreadName:pool-1-thread-5, userId:user2 处理完成!
  20. ThreadName:pool-1-thread-3, userId:user7 处理完成!
  21. ThreadName:pool-1-thread-1, userId:user5 处理完成!
  22. ========================恢复主线程执行==========================
  23. ThreadName:pool-1-thread-2, userId:user5 处理完成!
  24. ThreadName:pool-1-thread-7, userId:user8 处理完成!
  25. ThreadName:pool-1-thread-10, userId:user0 处理完成!
  26. ThreadName:pool-1-thread-4, userId:user1 处理完成!
  27. ThreadName:pool-1-thread-8, userId:user2 处理完成!
  28. ThreadName:pool-1-thread-5, userId:user7 处理完成!
  29. ThreadName:pool-1-thread-9, userId:user6 处理完成!
  30. ThreadName:pool-1-thread-6, userId:user4 处理完成!
  31. ThreadName:pool-1-thread-2, userId:user9 处理完成!
  32. ThreadName:pool-1-thread-7, userId:user3 处理完成!
  33. ========================恢复主线程执行==========================
  34. ThreadName:pool-1-thread-3, userId:user1 处理完成!
  35. ThreadName:pool-1-thread-1, userId:user8 处理完成!
  36. ThreadName:pool-1-thread-8, userId:user2 处理完成!
  37. ThreadName:pool-1-thread-9, userId:user6 处理完成!
  38. ThreadName:pool-1-thread-2, userId:user3 处理完成!
  39. ThreadName:pool-1-thread-1, userId:user4 处理完成!
  40. ThreadName:pool-1-thread-5, userId:user0 处理完成!
  41. ThreadName:pool-1-thread-10, userId:user5 处理完成!
  42. ThreadName:pool-1-thread-3, userId:user9 处理完成!
  43. ThreadName:pool-1-thread-4, userId:user7 处理完成!
  44. ========================恢复主线程执行==========================
  45. ThreadName:pool-1-thread-6, userId:user0 处理完成!
  46. ThreadName:pool-1-thread-7, userId:user8 处理完成!
  47. ThreadName:pool-1-thread-2, userId:user3 处理完成!
  48. ThreadName:pool-1-thread-5, userId:user5 处理完成!
  49. ThreadName:pool-1-thread-8, userId:user1 处理完成!
  50. ThreadName:pool-1-thread-10, userId:user6 处理完成!
  51. ThreadName:pool-1-thread-1, userId:user7 处理完成!
  52. ThreadName:pool-1-thread-7, userId:user2 处理完成!
  53. ThreadName:pool-1-thread-6, userId:user4 处理完成!
  54. ThreadName:pool-1-thread-9, userId:user9 处理完成!
  55. ========================恢复主线程执行==========================
  56. 循环退出,程序执行完成,counterBatch=6

程序分析:

1)模拟从数据库每一次取出一批数据,每批数据为 10 条;

2)CountDownLatch 计数器大小设定与数据条数相同,这里就为 10;

3)然后循环 List,每一条数据创建一个线程,然后提交线程池,每一个线程处理完要调 countDown(),每次减 1。

4)主线程也就是这里的 main 线程,调用了 await() 方法,await() 方法表示等待线程池的线程执行完成,恢复主线程执行,

即 CountDownLatch 计数器为 0 时恢复主线程,进行下一次的循环取批数据处理。

从而我们可以实现每一批数据取出后,交由线程池多线程处理,并且主线程会等待子线程都执行完成,

然后才恢复执行,进行下一次的循环取批处理,就不会出现取批次时取到正在处理的数据。

三 CountDownLatch 源码分析(jdk8)

1、CountDownLatch(int count) 构造函数

  1. public CountDownLatch(int count) {
  2. if (count < 0) throw new IllegalArgumentException("count < 0");
  3. this.sync = new Sync(count);
  4. }

先从构造函数看起,传入 int 的 count,对 count 进行校验,然后 new Sync(count)。

  1. Sync(int count) {
  2. setState(count);
  3. }

Sync 为 AQS 的子类,在构造函数里面,通过 setState 设置 state 的值为 count,state 为 volatile 变量,保证多线程可见性。

2、CountDownLatch#await()

  1. public void await() throws InterruptedException {
  2. sync.acquireSharedInterruptibly(1);
  3. }

调用 CountDownLatch 内部类 Sync 父类 AbstractQueuedSynchronizer 的模板方法 acquireSharedInterruptibly() 尝试获取共享锁。

  1. public final void acquireSharedInterruptibly(int arg)
  2. throws InterruptedException {
  3. // 判断线程是否中断
  4. if (Thread.interrupted())
  5. throw new InterruptedException();
  6. // 尝试获取共享锁
  7. if (tryAcquireShared(arg) < 0)
  8. // 添加到阻塞队列,挂起线程,等待唤醒获取锁
  9. doAcquireSharedInterruptibly(arg);
  10. }

第1步:Thread.interrupted() 判断线程是否中断,中断则抛出线程中断异常;

第2步:tryAcquireShared(arg) 方法尝试获取共享锁,当 state 为 0 时,返回 1 才能获取锁,主线程会继续执行

否则返回 -1,获取不到锁,则调用 await 的线程(主线程)通过 doAcquireSharedInterruptibly(arg)方 法进行阻塞操作;

这里可以结合实例理解为 main 主线程被阻塞,那么主线程在哪里被唤醒的?在 countDown() 方法里进行主线程唤醒。

  1. protected int tryAcquireShared(int acquires) {
  2. return (getState() == 0) ? 1 : -1;
  3. }

第3步:doAcquireSharedInterruptibly(arg) 如何阻塞主线程?

  1. /**
  2. * Acquires in shared interruptible mode.
  3. * @param arg the acquire argument
  4. */
  5. private void doAcquireSharedInterruptibly(int arg)
  6. throws InterruptedException {
  7. // 基于共享模式创建节点
  8. final Node node = addWaiter(Node.SHARED);
  9. boolean failed = true;
  10. try {
  11. for (;;) {
  12. // 获取当前节点的前驱节点
  13. final Node p = node.predecessor();
  14. if (p == head) {
  15. // 尝试获取共享锁
  16. int r = tryAcquireShared(arg);
  17. if (r >= 0) {
  18. setHeadAndPropagate(node, r);
  19. p.next = null; // help GC
  20. failed = false;
  21. return;
  22. }
  23. }
  24. // 获取锁失败后暂停线程
  25. if (shouldParkAfterFailedAcquire(p, node) &&
  26. parkAndCheckInterrupt())
  27. throw new InterruptedException();
  28. }
  29. } finally {
  30. // 如果获取锁失败,线程已经被暂停了,取消尝试获取锁的操作
  31. if (failed)
  32. cancelAcquire(node);
  33. }
  34. }

addWaiter(Node mode):初始化队里,并基于当前线程构建节点添加到队列尾部。

  1. private Node addWaiter(Node mode) {
  2. // 基于当前线程构建共享模式 Node
  3. Node node = new Node(Thread.currentThread(), mode);
  4. // Try the fast path of enq; backup to full enq on failure
  5. // 先尝试通过 compareAndSetTail 快速添加队列节点,不行再通过 enq 入队。
  6. Node pred = tail;
  7. // 添加第一个队列节点时,尾节点是空的,不会走快速添加,之后才会走CAS快速添加
  8. if (pred != null) {
  9. node.prev = pred;
  10. if (compareAndSetTail(pred, node)) {
  11. pred.next = node;
  12. return node;
  13. }
  14. }
  15. // 第一次添加节点,走这里,完成队列的初始化和元素的添加
  16. enq(node);
  17. return node;
  18. }

enq(final Node node):初始化队列并添加当前线程构建的节点到队尾。

  1. private Node enq(final Node node) {
  2. for (;;) {
  3. // 获取尾节点
  4. Node t = tail;
  5. // 第一次循环,t是null,会进入if判断,compareAndSetHead设置new Node()到队列,
  6. // 这个时候队列只有一个节点,就是头结点,也是尾节点
  7. if (t == null) { // Must initialize
  8. if (compareAndSetHead(new Node()))
  9. tail = head;
  10. } else {// 节点插入队尾
  11. // 第二次循环时,当前节点的前驱节点
  12. node.prev = t;
  13. // 节点添加到队尾
  14. if (compareAndSetTail(t, node)) {
  15. // t的下一个节点指向node,跟头结点建立引用,形成链表
  16. t.next = node;
  17. // 返回t(这个时候队列的头结点是new Node(),尾节点是我们传进来的node,队列里只有两个节点)
  18. return t;
  19. }
  20. }
  21. }
  22. }

shouldParkAfterFailedAcquire(Node pred, Node node):设置节点的状态为等待唤醒状态。

  1. private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {
  2. int ws = pred.waitStatus;
  3. if (ws == Node.SIGNAL)
  4. /*
  5. * This node has already set status asking a release
  6. * to signal it, so it can safely park.
  7. */
  8. return true;
  9. if (ws > 0) {
  10. /*
  11. * Predecessor was cancelled. Skip over predecessors and
  12. * indicate retry.
  13. */
  14. do {
  15. node.prev = pred = pred.prev;
  16. } while (pred.waitStatus > 0);
  17. pred.next = node;
  18. } else {
  19. /*
  20. * waitStatus must be 0 or PROPAGATE. Indicate that we
  21. * need a signal, but don't park yet. Caller will need to
  22. * retry to make sure it cannot acquire before parking.
  23. */
  24. compareAndSetWaitStatus(pred, ws, Node.SIGNAL);
  25. }
  26. return false;
  27. }

boolean parkAndCheckInterrupt():调用 LockSupport.park 暂停当前线程,并返回线程是否中断的状态。

  1. private final boolean parkAndCheckInterrupt() {
  2. // 暂停当前的线程
  3. LockSupport.park(this);
  4. // 获取线程是否中断的状态
  5. return Thread.interrupted();
  6. }

调用 await() 方法的现在在这里被暂停的,后期通过 countDown() 里面的逻辑进行唤醒。

3、CountDownLatch#countDown()

调用 countDown() 方法,每调用一次 state 就会减1。

  1. public void countDown() {
  2. sync.releaseShared(1);
  3. }

调用 CountDownLatch 内部类 Sync 的 releaseShared() 方法,arg 传值为 1。

  1. public final boolean releaseShared(int arg) {
  2. // 1
  3. if (tryReleaseShared(arg)) {
  4. // 2
  5. doReleaseShared();
  6. return true;
  7. }
  8. return false;
  9. }

第1步:执行tryReleaseShared(arg)方法,返回 true 或 false,尝试去释放共享锁。

  1. protected boolean tryReleaseShared(int releases) {
  2. // Decrement count; signal when transition to zero
  3. for (;;) {// 自旋减1
  4. int c = getState();
  5. if (c == 0)
  6. return false;
  7. int nextc = c-1;
  8. if (compareAndSetState(c, nextc))
  9. return nextc == 0;// nextc减到为0时,返回true
  10. }
  11. }

即当最后一次进行 countDown() 操作时 state 为 1,即 c 为 1,则 nextc 为 0,进行 CAS 操作后,state 变为 0,返回 true,

则执行 doReleaseShared() 方法。

第2步:执行doReleaseShared():方法释放共享锁,唤醒调用 await() 等待线程。

  1. private void doReleaseShared() {
  2. /*
  3. * Ensure that a release propagates, even if there are other
  4. * in-progress acquires/releases. This proceeds in the usual
  5. * way of trying to unparkSuccessor of head if it needs
  6. * signal. But if it does not, status is set to PROPAGATE to
  7. * ensure that upon release, propagation continues.
  8. * Additionally, we must loop in case a new node is added
  9. * while we are doing this. Also, unlike other uses of
  10. * unparkSuccessor, we need to know if CAS to reset status
  11. * fails, if so rechecking.
  12. */
  13. for (;;) {
  14. // 获取头结点
  15. Node h = head;
  16. // 判断头结点不为空,并且不是尾节点,则进入if逻辑
  17. if (h != null && h != tail) {
  18. int ws = h.waitStatus;
  19. if (ws == Node.SIGNAL) {// 头结点的状态为Node.SIGNAL
  20. if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
  21. continue; // loop to recheck cases
  22. unparkSuccessor(h);// 唤醒头节点的后续节点线程
  23. }
  24. else if (ws == 0 &&
  25. !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
  26. continue; // loop on failed CAS
  27. }
  28. // 队列里面只有头结点时,退出锁的循环释放
  29. if (h == head) // loop if head changed
  30. break;
  31. }
  32. }

unparkSuccessor(Node node):唤醒后续节点线程。

  1. /**
  2. * Wakes up node's successor, if one exists.
  3. *
  4. * @param node the node
  5. */
  6. private void unparkSuccessor(Node node) {
  7. /*
  8. * If status is negative (i.e., possibly needing signal) try
  9. * to clear in anticipation of signalling. It is OK if this
  10. * fails or if status is changed by waiting thread.
  11. */
  12. int ws = node.waitStatus;
  13. if (ws < 0)
  14. compareAndSetWaitStatus(node, ws, 0);
  15. /*
  16. * Thread to unpark is held in successor, which is normally
  17. * just the next node. But if cancelled or apparently null,
  18. * traverse backwards from tail to find the actual
  19. * non-cancelled successor.
  20. */
  21. // node是外层传入的头节点,s为头节点的后继节点
  22. Node s = node.next;
  23. if (s == null || s.waitStatus > 0) {
  24. s = null;
  25. for (Node t = tail; t != null && t != node; t = t.prev)
  26. if (t.waitStatus <= 0)
  27. s = t;
  28. }
  29. if (s != null)
  30. // 唤醒线程
  31. LockSupport.unpark(s.thread);
  32. }

主线程一开始被构建在 Node 节点中作为成员变量,被 LockSupport.park 暂停了,这里当 state 为 0 时获取锁到锁,

通过 LockSupport.unpark 唤醒主线程,当线程唤醒后,调用 await() 的线程会继续执行,去获取到锁,继续执行代码。

四 CountDownLatch 总结

1、CountDownLatch 主要用于主线程等待子线程执行完成,恢复主线程继续执行的场景。

2、A 操作调用 countDown() 减少计数器数值,M 调用 await() 一直等待,直到 countDown() 将 state 减为 0 时恢复主线程执行。

发表评论

表情:
评论列表 (有 0 条评论,391人围观)

还没有评论,来说两句吧...

相关阅读

    相关 并发工具类CountDownLatch分析

    > 同步工具类可以使任何一种对象,只要该对象可以根据自身的状态来协调控制线程的控制流。阻塞队列可以作为同步工具类,其他类型的同步工具类还包括:信号量(Semaphore)、栅栏