diff --git a/modules/caddyhttp/reverseproxy/httptransport.go b/modules/caddyhttp/reverseproxy/httptransport.go index 7d15644a9..bf11ab283 100644 --- a/modules/caddyhttp/reverseproxy/httptransport.go +++ b/modules/caddyhttp/reverseproxy/httptransport.go @@ -244,10 +244,6 @@ func (h *HTTPTransport) NewTransport(ctx caddy.Context) (*http.Transport, error) // RoundTrip implements http.RoundTripper. func (h *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { - repl := req.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer) - if h.Transport.TLSClientConfig != nil { - h.Transport.TLSClientConfig.ServerName = repl.ReplaceAll(h.Transport.TLSClientConfig.ServerName, "") - } h.SetScheme(req) diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index 1068c23da..c876128bd 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -730,10 +730,39 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe server := req.Context().Value(caddyhttp.ServerCtxKey).(*caddyhttp.Server) shouldLogCredentials := server.Logs != nil && server.Logs.ShouldLogCredentials + // Default to using the transport configured during provisioning stage + transport := h.Transport + + if tmpTransport, ok := transport.(*HTTPTransport); ok { + // check whether we have TLS and need to replace the servername in the TLSClientConfig + if tmpTransport.TLSEnabled() && strings.Contains(tmpTransport.TLS.ServerName, "{") { + // make a new transport, "copy" the parts we don't need to touch, add a new *tls.Config, replace servername and then call RoundTrip on that to avoid any races + newtransport := &HTTPTransport{ + Resolver: tmpTransport.Resolver, + TLS: tmpTransport.TLS, + KeepAlive: tmpTransport.KeepAlive, + Compression: tmpTransport.Compression, + MaxConnsPerHost: tmpTransport.MaxConnsPerHost, + DialTimeout: tmpTransport.DialTimeout, + FallbackDelay: tmpTransport.FallbackDelay, + ResponseHeaderTimeout: tmpTransport.ResponseHeaderTimeout, + ExpectContinueTimeout: tmpTransport.ExpectContinueTimeout, + MaxResponseHeaderSize: tmpTransport.MaxResponseHeaderSize, + WriteBufferSize: tmpTransport.WriteBufferSize, + ReadBufferSize: tmpTransport.ReadBufferSize, + Versions: tmpTransport.Versions, + Transport: tmpTransport.Transport.Clone(), + h2cTransport: tmpTransport.h2cTransport, + } + newtransport.Transport.TLSClientConfig.ServerName = repl.ReplaceAll(newtransport.Transport.TLSClientConfig.ServerName, "") + transport = newtransport + } + } + // do the round-trip; emit debug log with values we know are // safe, or if there is no error, emit fuller log entry start := time.Now() - res, err := h.Transport.RoundTrip(req) + res, err := transport.RoundTrip(req) duration := time.Since(start) logger := h.logger.With( zap.String("upstream", di.Upstream.String()),