Reverted previous TLS server name replacement, and implemented thread safe version.

pull/4836/head
Kiss Károly 2022-06-10 18:17:26 +02:00
parent 7683eb294b
commit 7b99772e85
2 changed files with 30 additions and 5 deletions

View File

@ -244,10 +244,6 @@ func (h *HTTPTransport) NewTransport(ctx caddy.Context) (*http.Transport, error)
// RoundTrip implements http.RoundTripper.
func (h *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
repl := req.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
if h.Transport.TLSClientConfig != nil {
h.Transport.TLSClientConfig.ServerName = repl.ReplaceAll(h.Transport.TLSClientConfig.ServerName, "")
}
h.SetScheme(req)

View File

@ -730,10 +730,39 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe
server := req.Context().Value(caddyhttp.ServerCtxKey).(*caddyhttp.Server)
shouldLogCredentials := server.Logs != nil && server.Logs.ShouldLogCredentials
// Default to using the transport configured during provisioning stage
transport := h.Transport
if tmpTransport, ok := transport.(*HTTPTransport); ok {
// check whether we have TLS and need to replace the servername in the TLSClientConfig
if tmpTransport.TLSEnabled() && strings.Contains(tmpTransport.TLS.ServerName, "{") {
// make a new transport, "copy" the parts we don't need to touch, add a new *tls.Config, replace servername and then call RoundTrip on that to avoid any races
newtransport := &HTTPTransport{
Resolver: tmpTransport.Resolver,
TLS: tmpTransport.TLS,
KeepAlive: tmpTransport.KeepAlive,
Compression: tmpTransport.Compression,
MaxConnsPerHost: tmpTransport.MaxConnsPerHost,
DialTimeout: tmpTransport.DialTimeout,
FallbackDelay: tmpTransport.FallbackDelay,
ResponseHeaderTimeout: tmpTransport.ResponseHeaderTimeout,
ExpectContinueTimeout: tmpTransport.ExpectContinueTimeout,
MaxResponseHeaderSize: tmpTransport.MaxResponseHeaderSize,
WriteBufferSize: tmpTransport.WriteBufferSize,
ReadBufferSize: tmpTransport.ReadBufferSize,
Versions: tmpTransport.Versions,
Transport: tmpTransport.Transport.Clone(),
h2cTransport: tmpTransport.h2cTransport,
}
newtransport.Transport.TLSClientConfig.ServerName = repl.ReplaceAll(newtransport.Transport.TLSClientConfig.ServerName, "")
transport = newtransport
}
}
// do the round-trip; emit debug log with values we know are
// safe, or if there is no error, emit fuller log entry
start := time.Now()
res, err := h.Transport.RoundTrip(req)
res, err := transport.RoundTrip(req)
duration := time.Since(start)
logger := h.logger.With(
zap.String("upstream", di.Upstream.String()),