From d37982d6d9961e79b87d8f919affc918d7f9a546 Mon Sep 17 00:00:00 2001
From: seungjong bae
Date: Wed, 24 Dec 2025 16:13:47 +0900
Subject: [PATCH] Add ConcurrentLruCache2

Signed-off-by: seungjong bae
---
 .../util/ConcurrentLruCache2Benchmark.java    | 126 ++++
 .../util/ConcurrentLruCache2.java             | 651 ++++++++++++++++++
 .../util/ConcurrentLruCache2Tests.java        |  70 ++
 3 files changed, 847 insertions(+)
 create mode 100644 spring-core/src/jmh/java/org/springframework/util/ConcurrentLruCache2Benchmark.java
 create mode 100644 spring-core/src/main/java/org/springframework/util/ConcurrentLruCache2.java
 create mode 100644 spring-core/src/test/java/org/springframework/util/ConcurrentLruCache2Tests.java

diff --git a/spring-core/src/jmh/java/org/springframework/util/ConcurrentLruCache2Benchmark.java b/spring-core/src/jmh/java/org/springframework/util/ConcurrentLruCache2Benchmark.java
new file mode 100644
index 000000000000..20383cbdf547
--- /dev/null
+++ b/spring-core/src/jmh/java/org/springframework/util/ConcurrentLruCache2Benchmark.java
@@ -0,0 +1,126 @@
+/*
+ * Copyright 2002-present the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.util;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Function;
+
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Level;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.Threads;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.infra.Blackhole;
+
+/**
+ * Benchmarks comparing {@link ConcurrentLruCache} and {@link ConcurrentLruCache2}.
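+ *
+ * <p>The two benchmarks are not strictly symmetrical: the legacy cache populates itself
+ * through the generator function supplied at construction time, whereas {@code ConcurrentLruCache2}
+ * is populated manually, so the {@code lruCache2} benchmark performs the lookup-then-put
+ * sequence itself on a cache miss, roughly as follows:
+ * <pre>{@code
+ * String value = cache.get(key);
+ * if (value == null) {
+ *     value = generator.apply(key);
+ *     cache.put(key, value);
+ * }
+ * }</pre>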
+ *
+ * @author Brian Clozel
+ */
+@BenchmarkMode(Mode.Throughput)
+@Fork(3)
+@Threads(8)
+@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
+@Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
+public class ConcurrentLruCache2Benchmark {
+
+	@Benchmark
+	public void legacyCache(LegacyBenchmarkData data, Blackhole bh) {
+		for (String element : data.elements) {
+			String value = data.lruCache.get(element);
+			bh.consume(value);
+		}
+	}
+
+	@Benchmark
+	public void lruCache2(LruCache2BenchmarkData data, Blackhole bh) {
+		for (String element : data.elements) {
+			String value = data.lruCache.get(element);
+			if (value == null) {
+				value = data.generator.apply(element);
+				data.lruCache.put(element, value);
+			}
+			bh.consume(value);
+		}
+	}
+
+	@State(Scope.Benchmark)
+	public static class LegacyBenchmarkData {
+
+		ConcurrentLruCache<String, String> lruCache;
+
+		@Param({"100"})
+		public int capacity;
+
+		@Param({"0.1"})
+		public float cacheMissRate;
+
+		public List<String> elements;
+
+		public Function<String, String> generator;
+
+		@Setup(Level.Iteration)
+		public void setup() {
+			this.generator = key -> key + "value";
+			this.lruCache = new ConcurrentLruCache<>(this.capacity, this.generator);
+			Assert.isTrue(this.cacheMissRate < 1, "cache miss rate should be < 1");
+			Random random = new Random();
+			int elementsCount = Math.round(this.capacity * (1 + this.cacheMissRate));
+			this.elements = new ArrayList<>(elementsCount);
+			random.ints(elementsCount).forEach(value -> this.elements.add(String.valueOf(value)));
+			this.elements.sort(String::compareTo);
+		}
+	}
+
+	@State(Scope.Benchmark)
+	public static class LruCache2BenchmarkData {
+
+		ConcurrentLruCache2<String, String> lruCache;
+
+		@Param({"100"})
+		public int capacity;
+
+		@Param({"0.1"})
+		public float cacheMissRate;
+
+		public List<String> elements;
+
+		public Function<String, String> generator;
+
+		@Setup(Level.Iteration)
+		public void setup() {
+			this.generator = key -> key + "value";
+			this.lruCache = new ConcurrentLruCache2<>(this.capacity);
+			Assert.isTrue(this.cacheMissRate < 1, "cache miss rate should be < 1");
+			Random random = new Random();
+			int elementsCount = Math.round(this.capacity * (1 + this.cacheMissRate));
+			this.elements = new ArrayList<>(elementsCount);
+			random.ints(elementsCount).forEach(value -> this.elements.add(String.valueOf(value)));
+			this.elements.sort(String::compareTo);
+		}
+	}
+}
+
diff --git a/spring-core/src/main/java/org/springframework/util/ConcurrentLruCache2.java b/spring-core/src/main/java/org/springframework/util/ConcurrentLruCache2.java
new file mode 100644
index 000000000000..e3e71e946b71
--- /dev/null
+++ b/spring-core/src/main/java/org/springframework/util/ConcurrentLruCache2.java
@@ -0,0 +1,651 @@
+/*
+ * Copyright 2002-present the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.util;
+
+import java.util.Queue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.concurrent.atomic.AtomicReferenceArray;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
+import java.util.function.Consumer;
+
+import org.jspecify.annotations.Nullable;
+
+/**
+ * Performance-tuned variant of {@link ConcurrentLruCache} with manual population and
+ * an optional eviction listener.
+ *
+ * @author Brian Clozel
+ * @param <K> the type of the key used for cache retrieval
+ * @param <V> the type of the cached values, does not allow null values
+ */
+@SuppressWarnings({"unchecked", "NullAway"})
+public final class ConcurrentLruCache2<K, V> {
+
+	private final int capacity;
+
+	private final AtomicInteger currentSize = new AtomicInteger();
+
+	private final ConcurrentMap<K, Node<K, V>> cache;
+
+	private final ReadOperations<K, V> readOperations;
+
+	private final WriteOperations writeOperations;
+
+	private final Lock evictionLock = new ReentrantLock();
+
+	/*
+	 * Queue that contains all ACTIVE cache entries, ordered with least recently used entries first.
+	 * Read and write operations are buffered and periodically processed to reorder the queue.
+	 */
+	private final EvictionQueue<K, V> evictionQueue = new EvictionQueue<>();
+
+	private final AtomicReference<DrainStatus> drainStatus = new AtomicReference<>(DrainStatus.IDLE);
+
+	private volatile Consumer<Entry<K, V>> evictionListener = entry -> {};
+
+	/**
+	 * Create a new cache instance with the given capacity.
+	 * @param maxSize the maximum number of entries in the cache
+	 * (0 indicates no caching; lookups always return {@code null})
+	 */
+	public ConcurrentLruCache2(int maxSize) {
+		this(maxSize, 16);
+	}
+
+	private ConcurrentLruCache2(int maxSize, int concurrencyLevel) {
+		this(maxSize, maxSize, concurrencyLevel);
+	}
+
+	private ConcurrentLruCache2(int initialCapacity, int maxSize, int concurrencyLevel) {
+		Assert.isTrue(maxSize >= 0, "Capacity must be >= 0");
+		this.capacity = maxSize;
+		this.cache = new ConcurrentHashMap<>(initialCapacity, 0.75f, concurrencyLevel);
+		this.readOperations = new ReadOperations<>(this.evictionQueue);
+		this.writeOperations = new WriteOperations();
+	}
+
+	/**
+	 * Retrieve an entry from the cache.
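+	 * <p>Unlike {@link ConcurrentLruCache}, this variant does not generate a value on a
+	 * cache miss; callers are expected to compute the value themselves and register it
+	 * with {@link #put}, for example:
+	 * <pre>{@code
+	 * V value = cache.get(key);
+	 * if (value == null) {
+	 *     value = computeValue(key);  // however the caller obtains the value
+	 *     cache.put(key, value);
+	 * }
+	 * }</pre>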
+	 * @param key the key to retrieve the entry for
+	 * @return the cached value, or {@code null} if absent
+	 */
+	@Nullable
+	public V get(K key) {
+		final Node<K, V> node = this.cache.get(key);
+		if (node == null) {
+			return null;
+		}
+		processRead(node);
+		return node.getValue();
+	}
+
+	public void put(K key, V value) {
+		Assert.notNull(key, "Key must not be null");
+		Assert.notNull(value, "Value must not be null");
+		final CacheEntry<V> cacheEntry = new CacheEntry<>(value, CacheEntryState.ACTIVE);
+		final Node<K, V> node = new Node<>(key, cacheEntry);
+		final Node<K, V> prior = this.cache.putIfAbsent(node.key, node);
+		if (prior == null) {
+			processWrite(new AddTask(node));
+		}
+		else {
+			processRead(prior);
+		}
+	}
+
+	private void processRead(Node<K, V> node) {
+		boolean delayable = this.readOperations.recordRead(node);
+		final DrainStatus status = this.drainStatus.get();
+		if (status.shouldDrainBuffers(delayable)) {
+			drainOperations();
+		}
+	}
+
+	private void processWrite(Runnable task) {
+		this.writeOperations.add(task);
+		boolean delayable = this.writeOperations.isDelayable();
+		final DrainStatus status = this.drainStatus.get();
+		if (status.shouldDrainBuffers(delayable)) {
+			drainOperations();
+		}
+	}
+
+	private void drainOperations() {
+		if (this.evictionLock.tryLock()) {
+			try {
+				this.drainStatus.lazySet(DrainStatus.PROCESSING);
+				this.readOperations.drain();
+				this.writeOperations.drain();
+			}
+			finally {
+				this.drainStatus.compareAndSet(DrainStatus.PROCESSING, DrainStatus.IDLE);
+				this.evictionLock.unlock();
+			}
+		}
+	}
+
+	/**
+	 * Return the maximum number of entries in the cache.
+	 * @see #size()
+	 */
+	public int capacity() {
+		return this.capacity;
+	}
+
+	/**
+	 * Return the maximum number of entries in the cache.
+	 * @deprecated in favor of {@link #capacity()} as of 6.0.
+	 */
+	@Deprecated(since = "6.0")
+	public int sizeLimit() {
+		return this.capacity;
+	}
+
+	/**
+	 * Return the current size of the cache.
+	 * @see #capacity()
+	 */
+	public int size() {
+		return this.cache.size();
+	}
+
+	/**
+	 * Immediately remove all entries from this cache.
+	 */
+	public void clear() {
+		this.evictionLock.lock();
+		try {
+			Node<K, V> node;
+			while ((node = this.evictionQueue.poll()) != null) {
+				this.cache.remove(node.key, node);
+				markAsRemoved(node);
+			}
+			this.readOperations.clear();
+			this.writeOperations.drainAll();
+		}
+		finally {
+			this.evictionLock.unlock();
+		}
+	}
+
+	/*
+	 * Transition the node to the {@code removed} state and decrement the current size of the cache.
+	 */
+	private void markAsRemoved(Node<K, V> node) {
+		for (; ; ) {
+			CacheEntry<V> current = node.get();
+			CacheEntry<V> removed = new CacheEntry<>(current.value, CacheEntryState.REMOVED);
+			if (node.compareAndSet(current, removed)) {
+				this.currentSize.lazySet(this.currentSize.get() - 1);
+				this.evictionListener.accept(new Entry<>(node.key, current.value));
+				return;
+			}
+		}
+	}
+
+	/**
+	 * Determine whether the given key is present in this cache.
+	 * @param key the key to check for
+	 * @return {@code true} if the key is present, {@code false} if there was no matching key
+	 */
+	public boolean contains(K key) {
+		return this.cache.containsKey(key);
+	}
+
+	/**
+	 * Set the eviction listener that will be called when an entry is evicted from the cache.
+	 * @param listener the listener to be called on eviction, or {@code null} to disable
+	 */
+	public void setEvictionListener(@Nullable Consumer<Entry<K, V>> listener) {
+		this.evictionListener = (listener == null) ? entry -> {} : listener;
+	}
+
+	/**
+	 * Immediately remove the given key and any associated value.
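+	 * <p>The removal is recorded as a write operation and applied when the buffers are
+	 * drained; once the entry reaches the removed state, the eviction listener registered
+	 * via {@link #setEvictionListener(Consumer)} is notified, for example:
+	 * <pre>{@code
+	 * cache.setEvictionListener(entry -> System.out.println("removed " + entry.key()));
+	 * cache.put("k1", "v1");
+	 * cache.remove("k1");  // the listener eventually receives ("k1", "v1")
+	 * }</pre>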
+	 * @param key the key to evict the entry for
+	 * @return {@code true} if the key was present before,
+	 * {@code false} if there was no matching key
+	 */
+	public boolean remove(K key) {
+		final Node<K, V> node = this.cache.remove(key);
+		if (node == null) {
+			return false;
+		}
+		markForRemoval(node);
+		processWrite(new RemovalTask(node));
+		return true;
+	}
+
+	/*
+	 * Transition the node from the {@code active} state to the {@code pending removal} state,
+	 * if the transition is valid.
+	 */
+	private void markForRemoval(Node<K, V> node) {
+		for (; ; ) {
+			final CacheEntry<V> current = node.get();
+			if (!current.isActive()) {
+				return;
+			}
+			final CacheEntry<V> pendingRemoval = new CacheEntry<>(current.value, CacheEntryState.PENDING_REMOVAL);
+			if (node.compareAndSet(current, pendingRemoval)) {
+				return;
+			}
+		}
+	}
+
+	/**
+	 * Write operation recorded when a new entry is added to the cache.
+	 */
+	private final class AddTask implements Runnable {
+
+		final Node<K, V> node;
+
+		AddTask(Node<K, V> node) {
+			this.node = node;
+		}
+
+		@Override
+		public void run() {
+			currentSize.lazySet(currentSize.get() + 1);
+			if (this.node.get().isActive()) {
+				evictionQueue.add(this.node);
+				evictEntries();
+			}
+		}
+
+		private void evictEntries() {
+			while (currentSize.get() > capacity) {
+				final Node<K, V> node = evictionQueue.poll();
+				if (node == null) {
+					return;
+				}
+				cache.remove(node.key, node);
+				markAsRemoved(node);
+			}
+		}
+
+	}
+
+
+	/**
+	 * Write operation recorded when an entry is removed from the cache.
+	 */
+	private final class RemovalTask implements Runnable {
+
+		final Node<K, V> node;
+
+		RemovalTask(Node<K, V> node) {
+			this.node = node;
+		}
+
+		@Override
+		public void run() {
+			evictionQueue.remove(this.node);
+			markAsRemoved(this.node);
+		}
+	}
+
+
+	/*
+	 * Draining status for the read/write buffers.
+	 */
+	private enum DrainStatus {
+
+		/*
+		 * No drain operation currently running.
+		 */
+		IDLE {
+			@Override
+			boolean shouldDrainBuffers(boolean delayable) {
+				return !delayable;
+			}
+		},
+
+		/*
+		 * A drain operation is in progress.
+		 */
+		PROCESSING {
+			@Override
+			boolean shouldDrainBuffers(boolean delayable) {
+				return false;
+			}
+		};
+
+		/**
+		 * Determine whether the buffers should be drained.
+		 * @param delayable if a drain should be delayed until required
+		 * @return if a drain should be attempted
+		 */
+		abstract boolean shouldDrainBuffers(boolean delayable);
+	}
+
+	private enum CacheEntryState {
+		ACTIVE, PENDING_REMOVAL, REMOVED
+	}
+
+	private record CacheEntry<V>(V value, CacheEntryState state) {
+
+		boolean isActive() {
+			return this.state == CacheEntryState.ACTIVE;
+		}
+	}
+
+	/**
+	 * Public representation of a cache entry, exposing the key and value.
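+	 * <p>Instances of this record are handed to the eviction listener registered through
+	 * {@link #setEvictionListener(Consumer)} whenever an entry is evicted, removed, or cleared.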
+	 * @param <K> the type of the key
+	 * @param <V> the type of the value
+	 * @param key the cache key
+	 * @param value the cache value
+	 */
+	public record Entry<K, V>(K key, V value) {
+	}
+
+	private static final class ReadOperations<K, V> {
+
+		private static final int BUFFER_COUNT = detectNumberOfBuffers();
+
+		private static int detectNumberOfBuffers() {
+			int availableProcessors = Runtime.getRuntime().availableProcessors();
+			int nextPowerOfTwo = 1 << (Integer.SIZE - Integer.numberOfLeadingZeros(availableProcessors - 1));
+			return nextPowerOfTwo;
+		}
+
+		private static final int BUFFERS_MASK = BUFFER_COUNT - 1;
+
+		private static final int MAX_PENDING_OPERATIONS = 32;
+
+		private static final int MAX_DRAIN_COUNT = 2 * MAX_PENDING_OPERATIONS;
+
+		private static final int BUFFER_SIZE = 2 * MAX_DRAIN_COUNT;
+
+		private static final int BUFFER_INDEX_MASK = BUFFER_SIZE - 1;
+
+		private final PaddedAtomicLong[] recordedCount = new PaddedAtomicLong[BUFFER_COUNT];
+
+		private final PaddedLong[] readCount = new PaddedLong[BUFFER_COUNT];
+
+		private final PaddedAtomicLong[] processedCount = new PaddedAtomicLong[BUFFER_COUNT];
+
+		@SuppressWarnings("rawtypes")
+		private final AtomicReferenceArray<Node<K, V>>[] buffers = new AtomicReferenceArray[BUFFER_COUNT];
+
+		private final EvictionQueue<K, V> evictionQueue;
+
+		ReadOperations(EvictionQueue<K, V> evictionQueue) {
+			this.evictionQueue = evictionQueue;
+			for (int i = 0; i < BUFFER_COUNT; i++) {
+				this.recordedCount[i] = new PaddedAtomicLong(0);
+				this.processedCount[i] = new PaddedAtomicLong(0);
+				this.readCount[i] = new PaddedLong();
+				this.buffers[i] = new AtomicReferenceArray<>(BUFFER_SIZE);
+			}
+		}
+
+		@SuppressWarnings("deprecation")  // for Thread.getId() on JDK 19
+		private static int getBufferIndex() {
+			return ((int) Thread.currentThread().getId()) & BUFFERS_MASK;
+		}
+
+		boolean recordRead(Node<K, V> node) {
+			int bufferIndex = getBufferIndex();
+			final long writeCount = this.recordedCount[bufferIndex].get();
+			this.recordedCount[bufferIndex].lazySet(writeCount + 1);
+
+			final int index = (int) (writeCount & BUFFER_INDEX_MASK);
+			this.buffers[bufferIndex].lazySet(index, node);
+
+			final long pending = (writeCount - this.processedCount[bufferIndex].get());
+			return (pending < MAX_PENDING_OPERATIONS);
+		}
+
+		@SuppressWarnings("deprecation")  // for Thread.getId() on JDK 19
+		void drain() {
+			final int start = (int) Thread.currentThread().getId();
+			final int end = start + BUFFER_COUNT;
+			for (int i = start; i < end; i++) {
+				drainReadBuffer(i & BUFFERS_MASK);
+			}
+		}
+
+		void clear() {
+			for (int i = 0; i < BUFFER_COUNT; i++) {
+				AtomicReferenceArray<Node<K, V>> buffer = this.buffers[i];
+				for (int j = 0; j < BUFFER_SIZE; j++) {
+					buffer.lazySet(j, null);
+				}
+			}
+		}
+
+		private void drainReadBuffer(int bufferIndex) {
+			final long writeCount = this.recordedCount[bufferIndex].get();
+			for (int i = 0; i < MAX_DRAIN_COUNT; i++) {
+				final int index = (int) (this.readCount[bufferIndex].value & BUFFER_INDEX_MASK);
+				final AtomicReferenceArray<Node<K, V>> buffer = this.buffers[bufferIndex];
+				final Node<K, V> node = buffer.get(index);
+				if (node == null) {
+					break;
+				}
+				buffer.lazySet(index, null);
+				this.evictionQueue.moveToBack(node);
+				this.readCount[bufferIndex].value++;
+			}
+			this.processedCount[bufferIndex].lazySet(writeCount);
+		}
+
+		/**
+		 * Padded AtomicLong to reduce false sharing between buffer counters.
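+		 * The unused {@code long} fields pad the counter out to (at least) a typical 64-byte
+		 * cache line, so that per-buffer counters updated by different threads do not share a
+		 * cache line and invalidate each other on every write.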
+		 */
+		private static final class PaddedAtomicLong extends AtomicLong {
+
+			private static final long serialVersionUID = 1L;
+
+			@SuppressWarnings("unused")
+			long p1;
+			long p2;
+			long p3;
+			long p4;
+			long p5;
+			long p6;
+			long p7;
+
+			PaddedAtomicLong(long initialValue) {
+				super(initialValue);
+			}
+		}
+
+		/**
+		 * Padded long container to reduce false sharing.
+		 */
+		private static final class PaddedLong {
+
+			volatile long value;
+
+			@SuppressWarnings("unused")
+			long p1;
+			long p2;
+			long p3;
+			long p4;
+			long p5;
+			long p6;
+			long p7;
+		}
+	}
+
+	private static final class WriteOperations {
+
+		private static final int MAX_PENDING_OPERATIONS = 32;
+
+		private static final int DRAIN_THRESHOLD = MAX_PENDING_OPERATIONS * 2;
+
+		private final Queue<Runnable> operations = new ConcurrentLinkedQueue<>();
+
+		private final AtomicInteger pendingCount = new AtomicInteger();
+
+		public void add(Runnable task) {
+			this.operations.add(task);
+			this.pendingCount.incrementAndGet();
+		}
+
+		public boolean isDelayable() {
+			return this.pendingCount.get() < MAX_PENDING_OPERATIONS;
+		}
+
+		public void drain() {
+			for (int i = 0; i < DRAIN_THRESHOLD; i++) {
+				final Runnable task = this.operations.poll();
+				if (task == null) {
+					break;
+				}
+				this.pendingCount.decrementAndGet();
+				task.run();
+			}
+		}
+
+		public void drainAll() {
+			Runnable task;
+			while ((task = this.operations.poll()) != null) {
+				this.pendingCount.decrementAndGet();
+				task.run();
+			}
+		}
+
+	}
+
+	@SuppressWarnings("serial")
+	private static final class Node<K, V> extends AtomicReference<CacheEntry<V>> {
+
+		final K key;
+
+		@Nullable
+		Node<K, V> prev;
+
+		@Nullable
+		Node<K, V> next;
+
+		Node(K key, CacheEntry<V> cacheEntry) {
+			super(cacheEntry);
+			this.key = key;
+		}
+
+		@Nullable
+		public Node<K, V> getPrevious() {
+			return this.prev;
+		}
+
+		public void setPrevious(@Nullable Node<K, V> prev) {
+			this.prev = prev;
+		}
+
+		@Nullable
+		public Node<K, V> getNext() {
+			return this.next;
+		}
+
+		public void setNext(@Nullable Node<K, V> next) {
+			this.next = next;
+		}
+
+		V getValue() {
+			return get().value;
+		}
+	}
+
+
+	private static final class EvictionQueue<K, V> {
+
+		@Nullable
+		Node<K, V> first;
+
+		@Nullable
+		Node<K, V> last;
+
+
+		@Nullable
+		Node<K, V> poll() {
+			if (this.first == null) {
+				return null;
+			}
+			final Node<K, V> f = this.first;
+			final Node<K, V> next = f.getNext();
+			f.setNext(null);
+
+			this.first = next;
+			if (next == null) {
+				this.last = null;
+			}
+			else {
+				next.setPrevious(null);
+			}
+			return f;
+		}
+
+		void add(Node<K, V> e) {
+			if (contains(e)) {
+				return;
+			}
+			linkLast(e);
+		}
+
+		private boolean contains(Node<K, V> e) {
+			return (e.getPrevious() != null) || (e.getNext() != null) || (e == this.first);
+		}
+
+		private void linkLast(final Node<K, V> e) {
+			final Node<K, V> l = this.last;
+			this.last = e;
+
+			if (l == null) {
+				this.first = e;
+			}
+			else {
+				l.setNext(e);
+				e.setPrevious(l);
+			}
+		}
+
+		private void unlink(Node<K, V> e) {
+			final Node<K, V> prev = e.getPrevious();
+			final Node<K, V> next = e.getNext();
+			if (prev == null) {
+				this.first = next;
+			}
+			else {
+				prev.setNext(next);
+				e.setPrevious(null);
+			}
+			if (next == null) {
+				this.last = prev;
+			}
+			else {
+				next.setPrevious(prev);
+				e.setNext(null);
+			}
+		}
+
+		void moveToBack(Node<K, V> e) {
+			if (contains(e) && e != this.last) {
+				unlink(e);
+				linkLast(e);
+			}
+		}
+
+		void remove(Node<K, V> e) {
+			if (contains(e)) {
+				unlink(e);
+			}
+		}
+
+	}
+
+}
diff --git a/spring-core/src/test/java/org/springframework/util/ConcurrentLruCache2Tests.java b/spring-core/src/test/java/org/springframework/util/ConcurrentLruCache2Tests.java
new file mode 100644
index 000000000000..3a89a6ffbc28
--- /dev/null
+++ b/spring-core/src/test/java/org/springframework/util/ConcurrentLruCache2Tests.java
@@ -0,0 +1,70 @@
+/*
+ * Copyright 2002-present the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.util;
+
+import org.junit.jupiter.api.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Tests for {@link ConcurrentLruCache2}.
+ */
+class ConcurrentLruCache2Tests {
+
+	@Test
+	void missReturnsNullAndDoesNotPopulate() {
+		ConcurrentLruCache2<String, String> cache = new ConcurrentLruCache2<>(2);
+
+		assertThat(cache.get("k1")).isNull();
+		assertThat(cache.size()).isZero();
+		assertThat(cache.contains("k1")).isFalse();
+	}
+
+	@Test
+	void manualPutCachesValueAfterMiss() {
+		ConcurrentLruCache2<String, String> cache = new ConcurrentLruCache2<>(2);
+
+		assertThat(cache.get("k1")).isNull();
+		assertThat(cache.size()).isZero();
+
+		cache.put("k1", "v1");
+		cache.put("k2", "v2");
+
+		assertThat(cache.get("k1")).isEqualTo("v1");
+		assertThat(cache.get("k2")).isEqualTo("v2");
+		assertThat(cache.size()).isEqualTo(2);
+		assertThat(cache.contains("k1")).isTrue();
+		assertThat(cache.contains("k2")).isTrue();
+	}
+
+	@Test
+	void differsFromLegacyCacheThatGeneratesOnMiss() {
+		ConcurrentLruCache<String, String> legacy = new ConcurrentLruCache<>(2, key -> key + "value");
+		ConcurrentLruCache2<String, String> cache = new ConcurrentLruCache2<>(2);
+
+		assertThat(cache.get("k1")).isNull();
+		assertThat(cache.size()).isZero();
+		assertThat(cache.contains("k1")).isFalse();
+
+		assertThat(legacy.get("k1")).isEqualTo("k1value");
+		assertThat(legacy.contains("k1")).isTrue();
+		assertThat(legacy.size()).isEqualTo(1);
+	}
+
+}