Skip to content

Commit 100d6d5

Browse files
Added type alias validation
1 parent 1290219 commit 100d6d5

File tree

3 files changed

+81
-45
lines changed

3 files changed

+81
-45
lines changed

pkg/analysis/numericbounds/analyzer.go

Lines changed: 67 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -69,63 +69,98 @@ func run(pass *analysis.Pass) (any, error) {
6969
return nil, kalerrors.ErrCouldNotGetInspector
7070
}
7171

72-
inspect.InspectFields(func(field *ast.Field, _ extractjsontags.FieldTagInfo, markersAccess markershelper.Markers, qualifiedFieldName string) {
73-
// Create TypeChecker with closure capturing markersAccess and qualifiedFieldName
74-
// Ignore TypeChecker's prefix since we use qualifiedFieldName from inspector
75-
typeChecker := utils.NewTypeChecker(func(pass *analysis.Pass, ident *ast.Ident, node ast.Node, _ string) {
76-
checkNumericType(pass, ident, node, markersAccess, qualifiedFieldName)
72+
// Check fields in structs
73+
inspect.InspectFields(func(field *ast.Field, _ extractjsontags.FieldTagInfo, markersAccess markershelper.Markers, qualifiedName string) {
74+
// Create TypeChecker with closure capturing markersAccess and qualifiedName
75+
typeChecker := utils.NewTypeChecker(isNumericType, func(pass *analysis.Pass, expr ast.Expr, node ast.Node, _ string) {
76+
checkNumericTypeExpr(pass, expr, node, markersAccess, qualifiedName)
7777
})
7878

7979
typeChecker.CheckNode(pass, field)
8080
})
8181

82+
// Check type declarations (type aliases)
83+
inspect.InspectTypeSpec(func(typeSpec *ast.TypeSpec, markersAccess markershelper.Markers) {
84+
// Create TypeChecker with closure capturing markersAccess and type name
85+
typeChecker := utils.NewTypeChecker(isNumericType, func(pass *analysis.Pass, expr ast.Expr, node ast.Node, _ string) {
86+
checkNumericTypeExpr(pass, expr, node, markersAccess, typeSpec.Name.Name)
87+
})
88+
89+
typeChecker.CheckNode(pass, typeSpec)
90+
})
91+
8292
return nil, nil //nolint:nilnil
8393
}
8494

85-
//nolint:cyclop
86-
func checkNumericType(pass *analysis.Pass, ident *ast.Ident, node ast.Node, markersAccess markershelper.Markers, qualifiedFieldName string) {
95+
// isNumericType checks if the expression is a numeric type we want to validate.
96+
func isNumericType(pass *analysis.Pass, expr ast.Expr) bool {
97+
ident, ok := expr.(*ast.Ident)
98+
if !ok {
99+
return false
100+
}
101+
87102
// Only check int32, int64, float32, and float64 types
88-
if ident.Name != "int32" && ident.Name != "int64" && ident.Name != "float32" && ident.Name != "float64" {
89-
return
103+
switch ident.Name {
104+
case "int32", "int64", "float32", "float64":
105+
return true
106+
default:
107+
return false
90108
}
109+
}
91110

92-
field, ok := node.(*ast.Field)
111+
//nolint:cyclop
112+
func checkNumericTypeExpr(pass *analysis.Pass, expr ast.Expr, node ast.Node, markersAccess markershelper.Markers, qualifiedName string) {
113+
// Extract the identifier from the expression
114+
ident, ok := expr.(*ast.Ident)
93115
if !ok {
94116
return
95117
}
96118

97-
fieldMarkers := utils.TypeAwareMarkerCollectionForField(pass, markersAccess, field)
98-
99-
// Check if this is an array/slice field
100-
isSlice := utils.IsArrayTypeOrAlias(pass, field)
119+
// Handle both fields and type aliases
120+
var markerSet markershelper.MarkerSet
121+
var isSlice bool
122+
var pos ast.Node
123+
124+
switch n := node.(type) {
125+
case *ast.Field:
126+
markerSet = utils.TypeAwareMarkerCollectionForField(pass, markersAccess, n)
127+
isSlice = utils.IsArrayTypeOrAlias(pass, n)
128+
pos = n
129+
case *ast.TypeSpec:
130+
markerSet = markersAccess.TypeMarkers(n)
131+
isSlice = false // Type aliases themselves are never slices
132+
pos = n
133+
default:
134+
return
135+
}
101136

102137
// Determine which markers to look for based on whether the field is a slice
103138
minMarkers, maxMarkers := getMarkerNames(isSlice)
104139

105140
// Get minimum and maximum marker values
106-
minimum, minErr := getMarkerNumericValue(fieldMarkers, minMarkers)
107-
maximum, maxErr := getMarkerNumericValue(fieldMarkers, maxMarkers)
141+
minimum, minErr := getMarkerNumericValue(markerSet, minMarkers)
142+
maximum, maxErr := getMarkerNumericValue(markerSet, maxMarkers)
108143

109144
// Check if markers are missing
110145
minMissing := errors.Is(minErr, errMarkerMissingValue)
111146
maxMissing := errors.Is(maxErr, errMarkerMissingValue)
112147

113148
// Report any invalid marker values (e.g., non-numeric values)
114149
if minErr != nil && !minMissing {
115-
pass.Reportf(field.Pos(), "%s has an invalid minimum marker: %v", qualifiedFieldName, minErr)
150+
pass.Reportf(pos.Pos(), "%s has an invalid minimum marker: %v", qualifiedName, minErr)
116151
}
117152

118153
if maxErr != nil && !maxMissing {
119-
pass.Reportf(field.Pos(), "%s has an invalid maximum marker: %v", qualifiedFieldName, maxErr)
154+
pass.Reportf(pos.Pos(), "%s has an invalid maximum marker: %v", qualifiedName, maxErr)
120155
}
121156

122157
// Report if markers are missing
123158
if minMissing {
124-
pass.Reportf(field.Pos(), "%s is missing minimum bound validation marker", qualifiedFieldName)
159+
pass.Reportf(pos.Pos(), "%s is missing minimum bound validation marker", qualifiedName)
125160
}
126161

127162
if maxMissing {
128-
pass.Reportf(field.Pos(), "%s is missing maximum bound validation marker", qualifiedFieldName)
163+
pass.Reportf(pos.Pos(), "%s is missing maximum bound validation marker", qualifiedName)
129164
}
130165

131166
// If any markers are missing or invalid, don't continue with bounds checks
@@ -134,7 +169,7 @@ func checkNumericType(pass *analysis.Pass, ident *ast.Ident, node ast.Node, mark
134169
}
135170

136171
// Validate bounds are within the type's valid range
137-
checkBoundsWithinTypeRange(pass, field, qualifiedFieldName, ident.Name, minimum, maximum)
172+
checkBoundsWithinTypeRange(pass, pos, qualifiedName, ident.Name, minimum, maximum)
138173
}
139174

140175
// getMarkerNames returns the appropriate minimum and maximum marker names
@@ -178,35 +213,35 @@ func getMarkerNumericValue(markerSet markershelper.MarkerSet, markerNames []stri
178213
// checkBoundsWithinTypeRange validates that the bounds are within the valid range for the type.
179214
// For int64, enforces JavaScript-safe bounds as per Kubernetes API conventions to ensure
180215
// compatibility with JavaScript clients.
181-
func checkBoundsWithinTypeRange(pass *analysis.Pass, field *ast.Field, prefix, typeName string, minimum, maximum float64) {
216+
func checkBoundsWithinTypeRange(pass *analysis.Pass, pos ast.Node, prefix, typeName string, minimum, maximum float64) {
182217
switch typeName {
183218
case "int32":
184-
checkBoundInRange(pass, field, prefix, minimum, minInt32, maxInt32, "minimum", "int32")
185-
checkBoundInRange(pass, field, prefix, maximum, minInt32, maxInt32, "maximum", "int32")
219+
checkBoundInRange(pass, pos, prefix, minimum, minInt32, maxInt32, "minimum", "int32")
220+
checkBoundInRange(pass, pos, prefix, maximum, minInt32, maxInt32, "maximum", "int32")
186221
case "int64":
187222
// K8s API conventions enforce JavaScript-safe bounds for int64 (±2^53-1)
188-
checkBoundInRange(pass, field, prefix, minimum, int64(minSafeInt64), int64(maxSafeInt64), "minimum", "JavaScript-safe int64",
223+
checkBoundInRange(pass, pos, prefix, minimum, int64(minSafeInt64), int64(maxSafeInt64), "minimum", "JavaScript-safe int64",
189224
"Consider using a string type to avoid precision loss in JavaScript clients")
190-
checkBoundInRange(pass, field, prefix, maximum, int64(minSafeInt64), int64(maxSafeInt64), "maximum", "JavaScript-safe int64",
225+
checkBoundInRange(pass, pos, prefix, maximum, int64(minSafeInt64), int64(maxSafeInt64), "maximum", "JavaScript-safe int64",
191226
"Consider using a string type to avoid precision loss in JavaScript clients")
192227
case "float32":
193-
checkBoundInRange(pass, field, prefix, minimum, minFloat32, maxFloat32, "minimum", "float32")
194-
checkBoundInRange(pass, field, prefix, maximum, minFloat32, maxFloat32, "maximum", "float32")
228+
checkBoundInRange(pass, pos, prefix, minimum, minFloat32, maxFloat32, "minimum", "float32")
229+
checkBoundInRange(pass, pos, prefix, maximum, minFloat32, maxFloat32, "maximum", "float32")
195230
case "float64":
196-
checkBoundInRange(pass, field, prefix, minimum, minFloat64, maxFloat64, "minimum", "float64")
197-
checkBoundInRange(pass, field, prefix, maximum, minFloat64, maxFloat64, "maximum", "float64")
231+
checkBoundInRange(pass, pos, prefix, minimum, minFloat64, maxFloat64, "minimum", "float64")
232+
checkBoundInRange(pass, pos, prefix, maximum, minFloat64, maxFloat64, "maximum", "float64")
198233
}
199234
}
200235

201236
// checkBoundInRange checks if a bound value is within the valid range.
202237
// Uses generics to work with both integer and float types.
203-
func checkBoundInRange[T constraints.Integer | constraints.Float](pass *analysis.Pass, field *ast.Field, prefix string, value float64, minBound, maxBound T, boundType, typeName string, extraMsg ...string) {
238+
func checkBoundInRange[T constraints.Integer | constraints.Float](pass *analysis.Pass, pos ast.Node, prefix string, value float64, minBound, maxBound T, boundType, typeName string, extraMsg ...string) {
204239
if value < float64(minBound) || value > float64(maxBound) {
205240
msg := fmt.Sprintf("%s has %s bound %%v that is outside the %s range [%%v, %%v]", prefix, boundType, typeName)
206241
if len(extraMsg) > 0 {
207242
msg += ". " + extraMsg[0]
208243
}
209244

210-
pass.Reportf(field.Pos(), msg, value, minBound, maxBound)
245+
pass.Reportf(pos.Pos(), msg, value, minBound, maxBound)
211246
}
212247
}

pkg/analysis/numericbounds/testdata/src/a/a.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,10 @@ type SliceFields struct {
148148
}
149149

150150
// TypeAliasFields with type aliases should be checked
151-
type Int32Alias int32
152-
type Int64Alias int64
153-
type Float32Alias float32
154-
type Float64Alias float64
151+
type Int32Alias int32 // want "Int32Alias is missing minimum bound validation marker" "Int32Alias is missing maximum bound validation marker"
152+
type Int64Alias int64 // want "Int64Alias is missing minimum bound validation marker" "Int64Alias is missing maximum bound validation marker"
153+
type Float32Alias float32 // want "Float32Alias is missing minimum bound validation marker" "Float32Alias is missing maximum bound validation marker"
154+
type Float64Alias float64 // want "Float64Alias is missing minimum bound validation marker" "Float64Alias is missing maximum bound validation marker"
155155

156156
// Type aliases with bounds on the type itself
157157
// +kubebuilder:validation:Minimum=0

pkg/analysis/utils/type_check.go

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,16 @@ type TypeChecker interface {
3131
}
3232

3333
// NewTypeChecker returns a new TypeChecker with the provided checkFunc.
34-
func NewTypeChecker(checkFunc func(pass *analysis.Pass, ident *ast.Ident, node ast.Node, qualifiedFieldName string)) TypeChecker {
34+
func NewTypeChecker(isTypeFunc func(pass *analysis.Pass, ident ast.Expr) bool, checkFunc func(pass *analysis.Pass, expr ast.Expr, node ast.Node, prefix string)) TypeChecker {
3535
return &typeChecker{
36-
checkFunc: checkFunc,
36+
isTypeFunc: isTypeFunc,
37+
checkFunc: checkFunc,
3738
}
3839
}
3940

4041
type typeChecker struct {
41-
checkFunc func(pass *analysis.Pass, ident *ast.Ident, node ast.Node, qualifiedFieldName string)
42+
isTypeFunc func(pass *analysis.Pass, expr ast.Expr) bool
43+
checkFunc func(pass *analysis.Pass, expr ast.Expr, node ast.Node, prefix string)
4244
}
4345

4446
// CheckNode checks the provided node for built-in types.
@@ -84,6 +86,11 @@ func (t *typeChecker) checkTypeSpec(pass *analysis.Pass, tSpec *ast.TypeSpec, no
8486
}
8587

8688
func (t *typeChecker) checkTypeExpr(pass *analysis.Pass, typeExpr ast.Expr, node ast.Node, prefix string) {
89+
if t.isTypeFunc(pass, typeExpr) {
90+
t.checkFunc(pass, typeExpr, node, prefix)
91+
return
92+
}
93+
8794
switch typ := typeExpr.(type) {
8895
case *ast.Ident:
8996
t.checkIdent(pass, typ, node, prefix)
@@ -102,12 +109,6 @@ func (t *typeChecker) checkTypeExpr(pass *analysis.Pass, typeExpr ast.Expr, node
102109
// checkIdent calls the checkFunc with the ident, when we have hit a built-in type.
103110
// If the ident is not a built in, we look at the underlying type until we hit a built-in type.
104111
func (t *typeChecker) checkIdent(pass *analysis.Pass, ident *ast.Ident, node ast.Node, prefix string) {
105-
if IsBasicType(pass, ident) {
106-
// We've hit a built-in type, no need to check further.
107-
t.checkFunc(pass, ident, node, prefix)
108-
return
109-
}
110-
111112
tSpec, ok := LookupTypeSpec(pass, ident)
112113
if !ok {
113114
return

0 commit comments

Comments
 (0)