Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions pkg/config/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ type (
AllowedCidrsV6 []string `toml:"allowed_cidrs_v6"`
}

sslEnforcement struct {
Enabled bool `toml:"enabled"`
}

db struct {
Image string `toml:"-"`
Port uint16 `toml:"port"`
Expand All @@ -88,6 +92,7 @@ type (
Seed seed `toml:"seed"`
Settings settings `toml:"settings"`
NetworkRestrictions networkRestrictions `toml:"network_restrictions"`
SslEnforcement *sslEnforcement `toml:"ssl_enforcement"`
Vault map[string]Secret `toml:"vault"`
}

Expand Down Expand Up @@ -233,3 +238,31 @@ func (n *networkRestrictions) DiffWithRemote(remoteConfig v1API.NetworkRestricti
}
return diff.Diff("remote[db.network_restrictions]", remoteCompare, "local[db.network_restrictions]", currentValue), nil
}

func (s sslEnforcement) ToUpdateSslEnforcementBody() v1API.V1UpdateSslEnforcementConfigJSONRequestBody {
body := v1API.V1UpdateSslEnforcementConfigJSONRequestBody{}
body.RequestedConfig.Database = s.Enabled
return body
}

func (s *sslEnforcement) FromRemoteSslEnforcement(remoteConfig v1API.SslEnforcementResponse) {
if s == nil {
return
}
s.Enabled = remoteConfig.CurrentConfig.Database
}

func (s *sslEnforcement) DiffWithRemote(remoteConfig v1API.SslEnforcementResponse) ([]byte, error) {
copy := *s
// Convert the config values into easily comparable remoteConfig values
currentValue, err := ToTomlBytes(copy)
if err != nil {
return nil, err
}
copy.FromRemoteSslEnforcement(remoteConfig)
remoteCompare, err := ToTomlBytes(copy)
if err != nil {
return nil, err
}
return diff.Diff("remote[db.ssl_enforcement]", remoteCompare, "local[db.ssl_enforcement]", currentValue), nil
}
4 changes: 4 additions & 0 deletions pkg/config/templates/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ allowed_cidrs = ["0.0.0.0/0"]
# Defaults to allow all IPv6 connections. Set empty array to block all IPs.
allowed_cidrs_v6 = ["::/0"]

# Uncomment to reject non-secure connections to the database.
# [db.ssl_enforcement]
# enabled = true

[realtime]
enabled = true
# Bind realtime via either IPv4 or IPv6. (default: IPv4)
Expand Down
4 changes: 4 additions & 0 deletions pkg/config/testdata/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ allowed_cidrs = ["0.0.0.0/0"]
# Defaults to allow all IPv6 connections. Set empty array to block all IPs.
allowed_cidrs_v6 = ["::/0"]

# Uncomment to reject non-secure connections to the database.
[db.ssl_enforcement]
enabled = true

[realtime]
enabled = true
# Bind realtime via either IPv4 or IPv6. (default: IPv6)
Expand Down
32 changes: 32 additions & 0 deletions pkg/config/updater.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ func (u *ConfigUpdater) UpdateDbConfig(ctx context.Context, projectRef string, c
if err := u.UpdateDbNetworkRestrictionsConfig(ctx, projectRef, c.NetworkRestrictions, filter...); err != nil {
return err
}
if c.SslEnforcement != nil {
return u.UpdateSslEnforcement(ctx, projectRef, *c.SslEnforcement, filter...)
}
return nil
}

Expand Down Expand Up @@ -132,6 +135,35 @@ func (u *ConfigUpdater) UpdateDbNetworkRestrictionsConfig(ctx context.Context, p
return nil
}

func (u *ConfigUpdater) UpdateSslEnforcement(ctx context.Context, projectRef string, s sslEnforcement, filter ...func(string) bool) error {
sslEnforcementConfig, err := u.client.V1GetSslEnforcementConfigWithResponse(ctx, projectRef)
if err != nil {
return errors.Errorf("failed to read SSL enforcement config: %w", err)
} else if sslEnforcementConfig.JSON200 == nil {
return errors.Errorf("unexpected status %d: %s", sslEnforcementConfig.StatusCode(), string(sslEnforcementConfig.Body))
}
sslEnforcementDiff, err := s.DiffWithRemote(*sslEnforcementConfig.JSON200)
if err != nil {
return err
} else if len(sslEnforcementDiff) == 0 {
fmt.Fprintln(os.Stderr, "Remote DB SSL enforcement config is up to date.")
return nil
}
fmt.Fprintln(os.Stderr, "Updating SSL enforcement with config:", string(sslEnforcementDiff))
for _, keep := range filter {
if !keep("db") {
return nil
}
}
updateBody := s.ToUpdateSslEnforcementBody()
if resp, err := u.client.V1UpdateSslEnforcementConfigWithResponse(ctx, projectRef, updateBody); err != nil {
return errors.Errorf("failed to update SSL enforcement config: %w", err)
} else if resp.JSON200 == nil {
return errors.Errorf("unexpected status %d: %s", resp.StatusCode(), string(resp.Body))
}
return nil
}

func (u *ConfigUpdater) UpdateAuthConfig(ctx context.Context, projectRef string, c auth, filter ...func(string) bool) error {
if !c.Enabled {
return nil
Expand Down