Skip to content

Commit d5096f1

Browse files
committed
Enhance marker scope validation by introducing allowDangerousTypes flag and updating type constraints. Refactor schema type definitions and improve error handling for invalid markers. Add new test cases for various marker scenarios.
Signed-off-by: nayuta-ai <nayuta723@gmail.com>
1 parent 2bb79c2 commit d5096f1

File tree

22 files changed

+1315
-664
lines changed

22 files changed

+1315
-664
lines changed

pkg/analysis/markerscope/analyzer.go

Lines changed: 47 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,9 @@ func init() {
5050
}
5151

5252
type analyzer struct {
53-
markerRules map[string]MarkerScopeRule
54-
policy MarkerScopePolicy
53+
markerRules map[string]MarkerScopeRule
54+
policy MarkerScopePolicy
55+
allowDangerousTypes bool
5556
}
5657

5758
// newAnalyzer creates a new analyzer.
@@ -61,8 +62,9 @@ func newAnalyzer(cfg *MarkerScopeConfig) *analysis.Analyzer {
6162
}
6263

6364
a := &analyzer{
64-
markerRules: mergeMarkerRules(DefaultMarkerRules(), cfg.MarkerRules),
65-
policy: cfg.Policy,
65+
markerRules: mergeMarkerRules(DefaultMarkerRules(), cfg.MarkerRules),
66+
policy: cfg.Policy,
67+
allowDangerousTypes: cfg.AllowDangerousTypes,
6668
}
6769

6870
// Register all markers (both default and custom) with the markers helper
@@ -168,22 +170,20 @@ func (a *analyzer) checkFieldMarkers(pass *analysis.Pass, field *ast.Field, mark
168170
}
169171

170172
// Check type constraints if present
171-
if rule.TypeConstraint != nil {
172-
if err := a.validateFieldTypeConstraint(pass, field, rule.TypeConstraint); err != nil {
173-
if a.policy == MarkerScopePolicySuggestFix {
174-
pass.Report(analysis.Diagnostic{
175-
Pos: marker.Pos,
176-
End: marker.End,
177-
Message: fmt.Sprintf("marker %q: %s", marker.Identifier, err),
178-
SuggestedFixes: a.suggestMoveToFieldsIfCompatible(pass, field, marker, rule),
179-
})
180-
} else {
181-
pass.Report(analysis.Diagnostic{
182-
Pos: marker.Pos,
183-
End: marker.End,
184-
Message: fmt.Sprintf("marker %q: %s", marker.Identifier, err),
185-
})
186-
}
173+
if err := a.validateFieldTypeConstraint(pass, field, rule, a.allowDangerousTypes); err != nil {
174+
if a.policy == MarkerScopePolicySuggestFix {
175+
pass.Report(analysis.Diagnostic{
176+
Pos: marker.Pos,
177+
End: marker.End,
178+
Message: fmt.Sprintf("marker %q: %s", marker.Identifier, err),
179+
SuggestedFixes: a.suggestMoveToFieldsIfCompatible(pass, field, marker, rule),
180+
})
181+
} else {
182+
pass.Report(analysis.Diagnostic{
183+
Pos: marker.Pos,
184+
End: marker.End,
185+
Message: fmt.Sprintf("marker %q: %s", marker.Identifier, err),
186+
})
187187
}
188188
}
189189
}
@@ -223,9 +223,7 @@ func (a *analyzer) checkSingleTypeMarkers(pass *analysis.Pass, typeSpec *ast.Typ
223223
}
224224

225225
// Check type constraints if present
226-
if rule.TypeConstraint != nil {
227-
a.checkTypeConstraintViolation(pass, typeSpec, marker, rule)
228-
}
226+
a.checkTypeConstraintViolation(pass, typeSpec, marker, rule, a.allowDangerousTypes)
229227
}
230228
}
231229

@@ -254,8 +252,8 @@ func (a *analyzer) reportTypeScopeViolation(pass *analysis.Pass, typeSpec *ast.T
254252
}
255253

256254
// checkTypeConstraintViolation checks and reports type constraint violations.
257-
func (a *analyzer) checkTypeConstraintViolation(pass *analysis.Pass, typeSpec *ast.TypeSpec, marker markershelper.Marker, rule MarkerScopeRule) {
258-
if err := a.validateTypeSpecTypeConstraint(pass, typeSpec, rule.TypeConstraint); err != nil {
255+
func (a *analyzer) checkTypeConstraintViolation(pass *analysis.Pass, typeSpec *ast.TypeSpec, marker markershelper.Marker, rule MarkerScopeRule, allowDangerousTypes bool) {
256+
if err := a.validateTypeSpecTypeConstraint(pass, typeSpec, rule.TypeConstraint, allowDangerousTypes); err != nil {
259257
var fixes []analysis.SuggestedFix
260258

261259
if a.policy == MarkerScopePolicySuggestFix {
@@ -273,22 +271,29 @@ func (a *analyzer) checkTypeConstraintViolation(pass *analysis.Pass, typeSpec *a
273271
}
274272

275273
// validateFieldTypeConstraint validates that a field's type matches the type constraint.
276-
func (a *analyzer) validateFieldTypeConstraint(pass *analysis.Pass, field *ast.Field, tc *TypeConstraint) error {
274+
func (a *analyzer) validateFieldTypeConstraint(pass *analysis.Pass, field *ast.Field, rule MarkerScopeRule, allowDangerousTypes bool) error {
277275
// Get the type of the field
278276
tv, ok := pass.TypesInfo.Types[field.Type]
279277
if !ok {
280278
return nil // Skip if we can't determine the type
281279
}
282280

283-
if err := validateTypeAgainstConstraint(tv.Type, tc); err != nil {
281+
if err := validateTypeAgainstConstraint(tv.Type, rule.TypeConstraint, allowDangerousTypes); err != nil {
284282
return err
285283
}
286284

285+
if rule.StrictTypeConstraint && rule.Scope == AnyScope {
286+
namedType, ok := tv.Type.(*types.Named)
287+
if ok {
288+
return fmt.Errorf("%w of %s instead of the field", errMarkerShouldBeOnTypeDefinition, namedType.Obj().Name())
289+
}
290+
}
291+
287292
return nil
288293
}
289294

290295
// validateTypeSpecTypeConstraint validates that a type spec's type matches the type constraint.
291-
func (a *analyzer) validateTypeSpecTypeConstraint(pass *analysis.Pass, typeSpec *ast.TypeSpec, tc *TypeConstraint) error {
296+
func (a *analyzer) validateTypeSpecTypeConstraint(pass *analysis.Pass, typeSpec *ast.TypeSpec, tc *TypeConstraint, allowDangerousTypes bool) error {
292297
// Get the type of the type spec
293298
obj := pass.TypesInfo.Defs[typeSpec.Name]
294299
if obj == nil {
@@ -300,30 +305,37 @@ func (a *analyzer) validateTypeSpecTypeConstraint(pass *analysis.Pass, typeSpec
300305
return nil
301306
}
302307

303-
return validateTypeAgainstConstraint(typeName.Type(), tc)
308+
return validateTypeAgainstConstraint(typeName.Type(), tc, allowDangerousTypes)
304309
}
305310

306311
// validateTypeAgainstConstraint validates that a Go type satisfies the type constraint.
307-
func validateTypeAgainstConstraint(t types.Type, tc *TypeConstraint) error {
312+
func validateTypeAgainstConstraint(t types.Type, tc *TypeConstraint, allowDangerousTypes bool) error {
313+
// Get the schema type from the Go type
314+
schemaType := getSchemaType(t)
315+
316+
// Check if dangerous types are disallowed
317+
if !allowDangerousTypes && schemaType == SchemaTypeNumber {
318+
// Get the underlying type for better error messages
319+
underlyingType := getUnderlyingType(t)
320+
return fmt.Errorf("type %s is dangerous and not allowed (set allowDangerousTypes to true to permit)", underlyingType.String())
321+
}
322+
308323
if tc == nil {
309324
return nil
310325
}
311326

312-
// Get the schema type from the Go type
313-
schemaType := getSchemaType(t)
314-
315327
// Check if the schema type is allowed
316328
if len(tc.AllowedSchemaTypes) > 0 {
317329
if !slices.Contains(tc.AllowedSchemaTypes, schemaType) {
318-
return fmt.Errorf("%w: type %s (expected one of: %v)", errTypeNotAllowed, schemaType, tc.AllowedSchemaTypes)
330+
return fmt.Errorf("type %s is not allowed (expected one of: %v)", schemaType, tc.AllowedSchemaTypes)
319331
}
320332
}
321333

322334
// Validate element constraint for arrays/slices
323335
if tc.ElementConstraint != nil && schemaType == SchemaTypeArray {
324336
elemType := getElementType(t)
325337
if elemType != nil {
326-
if err := validateTypeAgainstConstraint(elemType, tc.ElementConstraint); err != nil {
338+
if err := validateTypeAgainstConstraint(elemType, tc.ElementConstraint, allowDangerousTypes); err != nil {
327339
return fmt.Errorf("array element: %w", err)
328340
}
329341
}
@@ -332,69 +344,6 @@ func validateTypeAgainstConstraint(t types.Type, tc *TypeConstraint) error {
332344
return nil
333345
}
334346

335-
// getSchemaType converts a Go type to an OpenAPI schema type.
336-
//
337-
//nolint:cyclop // This function has many cases for different Go types
338-
func getSchemaType(t types.Type) SchemaType {
339-
// Unwrap pointer types
340-
if ptr, ok := t.(*types.Pointer); ok {
341-
t = ptr.Elem()
342-
}
343-
344-
// Unwrap named types to get underlying type
345-
if named, ok := t.(*types.Named); ok {
346-
t = named.Underlying()
347-
}
348-
349-
switch ut := t.Underlying().(type) {
350-
case *types.Basic:
351-
switch ut.Kind() {
352-
case types.Bool:
353-
return SchemaTypeBoolean
354-
case types.Int, types.Int8, types.Int16, types.Int32, types.Int64,
355-
types.Uint, types.Uint8, types.Uint16, types.Uint32, types.Uint64:
356-
return SchemaTypeInteger
357-
case types.Float32, types.Float64:
358-
return SchemaTypeNumber
359-
case types.String:
360-
return SchemaTypeString
361-
case types.Invalid, types.Uintptr, types.Complex64, types.Complex128,
362-
types.UnsafePointer, types.UntypedBool, types.UntypedInt, types.UntypedRune,
363-
types.UntypedFloat, types.UntypedComplex, types.UntypedString, types.UntypedNil:
364-
// These types are not supported in OpenAPI schemas
365-
return ""
366-
}
367-
case *types.Slice, *types.Array:
368-
return SchemaTypeArray
369-
case *types.Map, *types.Struct:
370-
return SchemaTypeObject
371-
}
372-
373-
return ""
374-
}
375-
376-
// getElementType returns the element type of an array or slice.
377-
func getElementType(t types.Type) types.Type {
378-
// Unwrap pointer types
379-
if ptr, ok := t.(*types.Pointer); ok {
380-
t = ptr.Elem()
381-
}
382-
383-
// Unwrap named types to get underlying type
384-
if named, ok := t.(*types.Named); ok {
385-
t = named.Underlying()
386-
}
387-
388-
switch ut := t.(type) {
389-
case *types.Slice:
390-
return ut.Elem()
391-
case *types.Array:
392-
return ut.Elem()
393-
}
394-
395-
return nil
396-
}
397-
398347
// extractIdent extracts an *ast.Ident from an ast.Expr, unwrapping pointers and arrays.
399348
func extractIdent(expr ast.Expr) *ast.Ident {
400349
switch e := expr.(type) {

0 commit comments

Comments
 (0)