JUC通信工具
CountDownLatch
一般用于让线程A等待其他多个线程完成某些操作(即计数减到0)后再继续执行,使用较为简单
原理
内部基于AQS实现,原理非常简单:内部类Sync继承AQS,用state保存剩余计数;countDown()通过CAS将计数减1,await()在计数不为0时阻塞,直到计数减为0后被唤醒
// CountDownLatch

/**
 * Simplified excerpt of {@code java.util.concurrent.CountDownLatch}.
 *
 * <p>The AQS state holds the remaining count. {@link #await()} blocks until the count is zero;
 * {@link #countDown()} decrements it and wakes the waiters when it reaches zero.
 */
public class CountDownLatch {

    /** AQS synchronizer whose state is the number of outstanding counts. */
    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;

        Sync(int initialCount) {
            setState(initialCount);
        }

        int getCount() {
            return getState();
        }

        /** Shared acquire succeeds (returns 1) only once the count is zero, otherwise fails (-1). */
        @Override
        protected int tryAcquireShared(int ignored) {
            if (getState() == 0) {
                return 1;
            }
            return -1;
        }

        /**
         * Decrements the count in a CAS retry loop.
         *
         * @return true only for the transition to zero, which makes AQS release the waiters
         */
        @Override
        protected boolean tryReleaseShared(int ignored) {
            while (true) {
                final int current = getState();
                if (current == 0) {
                    return false; // already open: extra countDown() calls are no-ops
                }
                final int remaining = current - 1;
                if (compareAndSetState(current, remaining)) {
                    return remaining == 0;
                }
            }
        }
    }

    private final Sync sync;

    /**
     * @param count number of {@link #countDown()} calls needed before waiters are released
     * @throws IllegalArgumentException if {@code count} is negative
     */
    public CountDownLatch(int count) {
        if (count >= 0) {
            this.sync = new Sync(count);
        } else {
            throw new IllegalArgumentException("count < 0");
        }
    }

    /** Blocks until the count reaches zero (or the thread is interrupted). */
    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    /** Decrements the count, releasing all waiters when it reaches zero. */
    public void countDown() {
        sync.releaseShared(1);
    }
}
CyclicBarrier
用来实现多个线程之间相互等待:各线程运行到某个点(屏障)时停下,当所有线程都到达该点后再一起继续执行,内部是使用Condition.await实现的。Cyclic(循环)的含义是当一组线程使用完之后,CyclicBarrier可以复用,代码用例:
/**
 * Usage example for {@link CyclicBarrier}: one worker thread per matrix row.
 *
 * <p>In every round each worker processes its own row and then waits at the barrier. The last
 * worker to arrive runs the barrier action ({@link #mergeRows()}) before any worker is released
 * into the next round. {@link #processRow(int)}, {@link #mergeRows()} and {@link #done()} are
 * minimal placeholders so the example compiles and runs; replace them with real per-row work,
 * merge logic and a real termination test.
 */
class Solver {
    final int N;
    final float[][] data;
    final CyclicBarrier barrier;

    /** Set by the barrier action; read by all worker threads, hence volatile. */
    private volatile boolean finished;

    /** Processes one row per round, then waits for the other rows at the barrier. */
    class Worker implements Runnable {
        final int myRow;

        Worker(int row) {
            myRow = row;
        }

        @Override
        public void run() {
            while (!done()) {
                processRow(myRow);
                try {
                    barrier.await();
                } catch (InterruptedException ex) {
                    Thread.currentThread().interrupt(); // keep the interrupt status visible
                    return;
                } catch (BrokenBarrierException ex) {
                    return;
                }
            }
        }
    }

    /**
     * Starts one worker per row and waits for all of them to finish.
     *
     * @param matrix rows to process; must have at least one row
     * @throws IllegalArgumentException if {@code matrix} has no rows (a barrier needs at least
     *     one party)
     */
    public Solver(float[][] matrix) {
        data = matrix;
        N = matrix.length;
        // Runs in the last thread to reach the barrier, before the others are released.
        Runnable barrierAction = this::mergeRows;
        barrier = new CyclicBarrier(N, barrierAction);

        List<Thread> threads = new ArrayList<>(N);
        for (int i = 0; i < N; i++) {
            Thread thread = new Thread(new Worker(i));
            threads.add(thread);
            thread.start();
        }

        // wait until done
        try {
            for (Thread thread : threads) {
                thread.join();
            }
        } catch (InterruptedException e) {
            // Stop waiting but let the caller observe the interrupt.
            Thread.currentThread().interrupt();
        }
    }

    /** Placeholder per-row work: doubles every element of the row. */
    void processRow(int row) {
        float[] values = data[row];
        for (int col = 0; col < values.length; col++) {
            values[col] *= 2;
        }
    }

    /** Placeholder merge step run by the barrier action; this example stops after one round. */
    void mergeRows() {
        finished = true;
    }

    /** Termination test checked by each worker before every round. */
    boolean done() {
        return finished;
    }
}
原理
简要来说就是内部使用ReentrantLock
,以及Condition
,线程获取到锁后将计数器减1;若计数器仍大于0,线程就在Condition上等待(进入waiting状态);当最后一个线程把计数器减为0之后,会先执行barrierAction(如果有),再调用signalAll
唤醒所有等待线程
/**
 * Simplified excerpt of {@code java.util.concurrent.CyclicBarrier}.
 *
 * <p>A fixed number of parties call {@link #await()}. Every caller except the last blocks on the
 * {@code trip} condition. The last caller runs the optional barrier action and then starts a new
 * {@link Generation}, which wakes all waiters and makes the barrier reusable.
 */
public class CyclicBarrier {

    /**
     * One use of the barrier. A new instance is created every time the barrier trips;
     * {@code broken} is set when a waiter is interrupted or times out, or the action throws.
     */
    private static class Generation {
        Generation() {}
        boolean broken;
    }

    /** Guards all barrier state below. */
    private final ReentrantLock lock = new ReentrantLock();
    /** Condition the non-last parties wait on until the barrier trips. */
    private final Condition trip = lock.newCondition();
    /** Number of parties per generation. */
    private final int parties;
    /** Action run by the last arriving thread; may be null. */
    private final Runnable barrierCommand;
    /** The current generation. */
    private Generation generation = new Generation();
    /** Parties still to arrive in the current generation; reset to {@code parties} on each trip. */
    private int count;

    /** Wakes all waiters and starts a new generation. Must be called with the lock held. */
    private void nextGeneration() {
        trip.signalAll();
        count = parties;
        generation = new Generation();
    }

    /** Marks the current generation broken and wakes everyone. Must be called with the lock held. */
    private void breakBarrier() {
        generation.broken = true;
        count = parties;
        trip.signalAll();
    }

    /**
     * Main barrier logic shared by the timed and untimed waits.
     *
     * @param timed whether {@code nanos} bounds the wait
     * @param nanos remaining wait time when {@code timed}
     * @return the arrival index: {@code parties - 1} for the first arrival, 0 for the last
     * @throws BrokenBarrierException if the generation is or becomes broken
     * @throws InterruptedException if this thread is interrupted while the generation is intact
     * @throws TimeoutException if a timed wait elapses
     */
    private int dowait(boolean timed, long nanos)
            throws InterruptedException, BrokenBarrierException, TimeoutException {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            final Generation g = generation;

            if (g.broken) {
                throw new BrokenBarrierException();
            }

            if (Thread.interrupted()) {
                breakBarrier();
                throw new InterruptedException();
            }

            int index = --count;
            if (index == 0) { // last to arrive: trip the barrier
                boolean ranAction = false;
                try {
                    final Runnable command = barrierCommand;
                    if (command != null) {
                        command.run();
                    }
                    ranAction = true;
                    nextGeneration();
                    return 0;
                } finally {
                    // The action threw: break the barrier so the waiters are not stuck forever.
                    if (!ranAction) {
                        breakBarrier();
                    }
                }
            }

            // Loop until tripped, broken, interrupted, or timed out.
            for (;;) {
                try {
                    if (!timed) {
                        trip.await();
                    } else if (nanos > 0L) {
                        nanos = trip.awaitNanos(nanos);
                    }
                } catch (InterruptedException ie) {
                    if (g == generation && !g.broken) {
                        breakBarrier();
                        throw ie;
                    } else {
                        // This generation already tripped or broke: just restore the flag.
                        Thread.currentThread().interrupt();
                    }
                }

                if (g.broken) {
                    throw new BrokenBarrierException();
                }
                if (g != generation) {
                    return index;
                }
                if (timed && nanos <= 0L) {
                    breakBarrier();
                    throw new TimeoutException();
                }
            }
        } finally {
            lock.unlock();
        }
    }

    /**
     * @param parties number of threads that must call {@link #await()} to trip the barrier
     * @param barrierAction run by the last arriving thread on each trip; may be null
     * @throws IllegalArgumentException if {@code parties} is less than 1
     */
    public CyclicBarrier(int parties, Runnable barrierAction) {
        if (parties <= 0) {
            throw new IllegalArgumentException();
        }
        this.parties = parties;
        this.count = parties;
        this.barrierCommand = barrierAction;
    }

    /**
     * Waits until all parties have called await on this barrier.
     *
     * @return the arrival index (0 for the last thread to arrive)
     */
    public int await() throws InterruptedException, BrokenBarrierException {
        try {
            return dowait(false, 0L);
        } catch (TimeoutException toe) {
            throw new Error(toe); // cannot happen: an untimed wait never times out
        }
    }
}
Semaphore
此工具用来控制能够同时执行的线程数量,例如在连接池中限制并发访问的线程数,内部使用AQS的共享模式实现,代码用例如下:
/**
 * Usage example for {@link Semaphore}: a fixed-size pool of {@code MAX_AVAILABLE} items.
 *
 * <p>The semaphore counts free items. {@link #getItem()} blocks until one is free, and
 * {@link #putItem(Object)} returns an item. The bookkeeping below is deliberately simple (a
 * linear scan) and exists only to make the example complete.
 */
class Pool {
    private static final int MAX_AVAILABLE = 100;
    private final Semaphore available = new Semaphore(MAX_AVAILABLE, true);

    /** The pooled items; placeholder objects in this example. */
    protected final Object[] items = new Object[MAX_AVAILABLE];
    /** {@code used[i]} is true while {@code items[i]} is checked out. Guarded by {@code this}. */
    protected final boolean[] used = new boolean[MAX_AVAILABLE];

    Pool() {
        for (int i = 0; i < MAX_AVAILABLE; i++) {
            items[i] = new Object();
        }
    }

    /** Takes a free item, blocking until one is available. */
    public Object getItem() throws InterruptedException {
        available.acquire();
        return getNextAvailableItem();
    }

    /** Returns an item to the pool; objects that are not checked-out pool items are ignored. */
    public void putItem(Object x) {
        if (markAsUnused(x)) {
            available.release();
        }
    }

    /** Marks the first free slot as used and returns its item. */
    protected synchronized Object getNextAvailableItem() {
        for (int i = 0; i < MAX_AVAILABLE; ++i) {
            if (!used[i]) {
                used[i] = true;
                return items[i];
            }
        }
        return null; // not reached: holding a permit guarantees a free slot
    }

    /**
     * Marks {@code item} as free again.
     *
     * @return true only if {@code item} is a pool item that was checked out, so a permit is
     *     released exactly once per checkout
     */
    protected synchronized boolean markAsUnused(Object item) {
        for (int i = 0; i < MAX_AVAILABLE; ++i) {
            if (item == items[i]) {
                if (used[i]) {
                    used[i] = false;
                    return true;
                }
                return false;
            }
        }
        return false;
    }
}
原理
这个类的结构和ReentrantLock类似(同样分为公平FairSync与非公平NonfairSync两种Sync实现),区别在于它使用AQS的共享模式,state表示剩余的许可数,不再赘述
/**
 * Simplified excerpt of {@code java.util.concurrent.Semaphore}.
 *
 * <p>The AQS state is the number of available permits. Acquiring CAS-decrements it in shared
 * mode; releasing CAS-increments it. The fair and non-fair variants differ only in
 * {@code tryAcquireShared}.
 */
public class Semaphore implements java.io.Serializable {
    /** Fair or non-fair synchronizer chosen at construction time. */
    private final Sync sync;

    /** Common AQS synchronizer: state == available permits. */
    abstract static class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 1192457210091910933L;

        Sync(int permits) {
            setState(permits);
        }

        final int getPermits() {
            return getState();
        }

        /**
         * Barging acquire.
         *
         * @return permits left after taking {@code acquires}; a negative value means there were
         *     not enough permits and the state was left unchanged
         */
        final int nonfairTryAcquireShared(int acquires) {
            while (true) {
                final int before = getState();
                final int after = before - acquires;
                if (after < 0) {
                    return after;
                }
                if (compareAndSetState(before, after)) {
                    return after;
                }
            }
        }

        /** Adds {@code releases} permits; always succeeds unless the count would overflow. */
        @Override
        protected final boolean tryReleaseShared(int releases) {
            while (true) {
                final int before = getState();
                final int after = before + releases;
                if (after < before) { // int overflow
                    throw new Error("Maximum permit count exceeded");
                }
                if (compareAndSetState(before, after)) {
                    return true;
                }
            }
        }

        /** Removes {@code reductions} permits without blocking; the count may go negative. */
        final void reducePermits(int reductions) {
            while (true) {
                final int before = getState();
                final int after = before - reductions;
                if (after > before) { // int underflow
                    throw new Error("Permit count underflow");
                }
                if (compareAndSetState(before, after)) {
                    return;
                }
            }
        }

        /** Takes every available permit and returns how many were taken. */
        final int drainPermits() {
            while (true) {
                final int before = getState();
                if (before == 0) {
                    return 0;
                }
                if (compareAndSetState(before, 0)) {
                    return before;
                }
            }
        }
    }

    /** Non-fair version: an arriving thread may take permits ahead of queued threads. */
    static final class NonfairSync extends Sync {
        private static final long serialVersionUID = -2694183684443567898L;

        NonfairSync(int permits) {
            super(permits);
        }

        @Override
        protected int tryAcquireShared(int acquires) {
            return nonfairTryAcquireShared(acquires);
        }
    }

    /** Fair version: fails whenever another thread is queued ahead of the caller. */
    static final class FairSync extends Sync {
        private static final long serialVersionUID = 2014338818796000944L;

        FairSync(int permits) {
            super(permits);
        }

        @Override
        protected int tryAcquireShared(int acquires) {
            while (true) {
                if (hasQueuedPredecessors()) {
                    return -1;
                }
                final int before = getState();
                final int after = before - acquires;
                if (after < 0) {
                    return after;
                }
                if (compareAndSetState(before, after)) {
                    return after;
                }
            }
        }
    }

    /**
     * @param permits initial number of permits (may be negative)
     * @param fair true for FIFO granting of permits under contention
     */
    public Semaphore(int permits, boolean fair) {
        if (fair) {
            sync = new FairSync(permits);
        } else {
            sync = new NonfairSync(permits);
        }
    }

    /** Takes one permit, blocking until one is available or the thread is interrupted. */
    public void acquire() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    /** Returns one permit, possibly releasing a blocked acquirer. */
    public void release() {
        sync.releaseShared(1);
    }
}
Phaser
该工具是JDK7提供的同步工具,可以替代CountDownLatch
和CyclicBarrier
,与它们相比,Phaser能够在运行过程中动态地注册和注销parties
,在等待栅栏的时候,可以阻塞,也可以不阻塞.相对于CyclicBarrier
,CyclicBarrier
是所有的parties
到达栅栏之后更新计数器,到达下一代Generation
,并且每一代的参与者parties
是不能变化的;对于Phaser
,用PHASE
这个概念表示阶段,并且会记录每个阶段的值(就是递增值),并且在不同的阶段,参与者可以不同.
数据结构分析
Phaser
采用了父子结构,存在一个root
节点,所有的新Phaser
都持有root
,并且指向其Parent
,这样设计的原因是:如果所有操作都集中在同一个Phaser上,当参与者parties数量很多时,会导致内部state的CAS操作竞争激烈,因此采用这种分层(树形)结构来分散竞争.
Phaser
内部维护了两个基于单向链表的无锁栈(Treiber stack),即evenQ和oddQ,分别存放在偶数phase和奇数phase上等待的线程节点,由所有父子Phaser
共享,Phaser
内部wait是通过空转onSpinWait
(JDK9),或者通过该结构实现的LockSupport.park
屏障A        屏障B
ThreadA |    ThreadA |
ThreadB |    ThreadB |
ThreadC |    ThreadC |
ThreadD |            |
首先,多个屏障(多阶段)这个行为并不是Phaser
独特的,CyclicBarrier
也能完成,只是后者每一代的参与者数量是固定的。举个例子:假设有3个阶段,第一阶段有4个参与者,第二阶段有一个参与者退出,第三阶段再增加三个参与者.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 public class MultiplyPhasePhaserTest { static class Task implements Runnable { Phaser phaser; Phaser phaserMain; public Task (Phaser phaser, Phaser phaserMain) { this .phaser = phaser; this .phaserMain = phaserMain; } @Override public void run () { System.out.println(Thread.currentThread().getName() + "--执行阶段" + phaser.getPhase()); if (Thread.currentThread().getName().equals("1" )) { phaser.arriveAndDeregister(); phaserMain.arriveAndAwaitAdvance(); return ; } else { phaser.arriveAndAwaitAdvance(); } System.out.println(Thread.currentThread().getName() + "--执行阶段" + phaser.getPhase()); phaser.arriveAndAwaitAdvance(); phaser.register(); Thread thread = new Thread(() -> { System.out.println(Thread.currentThread().getName() + "--执行阶段" + phaser.getPhase()); phaser.arriveAndAwaitAdvance(); }); thread.setName(Thread.currentThread().getName() + "XXX" ); thread.start(); System.out.println(Thread.currentThread().getName() + "--执行阶段" + phaser.getPhase()); phaser.arriveAndAwaitAdvance(); phaserMain.arriveAndAwaitAdvance(); } } private static int TASK_COUNT = 4 ; @Test public void taskTest () { Phaser phaser = new Phaser(TASK_COUNT); Phaser phaser1 = new Phaser(1 ); for (int i = 0 ; i < TASK_COUNT; i++) { phaser1.register(); Thread thread = new Thread(new Task(phaser, phaser1)); thread.setName(i+"" ); thread.start(); } phaser1.arriveAndAwaitAdvance(); } } 3 --执行阶段0 2 --执行阶段0 1 --执行阶段0 0 --执行阶段0 2 --执行阶段1 0 --执行阶段1 3 --执行阶段1 2 --执行阶段2 0 --执行阶段2 3 --执行阶段2 2 XXX--执行阶段2 3 XXX--执行阶段2 0 XXX--执行阶段2
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 public class Phaser { private volatile long state; private static final int MAX_PARTIES = 0xffff ; private static final int MAX_PHASE = Integer.MAX_VALUE; private static final int PARTIES_SHIFT = 16 ; private static final int PHASE_SHIFT = 32 ; private static final int UNARRIVED_MASK = 0xffff ; private static final long PARTIES_MASK = 0xffff0000L ; private static final long COUNTS_MASK = 0xffffffffL ; private static final long TERMINATION_BIT = 1L << 63 ; private static final int ONE_ARRIVAL = 1 ; private static final int ONE_PARTY = 1 << PARTIES_SHIFT; private static final int ONE_DEREGISTER = ONE_ARRIVAL|ONE_PARTY; private static final int EMPTY = 1 ; private static int unarrivedOf (long s) { int counts = (int )s; return (counts == EMPTY) ? 0 : (counts & UNARRIVED_MASK); } private static int partiesOf (long s) { return (int )s >>> PARTIES_SHIFT; } private static int phaseOf (long s) { return (int )(s >>> PHASE_SHIFT); } private static int arrivedOf (long s) { int counts = (int )s; return (counts == EMPTY) ? 0 : (counts >>> PARTIES_SHIFT) - (counts & UNARRIVED_MASK); } private final Phaser parent; private final Phaser root; private final AtomicReference<QNode> evenQ; private final AtomicReference<QNode> oddQ; }
Phaser的内部结构大致如上:state字段打包了phase、parties和未到达数unarrived,通过parent/root构成父子结构,并使用QNode节点实现线程的阻塞等待
运行逻辑
Phaser的运行逻辑基本是:对于任意一个phaser,当它内部的所有parties都到达屏障时,如果它是子phaser
,则通知其父phaser,依此递归,最终到达root时,由root修改state
的phase
部分完成advance
操作,并且会使所有阻塞的线程正确执行.
/**
 * Creates a phaser with the given parent and number of initially unarrived parties.
 *
 * <p>A child shares its root's wait stacks (evenQ/oddQ). If it starts with a non-zero number of
 * parties, it immediately registers itself as a single party of its parent and adopts the phase
 * returned by that registration.
 *
 * @param parent the parent phaser, or null for a root
 * @param parties number of parties; must fit in 16 bits
 * @throws IllegalArgumentException if {@code parties} is negative or above the 16-bit limit
 */
public Phaser(Phaser parent, int parties) {
    if ((parties >>> PARTIES_SHIFT) != 0) {
        throw new IllegalArgumentException("Illegal number of parties");
    }
    this.parent = parent;
    int phase = 0;
    if (parent == null) {
        this.root = this;
        this.evenQ = new AtomicReference<QNode>();
        this.oddQ = new AtomicReference<QNode>();
    } else {
        final Phaser sharedRoot = parent.root;
        this.root = sharedRoot;
        this.evenQ = sharedRoot.evenQ;
        this.oddQ = sharedRoot.oddQ;
        if (parties != 0) {
            // The parent sees this whole child as one party.
            phase = parent.doRegister(1);
        }
    }
    if (parties == 0) {
        this.state = (long) EMPTY;
    } else {
        final long phaseBits = (long) phase << PHASE_SHIFT;
        final long partiesBits = (long) parties << PARTIES_SHIFT;
        final long unarrivedBits = (long) parties; // every party starts unarrived
        this.state = phaseBits | partiesBits | unarrivedBits;
    }
}
当子phaser第一次注册(通过构造器或者第一次调用register)时,都会调用一次parent.doRegister
,原因在于Phaser
的机制是分层结构:当子Phaser的所有任务都到达屏障点时,会递归地通知父Phaser。从父Phaser的角度来看,它只需要把每个子Phaser当作一个party,并不需要知道子Phaser内部有多少任务,所以向父Phaser注册的数量是1.
 _______p__________
 |  |  |  |  |  |
 s1 s2 s3 s4 s5 s6
假设s1
到s6
都是子Phaser
,那么对于p
节点来说,它的parties
数量就是6,s1
到s6
自身的任务或者子phaser
,p并不需要知道,这是一种递归
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
// NOTE(review): the number run above is the code listing's line-number gutter captured when the
// page was extracted; it is not Java and should be dropped when this listing is restored.

/**
 * Registers {@code registrations} new parties with this phaser.
 *
 * <p>Three cases, matching the three branches of the loop:
 * <ul>
 *   <li>Not the first registration (counts != EMPTY): CAS-add the new parties. If every current
 *       party has already arrived (unarrived == 0), the phase is about to advance, so first wait
 *       on the root for that and retry.</li>
 *   <li>First registration of a root: CAS-install a fresh counts field.</li>
 *   <li>First registration of a child: inside {@code synchronized (this)}, register the child as
 *       one party of its parent, then install its own counts using the phase returned by the
 *       parent.</li>
 * </ul>
 *
 * @param registrations number of parties to add
 * @return the phase at which the registration took effect; negative if terminated
 * @throws IllegalStateException if the party count would exceed MAX_PARTIES
 */
private int doRegister (int registrations) {
    // Added to both the parties field (bits 16-31) and the unarrived field (bits 0-15).
    long adjust = ((long )registrations << PARTIES_SHIFT) | registrations;
    final Phaser parent = this .parent;
    int phase;
    for (;;) {
        // A child must see its phase reconciled with the root's before deciding anything.
        long s = (parent == null ) ? state : reconcileState();
        int counts = (int )s;
        int parties = counts >>> PARTIES_SHIFT;
        int unarrived = counts & UNARRIVED_MASK;
        if (registrations > MAX_PARTIES - parties)
            throw new IllegalStateException(badRegister(s));
        phase = (int )(s >>> PHASE_SHIFT);
        if (phase < 0 )
            break ; // terminated: report the negative phase
        if (counts != EMPTY) {
            // Not the first registration. For a child, only proceed if a fresh reconcile
            // returns the same state; otherwise loop and re-read.
            if (parent == null || reconcileState() == s) {
                if (unarrived == 0 )
                    // Everyone has arrived: wait for the advance, then retry.
                    root.internalAwaitAdvance(phase, null );
                else if (STATE.compareAndSet(this , s, s + adjust))
                    break ;
            }
        } else if (parent == null ) {
            // First registration of a root: build the state from scratch.
            long next = ((long )phase << PHASE_SHIFT) | adjust;
            if (STATE.compareAndSet(this , s, next))
                break ;
        } else {
            // First registration of a child. The lock serializes concurrent first-registration
            // attempts; the state == s re-check makes a thread that lost the race loop and retry.
            synchronized (this ) {
                if (state == s) {
                    // The parent counts this whole child as a single party.
                    phase = parent.doRegister(1 );
                    if (phase < 0 )
                        break ;
                    // Install the counts with the parent's phase. A weak CAS may fail spuriously,
                    // so on failure re-read our state and the root's phase, then retry.
                    while (!STATE.weakCompareAndSet
                           (this , s, ((long )phase << PHASE_SHIFT) | adjust)) {
                        s = state;
                        phase = (int )(root.state >>> PHASE_SHIFT);
                    }
                    break ;
                }
            }
        }
    }
    return phase;
}

/**
 * Brings a child phaser's state up to date with the root's phase.
 *
 * <p>This is where a child lazily applies the reset the root performed when it advanced. If the
 * root's phase differs from the phase stored here, the state is rewritten to carry the root's
 * phase and:
 * <ul>
 *   <li>keep the counts unchanged if the root phase is negative (terminated);</li>
 *   <li>become EMPTY if this phaser has no parties;</li>
 *   <li>otherwise reset unarrived to the number of parties.</li>
 * </ul>
 * A root's state is returned as-is.
 *
 * @return the (possibly updated) state of this phaser
 */
private long reconcileState () {
    final Phaser root = this .root;
    long s = state;
    if (root != this ) {
        int phase, p;
        // s is reassigned inside the CAS argument, so when the CAS succeeds the loop exits
        // with s already holding the new state; when it fails, s is re-read and the loop retries.
        while ((phase = (int )(root.state >>> PHASE_SHIFT)) != (int )(s >>> PHASE_SHIFT) &&
               !STATE.weakCompareAndSet
               (this , s, s = (((long )phase << PHASE_SHIFT) |
                               ((phase < 0 ) ? (s & COUNTS_MASK) :
                                (((p = (int )s >>> PARTIES_SHIFT) == 0 ) ? EMPTY :
                                 ((s & PARTIES_MASK) | p))))))
            s = state;
    }
    return s;
}
逻辑如下:
1. 若并非首次注册(counts != EMPTY):子phaser会先通过reconcileState同步phase;若当前phase的所有parties都已到达(unarrived == 0),本次注册需要等待phase推进(advance)后再重试,否则直接CAS修改state
2. 若是root的首次注册,则直接CAS设置state
3. 若是子phaser的首次注册,则在synchronized(this)内先向父phaser注册一个party(parent.doRegister(1)),再CAS修改自身的state
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
// NOTE(review): the number run above is the code listing's line-number gutter captured when the
// page was extracted; it is not Java and should be dropped when this listing is restored.

/**
 * Arrives at this phaser and waits for the other parties. This is a specialized combination of
 * arrive and await-advance.
 *
 * <p>A non-last arrival waits on the root. The last arrival at a child arrives at the parent on
 * behalf of the whole child, recursively. The last arrival at the root advances the phase and
 * releases the waiters.
 *
 * @return normally the next phase number; a negative value if the phaser is terminated
 * @throws IllegalStateException if there is no unarrived party to arrive (badArrive)
 */
public int arriveAndAwaitAdvance () {
    final Phaser root = this .root;
    for (;;) {
        // A root reads its state directly; a child first reconciles with the root's phase.
        long s = (root == this ) ? state : reconcileState();
        int phase = (int )(s >>> PHASE_SHIFT);
        if (phase < 0 )
            return phase; // terminated
        int counts = (int )s;
        int unarrived = (counts == EMPTY) ? 0 : (counts & UNARRIVED_MASK);
        if (unarrived <= 0 )
            throw new IllegalStateException(badArrive(s));
        // Record one arrival. "s -= ONE_ARRIVAL" also updates the local s, so inside the
        // if-block s is the post-arrival state.
        if (STATE.compareAndSet(this , s, s -= ONE_ARRIVAL)) {
            if (unarrived > 1 )
                // Other parties still pending: spin, then block on the root until the phase moves.
                return root.internalAwaitAdvance(phase, null );
            if (root != this )
                // Last arrival at a child: arrive at the parent for the whole child. The child's
                // own counts are reset later, lazily, by reconcileState.
                return parent.arriveAndAwaitAdvance();
            // Last arrival at the root: build the next state.
            long n = s & PARTIES_MASK; // keep only the parties field
            int nextUnarrived = (int )n >>> PARTIES_SHIFT;
            if (onAdvance(phase, nextUnarrived))
                n |= TERMINATION_BIT;
            else if (nextUnarrived == 0 )
                n |= EMPTY;
            else
                n |= nextUnarrived; // every party is unarrived again in the next phase
            int nextPhase = (phase + 1 ) & MAX_PHASE; // wraps to 0 after MAX_PHASE
            n |= (long )nextPhase << PHASE_SHIFT;
            if (!STATE.compareAndSet(this , s, n))
                // State changed under us (presumably a concurrent termination): report current phase.
                return (int )(state >>> PHASE_SHIFT);
            releaseWaiters(phase);
            return nextPhase;
        }
    }
}
逻辑如下:
1. 任意party到达时,先通过CAS把当前phaser的未到达数(unarrived)减1;若还有其他party未到达,则调用root.internalAwaitAdvance等待
2. 当前phaser的parties全部到达时:若存在parent,则通知parent(调用parent.arriveAndAwaitAdvance(),递归向上);若没有parent,说明是root,则推进root的state中的phase,并把未到达数重置为parties。子phaser的state重置会被延迟到之后调用register或await相关方法时,由reconcileState处理
3. root推进phase后调用releaseWaiters,唤醒那些在QNode上阻塞的线程
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
// NOTE(review): the number run above is the code listing's line-number gutter captured when the
// page was extracted; it is not Java and should be dropped when this listing is restored.

/**
 * Waits until the phase stored in this (root) phaser's state moves past {@code phase}.
 *
 * <p>The thread first spins. The spin budget grows while the unarrived count keeps changing.
 * Once the budget is exhausted, or the thread is found interrupted, it creates a QNode, pushes it
 * onto the wait stack selected by phase parity (evenQ/oddQ), and blocks through
 * ForkJoinPool.managedBlock.
 *
 * @param phase the phase being waited on
 * @param node an existing wait node, or null to create one lazily once spinning gives up
 * @return the new phase once it has changed; if the node became releasable while the phase is
 *     still unchanged, whatever abortWait(phase) returns
 */
private int internalAwaitAdvance (int phase, QNode node) {
    // Pop leftover waiters from the previous phase's stack (the other parity).
    releaseWaiters(phase-1 );
    boolean queued = false ; // has node been pushed onto a wait stack yet?
    int lastUnarrived = 0 ;
    int spins = SPINS_PER_ARRIVAL;
    long s;
    int p;
    while ((p = (int )((s = state) >>> PHASE_SHIFT)) == phase) {
        if (node == null ) {
            // Spinning. Extend the budget whenever the unarrived count changed (others are
            // arriving) and fewer than NCPU parties remain.
            int unarrived = (int )s & UNARRIVED_MASK;
            if (unarrived != lastUnarrived && (lastUnarrived = unarrived) < NCPU)
                spins += SPINS_PER_ARRIVAL;
            boolean interrupted = Thread.interrupted();
            if (interrupted || --spins < 0 ) {
                // Stop spinning: switch to a blocking node, remembering any interrupt seen.
                node = new QNode(this , phase, false , false , 0L );
                node.wasInterrupted = interrupted;
            }
            else
                Thread.onSpinWait();
        }
        else if (node.isReleasable())
            break ;
        else if (!queued) {
            // Push onto the stack for this phase's parity. Only push if the stack is empty or its
            // head belongs to the same phase, and the phase has not changed meanwhile.
            AtomicReference<QNode> head = (phase & 1 ) == 0 ? evenQ : oddQ;
            QNode q = node.next = head.get();
            if ((q == null || q.phase == phase) &&
                (int )(state >>> PHASE_SHIFT) == phase)
                queued = head.compareAndSet(q, node);
        }
        else {
            try {
                ForkJoinPool.managedBlock(node);
            } catch (InterruptedException cantHappen) {
                // Record the interrupt on the node instead of propagating it.
                node.wasInterrupted = true ;
            }
        }
    }

    if (node != null ) {
        // Clear the thread so releaseWaiters will not unpark this thread later.
        if (node.thread != null )
            node.thread = null ;
        // Thread.interrupted() cleared the flag above; re-assert it unless the node is interruptible.
        if (node.wasInterrupted && !node.interruptible)
            Thread.currentThread().interrupt();
        // Left the loop via isReleasable() while the phase is still unchanged
        // (presumably interrupt/timeout): let abortWait decide the result.
        if (p == phase && (p = (int )(state >>> PHASE_SHIFT)) == phase)
            return abortWait(phase);
    }
    releaseWaiters(phase);
    return p;
}

/**
 * Pops and unparks waiters from the stack selected by the parity of {@code phase} (evenQ for
 * even, oddQ for odd), as long as the head node's phase differs from the root's current phase.
 *
 * @param phase selects which stack to drain by its parity
 */
private void releaseWaiters (int phase) {
    QNode q;   // current head of the stack
    Thread t;  // thread parked on q, if any
    AtomicReference<QNode> head = (phase & 1 ) == 0 ? evenQ : oddQ;
    while ((q = head.get()) != null &&
           q.phase != (int )(root.state >>> PHASE_SHIFT)) {
        // Pop q. Only the thread whose CAS succeeds unparks it, and only if the waiter has not
        // already cleared its thread field.
        if (head.compareAndSet(q, q.next) &&
            (t = q.thread) != null ) {
            q.thread = null ;
            LockSupport.unpark(t);
        }
    }
}
逻辑如下:
internalAwaitAdvance
函数无论在哪个子phaser中,都是以root.internalAwaitAdvance的形式调用的,目的是在root.state中的phase推进之后,能把那些自旋或阻塞等待的线程释放.