001/*
002 * nimbus-jose-jwt
003 *
004 * Copyright 2012-2022, Connect2id Ltd.
005 *
006 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use
007 * this file except in compliance with the License. You may obtain a copy of the
008 * License at
009 *
010 *    http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software distributed
013 * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
014 * CONDITIONS OF ANY KIND, either express or implied. See the License for the
015 * specific language governing permissions and limitations under the License.
016 */
017
018package com.nimbusds.jose.jwk.source;
019
020
021import java.util.Objects;
022import java.util.concurrent.TimeUnit;
023import java.util.concurrent.locks.ReentrantLock;
024
025import net.jcip.annotations.ThreadSafe;
026
027import com.nimbusds.jose.KeySourceException;
028import com.nimbusds.jose.jwk.JWKSet;
029import com.nimbusds.jose.proc.SecurityContext;
030import com.nimbusds.jose.util.cache.CachedObject;
031import com.nimbusds.jose.util.events.EventListener;
032
033
034/**
035 * Caching {@linkplain JWKSetSource}. Blocks during cache updates.
036 *
037 * @author Thomas Rørvik Skjølberg
038 * @author Vladimir Dzhuvinov
039 * @version 2022-11-08
040 */
041@ThreadSafe
042public class CachingJWKSetSource<C extends SecurityContext> extends AbstractCachingJWKSetSource<C> {
043        
044        
045        static class AbstractCachingJWKSetSourceEvent<C extends SecurityContext> extends AbstractJWKSetSourceEvent<CachingJWKSetSource<C>, C> {
046                
047                private final int threadQueueLength;
048                
049                public AbstractCachingJWKSetSourceEvent(final CachingJWKSetSource<C> source,
050                                                        final int threadQueueLength,
051                                                        final C context) {
052                        super(source, context);
053                        this.threadQueueLength = threadQueueLength;
054                }
055                
056                
057                /**
058                 * Returns an estimate of the number of queued threads.
059                 *
060                 * @return An estimate of the number of queued threads.
061                 */
062                public int getThreadQueueLength() {
063                        return threadQueueLength;
064                }
065        }
066        
067        
068        /**
069         * JWK set cache refresh initiated event.
070         */
071        public static class RefreshInitiatedEvent<C extends SecurityContext> extends AbstractCachingJWKSetSourceEvent<C> {
072                
073                private RefreshInitiatedEvent(final CachingJWKSetSource<C> source, final int queueLength, final C context) {
074                        super(source, queueLength, context);
075                }
076        }
077        
078        
079        /**
080         * JWK set cache refresh completed event.
081         */
082        public static class RefreshCompletedEvent<C extends SecurityContext> extends AbstractCachingJWKSetSourceEvent<C> {
083                
084                private final JWKSet jwkSet;
085                
086                private RefreshCompletedEvent(final CachingJWKSetSource<C> source,
087                                              final JWKSet jwkSet,
088                                              final int queueLength,
089                                              final C context) {
090                        super(source, queueLength, context);
091                        Objects.requireNonNull(jwkSet);
092                        this.jwkSet = jwkSet;
093                }
094                
095                
096                /**
097                 * Returns the refreshed JWK set.
098                 *
099                 * @return The refreshed JWK set.
100                 */
101                public JWKSet getJWKSet() {
102                        return jwkSet;
103                }
104        }
105        
106        
107        /**
108         * Waiting for a JWK set cache refresh to complete on another thread
109         * event.
110         */
111        public static class WaitingForRefreshEvent<C extends SecurityContext> extends AbstractCachingJWKSetSourceEvent<C> {
112                
113                private WaitingForRefreshEvent(final CachingJWKSetSource<C> source, final int queueLength, final C context) {
114                        super(source, queueLength, context);
115                }
116        }
117        
118        
119        /**
120         * Unable to refresh the JWK set cache event.
121         */
122        public static class UnableToRefreshEvent<C extends SecurityContext> extends AbstractJWKSetSourceEvent<CachingJWKSetSource<C>, C> {
123                
124                private UnableToRefreshEvent(final CachingJWKSetSource<C> source, final C context) {
125                        super(source, context);
126                }
127        }
128        
129        
130        /**
131         * JWK set cache refresh timed out event.
132         */
133        public static class RefreshTimedOutEvent<C extends SecurityContext> extends AbstractCachingJWKSetSourceEvent<C> {
134                
135                private RefreshTimedOutEvent(final CachingJWKSetSource<C> source, final int queueLength, final C context) {
136                        super(source, queueLength, context);
137                }
138        }
139        
140        
141        private final ReentrantLock lock = new ReentrantLock();
142
143        private final long cacheRefreshTimeout;
144        
145        private final EventListener<CachingJWKSetSource<C>, C> eventListener;
146        
147        
148        /**
149         * Creates a new caching JWK set source.
150         *
151         * @param source              The JWK set source to decorate. Must not
152         *                            be {@code null}.
153         * @param timeToLive          The time to live of the cached JWK set,
154         *                            in milliseconds.
155         * @param cacheRefreshTimeout The cache refresh timeout, in
156         *                            milliseconds.
157         * @param eventListener       The event listener, {@code null} if not
158         *                            specified.
159         */
160        public CachingJWKSetSource(final JWKSetSource<C> source,
161                                   final long timeToLive,
162                                   final long cacheRefreshTimeout,
163                                   final EventListener<CachingJWKSetSource<C>, C> eventListener) {
164                super(source, timeToLive);
165                this.cacheRefreshTimeout = cacheRefreshTimeout;
166                this.eventListener = eventListener;
167        }
168
169        
170        @Override
171        public JWKSet getJWKSet(final JWKSetCacheRefreshEvaluator refreshEvaluator, final long currentTime, final C context) throws KeySourceException {
172                CachedObject<JWKSet> cache = getCachedJWKSet();
173                if (cache == null) {
174                        return loadJWKSetBlocking(JWKSetCacheRefreshEvaluator.noRefresh(), currentTime, context);
175                }
176
177                JWKSet jwkSet = cache.get();
178                if (refreshEvaluator.requiresRefresh(jwkSet)) {
179                        return loadJWKSetBlocking(refreshEvaluator, currentTime, context);
180                }
181                
182                if (cache.isExpired(currentTime)) {
183                        return loadJWKSetBlocking(JWKSetCacheRefreshEvaluator.referenceComparison(jwkSet), currentTime, context);
184                }
185
186                return cache.get();
187        }
188        
189        
190        /**
191         * Returns the cache refresh timeout.
192         *
193         * @return The cache refresh timeout, in milliseconds.
194         */
195        public long getCacheRefreshTimeout() {
196                return cacheRefreshTimeout;
197        }
198        
199        
200        /**
201         * Loads and caches the JWK set, with blocking.
202         *
203         * @param refreshEvaluator The JWK set cache refresh evaluator.
204         * @param currentTime      The current time, in milliseconds since the
205         *                         Unix epoch.
206         * @param context          Optional context, {@code null} if not
207         *                         required.
208         *
209         * @return The loaded and cached JWK set.
210         *
211         * @throws KeySourceException If retrieval failed.
212         */
213        JWKSet loadJWKSetBlocking(final JWKSetCacheRefreshEvaluator refreshEvaluator, final long currentTime, final C context)
214                throws KeySourceException {
215                
216                // Synchronize so that the first thread to acquire the lock
217                // exclusively gets to call the underlying source.
218                // Other (later) threads must wait until the result is ready.
219                //
220                // If the first to get the lock fails within the waiting interval,
221                // subsequent threads will attempt to update the cache themselves.
222                //
223                // This approach potentially blocks a number of threads,
224                // but requesting the same data downstream is not better, so
225                // this is a necessary evil.
226
227                final CachedObject<JWKSet> cache;
228                try {
229                        if (lock.tryLock()) {
230                                try {
231                                        // We hold the lock, so safe to update it now, 
232                                        // Check evaluator, another thread might have already updated the JWKs
233                                        CachedObject<JWKSet> cachedJWKSet = getCachedJWKSet();
234                                        if (cachedJWKSet == null || refreshEvaluator.requiresRefresh(cachedJWKSet.get())) {
235        
236                                                if (eventListener != null) {
237                                                        eventListener.notify(new RefreshInitiatedEvent<>(this, lock.getQueueLength(), context));
238                                                }
239                                                
240                                                CachedObject<JWKSet> result = loadJWKSetNotThreadSafe(refreshEvaluator, currentTime, context);
241                                                
242                                                if (eventListener != null) {
243                                                        eventListener.notify(new RefreshCompletedEvent<>(this, result.get(), lock.getQueueLength(), context));
244                                                }
245                                                
246                                                cache = result;
247                                        } else {
248                                                // load updated value
249                                                cache = cachedJWKSet;
250                                        }
251                                        
252                                } finally {
253                                        lock.unlock();
254                                }
255                        } else {
256                                // Lock held by another thread, wait for refresh timeout
257                                if (eventListener != null) {
258                                        eventListener.notify(new WaitingForRefreshEvent<>(this, lock.getQueueLength(), context));
259                                }
260
261                                if (lock.tryLock(getCacheRefreshTimeout(), TimeUnit.MILLISECONDS)) {
262                                        try {
263                                                // Check evaluator, another thread have most likely already updated the JWKs
264                                                CachedObject<JWKSet> cachedJWKSet = getCachedJWKSet();
265                                                if (cachedJWKSet == null || refreshEvaluator.requiresRefresh(cachedJWKSet.get())) {
266                                                        // Seems cache was not updated.
267                                                        // We hold the lock, so safe to update it now
268                                                        if (eventListener != null) {
269                                                                eventListener.notify(new RefreshInitiatedEvent<>(this, lock.getQueueLength(), context));
270                                                        }
271                                                        
272                                                        cache = loadJWKSetNotThreadSafe(refreshEvaluator, currentTime, context);
273                                                        
274                                                        if (eventListener != null) {
275                                                                eventListener.notify(new RefreshCompletedEvent<>(this, cache.get(), lock.getQueueLength(), context));
276                                                        }
277                                                } else {
278                                                        // load updated value
279                                                        cache = cachedJWKSet;
280                                                }
281                                        } finally {
282                                                lock.unlock();
283                                        }
284                                } else {
285
286                                        if (eventListener != null) {
287                                                eventListener.notify(new RefreshTimedOutEvent<>(this, lock.getQueueLength(), context));
288                                        }
289                                        
290                                        throw new JWKSetUnavailableException("Timeout while waiting for cache refresh (" + cacheRefreshTimeout + "ms exceeded)");
291                                }
292                        }
293
294                        if (cache != null && cache.isValid(currentTime)) {
295                                return cache.get();
296                        }
297                        
298                        if (eventListener != null) {
299                                eventListener.notify(new UnableToRefreshEvent<>(this, context));
300                        }
301                        
302                        throw new JWKSetUnavailableException("Unable to refresh cache");
303                        
304                } catch (InterruptedException e) {
305                        
306                        Thread.currentThread().interrupt(); // Restore interrupted state to make Sonar happy
307
308                        throw new JWKSetUnavailableException("Interrupted while waiting for cache refresh", e);
309                }
310        }
311        
312        
313        /**
314         * Loads the JWK set from the wrapped source and caches it. Should not
315         * be run by more than one thread at a time.
316         *
317         * @param refreshEvaluator The JWK set cache refresh evaluator.
318         * @param currentTime      The current time, in milliseconds since the
319         *                         Unix epoch.
320         * @param context          Optional context, {@code null} if not
321         *                         required.
322         *
323         * @return Reference to the cached JWK set.
324         *
325         * @throws KeySourceException If loading failed.
326         */
327        CachedObject<JWKSet> loadJWKSetNotThreadSafe(final JWKSetCacheRefreshEvaluator refreshEvaluator, final long currentTime, final C context)
328                throws KeySourceException {
329                
330                JWKSet jwkSet = getSource().getJWKSet(refreshEvaluator, currentTime, context);
331
332                return cacheJWKSet(jwkSet, currentTime);
333        }
334        
335        
336        /**
337         * Returns the lock.
338         *
339         * @return The lock.
340         */
341        ReentrantLock getLock() {
342                return lock;
343        }
344}