001/*
002 * nimbus-jose-jwt
003 *
004 * Copyright 2012-2022, Connect2id Ltd.
005 *
006 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use
007 * this file except in compliance with the License. You may obtain a copy of the
008 * License at
009 *
010 *    http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software distributed
013 * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
014 * CONDITIONS OF ANY KIND, either express or implied. See the License for the
015 * specific language governing permissions and limitations under the License.
016 */
017
018package com.nimbusds.jose.jwk.source;
019
020
021import java.io.IOException;
022import java.util.Objects;
023import java.util.concurrent.*;
024import java.util.concurrent.locks.ReentrantLock;
025
026import net.jcip.annotations.ThreadSafe;
027
028import com.nimbusds.jose.KeySourceException;
029import com.nimbusds.jose.jwk.JWKSet;
030import com.nimbusds.jose.proc.SecurityContext;
031import com.nimbusds.jose.util.cache.CachedObject;
032import com.nimbusds.jose.util.events.EventListener;
033
034
035/**
036 * Caching {@linkplain JWKSetSource} that refreshes the JWK set prior to its
037 * expiration. The updates run on a separate, dedicated thread. Updates can be
038 * repeatedly scheduled, or (lazily) triggered by incoming requests for the JWK
039 * set.
040 *
041 * <p>This class is intended for uninterrupted operation under high-load, to
042 * avoid a potentially large number of threads blocking when the cache expires
043 * (and must be refreshed).
044 *
045 * @author Thomas Rørvik Skjølberg
046 * @author Vladimir Dzhuvinov
047 * @version 2022-11-22
048 */
049@ThreadSafe
050public class RefreshAheadCachingJWKSetSource<C extends SecurityContext> extends CachingJWKSetSource<C> {
051        
052        
053        /**
054         * New JWK set refresh scheduled event.
055         */
056        public static class RefreshScheduledEvent<C extends SecurityContext> extends AbstractJWKSetSourceEvent<CachingJWKSetSource<C>, C> {
057                
058                public RefreshScheduledEvent(final RefreshAheadCachingJWKSetSource<C> source, final C context) {
059                        super(source, context);
060                }
061        }
062        
063        
064        /**
065         * JWK set refresh not scheduled event.
066         */
067        public static class RefreshNotScheduledEvent<C extends SecurityContext> extends AbstractJWKSetSourceEvent<CachingJWKSetSource<C>, C> {
068                
069                public RefreshNotScheduledEvent(final RefreshAheadCachingJWKSetSource<C> source, final C context) {
070                        super(source, context);
071                }
072        }
073        
074        
075        /**
076         * Scheduled JWK refresh failed event.
077         */
078        public static class ScheduledRefreshFailed<C extends SecurityContext> extends AbstractJWKSetSourceEvent<CachingJWKSetSource<C>, C> {
079                
080                private final Exception exception;
081                
082                public ScheduledRefreshFailed(final CachingJWKSetSource<C> source,
083                                              final Exception exception,
084                                              final C context) {
085                        super(source, context);
086                        Objects.requireNonNull(exception);
087                        this.exception = exception;
088                }
089                
090                
091                public Exception getException() {
092                        return exception;
093                }
094        }
095        
096        
097        /**
098         * Scheduled JWK set cache refresh initiated event.
099         */
100        public static class ScheduledRefreshInitiatedEvent<C extends SecurityContext> extends AbstractJWKSetSourceEvent<CachingJWKSetSource<C>, C> {
101                
102                private ScheduledRefreshInitiatedEvent(final CachingJWKSetSource<C> source, final C context) {
103                        super(source, context);
104                }
105        }
106        
107        
108        /**
109         * Scheduled JWK set cache refresh completed event.
110         */
111        public static class ScheduledRefreshCompletedEvent<C extends SecurityContext> extends AbstractJWKSetSourceEvent<CachingJWKSetSource<C>, C> {
112                
113                private final JWKSet jwkSet;
114                
115                private ScheduledRefreshCompletedEvent(final CachingJWKSetSource<C> source,
116                                                       final JWKSet jwkSet,
117                                                       final C context) {
118                        super(source, context);
119                        Objects.requireNonNull(jwkSet);
120                        this.jwkSet = jwkSet;
121                }
122                
123                
124                /**
125                 * Returns the refreshed JWK set.
126                 *
127                 * @return The refreshed JWK set.
128                 */
129                public JWKSet getJWKSet() {
130                        return jwkSet;
131                }
132        }
133        
134        
135        /**
136         * Unable to refresh the JWK set cache ahead of expiration event.
137         */
138        public static class UnableToRefreshAheadOfExpirationEvent<C extends SecurityContext> extends AbstractJWKSetSourceEvent<CachingJWKSetSource<C>, C> {
139                
140                
141                public UnableToRefreshAheadOfExpirationEvent(final CachingJWKSetSource<C> source, final C context) {
142                        super(source, context);
143                }
144        }
145        
146        
147        
148        // refresh ahead of expiration should execute when
149        // expirationTime - refreshAheadTime < currentTime < expirationTime
150        private final long refreshAheadTime; // milliseconds
151
152        private final ReentrantLock lazyLock = new ReentrantLock();
153
154        private final ExecutorService executorService;
155        private final boolean shutdownExecutorOnClose;
156        private final ScheduledExecutorService scheduledExecutorService;
157        
158        // cache expiration time (in milliseconds) used as fingerprint
159        private volatile long cacheExpiration;
160        
161        private ScheduledFuture<?> scheduledRefreshFuture;
162        
163        private final EventListener<CachingJWKSetSource<C>, C> eventListener;
164
165        
166        /**
167         * Creates a new refresh-ahead caching JWK set source.
168         *
169         * @param source              The JWK set source to decorate. Must not
170         *                            be {@code null}.
171         * @param timeToLive          The time to live of the cached JWK set,
172         *                            in milliseconds.
173         * @param cacheRefreshTimeout The cache refresh timeout, in
174         *                            milliseconds.
175         * @param refreshAheadTime    The refresh ahead time, in milliseconds.
176         * @param scheduled           {@code true} to refresh in a scheduled
177         *                            manner, regardless of requests.
178         * @param eventListener       The event listener, {@code null} if not
179         *                            specified.
180         */
181        public RefreshAheadCachingJWKSetSource(final JWKSetSource<C> source,
182                                               final long timeToLive,
183                                               final long cacheRefreshTimeout,
184                                               final long refreshAheadTime,
185                                               final boolean scheduled,
186                                               final EventListener<CachingJWKSetSource<C>, C> eventListener) {
187                
188                this(source, timeToLive, cacheRefreshTimeout, refreshAheadTime,
189                        scheduled, Executors.newSingleThreadExecutor(), true,
190                        eventListener);
191        }
192        
193
194        /**
195         * Creates a new refresh-ahead caching JWK set source with the
196         * specified executor service to run the updates in the background.
197         *
198         * @param source                  The JWK set source to decorate. Must
199         *                                not be {@code null}.
200         * @param timeToLive              The time to live of the cached JWK
201         *                                set, in milliseconds.
202         * @param cacheRefreshTimeout     The cache refresh timeout, in
203         *                                milliseconds.
204         * @param refreshAheadTime        The refresh ahead time, in
205         *                                milliseconds.
206         * @param scheduled               {@code true} to refresh in a
207         *                                scheduled manner, regardless of
208         *                                requests.
209         * @param executorService         The executor service to run the
210         *                                updates in the background.
211         * @param shutdownExecutorOnClose If {@code true} the executor service
212         *                                will be shut down upon closing the
213         *                                source.
214         * @param eventListener           The event listener, {@code null} if
215         *                                not specified.
216         */
217        public RefreshAheadCachingJWKSetSource(final JWKSetSource<C> source,
218                                               final long timeToLive,
219                                               final long cacheRefreshTimeout,
220                                               final long refreshAheadTime,
221                                               final boolean scheduled,
222                                               final ExecutorService executorService,
223                                               final boolean shutdownExecutorOnClose,
224                                               final EventListener<CachingJWKSetSource<C>, C> eventListener) {
225                
226                super(source, timeToLive, cacheRefreshTimeout, eventListener);
227
228                if (refreshAheadTime + cacheRefreshTimeout > timeToLive) {
229                        throw new IllegalArgumentException("The sum of the refresh-ahead time (" + refreshAheadTime +"ms) " +
230                                "and the cache refresh timeout (" + cacheRefreshTimeout +"ms) " +
231                                "must not exceed the time-to-lived time (" + timeToLive + "ms)");
232                }
233
234                this.refreshAheadTime = refreshAheadTime;
235                
236                Objects.requireNonNull(executorService, "The executor service must not be null");
237                this.executorService = executorService;
238                
239                this.shutdownExecutorOnClose = shutdownExecutorOnClose;
240
241                if (scheduled) {
242                        scheduledExecutorService = Executors.newSingleThreadScheduledExecutor();
243                } else {
244                        scheduledExecutorService = null;
245                }
246                
247                this.eventListener = eventListener;
248        }
249
250        
251        @Override
252        public JWKSet getJWKSet(final JWKSetCacheRefreshEvaluator refreshEvaluator, final long currentTime, final C context) throws KeySourceException {
253                CachedObject<JWKSet> cache = getCachedJWKSet();
254                if (cache == null) {
255                        return loadJWKSetBlocking(JWKSetCacheRefreshEvaluator.noRefresh(), currentTime, context);
256                }
257
258                JWKSet jwkSet = cache.get();
259                if (refreshEvaluator.requiresRefresh(jwkSet)) {
260                        return loadJWKSetBlocking(refreshEvaluator, currentTime, context);
261                }               
262                
263                if (cache.isExpired(currentTime)) {
264                        return loadJWKSetBlocking(JWKSetCacheRefreshEvaluator.referenceComparison(jwkSet), currentTime, context);
265                }
266                
267                refreshAheadOfExpiration(cache, false, currentTime, context);
268
269                return cache.get();
270        }
271        
272
273        @Override
274        CachedObject<JWKSet> loadJWKSetNotThreadSafe(final JWKSetCacheRefreshEvaluator refreshEvaluator, final long currentTime, final C context) throws KeySourceException {
275                // Never run by two threads at the same time!
276                CachedObject<JWKSet> cache = super.loadJWKSetNotThreadSafe(refreshEvaluator, currentTime, context);
277
278                if (scheduledExecutorService != null) {
279                        scheduleRefreshAheadOfExpiration(cache, currentTime, context);
280                }
281
282                return cache;
283        }
284        
285        
286        /**
287         * Schedules repeated refresh ahead of cached JWK set expiration.
288         */
289        void scheduleRefreshAheadOfExpiration(final CachedObject<JWKSet> cache, final long currentTime, final C context) {
290                
291                if (scheduledRefreshFuture != null) {
292                        scheduledRefreshFuture.cancel(false);
293                }
294
295                // so we want to keep other threads from triggering preemptive refresh
296                // subtracting the refresh timeout should be enough
297                long delay = cache.getExpirationTime() - currentTime - refreshAheadTime - getCacheRefreshTimeout();
298                if (delay > 0) {
299                        final RefreshAheadCachingJWKSetSource<C> that = this;
300                        Runnable command = new Runnable() {
301
302                                @Override
303                                public void run() {
304                                        try {
305                                                // so will only refresh if this specific cache entry still is the current one
306                                                refreshAheadOfExpiration(cache, true, System.currentTimeMillis(), context);
307                                        } catch (Exception e) {
308                                                if (eventListener != null) {
309                                                        eventListener.notify(new ScheduledRefreshFailed<C>(that, e, context));
310                                                }
311                                                // ignore
312                                        }
313                                }
314                        };
315                        this.scheduledRefreshFuture = scheduledExecutorService.schedule(command, delay, TimeUnit.MILLISECONDS);
316                        
317                        if (eventListener != null) {
318                                eventListener.notify(new RefreshScheduledEvent<C>(this, context));
319                        }
320                } else {
321                        // cache refresh not scheduled
322                        if (eventListener != null) {
323                                eventListener.notify(new RefreshNotScheduledEvent<C>(this, context));
324                        }
325                }
326        }
327
328        
329        /**
330         * Refreshes the cached JWK set if past the time threshold or refresh
331         * is forced.
332         *
333         * @param cache        The current cache. Must not be {@code null}.
334         * @param forceRefresh {@code true} to force refresh.
335         * @param currentTime  The current time.
336         */
337        void refreshAheadOfExpiration(final CachedObject<JWKSet> cache, final boolean forceRefresh, final long currentTime, final C context) {
338                
339                if (cache.isExpired(currentTime + refreshAheadTime) || forceRefresh) {
340                        
341                        // cache will expire soon, preemptively update it
342
343                        // check if an update is already in progress
344                        if (cacheExpiration < cache.getExpirationTime()) {
345                                // seems no update is in progress, see if we can get the lock
346                                if (lazyLock.tryLock()) {
347                                        try {
348                                                lockedRefresh(cache, currentTime, context);
349                                        } finally {
350                                                lazyLock.unlock();
351                                        }
352                                }
353                        }
354                }
355        }
356
357        
358        /**
359         * Checks if a refresh is in progress and if not triggers one. To be
360         * called by a single thread at a time.
361         *
362         * @param cache       The current cache. Must not be {@code null}.
363         * @param currentTime The current time.
364         */
365        void lockedRefresh(final CachedObject<JWKSet> cache, final long currentTime, final C context) {
366                // check if an update is already in progress (again now that this thread holds the lock)
367                if (cacheExpiration < cache.getExpirationTime()) {
368
369                        // still no update is in progress
370                        cacheExpiration = cache.getExpirationTime();
371                        
372                        final RefreshAheadCachingJWKSetSource<C> that = this;
373
374                        Runnable runnable = new Runnable() {
375
376                                @Override
377                                public void run() {
378                                        try {
379                                                if (eventListener != null) {
380                                                        eventListener.notify(new ScheduledRefreshInitiatedEvent<>(that, context));
381                                                }
382                                                
383                                                JWKSet jwkSet = RefreshAheadCachingJWKSetSource.this.loadJWKSetBlocking(JWKSetCacheRefreshEvaluator.forceRefresh(), currentTime, context);
384                                                
385                                                if (eventListener != null) {
386                                                        eventListener.notify(new ScheduledRefreshCompletedEvent<>(that, jwkSet, context));
387                                                }
388
389                                                // so next time this method is invoked, it'll be with the updated cache item expiry time
390                                        } catch (Throwable e) {
391                                                // update failed, but another thread can retry
392                                                cacheExpiration = -1L;
393                                                // ignore, unable to update
394                                                // another thread will attempt the same
395                                                if (eventListener != null) {
396                                                        eventListener.notify(new UnableToRefreshAheadOfExpirationEvent<C>(that, context));
397                                                }
398                                        }
399                                }
400                        };
401                        // run update in the background
402                        executorService.execute(runnable);
403                }
404        }
405
406
407        /**
408         * Returns the executor service running the updates in the background.
409         *
410         * @return The executor service.
411         */
412        public ExecutorService getExecutorService() {
413                return executorService;
414        }
415
416        
417        ReentrantLock getLazyLock() {
418                return lazyLock;
419        }
420        
421        
422        /**
423         * Returns the current scheduled refresh future.
424         *
425         * @return The current future, {@code null} if none.
426         */
427        ScheduledFuture<?> getScheduledRefreshFuture() {
428                return scheduledRefreshFuture;
429        }
430
431        
432        @Override
433        public void close() throws IOException {
434                
435                ScheduledFuture<?> currentScheduledRefreshFuture = this.scheduledRefreshFuture; // defensive copy
436                if (currentScheduledRefreshFuture != null) {
437                        currentScheduledRefreshFuture.cancel(true);
438                }
439                
440                super.close();
441                
442                if (shutdownExecutorOnClose) {
443                        executorService.shutdownNow();
444                        try {
445                                executorService.awaitTermination(getCacheRefreshTimeout(), TimeUnit.MILLISECONDS);
446                        } catch (InterruptedException e) {
447                                // ignore
448                                Thread.currentThread().interrupt();
449                        }
450                }
451                if (scheduledExecutorService != null) {
452                        scheduledExecutorService.shutdownNow();
453                        try {
454                                scheduledExecutorService.awaitTermination(getCacheRefreshTimeout(), TimeUnit.MILLISECONDS);
455                        } catch (InterruptedException e) {
456                                // ignore
457                                Thread.currentThread().interrupt();
458                        }
459                }               
460        }
461}