diff --git a/cmd/attested-get/main.go b/cmd/attested-get/main.go index b1ff349..304350f 100644 --- a/cmd/attested-get/main.go +++ b/cmd/attested-get/main.go @@ -4,7 +4,7 @@ package main // Make a HTTP GET request over a TEE-attested connection (to a server with aTLS support), // and print the verified measurements and the response payload. // -// Currently only works for Azure TDX but is straight-forward to expand. +// Currently supports Azure TDX and DCAP TDX attestation. // // Usage: // @@ -37,6 +37,7 @@ import ( "fmt" "io" "log" + "log/slog" "net/http" "os" "strings" @@ -50,6 +51,7 @@ import ( "github.com/flashbots/cvm-reverse-proxy/internal/config" "github.com/flashbots/cvm-reverse-proxy/multimeasurements" "github.com/flashbots/cvm-reverse-proxy/proxy" + dcap_tdx "github.com/flashbots/cvm-reverse-proxy/tdx" "github.com/urfave/cli/v2" // imports as package "cli" ) @@ -70,9 +72,9 @@ var flags []cli.Flag = []cli.Flag{ Usage: "Output file for the response payload", }, &cli.StringFlag{ - Name: "attestation-type", // TODO: Add support for other attestation types - Value: string(proxy.AttestationAzureTDX), - Usage: "type of attestation to present (currently only azure-tdx)", + Name: "attestation-type", + Value: string(proxy.AttestationAuto), + Usage: "type of attestation to present (auto, azure-tdx, or dcap-tdx)", }, &cli.StringFlag{ Name: "expected-measurements", @@ -105,6 +107,23 @@ func main() { } } +// createAzureTDXValidator creates an Azure TDX validator without required measurements +func createAzureTDXValidator(log *slog.Logger, overrideAzurev6Tcbinfo bool) atls.Validator { + attConfig := config.DefaultForAzureTDX() + attConfig.SetMeasurements(measurements.M{}) + validator := azure_tdx.NewValidator(attConfig, proxy.AttestationLogger{Log: log}) + if overrideAzurev6Tcbinfo { + azure_tcbinfo_override.OverrideAzureValidatorsForV6SEAMLoader(log, []atls.Validator{validator}) + } + return validator +} + +// createDCAPTDXValidator creates a DCAP TDX validator without required measurements +func createDCAPTDXValidator(log *slog.Logger) atls.Validator { + attConfig := &config.QEMUTDX{Measurements: measurements.M{}} + return dcap_tdx.NewValidator(attConfig, proxy.AttestationLogger{Log: log}) +} + func runClient(cCtx *cli.Context) (err error) { logDebug := cCtx.Bool("log-debug") addr := cCtx.String("addr") @@ -137,17 +156,17 @@ func runClient(cCtx *cli.Context) (err error) { var validators []atls.Validator switch attestationType { case proxy.AttestationAzureTDX: - // Prepare an azure-tdx validator without any required measurements - attConfig := config.DefaultForAzureTDX() - attConfig.SetMeasurements(measurements.M{}) - validator := azure_tdx.NewValidator(attConfig, proxy.AttestationLogger{Log: log}) - if overrideAzurev6Tcbinfo { - azure_tcbinfo_override.OverrideAzureValidatorsForV6SEAMLoader(log, []atls.Validator{validator}) - } - validators = append(validators, validator) + validators = append(validators, createAzureTDXValidator(log, overrideAzurev6Tcbinfo)) + case proxy.AttestationDCAPTDX: + validators = append(validators, createDCAPTDXValidator(log)) + case proxy.AttestationAuto: + // In auto mode, add all validators to support any attestation type + log.Info("Auto mode: creating validators for all supported attestation types") + validators = append(validators, createAzureTDXValidator(log, overrideAzurev6Tcbinfo)) + validators = append(validators, createDCAPTDXValidator(log)) default: - log.Error("currently only azure-tdx attestation is supported") - return errors.New("currently only azure-tdx attestation is supported") + log.Error("unsupported attestation type, see --help for available options") + return errors.New("unsupported attestation type") } // Load expected measurements from file or URL (if provided) @@ -188,7 +207,7 @@ func runClient(cCtx *cli.Context) (err error) { } // Extract the aTLS variant and measurements from the TLS connection - atlsVariant, extractedMeasurements, err := proxy.GetMeasurementsFromTLS(resp.TLS.PeerCertificates, []asn1.ObjectIdentifier{variant.AzureTDX{}.OID()}) + atlsVariant, extractedMeasurements, err := proxy.GetMeasurementsFromTLS(resp.TLS.PeerCertificates, []asn1.ObjectIdentifier{variant.AzureTDX{}.OID(), variant.QEMUTDX{}.OID()}) if err != nil { log.Error("Error in getMeasurementsFromTLS", "err", err) return err