Move TLS servername replacement into it's own function

pull/4836/head
Kiss Károly 2022-06-13 11:20:17 +02:00
parent 7b99772e85
commit a661daff98
1 changed files with 34 additions and 23 deletions

View File

@ -716,6 +716,38 @@ func (h Handler) addForwardedHeaders(req *http.Request) error {
return nil
}
// replaceTLSServername checks TLS servername to see if it needs replacing
// if it does need replacing, it creates a new cloned HTTPTransport object to avoid any races
// and does the replacing of the TLS servername on that and returns the new object
// if no replacement is necessary it returns the original
func (h *Handler) replaceTLSServername(transport *HTTPTransport, repl *caddy.Replacer) *HTTPTransport {
// check whether we have TLS and need to replace the servername in the TLSClientConfig
if transport.TLSEnabled() && strings.Contains(transport.TLS.ServerName, "{") {
// make a new transport, "copy" the parts we don't need to touch, add a new *tls.Config and replace servername
newtransport := &HTTPTransport{
Resolver: transport.Resolver,
TLS: transport.TLS,
KeepAlive: transport.KeepAlive,
Compression: transport.Compression,
MaxConnsPerHost: transport.MaxConnsPerHost,
DialTimeout: transport.DialTimeout,
FallbackDelay: transport.FallbackDelay,
ResponseHeaderTimeout: transport.ResponseHeaderTimeout,
ExpectContinueTimeout: transport.ExpectContinueTimeout,
MaxResponseHeaderSize: transport.MaxResponseHeaderSize,
WriteBufferSize: transport.WriteBufferSize,
ReadBufferSize: transport.ReadBufferSize,
Versions: transport.Versions,
Transport: transport.Transport.Clone(),
h2cTransport: transport.h2cTransport,
}
newtransport.Transport.TLSClientConfig.ServerName = repl.ReplaceAll(newtransport.Transport.TLSClientConfig.ServerName, "")
return newtransport
}
return transport
}
// reverseProxy performs a round-trip to the given backend and processes the response with the client.
// (This method is mostly the beginning of what was borrowed from the net/http/httputil package in the
// Go standard library which was used as the foundation.)
@ -733,30 +765,9 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe
// Default to using the transport configured during provisioning stage
transport := h.Transport
// If we have a HTTP transport, try to replace the TLS servername
if tmpTransport, ok := transport.(*HTTPTransport); ok {
// check whether we have TLS and need to replace the servername in the TLSClientConfig
if tmpTransport.TLSEnabled() && strings.Contains(tmpTransport.TLS.ServerName, "{") {
// make a new transport, "copy" the parts we don't need to touch, add a new *tls.Config, replace servername and then call RoundTrip on that to avoid any races
newtransport := &HTTPTransport{
Resolver: tmpTransport.Resolver,
TLS: tmpTransport.TLS,
KeepAlive: tmpTransport.KeepAlive,
Compression: tmpTransport.Compression,
MaxConnsPerHost: tmpTransport.MaxConnsPerHost,
DialTimeout: tmpTransport.DialTimeout,
FallbackDelay: tmpTransport.FallbackDelay,
ResponseHeaderTimeout: tmpTransport.ResponseHeaderTimeout,
ExpectContinueTimeout: tmpTransport.ExpectContinueTimeout,
MaxResponseHeaderSize: tmpTransport.MaxResponseHeaderSize,
WriteBufferSize: tmpTransport.WriteBufferSize,
ReadBufferSize: tmpTransport.ReadBufferSize,
Versions: tmpTransport.Versions,
Transport: tmpTransport.Transport.Clone(),
h2cTransport: tmpTransport.h2cTransport,
}
newtransport.Transport.TLSClientConfig.ServerName = repl.ReplaceAll(newtransport.Transport.TLSClientConfig.ServerName, "")
transport = newtransport
}
transport = h.replaceTLSServername(tmpTransport, repl)
}
// do the round-trip; emit debug log with values we know are