死磕 java线程系列之ForkJoinPool深入解析

喜欢ヅ旅行 2023-05-30 09:13 100阅读 0赞

forkjoinpool

(手机横屏看源码更方便)

#

注:java源码分析部分如无特殊说明均基于 java8 版本。

注:本文基于ForkJoinPool分治线程池类。

简介

随着在硬件上多核处理器的发展和广泛使用,并发编程成为程序员必须掌握的一门技术,在面试中也经常考查面试者并发相关的知识。

今天,我们就来看一道面试题:

如何充分利用多核CPU,计算很大数组中所有整数的和?

剖析

  • 单线程相加?

我们最容易想到就是单线程相加,一个for循环搞定。

  • 线程池相加?

如果进一步优化,我们会自然而然地想到使用线程池来分段相加,最后再把每个段的结果相加。

  • 其它?

Yes,就是我们今天的主角——ForkJoinPool,但是它要怎么实现呢?似乎没怎么用过哈^^

三种实现

OK,剖析完了,我们直接来看三种实现,不墨迹,直接上菜。

  1. /**
  2. * 计算1亿个整数的和
  3. */
  4. public class ForkJoinPoolTest01 {
  5. public static void main(String[] args) throws ExecutionException, InterruptedException {
  6. // 构造数据
  7. int length = 100000000;
  8. long[] arr = new long[length];
  9. for (int i = 0; i < length; i ) {
  10. arr[i] = ThreadLocalRandom.current().nextInt(Integer.MAX_VALUE);
  11. }
  12. // 单线程
  13. singleThreadSum(arr);
  14. // ThreadPoolExecutor线程池
  15. multiThreadSum(arr);
  16. // ForkJoinPool线程池
  17. forkJoinSum(arr);
  18. }
  19. private static void singleThreadSum(long[] arr) {
  20. long start = System.currentTimeMillis();
  21. long sum = 0;
  22. for (int i = 0; i < arr.length; i ) {
  23. // 模拟耗时,本文由公从号“彤哥读源码”原创
  24. sum = (arr[i]/3*3/3*3/3*3/3*3/3*3);
  25. }
  26. System.out.println("sum: " sum);
  27. System.out.println("single thread elapse: " (System.currentTimeMillis() - start));
  28. }
  29. private static void multiThreadSum(long[] arr) throws ExecutionException, InterruptedException {
  30. long start = System.currentTimeMillis();
  31. int count = 8;
  32. ExecutorService threadPool = Executors.newFixedThreadPool(count);
  33. List<Future<Long>> list = new ArrayList<>();
  34. for (int i = 0; i < count; i ) {
  35. int num = i;
  36. // 分段提交任务
  37. Future<Long> future = threadPool.submit(() -> {
  38. long sum = 0;
  39. for (int j = arr.length / count * num; j < (arr.length / count * (num 1)); j ) {
  40. try {
  41. // 模拟耗时
  42. sum = (arr[j]/3*3/3*3/3*3/3*3/3*3);
  43. } catch (Exception e) {
  44. e.printStackTrace();
  45. }
  46. }
  47. return sum;
  48. });
  49. list.add(future);
  50. }
  51. // 每个段结果相加
  52. long sum = 0;
  53. for (Future<Long> future : list) {
  54. sum = future.get();
  55. }
  56. System.out.println("sum: " sum);
  57. System.out.println("multi thread elapse: " (System.currentTimeMillis() - start));
  58. }
  59. private static void forkJoinSum(long[] arr) throws ExecutionException, InterruptedException {
  60. long start = System.currentTimeMillis();
  61. ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
  62. // 提交任务
  63. ForkJoinTask<Long> forkJoinTask = forkJoinPool.submit(new SumTask(arr, 0, arr.length));
  64. // 获取结果
  65. Long sum = forkJoinTask.get();
  66. forkJoinPool.shutdown();
  67. System.out.println("sum: " sum);
  68. System.out.println("fork join elapse: " (System.currentTimeMillis() - start));
  69. }
  70. private static class SumTask extends RecursiveTask<Long> {
  71. private long[] arr;
  72. private int from;
  73. private int to;
  74. public SumTask(long[] arr, int from, int to) {
  75. this.arr = arr;
  76. this.from = from;
  77. this.to = to;
  78. }
  79. @Override
  80. protected Long compute() {
  81. // 小于1000的时候直接相加,可灵活调整
  82. if (to - from <= 1000) {
  83. long sum = 0;
  84. for (int i = from; i < to; i ) {
  85. // 模拟耗时
  86. sum = (arr[i]/3*3/3*3/3*3/3*3/3*3);
  87. }
  88. return sum;
  89. }
  90. // 分成两段任务,本文由公从号“彤哥读源码”原创
  91. int middle = (from to) / 2;
  92. SumTask left = new SumTask(arr, from, middle);
  93. SumTask right = new SumTask(arr, middle, to);
  94. // 提交左边的任务
  95. left.fork();
  96. // 右边的任务直接利用当前线程计算,节约开销
  97. Long rightResult = right.compute();
  98. // 等待左边计算完毕
  99. Long leftResult = left.join();
  100. // 返回结果
  101. return leftResult rightResult;
  102. }
  103. }
  104. }

彤哥偷偷地告诉你,实际上计算1亿个整数相加,单线程是最快的,我的电脑大概是100ms左右,使用线程池反而会变慢。

所以,为了演示ForkJoinPool的牛逼之处,我把每个数都/3*3/3*3/3*3/3*3/3*3了一顿操作,用来模拟计算耗时。

来看结果:

  1. sum: 107352457433800662
  2. single thread elapse: 789
  3. sum: 107352457433800662
  4. multi thread elapse: 228
  5. sum: 107352457433800662
  6. fork join elapse: 189

可以看到,ForkJoinPool相对普通线程池还是有很大提升的。

问题:普通线程池能否实现ForkJoinPool这种计算方式呢,即大任务拆中任务,中任务拆小任务,最后再汇总?

forkjoinpool

你可以试试看(-᷅_-᷄)

OK,下面我们正式进入ForkJoinPool的解析。

分治法

  • 基本思想

把一个规模大的问题划分为规模较小的子问题,然后分而治之,最后合并子问题的解得到原问题的解。

  • 步骤

(1)分割原问题:

(2)求解子问题:

(3)合并子问题的解为原问题的解。

在分治法中,子问题一般是相互独立的,因此,经常通过递归调用算法来求解子问题。

  • 典型应用场景

(1)二分搜索

(2)大整数乘法

(3)Strassen矩阵乘法

(4)棋盘覆盖

(5)归并排序

(6)快速排序

(7)线性时间选择

(8)汉诺塔

ForkJoinPool继承体系

ForkJoinPool是 java 7 中新增的线程池类,它的继承体系如下:

forkjoinpool

ForkJoinPool和ThreadPoolExecutor都是继承自AbstractExecutorService抽象类,所以它和ThreadPoolExecutor的使用几乎没有多少区别,除了任务变成了ForkJoinTask以外。

这里又运用到了一种很重要的设计原则——开闭原则——对修改关闭,对扩展开放。

可见整个线程池体系一开始的接口设计就很好,新增一个线程池类,不会对原有的代码造成干扰,还能利用原有的特性。

ForkJoinTask

两个主要方法

  • fork()

fork()方法类似于线程的Thread.start()方法,但是它不是真的启动一个线程,而是将任务放入到工作队列中。

  • join()

join()方法类似于线程的Thread.join()方法,但是它不是简单地阻塞线程,而是利用工作线程运行其它任务。当一个工作线程中调用了join()方法,它将处理其它任务,直到注意到目标子任务已经完成了。

三个子类

  • RecursiveAction

无返回值任务。

  • RecursiveTask

有返回值任务。

  • CountedCompleter

无返回值任务,完成任务后可以触发回调。

ForkJoinPool内部原理

ForkJoinPool内部使用的是“工作窃取”算法实现的。

forkjoinpool

(1)每个工作线程都有自己的工作队列WorkQueue;

(2)这是一个双端队列,它是线程私有的;

(3)ForkJoinTask中fork的子任务,将放入运行该任务的工作线程的队头,工作线程将以LIFO的顺序来处理工作队列中的任务;

(4)为了最大化地利用CPU,空闲的线程将从其它线程的队列中“窃取”任务来执行;

(5)从工作队列的尾部窃取任务,以减少竞争;

(6)双端队列的操作:push()/pop()仅在其所有者工作线程中调用,poll()是由其它线程窃取任务时调用的;

(7)当只剩下最后一个任务时,还是会存在竞争,是通过CAS来实现的;

forkjoinpool

ForkJoinPool最佳实践

(1)最适合的是计算密集型任务,本文由公从号“彤哥读源码”原创;

(2)在需要阻塞工作线程时,可以使用ManagedBlocker;

(3)不应该在RecursiveTask 的内部使用ForkJoinPool.invoke()/invokeAll();

总结

(1)ForkJoinPool特别适合于“分而治之”算法的实现;

(2)ForkJoinPool和ThreadPoolExecutor是互补的,不是谁替代谁的关系,二者适用的场景不同;

(3)ForkJoinTask有两个核心方法——fork()和join(),有三个重要子类——RecursiveAction、RecursiveTask和CountedCompleter;

(4)ForkjoinPool内部基于“工作窃取”算法实现;

(5)每个线程有自己的工作队列,它是一个双端队列,自己从队列头存取任务,其它线程从尾部窃取任务;

(6)ForkJoinPool最适合于计算密集型任务,但也可以使用ManagedBlocker以便用于阻塞型任务;

(7)RecursiveTask内部可以少调用一次fork(),利用当前线程处理,这是一种技巧;

彩蛋

ManagedBlocker怎么使用?

答:ManagedBlocker相当于明确告诉ForkJoinPool框架要阻塞了,ForkJoinPool就会启另一个线程来运行任务,以最大化地利用CPU。

请看下面的例子,自己琢磨哈^^。

  1. /**
  2. * 斐波那契数列
  3. * 一个数是它前面两个数之和
  4. * 1,1,2,3,5,8,13,21
  5. */
  6. public class Fibonacci {
  7. public static void main(String[] args) {
  8. long time = System.currentTimeMillis();
  9. Fibonacci fib = new Fibonacci();
  10. int result = fib.f(1_000).bitCount();
  11. time = System.currentTimeMillis() - time;
  12. System.out.println("result,本文由公从号“彤哥读源码”原创 = " result);
  13. System.out.println("test1_000() time = " time);
  14. }
  15. public BigInteger f(int n) {
  16. Map<Integer, BigInteger> cache = new ConcurrentHashMap<>();
  17. cache.put(0, BigInteger.ZERO);
  18. cache.put(1, BigInteger.ONE);
  19. return f(n, cache);
  20. }
  21. private final BigInteger RESERVED = BigInteger.valueOf(-1000);
  22. public BigInteger f(int n, Map<Integer, BigInteger> cache) {
  23. BigInteger result = cache.putIfAbsent(n, RESERVED);
  24. if (result == null) {
  25. int half = (n 1) / 2;
  26. RecursiveTask<BigInteger> f0_task = new RecursiveTask<BigInteger>() {
  27. @Override
  28. protected BigInteger compute() {
  29. return f(half - 1, cache);
  30. }
  31. };
  32. f0_task.fork();
  33. BigInteger f1 = f(half, cache);
  34. BigInteger f0 = f0_task.join();
  35. long time = n > 10_000 ? System.currentTimeMillis() : 0;
  36. try {
  37. if (n % 2 == 1) {
  38. result = f0.multiply(f0).add(f1.multiply(f1));
  39. } else {
  40. result = f0.shiftLeft(1).add(f1).multiply(f1);
  41. }
  42. synchronized (RESERVED) {
  43. cache.put(n, result);
  44. RESERVED.notifyAll();
  45. }
  46. } finally {
  47. time = n > 10_000 ? System.currentTimeMillis() - time : 0;
  48. if (time > 50)
  49. System.out.printf("f(%d) took %d%n", n, time);
  50. }
  51. } else if (result == RESERVED) {
  52. try {
  53. ReservedFibonacciBlocker blocker = new ReservedFibonacciBlocker(n, cache);
  54. ForkJoinPool.managedBlock(blocker);
  55. result = blocker.result;
  56. } catch (InterruptedException e) {
  57. throw new CancellationException("interrupted");
  58. }
  59. }
  60. return result;
  61. // return f(n - 1).add(f(n - 2));
  62. }
  63. private class ReservedFibonacciBlocker implements ForkJoinPool.ManagedBlocker {
  64. private BigInteger result;
  65. private final int n;
  66. private final Map<Integer, BigInteger> cache;
  67. public ReservedFibonacciBlocker(int n, Map<Integer, BigInteger> cache) {
  68. this.n = n;
  69. this.cache = cache;
  70. }
  71. @Override
  72. public boolean block() throws InterruptedException {
  73. synchronized (RESERVED) {
  74. while (!isReleasable()) {
  75. RESERVED.wait();
  76. }
  77. }
  78. return true;
  79. }
  80. @Override
  81. public boolean isReleasable() {
  82. return (result = cache.get(n)) != RESERVED;
  83. }
  84. }
  85. }

#

欢迎关注我的公众号“彤哥读源码”,查看更多源码系列文章, 与彤哥一起畅游源码的海洋。

qrcode

发表评论

表情:
评论列表 (有 0 条评论,100人围观)

还没有评论,来说两句吧...

相关阅读