diff --git a/modules/caddytls/storageloader.go b/modules/caddytls/storageloader.go index c9487e892..46705d426 100644 --- a/modules/caddytls/storageloader.go +++ b/modules/caddytls/storageloader.go @@ -72,17 +72,16 @@ func (sl *StorageLoader) Provision(ctx caddy.Context) error { return nil } -// LoadCertificates returns the certificates to be loaded by sl. -func (sl StorageLoader) LoadCertificates() ([]Certificate, error) { +func (sl StorageLoader) Initialize(updateCertificates func(add []Certificate, remove []string) error) error { certs := make([]Certificate, 0, len(sl.Pairs)) for _, pair := range sl.Pairs { certData, err := sl.storage.Load(sl.ctx, pair.Certificate) if err != nil { - return nil, err + return err } keyData, err := sl.storage.Load(sl.ctx, pair.Key) if err != nil { - return nil, err + return err } var cert tls.Certificate @@ -94,21 +93,21 @@ func (sl StorageLoader) LoadCertificates() ([]Certificate, error) { // if the start of the key file looks like an encrypted private key, // reject it with a helpful error message if strings.Contains(string(keyData[:40]), "ENCRYPTED") { - return nil, fmt.Errorf("encrypted private keys are not supported; please decrypt the key first") + return fmt.Errorf("encrypted private keys are not supported; please decrypt the key first") } cert, err = tls.X509KeyPair(certData, keyData) default: - return nil, fmt.Errorf("unrecognized certificate/key encoding format: %s", pair.Format) + return fmt.Errorf("unrecognized certificate/key encoding format: %s", pair.Format) } if err != nil { - return nil, err + return err } certs = append(certs, Certificate{Certificate: cert, Tags: pair.Tags}) } - return certs, nil + return updateCertificates(certs, []string{}) } // Interface guard diff --git a/modules/caddytls/tls.go b/modules/caddytls/tls.go index 7b49c0208..9ba2b344d 100644 --- a/modules/caddytls/tls.go +++ b/modules/caddytls/tls.go @@ -252,18 +252,14 @@ func (t *TLS) Provision(ctx caddy.Context) error { DisableStorageCheck: t.DisableStorageCheck, }) certCacheMu.RUnlock() + for _, loader := range t.certificateLoaders { - certs, err := loader.LoadCertificates() + err := loader.Initialize(func(add []Certificate, remove []string) error { + return t.updateCertificates(ctx, magic, add, remove) + }) if err != nil { return fmt.Errorf("loading certificates: %v", err) } - for _, cert := range certs { - hash, err := magic.CacheUnmanagedTLSCertificate(ctx, cert.Certificate, cert.Tags) - if err != nil { - return fmt.Errorf("caching unmanaged certificate: %v", err) - } - t.loaded[hash] = "" - } } // on-demand permission module @@ -782,6 +778,20 @@ func (t *TLS) HasCertificateForSubject(subject string) bool { return false } +func (t *TLS) updateCertificates(ctx caddy.Context, magic *certmagic.Config, add []Certificate, remove []string) error { + for _, cert := range add { + hash, err := magic.CacheUnmanagedTLSCertificate(ctx, cert.Certificate, cert.Tags) + if err != nil { + return fmt.Errorf("caching unmanaged certificate: %v", err) + } + t.loaded[hash] = "" + } + certCacheMu.Lock() + certCache.Remove(remove) + certCacheMu.Unlock() + return nil +} + // keepStorageClean starts a goroutine that immediately cleans up all // known storage units if it was not recently done, and then runs the // operation at every tick from t.storageCleanTicker. @@ -887,7 +897,7 @@ func (t *TLS) onEvent(ctx context.Context, eventName string, data map[string]any // CertificateLoader is a type that can load certificates. // Certificates can optionally be associated with tags. type CertificateLoader interface { - LoadCertificates() ([]Certificate, error) + Initialize(updateCertificates func(add []Certificate, remove []string) error) error } // Certificate is a TLS certificate, optionally