/**
 * Return a composed filter function that first applies this filter, and
 * then applies the given {@code afterFilter}.
 * @param afterFilter the filter to apply after this filter
 * @return the composed filter
 */
default ExchangeFilterFunction andThen(ExchangeFilterFunction afterFilter) {
	Assert.notNull(afterFilter, "ExchangeFilterFunction must not be null");
	return (request, next) -> {
		// Hand this filter a downstream exchange that routes through afterFilter
		// before reaching the original "next" exchange.
		ExchangeFunction afterExchange = afterRequest -> afterFilter.filter(afterRequest, next);
		return filter(request, afterExchange);
	};
}
/**
 * Filter this exchange function with the given {@code ExchangeFilterFunction},
 * producing a filtered {@code ExchangeFunction}.
 * @param filter the filter to apply to this exchange
 * @return the filtered exchange
 * @see ExchangeFilterFunction#apply(ExchangeFunction)
 */
default ExchangeFunction filter(ExchangeFilterFunction filter) {
	ExchangeFunction target = this;
	return filter.apply(target);
}
/**
 * Return a filter that emits an error signal whenever the response's
 * {@link HttpStatus} satisfies the given predicate.
 * @param statusPredicate the predicate used to test the HTTP status
 * @param exceptionFunction the function used to create the exception from the response
 * @return the filter that generates the error signal
 */
public static ExchangeFilterFunction statusError(Predicate<HttpStatus> statusPredicate,
		Function<ClientResponse, ? extends Throwable> exceptionFunction) {

	Assert.notNull(statusPredicate, "Predicate must not be null");
	Assert.notNull(exceptionFunction, "Function must not be null");

	return ExchangeFilterFunction.ofResponseProcessor(response -> {
		if (statusPredicate.test(response.statusCode())) {
			return Mono.error(exceptionFunction.apply(response));
		}
		// Non-matching statuses pass through untouched.
		return Mono.just(response);
	});
}
@Test
public void andThen() {
	ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build();
	ClientResponse response = mock(ClientResponse.class);
	ExchangeFunction exchange = r -> Mono.just(response);

	// invoked[0] / invoked[1] record that the first / second filter has run.
	boolean[] invoked = new boolean[2];
	ExchangeFilterFunction first = (req, next) -> {
		assertFalse(invoked[0]);
		assertFalse(invoked[1]);
		invoked[0] = true;
		assertFalse(invoked[1]);
		return next.exchange(req);
	};
	ExchangeFilterFunction second = (req, next) -> {
		assertTrue(invoked[0]);
		assertFalse(invoked[1]);
		invoked[1] = true;
		return next.exchange(req);
	};

	ExchangeFilterFunction composed = first.andThen(second);
	ClientResponse result = composed.filter(request, exchange).block();

	assertEquals(response, result);
	assertTrue(invoked[0]);
	assertTrue(invoked[1]);
}
/**
 * Register the given filter on this builder. The new filter is composed in
 * front of the currently configured filter via {@code andThen}.
 * @param filter the filter to add; must not be {@code null}
 * @return this builder
 */
@Override
public WebClient.Builder filter(ExchangeFilterFunction filter) {
	Assert.notNull(filter, "'filter' must not be null");
	ExchangeFilterFunction combined = filter.andThen(this.filter);
	this.filter = combined;
	return this;
}
/**
 * Return a filter that adds an Authorization header for HTTP Basic Authentication.
 * <p>The header value is computed once, when the filter is created, and reused
 * for every request the filter processes.
 * @param username the username to use
 * @param password the password to use
 * @return the {@link ExchangeFilterFunction} that adds the Authorization header
 * @throws IllegalArgumentException if {@code username} or {@code password} is {@code null}
 */
public static ExchangeFilterFunction basicAuthentication(String username, String password) {
	Assert.notNull(username, "'username' must not be null");
	Assert.notNull(password, "'password' must not be null");

	// The header value depends only on the credentials, so encode it once here
	// instead of on every request. As a side effect, any exception thrown by
	// authorization(...) surfaces when the filter is created rather than when
	// it is first applied.
	String authorization = authorization(username, password);

	return ExchangeFilterFunction.ofRequestProcessor(
			clientRequest -> {
				ClientRequest<?> authorizedRequest = ClientRequest.from(clientRequest)
						.header(HttpHeaders.AUTHORIZATION, authorization)
						.body(clientRequest.inserter());
				return Mono.just(authorizedRequest);
			});
}
/**
 * Return a new {@code DefaultWebClient} that applies the given filter in
 * front of this client's current filter. This instance is left unchanged.
 * @param filter the filter to add; must not be {@code null}
 * @return the new, filtered client
 */
@Override
public WebClient filter(ExchangeFilterFunction filter) {
	Assert.notNull(filter, "'filter' must not be null");
	ExchangeFilterFunction combined = filter.andThen(this.filter);
	return new DefaultWebClient(this.clientHttpConnector, this.strategies, combined);
}
}
/**
 * Create a session-bound {@link WebClient} to be used by {@link VaultTemplate} for
 * Vault communication given {@link VaultEndpointProvider} and
 * {@link ClientHttpConnector} for calls that require an authenticated context.
 * {@link VaultEndpointProvider} is used to contribute host and port details for
 * relative URLs typically used by the Template API. Subclasses may override this
 * method to customize the {@link WebClient}.
 *
 * @param endpointProvider must not be {@literal null}.
 * @param connector must not be {@literal null}.
 * @return the {@link WebClient} used for Vault communication.
 * @since 2.1
 */
protected WebClient doCreateSessionWebClient(VaultEndpointProvider endpointProvider,
		ClientHttpConnector connector) {

	Assert.notNull(endpointProvider, "VaultEndpointProvider must not be null");
	Assert.notNull(connector, "ClientHttpConnector must not be null");

	// Before each request, obtain a token from vaultTokenSupplier and copy the
	// request with the Vault token header set.
	ExchangeFilterFunction tokenFilter = ofRequestProcessor(request -> vaultTokenSupplier
			.getVaultToken()
			.map(token -> ClientRequest.from(request)
					.headers(headers -> headers.set(VaultHttpHeaders.VAULT_TOKEN, token.getToken()))
					.build()));

	WebClient baseClient = doCreateWebClient(endpointProvider, connector);
	return baseClient.mutate().filter(tokenFilter).build();
}
/**
 * Apply this filter to the given {@linkplain ExchangeFunction}, producing
 * a filtered exchange function.
 * @param exchange the exchange function to filter
 * @return the filtered exchange function
 */
default ExchangeFunction apply(ExchangeFunction exchange) {
	Assert.notNull(exchange, "ExchangeFunction must not be null");
	// Every request sent through the returned function passes through this filter first.
	ExchangeFunction filtered = request -> filter(request, exchange);
	return filtered;
}
/**
 * Build a {@code DefaultWebClient} from the current builder state.
 * <p>Registered filters are folded with {@code andThen} in list order, so the
 * first registered filter is the outermost one. The composite is then applied
 * to the underlying exchange function.
 * @return the new client
 */
@Override
public WebClient build() {
	ExchangeFunction exchange = initExchangeFunction();

	ExchangeFunction filteredExchange = exchange;
	if (this.filters != null) {
		filteredExchange = this.filters.stream()
				.reduce(ExchangeFilterFunction::andThen)
				.map(composite -> composite.apply(exchange))
				.orElse(exchange);
	}

	return new DefaultWebClient(filteredExchange, initUriBuilderFactory(),
			this.defaultHeaders != null ? unmodifiableCopy(this.defaultHeaders) : null,
			this.defaultCookies != null ? unmodifiableCopy(this.defaultCookies) : null,
			this.defaultRequest, new DefaultWebClientBuilder(this));
}
@Test public void shouldApplyErrorHandlingFilter() { ExchangeFilterFunction filter = ExchangeFilterFunction.ofResponseProcessor( clientResponse -> { List<String> headerValues = clientResponse.headers().header("Foo");
@Test(expected = IllegalArgumentException.class)
public void basicAuthenticationInvalidCharacters() {
	ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build();
	ExchangeFunction exchange = r -> Mono.just(mock(ClientResponse.class));
	ExchangeFilterFunction authFilter =
			ExchangeFilterFunctions.basicAuthentication("foo", "\ud83d\udca9");
	authFilter.filter(request, exchange);
}
@Test
public void apply() {
	ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build();
	ClientResponse response = mock(ClientResponse.class);
	ExchangeFunction exchange = r -> Mono.just(response);

	boolean[] invoked = new boolean[1];
	ExchangeFilterFunction filter = (req, next) -> {
		assertFalse(invoked[0]);
		invoked[0] = true;
		return next.exchange(req);
	};

	ClientResponse result = filter.apply(exchange).exchange(request).block();

	assertEquals(response, result);
	assertTrue(invoked[0]);
}
@Test
public void statusHandlerMatch() {
	ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build();
	ClientResponse response = mock(ClientResponse.class);
	when(response.statusCode()).thenReturn(HttpStatus.NOT_FOUND);

	ExchangeFilterFunction errorHandler = ExchangeFilterFunctions.statusError(
			HttpStatus::is4xxClientError, r -> new MyException());
	Mono<ClientResponse> result = errorHandler.filter(request, r -> Mono.just(response));

	StepVerifier.create(result)
			.expectError(MyException.class)
			.verify();
}
@Test
public void statusHandlerNoMatch() {
	ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build();
	ClientResponse response = mock(ClientResponse.class);
	when(response.statusCode()).thenReturn(HttpStatus.NOT_FOUND);

	// A 404 does not satisfy the 5xx predicate, so the response must pass through.
	ExchangeFilterFunction errorHandler =
			ExchangeFilterFunctions.statusError(HttpStatus::is5xxServerError, req -> new MyException());
	ExchangeFunction exchange = req -> Mono.just(response);
	Mono<ClientResponse> result = errorHandler.filter(request, exchange);

	StepVerifier.create(result)
			.expectNext(response)
			.expectComplete()
			.verify();
}
@Test
@SuppressWarnings("deprecation")
public void basicAuthenticationAbsentAttributes() {
	ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build();
	ClientResponse response = mock(ClientResponse.class);

	// Without credential attributes on the request, no Authorization header may be added.
	ExchangeFunction exchange = req -> {
		assertFalse(req.headers().containsKey(HttpHeaders.AUTHORIZATION));
		return Mono.just(response);
	};
	ExchangeFilterFunction authFilter = ExchangeFilterFunctions.basicAuthentication();

	assertFalse(request.headers().containsKey(HttpHeaders.AUTHORIZATION));
	ClientResponse result = authFilter.filter(request, exchange).block();
	assertEquals(response, result);
}
@Test
public void basicAuthenticationUsernamePassword() {
	ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build();
	ClientResponse response = mock(ClientResponse.class);

	ExchangeFunction exchange = req -> {
		assertTrue(req.headers().containsKey(HttpHeaders.AUTHORIZATION));
		assertTrue(req.headers().getFirst(HttpHeaders.AUTHORIZATION).startsWith("Basic "));
		return Mono.just(response);
	};
	ExchangeFilterFunction authFilter = ExchangeFilterFunctions.basicAuthentication("foo", "bar");

	// The filter must add the header on a copy, leaving the original request untouched.
	assertFalse(request.headers().containsKey(HttpHeaders.AUTHORIZATION));
	ClientResponse result = authFilter.filter(request, exchange).block();
	assertEquals(response, result);
}
@Test
@SuppressWarnings("deprecation")
public void basicAuthenticationAttributes() {
	ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL)
			.attributes(org.springframework.web.reactive.function.client.ExchangeFilterFunctions
					.Credentials.basicAuthenticationCredentials("foo", "bar"))
			.build();
	ClientResponse response = mock(ClientResponse.class);

	ExchangeFunction exchange = req -> {
		assertTrue(req.headers().containsKey(HttpHeaders.AUTHORIZATION));
		assertTrue(req.headers().getFirst(HttpHeaders.AUTHORIZATION).startsWith("Basic "));
		return Mono.just(response);
	};
	ExchangeFilterFunction authFilter = ExchangeFilterFunctions.basicAuthentication();

	assertFalse(request.headers().containsKey(HttpHeaders.AUTHORIZATION));
	ClientResponse result = authFilter.filter(request, exchange).block();
	assertEquals(response, result);
}
@Test
public void limitResponseSize() {
	DefaultDataBufferFactory factory = new DefaultDataBufferFactory();
	Flux<DataBuffer> body = Flux.just(
			dataBuffer("foo", factory), dataBuffer("bar", factory), dataBuffer("baz", factory));

	ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build();
	ClientResponse response = ClientResponse.create(HttpStatus.OK).body(body).build();

	// Limit the body to 5 bytes: "foo" passes whole, "bar" is cut to "ba", and "baz" is dropped.
	ExchangeFilterFunction limiter = ExchangeFilterFunctions.limitResponseSize(5);
	Mono<ClientResponse> result = limiter.filter(request, req -> Mono.just(response));
	Flux<DataBuffer> limitedBody =
			result.flatMapMany(res -> res.body(BodyExtractors.toDataBuffers()));

	StepVerifier.create(limitedBody)
			.consumeNextWith(buffer -> assertEquals("foo", string(buffer)))
			.consumeNextWith(buffer -> assertEquals("ba", string(buffer)))
			.expectComplete()
			.verify();
}
/**
 * Return a composed filter function that first applies this filter and then
 * applies the {@code after} filter.
 * @param after the filter to apply once this filter has been applied
 * @return a composed filter that applies this function first and the
 * {@code after} function second
 */
default ExchangeFilterFunction andThen(ExchangeFilterFunction after) {
	Assert.notNull(after, "'after' must not be null");
	// The downstream exchange handed to this filter is "after" wrapped around "next".
	return (request, next) ->
			filter(request, exchangeRequest -> after.filter(exchangeRequest, next));
}