From f69c49896ecdb53b1e8dd53355a18829bb5bff3a Mon Sep 17 00:00:00 2001
From: Antti Kervinen
Date: Tue, 22 Apr 2025 15:52:31 +0300
Subject: [PATCH] Add memory policy support

Signed-off-by: Antti Kervinen
---
 cmd/oci-runtime-tool/generate.go  |  24 +++++
 generate/config.go                |   7 ++
 generate/generate.go              |  37 ++++++++
 validate/memorypolicy/validate.go | 141 ++++++++++++++++++++++++++++++
 validate/validate_linux.go        |   8 ++
 5 files changed, 217 insertions(+)
 create mode 100644 validate/memorypolicy/validate.go

diff --git a/cmd/oci-runtime-tool/generate.go b/cmd/oci-runtime-tool/generate.go
index 0936da76..123edf68 100644
--- a/cmd/oci-runtime-tool/generate.go
+++ b/cmd/oci-runtime-tool/generate.go
@@ -64,6 +64,9 @@ var generateFlags = []cli.Flag{
 	cli.StringFlag{Name: "linux-mems", Usage: "list of memory nodes in the cpuset (default is to use any available memory node)"},
 	cli.Uint64Flag{Name: "linux-mem-swap", Usage: "total memory limit (memory + swap) (in bytes)"},
 	cli.Uint64Flag{Name: "linux-mem-swappiness", Usage: "how aggressive the kernel will swap memory pages (Range from 0 to 100)"},
+	cli.StringFlag{Name: "linux-memorypolicy-mode", Usage: "mode of the default NUMA memory policy for page allocation, e.g MPOL_INTERLEAVE"},
+	cli.StringFlag{Name: "linux-memorypolicy-nodes", Usage: "nodes of the default NUMA memory policy, e.g 0-3,7"},
+	cli.StringSliceFlag{Name: "linux-memorypolicy-flag", Usage: "adds a flag for the default NUMA memory policy, e.g MPOL_F_STATIC_NODES"},
 	cli.StringFlag{Name: "linux-mount-label", Usage: "selinux mount context label"},
 	cli.StringSliceFlag{Name: "linux-namespace-add", Usage: "adds a namespace to the set of namespaces to create or join of the form 'ns[:path]'"},
 	cli.StringSliceFlag{Name: "linux-namespace-remove", Usage: "removes a namespace from the set of namespaces to create or join of the form 'ns'"},
@@ -782,6 +785,27 @@ func setupSpec(g *generate.Generator, context *cli.Context) error {
 		g.SetLinuxResourcesMemorySwappiness(context.Uint64("linux-mem-swappiness"))
 	}
 
+	if context.IsSet("linux-memorypolicy-mode") {
+		mpolMode := context.String("linux-memorypolicy-mode")
+		if err := g.SetLinuxMemoryPolicyMode(mpolMode); err != nil {
+			return err
+		}
+	}
+
+	if context.IsSet("linux-memorypolicy-nodes") {
+		mpolNodes := context.String("linux-memorypolicy-nodes")
+		if err := g.SetLinuxMemoryPolicyNodes(mpolNodes); err != nil {
+			return err
+		}
+	}
+
+	if context.IsSet("linux-memorypolicy-flag") {
+		mpolFlags := context.StringSlice("linux-memorypolicy-flag")
+		if err := g.SetLinuxMemoryPolicyFlags(mpolFlags); err != nil {
+			return err
+		}
+	}
+
 	if context.IsSet("linux-network-classid") {
 		g.SetLinuxResourcesNetworkClassID(uint32(context.Int("linux-network-classid")))
 	}
diff --git a/generate/config.go b/generate/config.go
index 48f281d2..b6dd9c3b 100644
--- a/generate/config.go
+++ b/generate/config.go
@@ -109,6 +109,13 @@ func (g *Generator) initConfigLinuxResourcesMemory() {
 	}
 }
 
+func (g *Generator) initConfigLinuxMemoryPolicy() {
+	g.initConfigLinux()
+	if g.Config.Linux.MemoryPolicy == nil {
+		g.Config.Linux.MemoryPolicy = &rspec.LinuxMemoryPolicy{}
+	}
+}
+
 func (g *Generator) initConfigLinuxResourcesNetwork() {
 	g.initConfigLinuxResources()
 	if g.Config.Linux.Resources.Network == nil {
diff --git a/generate/generate.go b/generate/generate.go
index 44c199e1..95efffb0 100644
--- a/generate/generate.go
+++ b/generate/generate.go
@@ -13,6 +13,7 @@ import (
 	rspec "github.com/opencontainers/runtime-spec/specs-go"
 	"github.com/opencontainers/runtime-tools/generate/seccomp"
 	capsCheck "github.com/opencontainers/runtime-tools/validate/capabilities"
+	mpolCheck "github.com/opencontainers/runtime-tools/validate/memorypolicy"
 )
 
 var (
@@ -932,6 +933,42 @@ func (g *Generator) SetLinuxResourcesMemorySwappiness(swappiness uint64) {
 	g.Config.Linux.Resources.Memory.Swappiness = &swappiness
 }
 
+// SetLinuxMemoryPolicyMode sets g.Config.Linux.MemoryPolicy.Mode to the upper-cased mode; it returns an error and leaves the config untouched if the mode is not valid.
+func (g *Generator) SetLinuxMemoryPolicyMode(mode string) error {
+	modecp := strings.ToUpper(mode)
+	if err := mpolCheck.MpolModeValid(modecp); err != nil {
+		return err
+	}
+	g.initConfigLinuxMemoryPolicy()
+	g.Config.Linux.MemoryPolicy.Mode = rspec.MemoryPolicyModeType(modecp)
+	return nil
+}
+
+// SetLinuxMemoryPolicyNodes sets g.Config.Linux.MemoryPolicy.Nodes; it returns an error and leaves the config untouched if nodes is not a valid node list.
+func (g *Generator) SetLinuxMemoryPolicyNodes(nodes string) error {
+	if err := mpolCheck.MpolNodesValid(nodes); err != nil {
+		return err
+	}
+	g.initConfigLinuxMemoryPolicy()
+	g.Config.Linux.MemoryPolicy.Nodes = nodes
+	return nil
+}
+
+// SetLinuxMemoryPolicyFlags replaces g.Config.Linux.MemoryPolicy.Flags with the upper-cased flags; it returns an error and leaves the config untouched on the first invalid flag.
+func (g *Generator) SetLinuxMemoryPolicyFlags(flags []string) error {
+	var validFlags []rspec.MemoryPolicyFlagType
+	for _, flag := range flags {
+		flagcp := strings.ToUpper(flag)
+		if err := mpolCheck.MpolFlagValid(flagcp); err != nil {
+			return err
+		}
+		validFlags = append(validFlags, rspec.MemoryPolicyFlagType(flagcp))
+	}
+	g.initConfigLinuxMemoryPolicy()
+	g.Config.Linux.MemoryPolicy.Flags = validFlags
+	return nil
+}
+
 // SetLinuxResourcesMemoryDisableOOMKiller sets g.Config.Linux.Resources.Memory.DisableOOMKiller.
 func (g *Generator) SetLinuxResourcesMemoryDisableOOMKiller(disable bool) {
 	g.initConfigLinuxResourcesMemory()
diff --git a/validate/memorypolicy/validate.go b/validate/memorypolicy/validate.go
new file mode 100644
index 00000000..7b193269
--- /dev/null
+++ b/validate/memorypolicy/validate.go
@@ -0,0 +1,141 @@
+package memorypolicy
+
+import (
+	"fmt"
+	"strconv"
+	"strings"
+
+	"github.com/hashicorp/go-multierror"
+	rspec "github.com/opencontainers/runtime-spec/specs-go"
+)
+
+var (
+	knownModes = map[rspec.MemoryPolicyModeType]struct{}{
+		rspec.MpolDefault:            {},
+		rspec.MpolBind:               {},
+		rspec.MpolInterleave:         {},
+		rspec.MpolWeightedInterleave: {},
+		rspec.MpolPreferred:          {},
+		rspec.MpolPreferredMany:      {},
+		rspec.MpolLocal:              {},
+	}
+
+	knownModeFlags = map[rspec.MemoryPolicyFlagType]struct{}{
+		rspec.MpolFNumaBalancing: {},
+		rspec.MpolFRelativeNodes: {},
+		rspec.MpolFStaticNodes:   {},
+	}
+)
+
+// MpolModeValid checks if the provided memory policy mode is valid, i.e. starts with "MPOL_" and is one of knownModes.
+func MpolModeValid(mode string) error {
+	if !strings.HasPrefix(mode, "MPOL_") {
+		return fmt.Errorf("memory policy mode %q must start with 'MPOL_'", mode)
+	}
+	if _, ok := knownModes[rspec.MemoryPolicyModeType(mode)]; !ok {
+		return fmt.Errorf("invalid memory policy mode %q", mode)
+	}
+	return nil
+}
+
+// MpolNodesValid checks if the provided nodes specification is valid; empty list elements (and thus an empty string) are accepted.
+func MpolNodesValid(nodes string) error {
+	// nodes is a comma-separated list of node IDs or ranges thereof.
+	nodeRanges := strings.Split(nodes, ",")
+	for _, nodeRange := range nodeRanges {
+		nodeRange = strings.TrimSpace(nodeRange)
+		if nodeRange == "" {
+			continue
+		}
+		bounds := strings.Split(nodeRange, "-")
+		switch len(bounds) {
+		case 1:
+			// Single node
+			number := strings.TrimSpace(bounds[0])
+			if _, err := parseNodeID(number); err != nil {
+				return err
+			}
+		case 2:
+			// Range of nodes
+			startNumber := strings.TrimSpace(bounds[0])
+			startID, err := parseNodeID(startNumber)
+			if err != nil {
+				return err
+			}
+			endNumber := strings.TrimSpace(bounds[1])
+			endID, err := parseNodeID(endNumber)
+			if err != nil {
+				return err
+			}
+			if startID > endID {
+				return fmt.Errorf("invalid memory policy node range %q: start ID greater than end ID", nodeRange)
+			}
+		default:
+			return fmt.Errorf("invalid memory policy node range %q", nodeRange)
+		}
+	}
+	return nil
+}
+
+func parseNodeID(nodeStr string) (int, error) {
+	nodeID, err := strconv.Atoi(nodeStr)
+	if err != nil {
+		return 0, fmt.Errorf("invalid memory policy node %q", nodeStr)
+	}
+	if nodeID < 0 {
+		return 0, fmt.Errorf("memory policy node %d must be non-negative", nodeID)
+	}
+	return nodeID, nil
+}
+
+// MpolFlagValid checks if the provided memory policy flag is valid, i.e. starts with "MPOL_F_" and is one of knownModeFlags.
+func MpolFlagValid(flag string) error {
+	if !strings.HasPrefix(flag, "MPOL_F_") {
+		return fmt.Errorf("memory policy flag %q must start with 'MPOL_F_'", flag)
+	}
+	if _, ok := knownModeFlags[rspec.MemoryPolicyFlagType(flag)]; !ok {
+		return fmt.Errorf("invalid memory policy flag %q", flag)
+	}
+	return nil
+}
+
+// MpolModeNodesValid checks if the nodes specification is valid for the given memory policy mode.
+func MpolModeNodesValid(mode rspec.MemoryPolicyModeType, nodes string) error {
+	switch mode {
+	case rspec.MpolDefault, rspec.MpolLocal:
+		if nodes != "" {
+			return fmt.Errorf("memory policy mode %q must not have nodes specified", mode)
+		}
+	case rspec.MpolBind, rspec.MpolInterleave, rspec.MpolWeightedInterleave, rspec.MpolPreferred, rspec.MpolPreferredMany:
+		if nodes == "" {
+			return fmt.Errorf("memory policy mode %q must have nodes specified", mode)
+		}
+	case "":
+		return fmt.Errorf("memory policy mode must be specified")
+	default:
+		return fmt.Errorf("unknown memory policy mode %q", mode)
+	}
+	return nil
+}
+
+// MpolValid checks mode, nodes and every flag individually, collecting all failures; only if none failed does it check that nodes is consistent with mode.
+func MpolValid(mode rspec.MemoryPolicyModeType, nodes string, flags []rspec.MemoryPolicyFlagType) (errs error) {
+	if err := MpolModeValid(string(mode)); err != nil {
+		errs = multierror.Append(errs, err)
+	}
+	if err := MpolNodesValid(nodes); err != nil {
+		errs = multierror.Append(errs, err)
+	}
+	for _, flag := range flags {
+		if err := MpolFlagValid(string(flag)); err != nil {
+			errs = multierror.Append(errs, err) // keep the result: with errs still nil, Append returns a new error that would otherwise be lost
+		}
+	}
+	if errs == nil {
+		err := MpolModeNodesValid(mode, nodes)
+		if err != nil {
+			errs = multierror.Append(errs, err)
+		}
+	}
+	return errs
+}
diff --git a/validate/validate_linux.go b/validate/validate_linux.go
index 2c7cdb75..736089d7 100644
--- a/validate/validate_linux.go
+++ b/validate/validate_linux.go
@@ -14,6 +14,7 @@ import (
 	rspec "github.com/opencontainers/runtime-spec/specs-go"
 	osFilepath "github.com/opencontainers/runtime-tools/filepath"
 	"github.com/opencontainers/runtime-tools/specerror"
+	mpolCheck "github.com/opencontainers/runtime-tools/validate/memorypolicy"
 	"github.com/opencontainers/selinux/go-selinux/label"
 	"github.com/sirupsen/logrus"
 )
@@ -220,5 +221,12 @@ func (v *Validator) CheckLinux() (errs error) {
 		}
 	}
 
+	if v.spec.Linux.MemoryPolicy != nil {
+		mp := v.spec.Linux.MemoryPolicy
+		if err := mpolCheck.MpolValid(mp.Mode, mp.Nodes, mp.Flags); err != nil {
+			errs = multierror.Append(errs, err)
+		}
+	}
+
 	return
 }