diff --git a/pkg/config/db.go b/pkg/config/db.go index 122f5364b..a367beb3a 100644 --- a/pkg/config/db.go +++ b/pkg/config/db.go @@ -75,6 +75,10 @@ type ( AllowedCidrsV6 []string `toml:"allowed_cidrs_v6"` } + sslEnforcement struct { + Enabled bool `toml:"enabled"` + } + db struct { Image string `toml:"-"` Port uint16 `toml:"port"` @@ -88,6 +92,7 @@ type ( Seed seed `toml:"seed"` Settings settings `toml:"settings"` NetworkRestrictions networkRestrictions `toml:"network_restrictions"` + SslEnforcement *sslEnforcement `toml:"ssl_enforcement"` Vault map[string]Secret `toml:"vault"` } @@ -233,3 +238,31 @@ func (n *networkRestrictions) DiffWithRemote(remoteConfig v1API.NetworkRestricti } return diff.Diff("remote[db.network_restrictions]", remoteCompare, "local[db.network_restrictions]", currentValue), nil } + +func (s sslEnforcement) ToUpdateSslEnforcementBody() v1API.V1UpdateSslEnforcementConfigJSONRequestBody { + body := v1API.V1UpdateSslEnforcementConfigJSONRequestBody{} + body.RequestedConfig.Database = s.Enabled + return body +} + +func (s *sslEnforcement) FromRemoteSslEnforcement(remoteConfig v1API.SslEnforcementResponse) { + if s == nil { + return + } + s.Enabled = remoteConfig.CurrentConfig.Database +} + +func (s *sslEnforcement) DiffWithRemote(remoteConfig v1API.SslEnforcementResponse) ([]byte, error) { + copy := *s + // Convert the config values into easily comparable remoteConfig values + currentValue, err := ToTomlBytes(copy) + if err != nil { + return nil, err + } + copy.FromRemoteSslEnforcement(remoteConfig) + remoteCompare, err := ToTomlBytes(copy) + if err != nil { + return nil, err + } + return diff.Diff("remote[db.ssl_enforcement]", remoteCompare, "local[db.ssl_enforcement]", currentValue), nil +} diff --git a/pkg/config/templates/config.toml b/pkg/config/templates/config.toml index 87af2c389..44b58cdb0 100644 --- a/pkg/config/templates/config.toml +++ b/pkg/config/templates/config.toml @@ -74,6 +74,10 @@ allowed_cidrs = ["0.0.0.0/0"] # Defaults to allow all IPv6 connections. Set empty array to block all IPs. allowed_cidrs_v6 = ["::/0"] +# Uncomment to reject non-secure connections to the database. +# [db.ssl_enforcement] +# enabled = true + [realtime] enabled = true # Bind realtime via either IPv4 or IPv6. (default: IPv4) diff --git a/pkg/config/testdata/config.toml b/pkg/config/testdata/config.toml index ad2a0dd3e..b228a9c07 100644 --- a/pkg/config/testdata/config.toml +++ b/pkg/config/testdata/config.toml @@ -74,6 +74,10 @@ allowed_cidrs = ["0.0.0.0/0"] # Defaults to allow all IPv6 connections. Set empty array to block all IPs. allowed_cidrs_v6 = ["::/0"] +# Uncomment to reject non-secure connections to the database. +[db.ssl_enforcement] +enabled = true + [realtime] enabled = true # Bind realtime via either IPv4 or IPv6. (default: IPv6) diff --git a/pkg/config/updater.go b/pkg/config/updater.go index 8915f43fb..eb0e87c08 100644 --- a/pkg/config/updater.go +++ b/pkg/config/updater.go @@ -100,6 +100,9 @@ func (u *ConfigUpdater) UpdateDbConfig(ctx context.Context, projectRef string, c if err := u.UpdateDbNetworkRestrictionsConfig(ctx, projectRef, c.NetworkRestrictions, filter...); err != nil { return err } + if c.SslEnforcement != nil { + return u.UpdateSslEnforcement(ctx, projectRef, *c.SslEnforcement, filter...) + } return nil } @@ -132,6 +135,35 @@ func (u *ConfigUpdater) UpdateDbNetworkRestrictionsConfig(ctx context.Context, p return nil } +func (u *ConfigUpdater) UpdateSslEnforcement(ctx context.Context, projectRef string, s sslEnforcement, filter ...func(string) bool) error { + sslEnforcementConfig, err := u.client.V1GetSslEnforcementConfigWithResponse(ctx, projectRef) + if err != nil { + return errors.Errorf("failed to read SSL enforcement config: %w", err) + } else if sslEnforcementConfig.JSON200 == nil { + return errors.Errorf("unexpected status %d: %s", sslEnforcementConfig.StatusCode(), string(sslEnforcementConfig.Body)) + } + sslEnforcementDiff, err := s.DiffWithRemote(*sslEnforcementConfig.JSON200) + if err != nil { + return err + } else if len(sslEnforcementDiff) == 0 { + fmt.Fprintln(os.Stderr, "Remote DB SSL enforcement config is up to date.") + return nil + } + fmt.Fprintln(os.Stderr, "Updating SSL enforcement with config:", string(sslEnforcementDiff)) + for _, keep := range filter { + if !keep("db") { + return nil + } + } + updateBody := s.ToUpdateSslEnforcementBody() + if resp, err := u.client.V1UpdateSslEnforcementConfigWithResponse(ctx, projectRef, updateBody); err != nil { + return errors.Errorf("failed to update SSL enforcement config: %w", err) + } else if resp.JSON200 == nil { + return errors.Errorf("unexpected status %d: %s", resp.StatusCode(), string(resp.Body)) + } + return nil +} + func (u *ConfigUpdater) UpdateAuthConfig(ctx context.Context, projectRef string, c auth, filter ...func(string) bool) error { if !c.Enabled { return nil