jdk源码系列-ConcurrentHashMap

jdk源码系列-ConcurrentHashMap

ConcurrentHashMap 是将锁的范围细化来实现高效并发的。 基本策略是将数据结构分为一个一个 Segment(每一个都是一个并发可读的 hash table, 即分段锁)作为一个并发单元。 为了减少开销, 除了一处 Segment 是在构造器初始化的, 其他都延迟初始化(详见 ensureSegment)。 并使用 volatile 关键字来保证 Segment 延迟初始化的可见性问题。

HashMap 不是线程安全的, 故多线程情况下会出现 infinit loop。 HashTable 是线程安全的, 但是是用全局锁来保障, 效率很低。 所以 Doug Lea 并发专家研发了高效并发的 ConcurrentHashMap 来应对并发情况下的情景。 阅读本文前最好先看: Java内存模型AtomicInteger 分析

术语定义

  1. 哈希算法 是一种将任意内容的输入转换成相同长度输出的加密方式,其输出被称为哈希值。

  2. 哈希表 根据设定的哈希函数H(key)和处理冲突方法将一组关键字映象到一个有限的地址区间上,并以关键字在地址区间中的象作为记录在表中的存储位置,这种表称为哈希表或散列,所得存储位置称为哈希地址或散列地址。

数据结构

抽象结构图

从上述类图和抽象结构图可以看出 ConcurrentHashMap 是由 Segment 数组结构和 HashEntry 数组结构组成。Segment 继承(Generalization)了可重入锁ReentrantLock,HashEntry 用于存储键值对数据。一个 ConcurrentHashMap 里 contains-a (Composition) 一个 Segment [],Segment 的结构和 HashMap 类似,是一种数组和链表结构, 一个 Segment 里 has-a (Aggregation)一个 HashEntry 数组,每个 HashEntry 是一个链表结构的元素, 每个 Segment 锁定一个 HashEntry 数组里的元素, 当对 HashEntry 数组的数据进行 put 等修改操作时,必须先获得它对应的 Segment 锁。

源码解析

内部类 Segment 类: Segment 维护着条目列表状态一致性, 所以可以实现无锁读。 在表超出 threshold 进行 resize 的时候, 复制节点, 所以在做 resize 修改操作的时候, 还可以进行读操作(在旧 list 读)。 本类里只有变化操作的方法才需要加锁, 变化的方法利用一系列忙等控制来处理资源争用, 例如 scanAndLock 和 scanAndLockForPut 。 那些遍历去查找节点的 tryLocks() 方法, 主要是用来吸收 cached 不命中(在 hash tables 经常出现), 这样后续获取锁的遍历操作效率将会有不小提升。 我们可能不是真的需要使用找到的数据, 因为重新获得数据还需要加锁来保证更新操作的一致性, 但他们会更快地进行重定位。 此外, scanAndLockForPut 特地创建新数据用于没有数据被找到的 put 方法。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
// 在准备锁住 segment 操作前最大的 tryLock() 次数。 多核情况下, 在定位 nodes 时使用 64 次最大值维持缓存  
  
static final int MAX_SCAN_RETRIES =  
    Runtime.getRuntime().availableProcessors() > 1 ? 64 : 1;  
  
/** 
 * The per-segment table. 数据访问通过 
 * entryAt/setEntryAt 提供的 volatile 语义来保证可见性. 
 */  
transient volatile HashEntry<K,V>[] table;  
  
/** 
 * sengment 内 hash entry 元素个数 , 之所以在每个 Segment 对象中包含一个计数器,而不是在 ConcurrentHashMap 中使用全局的计数器,是为了避免出现“热点域”而影响 ConcurrentHashMap 的并发性。 
 
 */  
transient int count;  
  
/** 
 * 在 segment 中可变操作总数, 即更新次数 
 */  
transient int modCount;  
  
/** 
 * 当超过threshold 时候,  table 再哈希 
 * (The value of this field is always <tt>(int)(capacity * 
 * loadFactor)</tt>.) 
 */  
transient int threshold;  
  
/** 
 *  hash table 负载因子.  Even though this value 
 * is same for all segments, it is replicated to avoid needing 
 * links to outer object. 
 * @serial 
 */  
final float loadFactor;  

put 方法: 插入A, B, C后 Segment 示意图:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
final V put(K key, int hash, V value, boolean onlyIfAbsent) {  
 // tryLock 一般缓存作用  
HashEntry<K,V> node = tryLock() ? null :  
        scanAndLockForPut(key, hash, value);  
    V oldValue;  
    try {  
        HashEntry<K,V>[] tab = table;  
        int index = (tab.length - 1) & hash;  
        // 找到 bucket 位置  
        HashEntry<K,V> first = entryAt(tab, index);  
        for (HashEntry<K,V> e = first;;) {  
            if (e != null) {  
                K k;  
                if ((k = e.key) == key ||  
                    (e.hash == hash && key.equals(k))) {  
                    // 保存旧值, 这样 get 操作就可以无锁访问正在写操作的节点  
                    oldValue = e.value;  
                    if (!onlyIfAbsent) {  
                        // 覆盖原来的值  
                        e.value = value;  
                        // 修改计数  
                        ++modCount;  
                    }  
                    break;  
                }  
                // 每次从头部插入  
                e = e.next;  
            }  
            else {  
                if (node != null)  
                    node.setNext(first);  
                else  
                    node = new HashEntry<K,V>(hash, key, value, first);  
                int c = count + 1;  
                // 超过 threshold 则, rehash()  
                if (c > threshold && tab.length < MAXIMUM_CAPACITY)  
                    rehash(node);  
                else  
                    setEntryAt(tab, index, node);  
                ++modCount;  
                count = c;  
                oldValue = null;  
                break;  
            }  
        }  
    } finally {  
        // 典型的 ReentrantLock 释放锁  
        unlock();  
    }  
    return oldValue;  
}  

ConcurrentHashMap的size操作 如果我们要统计整个ConcurrentHashMap里元素的大小,就必须统计所有Segment里元素的大小后求和。Segment里的全局变量count是一个volatile变量,那么在多线程场景下,我们是不是直接把所有Segment的count相加就可以得到整个ConcurrentHashMap大小了呢?不是的,虽然相加时可以获取每个Segment的count的最新值,但是拿到之后可能累加前使用的count发生了变化,那么统计结果就不准了。所以最安全的做法,是在统计size的时候把所有Segment的put,remove和clean方法全部锁住,但是这种做法显然非常低效。 因为在累加count操作过程中,之前累加过的count发生变化的几率非常小,所以ConcurrentHashMap的做法是先尝试2次通过不锁住Segment的方式来统计各个Segment大小,如果统计的过程中,容器的count发生了变化,则再采用加锁的方式来统计所有Segment的大小。

那么ConcurrentHashMap是如何判断在统计的时候容器是否发生了变化呢?使用modCount变量,在put , remove和clean方法里操作元素前都会将变量modCount进行加1,那么在统计size前后比较modCount是否发生变化,从而得知容器的大小是否发生变化。

rehash() 方法: 在新 table 中重新分类节点。 因为使用了 2 的幂指数扩展方式, bucket/bin 中的数据还是在原位, 即旧数据的索引位置不变或者偏移了 2 的幂指数距离。 可以重用旧节点减少不必要的节点生成。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
/** 
 * table 大小 *2, 重新放置 HashEntry, 加入新节点 
 */  
@SuppressWarnings("unchecked")  
private void rehash(HashEntry<K,V> node) {  
    // 保存旧值以便 get 操作遍历  
    HashEntry<K,V>[] oldTable = table;  
    int oldCapacity = oldTable.length;  
    int newCapacity = oldCapacity << 1;  
    threshold = (int)(newCapacity * loadFactor);  
    HashEntry<K,V>[] newTable =  
        (HashEntry<K,V>[]) new HashEntry[newCapacity];  
    int sizeMask = newCapacity - 1;  
    for (int i = 0; i < oldCapacity ; i++) {  
        HashEntry<K,V> e = oldTable[i];  
        if (e != null) {  
            HashEntry<K,V> next = e.next;  
            int idx = e.hash & sizeMask;  
            if (next == null)   //  单节点链表  
                newTable[idx] = e;  
            else { // 重用在同一个 slot 的连续序列  
                HashEntry<K,V> lastRun = e;  
                int lastIdx = idx;  
                for (HashEntry<K,V> last = next;  
                     last != null;  
                     last = last.next) {  
                    int k = last.hash & sizeMask;  
                    if (k != lastIdx) {  
                        lastIdx = k;  
                        lastRun = last;  
                    }  
                }  
                newTable[lastIdx] = lastRun;  
                // Clone remaining nodes  
                for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {  
                    V v = p.value;  
                    int h = p.hash;  
                    int k = h & sizeMask;  
                    HashEntry<K,V> n = newTable[k];  
                    newTable[k] = new HashEntry<K,V>(h, p.key, v, n);  
                }  
            }  
        }  
    }  
    int nodeIndex = node.hash & sizeMask; // add the new node  
    node.setNext(newTable[nodeIndex]);  
    newTable[nodeIndex] = node;  
    table = newTable;  
}  

内部类 HashEntry:链表结构

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
static final class HashEntry<K,V> {  
    // final 保证不变性, 即为线程安全的字段  
    final int hash;  
    final K key;  
    // 根据 volatile 写永远先于读操作的 happens-before 原则来保证获取到都是最新值  
    volatile V value;  
    volatile HashEntry<K,V> next;  
  
    HashEntry(int hash, K key, V value, HashEntry<K,V> next) {  
        this.hash = hash;  
        this.key = key;  
        this.value = value;  
        this.next = next;  
    }  
  
    /** 
     * 使用 volatile 写语义设置 next 
     */  
    final void setNext(HashEntry<K,V> n) {  
        UNSAFE.putOrderedObject(this, nextOffset, n);  
    }  
  
    // Unsafe mechanics  
    static final sun.misc.Unsafe UNSAFE;  
    static final long nextOffset;  
    static {  
        try {  
            UNSAFE = sun.misc.Unsafe.getUnsafe();  
            Class k = HashEntry.class;  
            nextOffset = UNSAFE.objectFieldOffset  
                (k.getDeclaredField("next"));  
        } catch (Exception e) {  
            throw new Error(e);  
        }  
    }  
}  

ConcurrentHashMap 类:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
/* ---------------- Constants -------------- */  
  
/** 
 * The default initial capacity for this table, 
 * used when not otherwise specified in a constructor. 
 */  
static final int DEFAULT_INITIAL_CAPACITY = 16;  
  
/** 
 * 本值是 HashEntry 个数与 table 数组长度的比值 
 * 
 */  
static final float DEFAULT_LOAD_FACTOR = 0.75f;  
  
/** 
 * 
 * 当前并发线程的使用数 
 */  
static final int DEFAULT_CONCURRENCY_LEVEL = 16;  
  
/** 
 * The maximum capacity, used if a higher value is implicitly 
 * specified by either of the constructors with arguments.  MUST 
 * be a power of two <= 1<<30 to ensure that entries are indexable 
 * using ints. 
 */  
static final int MAXIMUM_CAPACITY = 1 << 30;  
  
/** 
 * The minimum capacity for per-segment tables.  Must be a power 
 * of two, at least two to avoid immediate resizing on next use 
 * after lazy construction. 
 */  
static final int MIN_SEGMENT_TABLE_CAPACITY = 2;  
  
/** 
 * The maximum number of segments to allow; used to bound 
 * constructor arguments. Must be power of two less than 1 << 24. 
 */  
static final int MAX_SEGMENTS = 1 << 16; // slightly conservative  
  
/** 
 * Number of unsynchronized retries in size and containsValue 
 * methods before resorting to locking. This is used to avoid 
 * unbounded retries if tables undergo continuous modification 
 * which would make it impossible to obtain an accurate result. 
 */  
static final int RETRIES_BEFORE_LOCK = 2;  
/** 
 * 索引 segments 时使用: 使用高比特位的 hash 值去选择 segment 
 */  
final int segmentMask;  
  
/** 
 * Shift value for indexing within segments. 
 */  
final int segmentShift;  

初始化 ConcurrentHashMap: ConcurrentHashMap 结构图

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
public ConcurrentHashMap(int initialCapacity,  
                         float loadFactor, int concurrencyLevel) {  
    if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)  
        throw new IllegalArgumentException();  
    if (concurrencyLevel > MAX_SEGMENTS)  
        concurrencyLevel = MAX_SEGMENTS;  
    // Find power-of-two sizes best matching arguments  
    int sshift = 0;  
    int ssize = 1;  
    while (ssize < concurrencyLevel) {  
        ++sshift;  
        ssize <<= 1;  
    }  
    this.segmentShift = 32 - sshift;  
    this.segmentMask = ssize - 1;  
    if (initialCapacity > MAXIMUM_CAPACITY)  
        initialCapacity = MAXIMUM_CAPACITY;  
    int c = initialCapacity / ssize;  
    if (c * ssize < initialCapacity)  
        ++c;  
    int cap = MIN_SEGMENT_TABLE_CAPACITY;  
    while (cap < c)  
        cap <<= 1;  
    // create segments and segments[0]  
    Segment<K,V> s0 =  
        new Segment<K,V>(loadFactor, (int)(cap * loadFactor),  
                         (HashEntry<K,V>[])new HashEntry[cap]);  
    Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];  
    UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]  
    this.segments = ss;  
}  

由上面的代码可知segments数组的长度size通过concurrencyLevel计算得出。为了能通过按位与的哈希算法来定位segments数组的索引,必须保证segments数组的长度是2的N次方(power-of-two size),所以必须计算出一个是大于或等于concurrencyLevel的最小的2的N次方值来作为segments数组的长度。假如concurrencyLevel等于14,15或16,ssize都会等于16,即容器里锁的个数也是16。注意concurrencyLevel的最大大小是65535,意味着segments数组的长度最大为65536,对应的二进制是16位。

初始化segmentShift和segmentMask。这两个全局变量在定位segment时的哈希算法里需要使用,sshift等于ssize从1向左移位的次数,在默认情况下concurrencyLevel等于16,1需要向左移位移动4次,所以sshift等于4。segmentShift用于定位参与hash运算的位数,segmentShift等于32减sshift,所以等于28,这里之所以用32是因为ConcurrentHashMap里的hash()方法输出的最大数是32位的,后面的测试中我们可以看到这点。segmentMask是哈希运算的掩码,等于ssize减1,即15,掩码的二进制各个位的值都是1。因为ssize的最大长度是65536,所以segmentShift最大值是16,segmentMask最大值是65535,对应的二进制是16位,每个位都是1。

变量cap就是segment里HashEntry数组的长度,它等于initialCapacity除以ssize的倍数c,如果c大于1,就会取大于等于c的2的N次方值,所以cap不是1,就是2的N次方。segment的容量threshold=(int)cap*loadFactor,默认情况下initialCapacity等于16,load factor等于0.75,通过运算cap等于1,threshold等于零。

在jdk1.8中主要做了2方面的改进

数据结构 1.8中放弃了Segment臃肿的设计,取而代之的是采用Node + CAS + Synchronized来保证并发安全进行实现,结构如下:

只有在执行第一次put方法时才会调用initTable()初始化Node数组,实现如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
private final Node<K,V>[] initTable() {
    Node<K,V>[] tab; int sc;
    while ((tab = table) == null || tab.length == 0) {
        if ((sc = sizeCtl) < 0)
            Thread.yield(); // lost initialization race; just spin
        else if (U.compareAndSwapInt(this, SIZECTL, sc, -1)) {
            try {
                if ((tab = table) == null || tab.length == 0) {
                    int n = (sc > 0) ? sc : DEFAULT_CAPACITY;
                    @SuppressWarnings("unchecked")
                    Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n];
                    table = tab = nt;
                    sc = n - (n >>> 2);
                }
            } finally {
                sizeCtl = sc;
            }
            break;
        }
    }
    return tab;
}

put实现 当执行put方法插入数据时,根据key的hash值,在Node数组中找到相应的位置,实现如下:

1、如果相应位置的Node还未初始化,则通过CAS插入相应的数据;

1
2
3
4
else if ((f = tabAt(tab, i = (n - 1) & hash)) == null) {
    if (casTabAt(tab, i, null, new Node<K,V>(hash, key, value, null)))
        break;                   // no lock when adding to empty bin
}

2、如果相应位置的Node不为空,且当前该节点不处于移动状态,则对该节点加synchronized锁,如果该节点的hash不小于0,则遍历链表更新节点或插入新节点;

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
if (fh >= 0) {
    binCount = 1;
    for (Node<K,V> e = f;; ++binCount) {
        K ek;
        if (e.hash == hash &&
            ((ek = e.key) == key ||
             (ek != null && key.equals(ek)))) {
            oldVal = e.val;
            if (!onlyIfAbsent)
                e.val = value;
            break;
        }
        Node<K,V> pred = e;
        if ((e = e.next) == null) {
            pred.next = new Node<K,V>(hash, key, value, null);
            break;
        }
    }
}

3、如果该节点是TreeBin类型的节点,说明是红黑树结构,则通过putTreeVal方法往红黑树中插入节点;

1
2
3
4
5
6
7
8
9
else if (f instanceof TreeBin) {
    Node<K,V> p;
    binCount = 2;
    if ((p = ((TreeBin<K,V>)f).putTreeVal(hash, key, value)) != null) {
        oldVal = p.val;
        if (!onlyIfAbsent)
            p.val = value;
    }
}

4、如果binCount不为0,说明put操作对数据产生了影响,如果当前链表的个数达到8个,则通过treeifyBin方法转化为红黑树,如果oldVal不为空,说明是一次更新操作,没有对元素个数产生影响,则直接返回旧值;

1
2
3
4
5
6
7
if (binCount != 0) {
    if (binCount >= TREEIFY_THRESHOLD)
        treeifyBin(tab, i);
    if (oldVal != null)
        return oldVal;
    break;
}

5、如果插入的是一个新节点,则执行addCount()方法尝试更新元素个数baseCount;

size实现 1.8中使用一个volatile类型的变量baseCount记录元素的个数,当插入新数据或则删除数据时,会通过addCount()方法更新baseCount,实现如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
if ((as = counterCells) != null ||
    !U.compareAndSwapLong(this, BASECOUNT, b = baseCount, s = b + x)) {
    CounterCell a; long v; int m;
    boolean uncontended = true;
    if (as == null || (m = as.length - 1) < 0 ||
        (a = as[ThreadLocalRandom.getProbe() & m]) == null ||
        !(uncontended =
          U.compareAndSwapLong(a, CELLVALUE, v = a.value, v + x))) {
        fullAddCount(x, uncontended);
        return;
    }
    if (check <= 1)
        return;
    s = sumCount();
}

1、初始化时counterCells为空,在并发量很高时,如果存在两个线程同时执行CAS修改baseCount值,则失败的线程会继续执行方法体中的逻辑,使用CounterCell记录元素个数的变化;

2、如果CounterCell数组counterCells为空,调用fullAddCount()方法进行初始化,并插入对应的记录数,通过CAS设置cellsBusy字段,只有设置成功的线程才能初始化CounterCell数组,实现如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
else if (cellsBusy == 0 && counterCells == as &&
         U.compareAndSwapInt(this, CELLSBUSY, 0, 1)) {
    boolean init = false;
    try {                           // Initialize table
        if (counterCells == as) {
            CounterCell[] rs = new CounterCell[2];
            rs[h & 1] = new CounterCell(x);
            counterCells = rs;
            init = true;
        }
    } finally {
        cellsBusy = 0;
    }
    if (init)
        break;
}

3、如果通过CAS设置cellsBusy字段失败的话,则继续尝试通过CAS修改baseCount字段,如果修改baseCount字段成功的话,就退出循环,否则继续循环插入CounterCell对象;

1
2
else if (U.compareAndSwapLong(this, BASECOUNT, v = baseCount, v + x))
    break;

所以在1.8中的size实现比1.7简单多,因为元素个数保存baseCount中,部分元素的变化个数保存在CounterCell数组中,实现如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
public int size() {
    long n = sumCount();
    return ((n < 0L) ? 0 :
            (n > (long)Integer.MAX_VALUE) ? Integer.MAX_VALUE :
            (int)n);
}
 
final long sumCount() {
    CounterCell[] as = counterCells; CounterCell a;
    long sum = baseCount;
    if (as != null) {
        for (int i = 0; i < as.length; ++i) {
            if ((a = as[i]) != null)
                sum += a.value;
        }
    }
    return sum;
}

通过累加baseCount和CounterCell数组中的数量,即可得到元素的总个数;

鸣谢

wenniuwuren

署名 - 非商业性使用 - 禁止演绎 4.0