清理ThreadLocal

jopen 9年前發布 | 29K 次閱讀 Java開發 ThreadLocal

在我很多的課程里(masterconcurrencyxj-conc-j8),我經常提起ThreadLocal。它經常受到我嚴厲的指責要盡可能的避免使用。ThreadLocal是為了那些使用完就銷毀的線程設計的。線程生成之前,線程內的局部變量都會被清除掉。實際上,如果你讀過 Why 0x61c88647?,這篇文章中解釋了實際的值是存在一個內部的map中,這個map是伴隨著線程的產生而產生的。存在于線程池的線程不會只存活于單個用戶請求,這很容易導致內存溢出。通常我在討論這個的時候,至少有一位學生有過因為在一個線程局部變量持有某個類而導致整個生產系統奔潰。因此,預防實體類加載后不被卸載,是一個非常普遍的問題。

在這篇文章中,我將演示一個ThreadLocalCleaner類。通過這個類,可以在線程回到線程池之前,恢復所有本地線程變量到最開始的狀態。最基礎的,我們可以保存當前線程的ThreadLocal的狀態,之后再重置。我們可以使用Java 7提供的try-with-resource結構來完成這件事情。例如:

try (ThreadLocalCleaner tlc = new ThreadLocalCleaner()) {
  // some code that potentially creates and adds thread locals
}
// at this point, the new thread locals have been cleared

為了簡化調試,我們增加一個觀察者的機制,這樣我們能夠監測到線程局部map發生的任何變化。這能幫助我們發現可能出現泄漏的線程局部變量。這是我們的監聽器:

package threadcleaner;
@FunctionalInterface
public interface ThreadLocalChangeListener {
  void changed(Mode mode, Thread thread,
               ThreadLocal<?> threadLocal, Object value);

ThreadLocalChangeListener EMPTY = (m, t, tl, v) -> {};

ThreadLocalChangeListener PRINTER = (m, t, tl, v) -> System.out.printf( "Thread %s %s ThreadLocal %s with value %s%n", t, m, tl.getClass(), v);

enum Mode { ADDED, REMOVED } }</pre>

這個地方可能需要做一下必要的說明。首先,我添加了注解@FunctionalInterface,這個注解是Java 8提供的,它的意思是該類只有一個抽象方法,可以作為lambda表達式使用。其次,我在該類的內部定義了一個EMPTY的lambda表達式。這樣,你可以見識到,這段代碼會非常短小。第三,我還定義了一個默認的PRINTER,它可以簡單的通過System.out輸出改變的信息。最后,我們還有兩個不不同的事件,但是因為想設計成為一個函數式編程接口(@FunctionalInterface),我不得不把這個標示定義為單獨的屬性,這里定義成了枚舉。

當我們構造ThreadLocalCleaner時,我們可以傳遞一個ThreadLocalChangeListener。這樣,從Treadlocal創建開始發生的任何變化,我們都能監測到。請注意,這種機制只適合于當前線程。這有一個例子演示我們怎樣通過try-with-resource代碼塊來使用ThreadLocalCleaner:任何定義在在 try(…) 中的局部變量,都會在代碼塊的自后進行自動關閉。因此,我們需要在ThreadLocalCleaner內部有一個 close() 方法,用于恢復線程局部變量到初始值。

import java.text.*;
public class ThreadLocalCleanerExample {
  private static final ThreadLocal df =
      new ThreadLocal() {
        protected DateFormat initialValue() {
          return new SimpleDateFormat("yyyy-MM-dd");
        }
      };

public static void main(String... args) { System.out.println("First ThreadLocalCleaner context"); try (ThreadLocalCleaner tlc = new ThreadLocalCleaner( ThreadLocalChangeListener.PRINTER)) { System.out.println(System.identityHashCode(df.get())); System.out.println(System.identityHashCode(df.get())); System.out.println(System.identityHashCode(df.get())); }

System.out.println("Another ThreadLocalCleaner context");
try (ThreadLocalCleaner tlc = new ThreadLocalCleaner(
    ThreadLocalChangeListener.PRINTER)) {
  System.out.println(System.identityHashCode(df.get()));
  System.out.println(System.identityHashCode(df.get()));
  System.out.println(System.identityHashCode(df.get()));
}

} }</pre>

在ThreadLocalCleaner類中還有兩個公共的靜態方法:forEach() 和 cleanup(Thread)。forEach() 方法有兩個參數:Thread和BiConsumer。該方法通過ThreadLocal調用來遍歷其中的每一個值。我們跳過了key為null的對象,但是沒有跳過值為null的對象。理由是如果僅僅是值為null,ThreadLocal仍然有可能出現內存泄露。一旦我們使用完了ThreadLocal,就應該在將線程返回給線程池之前調用 remove() 方法。cleanup(Thread) 方法設置ThreadLocal的map在該線程內為null,因此,允許垃圾回收器回收所有的對象。如果一個ThreadLocal在我們清理后再次使用,就簡單的調用 initialValue() 方法來創建對象。這是方法的定義:

public static void forEach(Thread thread,
      BiConsumer<ThreadLocal<?>, Object> consumer) { ... }

public static void cleanup(Thread thread) { ... }</pre>

ThreadLocalCleaner類完整的代碼如下。該類使用了許多反射來操作私有域。它可能只能在OpenJDK或其直接衍生產品上運行。你也能注意到我使用了Java 8的語法。我糾結過很長一段時間是否使用Java 8 或 7。我的某些客戶端還在使用1.4。最后,我的大部分大銀行客戶已經在產品中開始使用Java 8。銀行通常來說不是最先采用的新技術人,除非存在非常大的經濟意義。因此,如果你還沒有在產品中使用Java 8,你應該盡可能快的移植過去,甚至可以跳過Java 8,直接到Java 9。你應該可以很容易的反向移植到Java 7上,只需要自定義一個BiConsumer接口。Java 6不支持try-with-resource結構,所以反向移植會比較困難一點。

package threadcleaner;

import java.lang.ref.; import java.lang.reflect.; import java.util.; import java.util.function.;

import static threadcleaner.ThreadLocalChangeListener.Mode.*;

public class ThreadLocalCleaner implements AutoCloseable { private final ThreadLocalChangeListener listener;

public ThreadLocalCleaner() { this(ThreadLocalChangeListener.EMPTY); }

public ThreadLocalCleaner(ThreadLocalChangeListener listener) { this.listener = listener; saveOldThreadLocals(); }

public void close() { cleanup(); }

public void cleanup() { diff(threadLocalsField, copyOfThreadLocals.get()); diff(inheritableThreadLocalsField, copyOfInheritableThreadLocals.get()); restoreOldThreadLocals(); }

public static void forEach( Thread thread, BiConsumer<ThreadLocal<?>, Object> consumer) { forEach(thread, threadLocalsField, consumer); forEach(thread, inheritableThreadLocalsField, consumer); }

public static void cleanup(Thread thread) { try { threadLocalsField.set(thread, null); inheritableThreadLocalsField.set(thread, null); } catch (IllegalAccessException e) { throw new IllegalStateException( "Could not clear thread locals: " + e); } }

private void diff(Field field, Reference<?>[] backup) { try { Thread thread = Thread.currentThread(); Object threadLocals = field.get(thread); if (threadLocals == null) { if (backup != null) { for (Reference<?> reference : backup) { changed(thread, reference, REMOVED); } } return; }

  Reference<?>[] current =
      (Reference<?>[]) tableField.get(threadLocals);
  if (backup == null) {
    for (Reference<?> reference : current) {
      changed(thread, reference, ADDED);
    }
  } else {
    // nested loop - both arrays *should* be relatively small
    next:
    for (Reference<?> curRef : current) {
      if (curRef != null) {
        if (curRef.get() == copyOfThreadLocals ||
            curRef.get() == copyOfInheritableThreadLocals) {
          continue next;
        }
        for (Reference<?> backupRef : backup) {
          if (curRef == backupRef) continue next;
        }
        // could not find it in backup - added
        changed(thread, curRef, ADDED);
      }
    }
    next:
    for (Reference<?> backupRef : backup) {
      for (Reference<?> curRef : current) {
        if (curRef == backupRef) continue next;
      }
      // could not find it in current - removed
      changed(thread, backupRef,
          REMOVED);
    }
  }
} catch (IllegalAccessException e) {
  throw new IllegalStateException("Access denied", e);
}

}

private void changed(Thread thread, Reference<?> reference, ThreadLocalChangeListener.Mode mode) throws IllegalAccessException { listener.changed(mode, thread, (ThreadLocal<?>) reference.get(), threadLocalEntryValueField.get(reference)); }

private static Field field(Class<?> c, String name) throws NoSuchFieldException { Field field = c.getDeclaredField(name); field.setAccessible(true); return field; }

private static Class<?> inner(Class<?> clazz, String name) { for (Class<?> c : clazz.getDeclaredClasses()) { if (c.getSimpleName().equals(name)) { return c; } } throw new IllegalStateException( "Could not find inner class " + name + " in " + clazz); }

private static void forEach( Thread thread, Field field, BiConsumer<ThreadLocal<?>, Object> consumer) { try { Object threadLocals = field.get(thread); if (threadLocals != null) { Reference<?>[] table = (Reference<?>[]) tableField.get(threadLocals); for (Reference<?> ref : table) { if (ref != null) { ThreadLocal<?> key = (ThreadLocal<?>) ref.get(); if (key != null) { Object value = threadLocalEntryValueField.get(ref); consumer.accept(key, value); } } } } } catch (IllegalAccessException e) { throw new IllegalStateException(e); } }

private static final ThreadLocal<Reference<?>[]> copyOfThreadLocals = new ThreadLocal<>();

private static final ThreadLocal<Reference<?>[]> copyOfInheritableThreadLocals = new ThreadLocal<>();

private static void saveOldThreadLocals() { copyOfThreadLocals.set(copy(threadLocalsField)); copyOfInheritableThreadLocals.set( copy(inheritableThreadLocalsField)); }

private static Reference<?>[] copy(Field field) { try { Thread thread = Thread.currentThread(); Object threadLocals = field.get(thread); if (threadLocals == null) return null; Reference<?>[] table = (Reference<?>[]) tableField.get(threadLocals); return Arrays.copyOf(table, table.length); } catch (IllegalAccessException e) { throw new IllegalStateException("Access denied", e); } }

private static void restoreOldThreadLocals() { try { restore(threadLocalsField, copyOfThreadLocals.get()); restore(inheritableThreadLocalsField, copyOfInheritableThreadLocals.get()); } finally { copyOfThreadLocals.remove(); copyOfInheritableThreadLocals.remove(); } }

private static void restore(Field field, Object value) { try { Thread thread = Thread.currentThread(); if (value == null) { field.set(thread, null); } else { tableField.set(field.get(thread), value); } } catch (IllegalAccessException e) { throw new IllegalStateException("Access denied", e); } }

/ Reflection fields /

private static final Field threadLocalsField;

private static final Field inheritableThreadLocalsField; private static final Class<?> threadLocalMapClass; private static final Field tableField; private static final Class<?> threadLocalMapEntryClass;

private static final Field threadLocalEntryValueField;

static { try { threadLocalsField = field(Thread.class, "threadLocals"); inheritableThreadLocalsField = field(Thread.class, "inheritableThreadLocals");

  threadLocalMapClass =
      inner(ThreadLocal.class, "ThreadLocalMap");

  tableField = field(threadLocalMapClass, "table");
  threadLocalMapEntryClass =
      inner(threadLocalMapClass, "Entry");

  threadLocalEntryValueField =
      field(threadLocalMapEntryClass, "value");
} catch (NoSuchFieldException e) {
  throw new IllegalStateException(
      "Could not locate threadLocals field in Thread.  " +
          "Will not be able to clear thread locals: " + e);
}

} }</pre>

這是一個ThreadLocalCleaner在實踐中應用的例子:

import java.text.*;

public class ThreadLocalCleanerExample { private static final ThreadLocal<DateFormat> df = new ThreadLocal<DateFormat>() { protected DateFormat initialValue() { return new SimpleDateFormat("yyyy-MM-dd"); } };

public static void main(String... args) { System.out.println("First ThreadLocalCleaner context"); try (ThreadLocalCleaner tlc = new ThreadLocalCleaner( ThreadLocalChangeListener.PRINTER)) { System.out.println(System.identityHashCode(df.get())); System.out.println(System.identityHashCode(df.get())); System.out.println(System.identityHashCode(df.get())); }

System.out.println("Another ThreadLocalCleaner context");
try (ThreadLocalCleaner tlc = new ThreadLocalCleaner(
    ThreadLocalChangeListener.PRINTER)) {
  System.out.println(System.identityHashCode(df.get()));
  System.out.println(System.identityHashCode(df.get()));
  System.out.println(System.identityHashCode(df.get()));
}

} }</pre>

你的輸出結果可能會包含不同的hash code值。但是請記住我在Identity Crisis Newsletter中所說的:hash code的生成算法是一個隨機數字生成器。這是我的輸出。注意,在try-with-resource內部,線程局部變量的值是相同的。

First ThreadLocalCleaner context
186370029
186370029
186370029
Thread Thread[main,5,main] ADDED ThreadLocal class 
    ThreadLocalCleanerExample$1 with value 
    java.text.SimpleDateFormat@f67a0200
Another ThreadLocalCleaner context
2094548358
2094548358
2094548358
Thread Thread[main,5,main] ADDED ThreadLocal class 
    ThreadLocalCleanerExample$1 with value 
    java.text.SimpleDateFormat@f67a0200

為了讓這個代碼使用起來更簡單,我寫了一個Facede。門面設計模式不是阻止用戶使用直接使用子系統,而是提供一種更簡單的接口來完成復雜的系統。最典型的方式是將最常用子系統作為方法提供,我們的門面包括兩個方法:findAll(Thread) 和 printThreadLocals() 方法。findAll() 方法返回一個線程內部的Entry集合。

package threadcleaner;

import java.io.; import java.lang.ref.; import java.util.AbstractMap.; import java.util.; import java.util.Map.; import java.util.function.;

import static threadcleaner.ThreadLocalCleaner.*;

public class ThreadLocalCleaners { public static Collection<Entry<ThreadLocal<?>, Object>> findAll( Thread thread) { Collection<Entry<ThreadLocal<?>, Object>> result = new ArrayList<>(); BiConsumer<ThreadLocal<?>, Object> adder = (key, value) -> result.add(new SimpleImmutableEntry<>(key, value)); forEach(thread, adder); return result; }

public static void printThreadLocals() { printThreadLocals(System.out); }

public static void printThreadLocals(Thread thread) { printThreadLocals(thread, System.out); }

public static void printThreadLocals(PrintStream out) { printThreadLocals(Thread.currentThread(), out); }

public static void printThreadLocals(Thread thread, PrintStream out) { out.println("Thread " + thread.getName()); out.println(" ThreadLocals"); printTable(thread, out); }

private static void printTable( Thread thread, PrintStream out) { forEach(thread, (key, value) -> { out.printf(" {%s,%s", key, value); if (value instanceof Reference) { out.print("->" + ((Reference<?>) value).get()); } out.println("}"); }); } }</pre>

線程可以包含兩個不同類型的ThreadLocal:一個是普通的,另一個是可繼承的。大部分情況下,我們使用普通的那個。可繼承意味著如果你從當前線程構造出一個新的線程,則所有可繼承的ThreadLocals將被新線程繼承過去。非常同意。我們很少這么使用。所以現在我們可以忘了這種情況,或者永遠忘記。

一個使用ThreadLocalCleaner的典型場景是和ThreadPoolExecutor一起使用。我們寫一個子類,覆蓋 beforeExecute() 和 afterExecute() 方法。這個類比較長,因為我們不得不編寫所有的構造函數。有意思的地方在最后面。

package threadcleaner;

import java.util.concurrent.*;

public class ThreadPoolExecutorExt extends ThreadPoolExecutor { private final ThreadLocalChangeListener listener;

// Bunch of constructors following - you can ignore those

public ThreadPoolExecutorExt( int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue) { this(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, ThreadLocalChangeListener.EMPTY); }

public ThreadPoolExecutorExt( int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory) { this(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, ThreadLocalChangeListener.EMPTY); }

public ThreadPoolExecutorExt( int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, RejectedExecutionHandler handler) { this(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, handler, ThreadLocalChangeListener.EMPTY); }

public ThreadPoolExecutorExt( int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory, RejectedExecutionHandler handler) { this(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, handler, ThreadLocalChangeListener.EMPTY); }

public ThreadPoolExecutorExt( int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadLocalChangeListener listener) { super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue); this.listener = listener; }

public ThreadPoolExecutorExt( int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory, ThreadLocalChangeListener listener) { super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory); this.listener = listener; }

public ThreadPoolExecutorExt( int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, RejectedExecutionHandler handler, ThreadLocalChangeListener listener) { super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, handler); this.listener = listener; }

public ThreadPoolExecutorExt( int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory, RejectedExecutionHandler handler, ThreadLocalChangeListener listener) { super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, handler); this.listener = listener; }

/ The interest bit of this class is below ... /

private static final ThreadLocal<ThreadLocalCleaner> local = new ThreadLocal<>();

protected void beforeExecute(Thread t, Runnable r) { assert t == Thread.currentThread(); local.set(new ThreadLocalCleaner(listener)); }

protected void afterExecute(Runnable r, Throwable t) { ThreadLocalCleaner cleaner = local.get(); local.remove(); cleaner.cleanup(); } }</pre>

你可以像使用一個普通的ThreadPoolExecutor一樣使用這個類,該類不同的地方在于,當每個Runnable執行完之后需要重置線程局部變量的狀態。如果需要調試系統,你也可以獲取到綁定的監聽器。在我們的這個例子里,你可以看到,我們將監聽器綁定到我們的增加線程局部變量的LOG上。注意,在Java 8中,java.util.logging.Logger的方法使用Supplier作為參數,這意味著我們不再需要任何代碼來保證日志的性能。

import java.text.;
import java.util.concurrent.;
import java.util.logging.*;

public class ThreadPoolExecutorExtTest { private final static Logger LOG = Logger.getLogger( ThreadPoolExecutorExtTest.class.getName() );

private static final ThreadLocal<DateFormat> df = new ThreadLocal<DateFormat>() { protected DateFormat initialValue() { return new SimpleDateFormat("yyyy-MM-dd"); } };

public static void main(String... args) throws InterruptedException { ThreadPoolExecutor tpe = new ThreadPoolExecutorExt( 1, 1, 0, TimeUnit.SECONDS, new LinkedBlockingQueue<>(), (m, t, tl, v) -> { LOG.warning( () -> String.format( "Thread %s %s ThreadLocal %s with value %s%n", t, m, tl.getClass(), v) ); } );

for (int i = 0; i < 10; i++) {
  tpe.submit(() ->
      System.out.println(System.identityHashCode(df.get())));
  Thread.sleep(1000);
}
tpe.shutdown();

} }</pre>

我機器的輸出結果如下:

914524658
May 23, 2015 9:28:50 PM ThreadPoolExecutorExtTest lambda$main$1
WARNING: Thread Thread[pool-1-thread-1,5,main] 
    ADDED ThreadLocal class ThreadPoolExecutorExtTest$1 
    with value java.text.SimpleDateFormat@f67a0200

957671209 May 23, 2015 9:28:51 PM ThreadPoolExecutorExtTest lambda$main$1 WARNING: Thread Thread[pool-1-thread-1,5,main] ADDED ThreadLocal class ThreadPoolExecutorExtTest$1 with value java.text.SimpleDateFormat@f67a0200

466968587 May 23, 2015 9:28:52 PM ThreadPoolExecutorExtTest lambda$main$1 WARNING: Thread Thread[pool-1-thread-1,5,main] ADDED ThreadLocal class ThreadPoolExecutorExtTest$1 with value java.text.SimpleDateFormat@f67a0200</pre>

現在,這段代碼還沒有在生產服務器上經過嚴格的考驗,所以請謹慎使用。非常感謝你的閱讀和支持。我真的非常感激。

Kind regards

Heinz

原文鏈接: javaspecialists 翻譯: ImportNew.com - paddx
譯文鏈接: http://www.importnew.com/16112.html
 

 本文由用戶 jopen 自行上傳分享,僅供網友學習交流。所有權歸原作者,若您的權利被侵害,請聯系管理員。
 轉載本站原創文章,請注明出處,并保留原始鏈接、圖片水印。
 本站是一個以用戶分享為主的開源技術平臺,歡迎各類分享!