Java並發編程的藝術——ThreadLocal原理和使用

一個即將被退役的碼農 發佈 2022-10-02T07:13:56.661508+00:00

官網的解釋是這樣的:This class provides thread-local variables. These variables differ from their normal counterparts in that each thread that accesses one has its own, independently initialized copy of the variable. {@code ThreadLocal} instances are typically private static fields in classes that wish to associate state with a thread 該類提供了線程局部 變量。

帶著BAT大廠的面試問題去理解

請帶著這些問題繼續後文,會很大程度上幫助你更好地理解相關知識點。

  • 什麼是ThreadLocal? 用來解決什麼問題的?
  • 說說你對ThreadLocal的理解
  • ThreadLocal是如何實現線程隔離的?
  • 為什麼ThreadLocal會造成內存泄露? 如何解決
  • 還有哪些使用ThreadLocal的應用場景?

ThreadLocal簡介

我們在Java 並發 - 並發理論基礎總結過線程安全(是指廣義上的共享資源訪問安全性,因為線程隔離是通過副本保證本線程訪問資源安全性,它不保證線程之間還存在共享關係的狹義上的安全性)的解決思路:

  • 互斥同步: synchronized 和 ReentrantLock
  • 非阻塞同步: CAS, AtomicXXXX
  • 無同步方案: 棧封閉,本地存儲(Thread Local),可重入代碼

這個章節將詳細地講講 本地存儲(Thread Local)。官網的解釋是這樣的:

This class provides thread-local variables. These variables differ from their normal counterparts in that each thread that accesses one (via its {@code get} or {@code set} method) has its own, independently initialized copy of the variable. {@code ThreadLocal} instances are typically private static fields in classes that wish to associate state with a thread (e.g., a user ID or Transaction ID) 該類提供了線程局部 (thread-local) 變量。這些變量不同於它們的普通對應物,因為訪問某個變量(通過其 get 或 set 方法)的每個線程都有自己的局部變量,它獨立於變量的初始化副本。ThreadLocal 實例通常是類中的 private static 欄位,它們希望將狀態與某一個線程(例如,用戶 ID 或事務 ID)相關聯。

總結而言:ThreadLocal是一個將在多線程中為每一個線程創建單獨的變量副本的類; 當使用ThreadLocal來維護變量時, ThreadLocal會為每個線程創建單獨的變量副本, 避免因多線程操作共享變量而導致的數據不一致的情況。

ThreadLocal理解

提到ThreadLocal被提到應用最多的是session管理和資料庫連結管理,這裡以數據訪問為例幫助你理解threadLocal:

  • 如下資料庫管理類在單線程使用是沒有任何問題的
class ConnectionManager {
    private static Connection connect = null;

    public static Connection openConnection() {
        if (connect == null) {
            connect = DriverManager.getConnection();
        }
        return connect;
    }

    public static void closeConnection() {
        if (connect != null)
            connect.close();
    }
}

很顯然,在多線程中使用會存在線程安全問題:第一,這裡面的2個方法都沒有進行同步,很可能在openConnection方法中會多次創建connect;第二,由於connect是共享變量,那麼必然在調用connect的地方需要使用到同步來保障線程安全,因為很可能一個線程在使用connect進行資料庫操作,而另外一個線程調用closeConnection關閉連結。

  • 為了解決上述線程安全的問題,第一考慮:互斥同步

你可能會說,將這段代碼的兩個方法進行同步處理,並且在調用connect的地方需要進行同步處理,比如用Synchronized或者ReentrantLock互斥鎖。

  • 這裡再拋出一個問題:這地方到底需不需要將connect變量進行共享?

事實上,是不需要的。假如每個線程中都有一個connect變量,各個線程之間對connect變量的訪問實際上是沒有依賴關係的,即一個線程不需要關心其他線程是否對這個connect進行了修改的。即改後的代碼可以這樣:

class ConnectionManager {
    private Connection connect = null;

    public Connection openConnection() {
        if (connect == null) {
            connect = DriverManager.getConnection();
        }
        return connect;
    }

    public void closeConnection() {
        if (connect != null)
            connect.close();
    }
}

class Dao {
    public void insert() {
        ConnectionManager connectionManager = new ConnectionManager();
        Connection connection = connectionManager.openConnection();

        // 使用connection進行操作

        connectionManager.closeConnection();
    }
}
    

確實也沒有任何問題,由於每次都是在方法內部創建的連接,那麼線程之間自然不存在線程安全問題。但是這樣會有一個致命的影響:導致伺服器壓力非常大,並且嚴重影響程序執行性能。由於在方法中需要頻繁地開啟和關閉資料庫連接,這樣不僅嚴重影響程序執行效率,還可能導致伺服器壓力巨大。

  • 這時候ThreadLocal登場了

那麼這種情況下使用ThreadLocal是再適合不過的了,因為ThreadLocal在每個線程中對該變量會創建一個副本,即每個線程內部都會有一個該變量,且在線程內部任何地方都可以使用,線程之間互不影響,這樣一來就不存在線程安全問題,也不會嚴重影響程序執行性能。下面就是網上出現最多的例子:

import java.SQL.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;

public class ConnectionManager {

    private static final ThreadLocal<Connection> dbConnectionLocal = new ThreadLocal<Connection>() {
        @Override
        protected Connection initialValue() {
            try {
                return DriverManager.getConnection("", "", "");
            } catch (SQLException e) {
                e.printStackTrace();
            }
            return null;
        }
    };

    public Connection getConnection() {
        return dbConnectionLocal.get();
    }
}
  • 再注意下ThreadLocal的修飾符

ThreaLocal的JDK文檔中說明:ThreadLocal instances are typically private static fields in classes that wish to associate state with a thread。如果我們希望通過某個類將狀態(例如用戶ID、事務ID)與線程關聯起來,那麼通常在這個類中定義private static類型的ThreadLocal 實例。

但是要注意,雖然ThreadLocal能夠解決上面說的問題,但是由於在每個線程中都創建了副本,所以要考慮它對資源的消耗,比如內存的占用會比不使用ThreadLocal要大。

ThreadLocal原理

如何實現線程隔離

主要是用到了Thread對象中的一個ThreadLocalMap類型的變量threadLocals, 負責存儲當前線程的關於Connection的對象, dbConnectionLocal(以上述例子中為例) 這個變量為Key, 以新建的Connection對象為Value; 這樣的話, 線程第一次讀取的時候如果不存在就會調用ThreadLocal的initialValue方法創建一個Connection對象並且返回;

具體關於為線程分配變量副本的代碼如下:

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap threadLocals = getMap(t);
    if (threadLocals != null) {
        ThreadLocalMap.Entry e = threadLocals.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}
  • 首先獲取當前線程對象t, 然後從線程t中獲取到ThreadLocalMap的成員屬性threadLocals
  • 如果當前線程的threadLocals已經初始化(即不為null) 並且存在以當前ThreadLocal對象為Key的值, 則直接返回當前線程要獲取的對象(本例中為Connection);
  • 如果當前線程的threadLocals已經初始化(即不為null)但是不存在以當前ThreadLocal對象為Key的的對象, 那麼重新創建一個Connection對象, 並且添加到當前線程的threadLocals Map中,並返回
  • 如果當前線程的threadLocals屬性還沒有被初始化, 則重新創建一個ThreadLocalMap對象, 並且創建一個Connection對象並添加到ThreadLocalMap對象中並返回。

如果存在則直接返回很好理解, 那麼對於如何初始化的代碼又是怎樣的呢?

private T setInitialValue() {
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
    return value;
}
  • 首先調用我們上面寫的重載過後的initialValue方法, 產生一個Connection對象
  • 繼續查看當前線程的threadLocals是不是空的, 如果ThreadLocalMap已被初始化, 那麼直接將產生的對象添加到ThreadLocalMap中, 如果沒有初始化, 則創建並添加對象到其中;

同時, ThreadLocal還提供了直接操作Thread對象中的threadLocals的方法

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

這樣我們也可以不實現initialValue, 將初始化工作放到DBConnectionFactory的getConnection方法中:

public Connection getConnection() {
    Connection connection = dbConnectionLocal.get();
    if (connection == null) {
        try {
            connection = DriverManager.getConnection("", "", "");
            dbConnectionLocal.set(connection);
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }
    return connection;
}

那麼我們看過代碼之後就很清晰的知道了為什麼ThreadLocal能夠實現變量的多線程隔離了; 其實就是用了Map的數據結構給當前線程緩存了, 要使用的時候就從本線程的threadLocals對象中獲取就可以了, key就是當前線程;

當然了在當前線程下獲取當前線程裡面的Map裡面的對象並操作肯定沒有線程並發問題了, 當然能做到變量的線程間隔離了;

現在我們知道了ThreadLocal到底是什麼了, 又知道了如何使用ThreadLocal以及其基本實現原理了是不是就可以結束了呢? 其實還有一個問題就是ThreadLocalMap是個什麼對象, 為什麼要用這個對象呢?

ThreadLocalMap對象是什麼

本質上來講, 它就是一個Map, 但是這個ThreadLocalMap與我們平時見到的Map有點不一樣

  • 它沒有實現Map接口;
  • 它沒有public的方法, 最多有一個default的構造方法, 因為這個ThreadLocalMap的方法僅僅在ThreadLocal類中調用, 屬於靜態內部類
  • ThreadLocalMap的Entry實現繼承了WeakReference<ThreadLocal<?>>
  • 該方法僅僅用了一個Entry數組來存儲Key, Value; Entry並不是鍊表形式, 而是每個bucket裡面僅僅放一個Entry;

要了解ThreadLocalMap的實現, 我們先從入口開始, 就是往該Map中添加一個值:

private void set(ThreadLocal<?> key, Object value) {

    // We don't use a fast path as with get() because it is at
    // least as common to use set() to create new entries as
    // it is to replace existing ones, in which case, a fast
    // path would fail more often than not.

    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);

    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();

        if (k == key) {
            e.value = value;
            return;
        }

        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }

    tab[i] = new Entry(key, value);
    int sz = ++size;
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

先進行簡單的分析, 對該代碼表層意思進行解讀:

  • 看下當前threadLocal的在數組中的索引位置 比如: i = 2, 看 i = 2 位置上面的元素(Entry)的Key是否等於threadLocal 這個 Key, 如果等於就很好說了, 直接將該位置上面的Entry的Value替換成最新的就可以了;
  • 如果當前位置上面的 Entry 的 Key為空, 說明ThreadLocal對象已經被回收了, 那麼就調用replaceStaleEntry
  • 如果清理完無用條目(ThreadLocal被回收的條目)、並且數組中的數據大小 > 閾值的時候對當前的Table進行重新哈希 所以, 該HashMap是處理衝突檢測的機制是向後移位, 清除過期條目 最終找到合適的位置;

了解完Set方法, 後面就是Get方法了:

private Entry getEntry(ThreadLocal<?> key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    if (e != null && e.get() == key)
        return e;
    else
        return getEntryAfterMiss(key, i, e);
}

先找到ThreadLocal的索引位置, 如果索引位置處的entry不為空並且鍵與threadLocal是同一個對象, 則直接返回; 否則去後面的索引位置繼續查找。

ThreadLocal造成內存泄露的問題

網上有這樣一個例子:

import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

public class ThreadLocalDemo {
    static class LocalVariable {
        private Long[] a = new Long[1024 * 1024];
    }

    // (1)
    final static ThreadPoolExecutor poolExecutor = new ThreadPoolExecutor(5, 5, 1, TimeUnit.MINUTES,
            new LinkedBlockingQueue<>());
    // (2)
    final static ThreadLocal<LocalVariable> localVariable = new ThreadLocal<LocalVariable>();

    public static void main(String[] args) throws InterruptedException {
        // (3)
        Thread.sleep(5000 * 4);
        for (int i = 0; i < 50; ++i) {
            poolExecutor.execute(new Runnable() {
                public void run() {
                    // (4)
                    localVariable.set(new LocalVariable());
                    // (5)
                    System.out.println("use local varaible" + localVariable.get());
                    localVariable.remove();
                }
            });
        }
        // (6)
        System.out.println("pool execute over");
    }
}

如果用線程池來操作ThreadLocal 對象確實會造成內存泄露, 因為對於線程池裡面不會銷毀的線程, 裡面總會存在著<ThreadLocal, LocalVariable>的強引用, 因為final static 修飾的 ThreadLocal 並不會釋放, 而ThreadLocalMap 對於 Key 雖然是弱引用, 但是強引用不會釋放, 弱引用當然也會一直有值, 同時創建的LocalVariable對象也不會釋放, 就造成了內存泄露; 如果LocalVariable對象不是一個大對象的話, 其實泄露的並不嚴重, 泄露的內存 = 核心線程數 * LocalVariable對象的大小;

所以, 為了避免出現內存泄露的情況, ThreadLocal提供了一個清除線程中對象的方法, 即 remove, 其實內部實現就是調用 ThreadLocalMap 的remove方法:

private void remove(ThreadLocal<?> key) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        if (e.get() == key) {
            e.clear();
            expungeStaleEntry(i);
            return;
        }
    }
}

找到Key對應的Entry, 並且清除Entry的Key(ThreadLocal)置空, 隨後清除過期的Entry即可避免內存泄露。

再看ThreadLocal應用場景

除了上述的資料庫管理類的例子,我們再看看其它一些應用:

每個線程維護了一個「序列號」

再回想上文說的,如果我們希望通過某個類將狀態(例如用戶ID、事務ID)與線程關聯起來,那麼通常在這個類中定義private static類型的ThreadLocal 實例。

每個線程維護了一個「序列號」

public class SerialNum {
    // The next serial number to be assigned
    private static int nextSerialNum = 0;

    private static ThreadLocal serialNum = new ThreadLocal() {
        protected synchronized Object initialValue() {
            return new Integer(nextSerialNum++);
        }
    };

    public static int get() {
        return ((Integer) (serialNum.get())).intValue();
    }
}

Session的管理

經典的另外一個例子:

private static final ThreadLocal threadSession = new ThreadLocal();  
  
public static Session getSession() throws InfrastructureException {  
    Session s = (Session) threadSession.get();  
    try {  
        if (s == null) {  
            s = getSessionFactory().openSession();  
            threadSession.set(s);  
        }  
    } catch (HibernateException ex) {  
        throw new InfrastructureException(ex);  
    }  
    return s;  
}  

在線程內部創建ThreadLocal

還有一種用法是在線程類內部創建ThreadLocal,基本步驟如下:

  • 在多線程的類(如ThreadDemo類)中,創建一個ThreadLocal對象threadXxx,用來保存線程間需要隔離處理的對象xxx。
  • 在ThreadDemo類中,創建一個獲取要隔離訪問的數據的方法getXxx(),在方法中判斷,若ThreadLocal對象為null時候,應該new()一個隔離訪問類型的對象,並強制轉換為要應用的類型。
  • 在ThreadDemo類的run()方法中,通過調用getXxx()方法獲取要操作的數據,這樣可以保證每個線程對應一個數據對象,在任何時刻都操作的是這個對象。
public class ThreadLocalTest implements Runnable{
    
    ThreadLocal<Student> StudentThreadLocal = new ThreadLocal<Student>();

    @Override
    public void run() {
        String currentThreadName = Thread.currentThread().getName();
        System.out.println(currentThreadName + " is running...");
        Random random = new Random();
        int age = random.nextInt(100);
        System.out.println(currentThreadName + " is set age: "  + age);
        Student Student = getStudentt(); //通過這個方法,為每個線程都獨立的new一個Studentt對象,每個線程的的Studentt對象都可以設置不同的值
        Student.setAge(age);
        System.out.println(currentThreadName + " is first get age: " + Student.getAge());
        try {
            Thread.sleep(500);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        System.out.println( currentThreadName + " is second get age: " + Student.getAge());
        
    }
    
    private Student getStudentt() {
        Student Student = StudentThreadLocal.get();
        if (null == Student) {
            Student = new Student();
            StudentThreadLocal.set(Student);
        }
        return Student;
    }

    public static void main(String[] args) {
        ThreadLocalTest t = new ThreadLocalTest();
        Thread t1 = new Thread(t,"Thread A");
        Thread t2 = new Thread(t,"Thread B");
        t1.start();
        t2.start();
    }
    
}

class Student{
    int age;
    public int getAge() {
        return age;
    }
    public void setAge(int age) {
        this.age = age;
    }
    
}

java 開發手冊中推薦的 ThreadLocal

看看阿里巴巴 java 開發手冊中推薦的 ThreadLocal 的用法:

import java.text.DateFormat;
import java.text.SimpleDateFormat;
 
public class DateUtils {
    public static final ThreadLocal<DateFormat> df = new ThreadLocal<DateFormat>(){
        @Override
        protected DateFormat initialValue() {
            return new SimpleDateFormat("yyyy-MM-dd");
        }
    };
}

然後我們再要用到 DateFormat 對象的地方,這樣調用:

DateUtils.df.get().format(new Date());
關鍵字: