Java源码系列(14) -- CopyOnWriteArrayList

August 9, 2018

一、类签名

ArrayList 在多线程操作下是不安全的,为此应使用 CopyOnWriteArrayList。通过CopyOnWrite(简称COW,写时复制)策略,所有读取共享同一个数组对象,修改时另拷贝出新数组,操作在新数组完成后再替换掉旧数组。

1
2
public class CopyOnWriteArrayList<E>
    implements List<E>, RandomAccess, Cloneable, java.io.Serializable

由于修改时方法会自行拷贝得到新数组,所以这段时间内存同时存在原数组和新数组对象。如果修改操作过于频繁,产生大量废弃对象将增加垃圾回收负担。

由此,可推理出此类适合在读多写少的场景下使用。通过读写分离,即使修改操作费时也不会阻塞读取,而读取的数组数据未必是最新的。还有修改操作是线程安全的,每次最多只有一个线程在进行修改,以此保证数据最终一致性。

此次源码来自JDK10,和之前版本有一定差别。

二、数据成员

通过网上阅读JDK8版本的 CopyOnWriteArrayList 源码,可了解以前约束同步使用的是 ReentrantLock。而在JDK10中用 synchronized (lock) 方式,暂时不知道对性能有多大提升。

1
final transient Object lock = new Object();

变量array通过getArray()、setArray()获取,见第四节。注意array变量的类型是Object[],而不是E[]。

1
private transient volatile Object[] array;

三、构造方法

构造空列表

1
2
3
public CopyOnWriteArrayList() {
    setArray(new Object[0]);
}

从集合c构建实例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
public CopyOnWriteArrayList(Collection<? extends E> c) {
    Object[] elements;
    if (c.getClass() == CopyOnWriteArrayList.class)
        elements = ((CopyOnWriteArrayList<?>)c).getArray();
    else {
        elements = c.toArray();
        // defend against c.toArray (incorrectly) not returning Object[]
        // (see e.g. https://bugs.openjdk.java.net/browse/JDK-6260652)
        if (elements.getClass() != Object[].class)
            // 类型不是Object[].class,重新拷贝
            elements = Arrays.copyOf(elements, elements.length, Object[].class);
    }
    setArray(elements);
}

从数组构建实例

1
2
3
public CopyOnWriteArrayList(E[] toCopyIn) {
    setArray(Arrays.copyOf(toCopyIn, toCopyIn.length, Object[].class));
}

四、列表存取

获取数组,返回类型为Object[]。

1
2
3
final Object[] getArray() {
    return array;
}

设置数组,形参类型为Object[]。

1
2
3
final void setArray(Object[] a) {
    array = a;
}

五、基本操作

5.1、增加

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
// 把新元素插入到列表尾
public boolean add(E e) {
    // 修改时获取锁以保证线程安全
    synchronized (lock) {
        // 获取原数组
        Object[] elements = getArray();
        // 获取原数组长度
        int len = elements.length;
        // 用原数组拷贝出新数组,长度增加1个单位
        Object[] newElements = Arrays.copyOf(elements, len + 1);
        // 在新数组最后位置存入新元素e
        newElements[len] = e;
        // 此时同时存在新、旧两个数组,用新数组引用替换旧数组引用
        setArray(newElements);
        return true;
    }
}

// 把新元素插入到指定索引位置,指定索引原位置元素及后续元素都相对后移一个位置
public void add(int index, E element) {
    // 修改时获取锁以保证线程安全
    synchronized (lock) {
        // 获取原数组
        Object[] elements = getArray();
        // 获取原数组长度
        int len = elements.length;
        // 检查插入位置是否越界
        if (index > len || index < 0)
            throw new IndexOutOfBoundsException(outOfBounds(index, len));
        Object[] newElements;
        // 有多少个元素需要后移
        int numMoved = len - index;
        // numMoved为0表示新元素插入到最后一个索引
        if (numMoved == 0)
            // 创建新数组
            newElements = Arrays.copyOf(elements, len + 1);
        else {
            // 创建新数组
            newElements = new Object[len + 1];
            // 把旧元素赋值到新数组[0, index)
            System.arraycopy(elements, 0, newElements, 0, index);
            // 把旧元素赋值到新数组[index+1, numMoved)
            System.arraycopy(elements, index, newElements, index + 1,
                             numMoved);
        }
        // 在新数组指定位置index插入新元素
        newElements[index] = element;
        // 此时同时存在新、旧两个数组,用新数组引用替换旧数组引用
        setArray(newElements);
    }
}

// 把集合c的所有元素添加到本数组中
public boolean addAll(Collection<? extends E> c) {
    Object[] cs = (c.getClass() == CopyOnWriteArrayList.class) ?
        ((CopyOnWriteArrayList<?>)c).getArray() : c.toArray();
    
    // 没有添加任何元素
    if (cs.length == 0)
        return false;
        
    // 修改时获取锁以保证线程安全
    synchronized (lock) {
        // 获取原数组
        Object[] elements = getArray();
        // 获取原数组长度
        int len = elements.length;
        if (len == 0 && cs.getClass() == Object[].class)
            setArray(cs);
        else {
            // 用原数组构建新数组,长度为原数组长度与集合c长度之和
            Object[] newElements = Arrays.copyOf(elements, len + cs.length);
            // 把集合c的元素拷贝到新数组中
            System.arraycopy(cs, 0, newElements, len, cs.length);
            // 新数组替换旧数组
            setArray(newElements);
        }
        return true; // 完成添加所有新元素
    }
}

// 在原数组index位置开始插入所有集合c的元素
public boolean addAll(int index, Collection<? extends E> c) {
    Object[] cs = c.toArray();
    // 修改时获取锁以保证线程安全
    synchronized (lock) {
        // 原数组
        Object[] elements = getArray();
        // 原数组长度
        int len = elements.length;
        // 检查元素插入索引
        if (index > len || index < 0)
            throw new IndexOutOfBoundsException(outOfBounds(index, len));
        if (cs.length == 0)
            return false;
        // 有多少个元素需要后移
        int numMoved = len - index;
        // 新数组
        Object[] newElements;
        if (numMoved == 0)
            // 用原数组构建新数组,新长度为原数组长度与集合c长度之和
            newElements = Arrays.copyOf(elements, len + cs.length);
        else {
            // 创建新数组
            newElements = new Object[len + cs.length];
            // 拷贝原数组数据到新数组,[0, index)
            System.arraycopy(elements, 0, newElements, 0, index);
            // 拷贝原数组数据到新数组,[index + cs.length, index + cs.length + numMoved)
            System.arraycopy(elements, index,
                             newElements, index + cs.length,
                             numMoved);
        }
        // 拷贝新元素到新数组,[index, index + cs.length)
        System.arraycopy(cs, 0, newElements, index, cs.length);
        // 新数组索引替换旧数组
        setArray(newElements);
        return true;
    }
}

// 若原列表不存在此元素,则存入
public boolean addIfAbsent(E e) {
    // 先拷贝出一份快照
    Object[] snapshot = getArray();
    // 检查快照是否包含此元素,不包含则存入此元素
    return indexOf(e, snapshot, 0, snapshot.length) >= 0 ? false :
        addIfAbsent(e, snapshot);
}

private boolean addIfAbsent(E e, Object[] snapshot) {
    // 修改时获取锁以保证线程安全
    synchronized (lock) {
        // 原数组
        Object[] current = getArray();
        // 获取数组长度
        int len = current.length;
        // 对比快照与原数组是否为同一对象
        if (snapshot != current) {
            // Optimize for lost race to another addXXX operation
            int common = Math.min(snapshot.length, len);
            
            // 依次遍历两者最小长度下的元素
            for (int i = 0; i < common; i++)
                // 检查快照的顺序元素是否和原数组的顺序元素一致
                // 同时检查原数组是否已包含元素e
                if (current[i] != snapshot[i]
                    && Objects.equals(e, current[i]))
                    return false; // 快照与原数组不一致,或原数组已包含元素e
            // 元素已经存在原数组中
            if (indexOf(e, current, common, len) >= 0)
                    return false;
        }
        // 创建新数组
        Object[] newElements = Arrays.copyOf(current, len + 1);
        // 存入新元素
        newElements[len] = e;
        // 新数组引用替换旧数组引用
        setArray(newElements);
        // 添加成功
        return true;
    }
}

// 集合c中不包含在原列表的元素添加到原列表尾部
public int addAllAbsent(Collection<? extends E> c) {
    // c为空抛出空指针异常,cs是c的拷贝
    Object[] cs = c.toArray();
    if (cs.length == 0)
        return 0;
    // 修改时获取锁以保证线程安全
    synchronized (lock) {
        // 获取原数组
        Object[] elements = getArray();
        // 获取原数组长度
        int len = elements.length;
        int added = 0;
        // 依次遍历集合c中的元素
        for (int i = 0; i < cs.length; ++i) {
            Object e = cs[i];
            // 如果集合c的元素e不在elements中
            if (indexOf(e, elements, 0, len) < 0 &&
                indexOf(e, cs, 0, added) < 0)
                cs[added++] = e; // 把元素e放在cs前面
        }
        if (added > 0) {
            // 根据原列表拷贝新数组
            Object[] newElements = Arrays.copyOf(elements, len + added);
            // 把cs中前added个元素尾插入到新数组
            System.arraycopy(cs, 0, newElements, len, added);
            // 新数组替换旧数组
            setArray(newElements);
        }
        return added;
    }
}

5.2、删除

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
// 移除数组中指定索引的元素
public E remove(int index) {
    // 修改时获取锁以保证线程安全
    synchronized (lock) {
        // 获取原数组
        Object[] elements = getArray();
        // 获取原数组长度
        int len = elements.length;
        // 获取指定索引元素
        E oldValue = elementAt(elements, index);
        int numMoved = len - index - 1;
        if (numMoved == 0)
            setArray(Arrays.copyOf(elements, len - 1));
        else {
            // 创建新数组
            Object[] newElements = new Object[len - 1];
            // 把旧元素赋值到新数组[0, index)
            System.arraycopy(elements, 0, newElements, 0, index);
            // 把旧元素赋值到新数组[index, numMoved)
            System.arraycopy(elements, index + 1, newElements, index,
                             numMoved);
            // 此时同时存在新、旧两个数组,用新数组引用替换旧数组引用
            setArray(newElements);
        }
        return oldValue; // 返回被移除的元素
    }
}

// 移除该元素
public boolean remove(Object o) {
    // 获取原数组
    Object[] snapshot = getArray();
    // 获取指定元素在数组中的索引
    int index = indexOf(o, snapshot, 0, snapshot.length);
    // 成功移除返回true,否则返回false
    return (index < 0) ? false : remove(o, snapshot, index);
}

// 根据快照移除元素
private boolean remove(Object o, Object[] snapshot, int index) {
    // 修改时获取锁以保证线程安全
    synchronized (lock) {
        // 获取原数组
        Object[] current = getArray();
        // 获取原数组长度
        int len = current.length;
        if (snapshot != current) findIndex: {
            int prefix = Math.min(index, len);
            for (int i = 0; i < prefix; i++) {
                if (current[i] != snapshot[i]
                    && Objects.equals(o, current[i])) {
                    index = i;
                    break findIndex;
                }
            }
            if (index >= len)
                return false;
            if (current[index] == o)
                break findIndex;
            index = indexOf(o, current, index, len);
            if (index < 0)
                return false;
        }
        // 除了被移除的元素,其余元素全部复制到新数组中
        Object[] newElements = new Object[len - 1];
        // 拷贝原数组[0, index)元素到新数组[0, index)
        System.arraycopy(current, 0, newElements, 0, index);
        // 拷贝原数组[index+1, len)元素到新数组[index, len-1)
        System.arraycopy(current, index + 1,
                         newElements, index,
                         len - index - 1);
        // 此时同时存在新、旧两个数组,用新数组引用替换旧数组引用
        setArray(newElements);
        return true;
    }
}

// 移除指定范围的元素[fromIndex, toIndex)
void removeRange(int fromIndex, int toIndex) {
    // 修改时获取锁以保证线程安全
    synchronized (lock) {
        // 原数组
        Object[] elements = getArray();
        // 获取数组长度
        int len = elements.length;
        
        // 越界检查
        if (fromIndex < 0 || toIndex > len || toIndex < fromIndex)
            throw new IndexOutOfBoundsException();
        // 计算移除元素后数组的长度
        int newlen = len - (toIndex - fromIndex);
        int numMoved = len - toIndex;
        if (numMoved == 0)
            setArray(Arrays.copyOf(elements, newlen));
        else {
            // 用新长度创建新数组
            Object[] newElements = new Object[newlen];
            // 旧数组[0, fromIndex)拷贝元素到新数组
            System.arraycopy(elements, 0, newElements, 0, fromIndex);
            // 旧数组[toIndex, toIndex+numMoved)拷贝元素到新数组
            System.arraycopy(elements, toIndex, newElements,
                             fromIndex, numMoved);
            // 用新数组引用替换旧数组
            setArray(newElements);
        }
    }
}

// 在原数组中移除所有与集合c有关的元素
public boolean removeAll(Collection<?> c) {
    // c不能为空
    Objects.requireNonNull(c);
    return bulkRemove(e -> c.contains(e));
}

// 被removeAll调用
boolean bulkRemove(Predicate<? super E> filter, int i, int end) {
    // assert Thread.holdsLock(lock);
    final Object[] es = getArray();
    // Optimize for initial run of survivors
    for (; i < end && !filter.test(elementAt(es, i)); i++)
        ;
    if (i < end) {
        final int beg = i;
        final long[] deathRow = nBits(end - beg);
        int deleted = 1;
        deathRow[0] = 1L;   // set bit 0
        for (i = beg + 1; i < end; i++)
            if (filter.test(elementAt(es, i))) {
                setBit(deathRow, i - beg);
                deleted++;
            }
        // Did filter reentrantly modify the list?
        if (es != getArray())
            throw new ConcurrentModificationException();
        final Object[] newElts = Arrays.copyOf(es, es.length - deleted);
        int w = beg;
        for (i = beg; i < end; i++)
            if (isClear(deathRow, i - beg))
                newElts[w++] = es[i];
        System.arraycopy(es, i, newElts, w, es.length - i);
        setArray(newElts);
        return true;
    } else {
        if (es != getArray())
            throw new ConcurrentModificationException();
        return false;
    }
}

5.3、查询

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
// 获取列表元素的数量
public int size() {
    return getArray().length;
}

// 检查列表是否为空
public boolean isEmpty() {
    return size() == 0;
}

// 获取指定索引的元素,元素来自本数组getArray()
public E get(int index) {
    return elementAt(getArray(), index);
}

// 检查列表是否包含指定对象
public boolean contains(Object o) {
    Object[] elements = getArray();
    return indexOf(o, elements, 0, elements.length) >= 0;
}

// 在[0, fence)的范围内查找元素
private static int indexOf(Object o, Object[] elements,
                           int index, int fence) {
    if (o == null) {
        // 返回数组中第一个内容为null的索引
        for (int i = index; i < fence; i++)
            if (elements[i] == null)
                return i;
    } else {
        // 返回数组中第一个匹配内容的索引
        for (int i = index; i < fence; i++)
            if (o.equals(elements[i]))
                return i;
    }
    return -1;
}

// 在[0, index]的范围内倒序查找元素
private static int lastIndexOf(Object o, Object[] elements, int index) {
    if (o == null) {
        // 返回数组中最后一个内容为null的索引
        for (int i = index; i >= 0; i--)
            if (elements[i] == null)
                return i;
    } else {
        // 返回数组中最后一个匹配内容的索引
        for (int i = index; i >= 0; i--)
            if (o.equals(elements[i]))
                return i;
    }
    return -1;
}

// 获取列表中指定元素的索引值
public int indexOf(Object o) {
    Object[] elements = getArray();
    return indexOf(o, elements, 0, elements.length);
}

// 获取列表中指定元素在index之后的索引值
public int indexOf(E e, int index) {
    Object[] elements = getArray();
    return indexOf(e, elements, index, elements.length);
}

// 获取列表中指定元素最后出现的索引值
public int lastIndexOf(Object o) {
    Object[] elements = getArray();
    return lastIndexOf(o, elements, elements.length - 1);
}

// 获取列表中指定元素在索引之后,且最后出现的索引值
public int lastIndexOf(E e, int index) {
    Object[] elements = getArray();
    return lastIndexOf(e, elements, index);
}

// 检查原列表是否全部包含集合c元素
public boolean containsAll(Collection<?> c) {
    // 原数组
    Object[] elements = getArray();
    // 原数组长度
    int len = elements.length;
    // 逐个确认元素是否包含
    for (Object e : c) {
        if (indexOf(e, elements, 0, len) < 0)
            // 有一个元素不包含就返回false
            return false;
    }
    // 原数组包含所有集合c的元素
    return true;
}

5.4、修改

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
// 把指定元素设置到指定索引位置
public E set(int index, E element) {
    // 修改时获取锁以保证线程安全
    synchronized (lock) {
        // 原数组
        Object[] elements = getArray();
        // 获取原数组指定索引位置下的元素
        E oldValue = elementAt(elements, index);
        
        // 旧元素和新元素是否相同
        if (oldValue != element) {
            int len = elements.length;
            // 拷贝出新数组
            Object[] newElements = Arrays.copyOf(elements, len);
            // 把新元素放入新数组
            newElements[index] = element;
            // 新数组引用替换旧数组引用
            setArray(newElements);
        } else {
            // Not quite a no-op; ensures volatile write semantics
            setArray(elements);
        }
        return oldValue; // 返回旧元素
    }
}

// 原数组仅保留所有集合c的元素,相当于把两者交集结果作为原数组的新结果
public boolean retainAll(Collection<?> c) {
    Objects.requireNonNull(c);
    return bulkRemove(e -> !c.contains(e));
}

六、其他

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
// 根据原数组浅拷贝一份独立新数组,长度一致并保持元素原有顺序,元素类型为Object
public Object[] toArray() {
    Object[] elements = getArray();
    return Arrays.copyOf(elements, elements.length);
}

// 把原数组元素放入数组a中
@SuppressWarnings("unchecked")
public <T> T[] toArray(T[] a) {
    // 原数组
    Object[] elements = getArray();
    // 原数组长度
    int len = elements.length;
    // 数组a长度不足以存放所有数组elements的元素
    if (a.length < len)
        // 创建新的数组,类型和数组a一致
        return (T[]) Arrays.copyOf(elements, len, a.getClass()); 
    else {
        // 空间足够,把elements数组的元素放入数组a中
        System.arraycopy(elements, 0, a, 0, len);
        // 如果a还有空闲位置,则把空闲位置置空
        if (a.length > len)
            a[len] = null;
        return a; // 返回结果
    }
}

// 获取指定索引的元素,元素来自传入数组a
@SuppressWarnings("unchecked")
static <E> E elementAt(Object[] a, int index) {
    // 有可能出现数组越界
    return (E) a[index];
}

static String outOfBounds(int index, int size) {
    return "Index: " + index + ", Size: " + size;
}

// 清空列表
public void clear() {
    synchronized (lock) {
        // 申请一个空数组,并作为新结果存入
        setArray(new Object[0]);
    }
}

// A tiny bit set implementation
private static long[] nBits(int n) {
    return new long[((n - 1) >> 6) + 1];
}

private static void setBit(long[] bits, int i) {
    bits[i >> 6] |= 1L << i;
}

private static boolean isClear(long[] bits, int i) {
    return (bits[i >> 6] & (1L << i)) == 0;
}

private boolean bulkRemove(Predicate<? super E> filter) {
    synchronized (lock) {
        return bulkRemove(filter, 0, getArray().length);
    }
}

public boolean equals(Object o) {
    if (o == this)
        // 同一个元素
        return true;
    if (!(o instanceof List))
        // 类型不一样,返回false
        return false;

    List<?> list = (List<?>)o;
    Iterator<?> it = list.iterator(); // 获取o的迭代器
    Object[] elements = getArray();
    
    // 逐个对比数组o与原数组的元素
    for (int i = 0, len = elements.length; i < len; i++)
        if (!it.hasNext() || !Objects.equals(elements[i], it.next()))
            return false;

    // 对象o还存在更多元素
    if (it.hasNext())
        return false;

    // 元素完全一致
    return true;
}

// 返回列表所有元素哈希值的总值
public int hashCode() {
    int hashCode = 1;
    for (Object x : getArray())
        // 逐个叠加哈希值
        hashCode = 31 * hashCode + (x == null ? 0 : x.hashCode());
    return hashCode;
}

七、重置锁

反序列化或克隆后需要重置同步对象

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
private void resetLock() {
    Field lockField = java.security.AccessController.doPrivileged(
        (java.security.PrivilegedAction<Field>) () -> {
            try {
                // 反射获取lock变量
                Field f = CopyOnWriteArrayList.class
                    .getDeclaredField("lock");
                f.setAccessible(true);
                return f;
            } catch (ReflectiveOperationException e) {
                throw new Error(e);
            }});
    try {
        // 给lock赋对象
        lockField.set(this, new Object());
    } catch (IllegalAccessException e) {
        throw new Error(e);
    }
}