block local addresses to prevent SSRF attacks

This commit is contained in:
Athou
2025-02-14 11:49:55 +01:00
parent dc3e5476a1
commit f519aa039f
8 changed files with 166 additions and 13 deletions

View File

@@ -2,7 +2,9 @@ package com.commafeed.backend;
import java.io.IOException;
import java.io.InputStream;
import java.net.InetAddress;
import java.net.URI;
import java.net.UnknownHostException;
import java.time.Duration;
import java.time.Instant;
import java.time.InstantSource;
@@ -10,8 +12,11 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.stream.Stream;
import org.apache.commons.lang3.StringUtils;
import org.apache.hc.client5.http.DnsResolver;
import org.apache.hc.client5.http.SystemDefaultDnsResolver;
import org.apache.hc.client5.http.config.ConnectionConfig;
import org.apache.hc.client5.http.config.RequestConfig;
import org.apache.hc.client5.http.config.TlsConfig;
@@ -66,6 +71,7 @@ public class HttpGetter {
private final InstantSource instantSource;
private final CloseableHttpClient client;
private final Cache<HttpRequest, HttpResponse> cache;
private final DnsResolver dnsResolver = SystemDefaultDnsResolver.INSTANCE;
public HttpGetter(CommaFeedConfiguration config, InstantSource instantSource, CommaFeedVersion version, MetricRegistry metrics) {
this.config = config;
@@ -89,11 +95,20 @@ public class HttpGetter {
() -> cache == null ? 0 : cache.asMap().values().stream().mapToInt(e -> e.content != null ? e.content.length : 0).sum());
}
public HttpResult get(String url) throws IOException, NotModifiedException, TooManyRequestsException {
public HttpResult get(String url)
throws IOException, NotModifiedException, TooManyRequestsException, SchemeNotAllowedException, HostNotAllowedException {
return get(HttpRequest.builder(url).build());
}
public HttpResult get(HttpRequest request) throws IOException, NotModifiedException, TooManyRequestsException {
public HttpResult get(HttpRequest request)
throws IOException, NotModifiedException, TooManyRequestsException, SchemeNotAllowedException, HostNotAllowedException {
URI uri = URI.create(request.getUrl());
ensureHttpScheme(uri.getScheme());
if (config.httpClient().blockLocalAddresses()) {
ensurePublicAddress(uri.getHost());
}
final HttpResponse response;
if (cache == null) {
response = invoke(request);
@@ -141,6 +156,28 @@ public class HttpGetter {
response.getUrlAfterRedirect(), validFor);
}
private void ensureHttpScheme(String scheme) throws SchemeNotAllowedException {
if (!"http".equals(scheme) && !"https".equals(scheme)) {
throw new SchemeNotAllowedException(scheme);
}
}
private void ensurePublicAddress(String host) throws HostNotAllowedException, UnknownHostException {
if (host == null) {
throw new HostNotAllowedException(null);
}
InetAddress[] addresses = dnsResolver.resolve(host);
if (Stream.of(addresses).anyMatch(this::isPrivateAddress)) {
throw new HostNotAllowedException(host);
}
}
private boolean isPrivateAddress(InetAddress address) {
return address.isSiteLocalAddress() || address.isAnyLocalAddress() || address.isLinkLocalAddress() || address.isLoopbackAddress()
|| address.isMulticastAddress();
}
private HttpResponse invoke(HttpRequest request) throws IOException {
log.debug("fetching {}", request.getUrl());
@@ -229,7 +266,7 @@ public class HttpGetter {
}
}
private static PoolingHttpClientConnectionManager newConnectionManager(CommaFeedConfiguration config) {
private PoolingHttpClientConnectionManager newConnectionManager(CommaFeedConfiguration config) {
SSLFactory sslFactory = SSLFactory.builder().withUnsafeTrustMaterial().withUnsafeHostnameVerifier().build();
int poolSize = config.feedRefresh().httpThreads();
@@ -243,6 +280,7 @@ public class HttpGetter {
.setDefaultTlsConfig(TlsConfig.custom().setHandshakeTimeout(Timeout.of(config.httpClient().sslHandshakeTimeout())).build())
.setMaxConnPerRoute(poolSize)
.setMaxConnTotal(poolSize)
.setDnsResolver(dnsResolver)
.build();
}
@@ -279,6 +317,22 @@ public class HttpGetter {
.build();
}
public static class SchemeNotAllowedException extends Exception {
private static final long serialVersionUID = 1L;
public SchemeNotAllowedException(String scheme) {
super("Scheme not allowed: " + scheme);
}
}
public static class HostNotAllowedException extends Exception {
private static final long serialVersionUID = 1L;
public HostNotAllowedException(String host) {
super("Host not allowed: " + host);
}
}
@Getter
public static class NotModifiedException extends Exception {
private static final long serialVersionUID = 1L;