Java並發編程的藝術——並發工具類CountDownLatch詳解

一個即將被退役的碼農 發佈 2022-10-02T08:19:48.912596+00:00

CountDownLatch底層也是由AQS,用來同步一個或多個任務的常用並發工具類,強制它們等待由其他任務執行的一組操作完成帶著BAT大廠的面試問題去理解請帶著這些問題繼續後文,會很大程度上幫助你更好地理解相關知識點。什麼是CountDownLatch?

CountDownLatch底層也是由AQS,用來同步一個或多個任務的常用並發工具類,強制它們等待由其他任務執行的一組操作完成

帶著BAT大廠的面試問題去理解

請帶著這些問題繼續後文,會很大程度上幫助你更好地理解相關知識點。

  • 什麼是CountDownLatch?
  • CountDownLatch底層實現原理?
  • CountDownLatch一次可以喚醒幾個任務? 多個
  • CountDownLatch有哪些主要方法? await(),countDown()
  • CountDownLatch適用於什麼場景?
  • 寫道題:實現一個容器,提供兩個方法,add,size 寫兩個線程,線程1添加10個元素到容器中,線程2實現監控元素的個數,當個數到5個時,線程2給出提示並結束? 使用CountDownLatch 代替wait notify 好處。

CountDownLatch介紹

從源碼可知,其底層是由AQS提供支持,所以其數據結構可以參考AQS的數據結構,而AQS的數據結構核心就是兩個虛擬隊列: 同步隊列sync queue 和條件隊列condition queue,不同的條件會有不同的條件隊列。CountDownLatch典型的用法是將一個程序分為n個互相獨立的可解決任務,並創建值為n的CountDownLatch。當每一個任務完成時,都會在這個鎖存器上調用countDown,等待問題被解決的任務調用這個鎖存器的await,將他們自己攔住,直至鎖存器計數結束。

CountDownLatch源碼分析

類的繼承關係

CountDownLatch沒有顯示繼承哪個父類或者實現哪個父接口, 它底層是AQS是通過內部類Sync來實現的。

public class CountDownLatch {

類的內部類

CountDownLatch類存在一個內部類Sync,繼承自AbstractQueuedSynchronizer,其原始碼如下。

private static final class Sync extends AbstractQueuedSynchronizer {
    // 版本號
    private static final long serialVersionUID = 4982264981922014374L;
    
    // 構造器
    Sync(int count) {
        setstate(count);
    }
    
    // 返回當前計數
    int getCount() {
        return getState();
    }

    // 試圖在共享模式下獲取對象狀態
    protected int tryAcquireShared(int acquires) {
        return (getState() == 0) ? 1 : -1;
    }

    // 試圖設置狀態來反映共享模式下的一個釋放
    protected boolean tryReleaseShared(int releases) {
        // Decrement count; signal when transition to zero
        // 無限循環
        for (;;) {
            // 獲取狀態
            int c = getState();
            if (c == 0) // 沒有被線程占有
                return false;
            // 下一個狀態
            int nextc = c-1;
            if (compareAndSetState(c, nextc)) // 比較並且設置成功
                return nextc == 0;
        }
    }
}
    

說明: 對CountDownLatch方法的調用會轉發到對Sync或AQS的方法的調用,所以,AQS對CountDownLatch提供支持。

類的屬性

可以看到CountDownLatch類的內部只有一個Sync類型的屬性:

public class CountDownLatch {
    // 同步隊列
    private final Sync sync;
}

類的構造函數

public CountDownLatch(int count) {
    if (count < 0) throw new IllegalArgumentException("count < 0");
    // 初始化狀態數
    this.sync = new Sync(count);
}

說明: 該構造函數可以構造一個用給定計數初始化的CountDownLatch,並且構造函數內完成了sync的初始化,並設置了狀態數。

核心函數 - await函數

此函數將會使當前線程在鎖存器倒計數至零之前一直等待,除非線程被中斷。其源碼如下

public void await() throws InterruptedException {
    // 轉發到sync對象上
    sync.acquireSharedInterruptibly(1);
}

說明: 由源碼可知,對CountDownLatch對象的await的調用會轉發為對Sync的acquireSharedInterruptibly(從AQS繼承的方法)方法的調用。

  • acquireSharedInterruptibly源碼如下:
public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
    if (Thread.interrupted())
        throw new InterruptedException();
    if (tryAcquireShared(arg) < 0)
        doAcquireSharedInterruptibly(arg);
}

說明: 從源碼中可知,acquireSharedInterruptibly又調用了CountDownLatch的內部類Sync的tryAcquireShared和AQS的doAcquireSharedInterruptibly函數。

  • tryAcquireShared函數的源碼如下:
protected int tryAcquireShared(int acquires) {
    return (getState() == 0) ? 1 : -1;
}

說明: 該函數只是簡單的判斷AQS的state是否為0,為0則返回1,不為0則返回-1。

  • doAcquireSharedInterruptibly函數的源碼如下:
private void doAcquireSharedInterruptibly(int arg) throws InterruptedException {
    // 添加節點至等待隊列
    final Node node = addWaiter(Node.SHARED);
    boolean failed = true;
    try {
        for (;;) { // 無限循環
            // 獲取node的前驅節點
            final Node p = node.predecessor();
            if (p == head) { // 前驅節點為頭節點
                // 試圖在共享模式下獲取對象狀態
                int r = tryAcquireShared(arg);
                if (r >= 0) { // 獲取成功
                    // 設置頭節點並進行繁殖
                    setHeadAndPropagate(node, r);
                    // 設置節點next域
                    p.next = null; // help GC
                    failed = false;
                    return;
                }
            }
            if (shouldParkAfterFailedAcquire(p, node) &&
                parkAndCheckInterrupt()) // 在獲取失敗後是否需要禁止線程並且進行中斷檢查
                // 拋出異常
                throw new InterruptedException();
        }
    } finally {
        if (failed)
            cancelAcquire(node);
    }
}

說明: 在AQS的doAcquireSharedInterruptibly中可能會再次調用CountDownLatch的內部類Sync的tryAcquireShared方法和AQS的setHeadAndPropagate方法。

  • setHeadAndPropagate方法源碼如下。
private void setHeadAndPropagate(Node node, int propagate) {
    // 獲取頭節點
    Node h = head; // Record old head for check below
    // 設置頭節點
    setHead(node);
    /*
        * Try to signal next queued node if:
        *   Propagation was indicated by caller,
        *     or was recorded (as h.waitStatus either before
        *     or after setHead) by a previous operation
        *     (note: this uses sign-check of waitStatus because
        *      PROPAGATE status may transition to SIGNAL.)
        * and
        *   The next node is waiting in shared mode,
        *     or we don't know, because it appears null
        *
        * The conservatism in both of these checks may cause
        * unnecessary wake-ups, but only when there are multiple
        * racing acquires/releases, so most need signals now or soon
        * anyway.
        */
    // 進行判斷
    if (propagate > 0 || h == null || h.waitStatus < 0 ||
        (h = head) == null || h.waitStatus < 0) {
        // 獲取節點的後繼
        Node s = node.next;
        if (s == null || s.isShared()) // 後繼為空或者為共享模式
            // 以共享模式進行釋放
            doReleaseShared();
    }
}

說明: 該方法設置頭節點並且釋放頭節點後面的滿足條件的結點,該方法中可能會調用到AQS的doReleaseShared方法,其源碼如下。

private void doReleaseShared() {
    /*
        * Ensure that a release propagates, even if there are other
        * in-progress acquires/releases.  This proceeds in the usual
        * way of trying to unparkSuccessor of head if it needs
        * signal. But if it does not, status is set to PROPAGATE to
        * ensure that upon release, propagation continues.
        * Additionally, we must loop in case a new node is added
        * while we are doing this. Also, unlike other uses of
        * unparkSuccessor, we need to know if CAS to reset status
        * fails, if so rechecking.
        */
    // 無限循環
    for (;;) {
        // 保存頭節點
        Node h = head;
        if (h != null && h != tail) { // 頭節點不為空並且頭節點不為尾結點
            // 獲取頭節點的等待狀態
            int ws = h.waitStatus; 
            if (ws == Node.SIGNAL) { // 狀態為SIGNAL
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0)) // 不成功就繼續
                    continue;            // loop to recheck cases
                // 釋放後繼結點
                unparkSuccessor(h);
            }
            else if (ws == 0 &&
                        !compareAndSetWaitStatus(h, 0, Node.PROPAGATE)) // 狀態為0並且不成功,繼續
                continue;                // loop on failed CAS
        }
        if (h == head) // 若頭節點改變,繼續循環  
            break;
    }
}

說明: 該方法在共享模式下釋放,具體的流程再之後會通過一個示例給出。

所以,對CountDownLatch的await調用大致會有如下的調用鏈。

說明: 上圖給出了可能會調用到的主要方法,並非一定會調用到,之後,會通過一個示例給出詳細的分析。

核心函數 - countDown函數

此函數將遞減鎖存器的計數,如果計數到達零,則釋放所有等待的線程

public void countDown() {
    sync.releaseShared(1);
}
    

說明: 對countDown的調用轉換為對Sync對象的releaseShared(從AQS繼承而來)方法的調用。

  • releaseShared源碼如下
public final boolean releaseShared(int arg) {
    if (tryReleaseShared(arg)) {
        doReleaseShared();
        return true;
    }
    return false;
}

說明: 此函數會以共享模式釋放對象,並且在函數中會調用到CountDownLatch的tryReleaseShared函數,並且可能會調用AQS的doReleaseShared函數。

  • tryReleaseShared源碼如下
protected boolean tryReleaseShared(int releases) {
    // Decrement count; signal when transition to zero
    // 無限循環
    for (;;) {
        // 獲取狀態
        int c = getState();
        if (c == 0) // 沒有被線程占有
            return false;
        // 下一個狀態
        int nextc = c-1;
        if (compareAndSetState(c, nextc)) // 比較並且設置成功
            return nextc == 0;
    }
}

說明: 此函數會試圖設置狀態來反映共享模式下的一個釋放。具體的流程在下面的示例中會進行分析。

  • AQS的doReleaseShared的源碼如下
private void doReleaseShared() {
    /*
        * Ensure that a release propagates, even if there are other
        * in-progress acquires/releases.  This proceeds in the usual
        * way of trying to unparkSuccessor of head if it needs
        * signal. But if it does not, status is set to PROPAGATE to
        * ensure that upon release, propagation continues.
        * Additionally, we must loop in case a new node is added
        * while we are doing this. Also, unlike other uses of
        * unparkSuccessor, we need to know if CAS to reset status
        * fails, if so rechecking.
        */
    // 無限循環
    for (;;) {
        // 保存頭節點
        Node h = head;
        if (h != null && h != tail) { // 頭節點不為空並且頭節點不為尾結點
            // 獲取頭節點的等待狀態
            int ws = h.waitStatus; 
            if (ws == Node.SIGNAL) { // 狀態為SIGNAL
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0)) // 不成功就繼續
                    continue;            // loop to recheck cases
                // 釋放後繼結點
                unparkSuccessor(h);
            }
            else if (ws == 0 &&
                        !compareAndSetWaitStatus(h, 0, Node.PROPAGATE)) // 狀態為0並且不成功,繼續
                continue;                // loop on failed CAS
        }
        if (h == head) // 若頭節點改變,繼續循環  
            break;
    }
}

說明: 此函數在共享模式下釋放資源。

所以,對CountDownLatch的countDown調用大致會有如下的調用鏈。

說明: 上圖給出了可能會調用到的主要方法,並非一定會調用到,之後,會通過一個示例給出詳細的分析。

CountDownLatch示例

下面給出了一個使用CountDownLatch的示例。

import java.util.concurrent.CountDownLatch;

class MyThread extends Thread {
    private CountDownLatch countDownLatch;
    
    public MyThread(String name, CountDownLatch countDownLatch) {
        super(name);
        this.countDownLatch = countDownLatch;
    }
    
    public void run() {
        System.out.println(Thread.currentThread().getName() + " doing something");
        try {
            Thread.sleep(1000);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        System.out.println(Thread.currentThread().getName() + " finish");
        countDownLatch.countDown();
    }
}

public class CountDownLatchDemo {
    public static void main(String[] args) {
        CountDownLatch countDownLatch = new CountDownLatch(2);
        MyThread t1 = new MyThread("t1", countDownLatch);
        MyThread t2 = new MyThread("t2", countDownLatch);
        t1.start();
        t2.start();
        System.out.println("Waiting for t1 thread and t2 thread to finish");
        try {
            countDownLatch.await();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }            
        System.out.println(Thread.currentThread().getName() + " continue");        
    }
}

運行結果(某一次):

Waiting for t1 thread and t2 thread to finish
t1 doing something
t2 doing something
t1 finish
t2 finish
main continue

說明: 本程序首先計數器初始化為2。根據結果,可能會存在如下的一種時序圖。

說明: 首先main線程會調用await操作,此時main線程會被阻塞,等待被喚醒,之後t1線程執行了countDown操作,最後,t2線程執行了countDown操作,此時main線程就被喚醒了,可以繼續運行。下面,進行詳細分析。

  • main線程執行countDownLatch.await操作,主要調用的函數如下。

說明: 在最後,main線程就被park了,即禁止運行了。此時Sync queue(同步隊列)中有兩個節點,AQS的state為2,包含main線程的結點的nextWaiter指向SHARED結點。

  • t1線程執行countDownLatch.countDown操作,主要調用的函數如下。

說明: 此時,Sync queue隊列里的結點個數未發生變化,但是此時,AQS的state已經變為1了。

  • t2線程執行countDownLatch.countDown操作,主要調用的函數如下。

說明: 經過調用後,AQS的state為0,並且此時,main線程會被unpark,可以繼續運行。當main線程獲取cpu資源後,繼續運行。

  • main線程獲取cpu資源,繼續運行,由於main線程是在parkAndCheckInterrupt函數中被禁止的,所以此時,繼續在parkAndCheckInterrupt函數運行。

說明: main線程恢復,繼續在parkAndCheckInterrupt函數中運行,之後又會回到最終達到的狀態為AQS的state為0,並且head與tail指向同一個結點,該節點的額nextWaiter域還是指向SHARED結點。

更深入理解

寫道面試題

實現一個容器,提供兩個方法,add,size 寫兩個線程,線程1添加10個元素到容器中,線程2實現監控元素的個數,當個數到5個時,線程2給出提示並結束.

使用wait和notify實現

import java.util.ArrayList;
import java.util.List;

/**
 *  必須先讓t2先進行啟動 使用wait 和 notify 進行相互通訊,wait會釋放鎖,notify不會釋放鎖
 */
public class T2 {

 volatile   List list = new ArrayList();

    public void add (int i){
        list.add(i);
    }

    public int getSize(){
        return list.size();
    }

    public static void main(String[] args) {

        T2 t2 = new T2();

        Object lock = new Object();

        new Thread(() -> {
            synchronized(lock){
                System.out.println("t2 啟動");
                if(t2.getSize() != 5){
                    try {
                        /**會釋放鎖*/
                        lock.wait();
                        System.out.println("t2 結束");
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
                lock.notify();
            }
        },"t2").start();

        new Thread(() -> {
           synchronized (lock){
               System.out.println("t1 啟動");
               for (int i=0;i<9;i++){
                   t2.add(i);
                   System.out.println("add"+i);
                   if(t2.getSize() == 5){
                       /**不會釋放鎖*/
                       lock.notify();
                       try {
                           lock.wait();
                       } catch (InterruptedException e) {
                           e.printStackTrace();
                       }
                   }
               }
           }
        }).start();
    }
}
    

輸出:

t2 啟動
t1 啟動
add0
add1
add2
add3
add4
t2 結束
add5
add6
add7
add8

CountDownLatch實現

說出使用CountDownLatch 代替wait notify 好處?

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;

/**
 * 使用CountDownLatch 代替wait notify 好處是通訊方式簡單,不涉及鎖定  Count 值為0時當前線程繼續執行,
 */
public class T3 {

   volatile List list = new ArrayList();

    public void add(int i){
        list.add(i);
    }

    public int getSize(){
        return list.size();
    }


    public static void main(String[] args) {
        T3 t = new T3();
        CountDownLatch countDownLatch = new CountDownLatch(1);

        new Thread(() -> {
            System.out.println("t2 start");
           if(t.getSize() != 5){
               try {
                   countDownLatch.await();
                   System.out.println("t2 end");
               } catch (InterruptedException e) {
                   e.printStackTrace();
               }
           }
        },"t2").start();

        new Thread(()->{
            System.out.println("t1 start");
           for (int i = 0;i<9;i++){
               t.add(i);
               System.out.println("add"+ i);
               if(t.getSize() == 5){
                   System.out.println("countdown is open");
                   countDownLatch.countDown();
               }
           }
            System.out.println("t1 end");
        },"t1").start();
    }
}
關鍵字: