diff --git a/runner/options.go b/runner/options.go
index 2fa8d601..44bda3b0 100644
--- a/runner/options.go
+++ b/runner/options.go
@@ -274,6 +274,8 @@ type Options struct {
 	RateLimitMinute int
 	Probe           bool
 	Resume          bool
+	RetryRounds     int
+	RetryDelay      int
 	resumeCfg       *ResumeCfg
 	Exclude         goflags.StringSlice
 	HostMaxErrors   int
@@ -530,6 +532,8 @@ func ParseOptions() *Options {
 		flagSet.DurationVar(&options.Delay, "delay", -1, "duration between each http request (eg: 200ms, 1s)"),
 		flagSet.IntVarP(&options.MaxResponseBodySizeToSave, "response-size-to-save", "rsts", math.MaxInt32, "max response size to save in bytes"),
 		flagSet.IntVarP(&options.MaxResponseBodySizeToRead, "response-size-to-read", "rstr", math.MaxInt32, "max response size to read in bytes"),
+		flagSet.IntVar(&options.RetryRounds, "retry-rounds", 0, "number of retry rounds for HTTP 429 responses (Too Many Requests)"),
+		flagSet.IntVar(&options.RetryDelay, "retry-delay", 500, "delay in milliseconds between retry rounds for HTTP 429 responses"),
 	)
 
 	flagSet.CreateGroup("cloud", "Cloud",
@@ -757,6 +761,10 @@ func (options *Options) ValidateOptions() error {
 		options.Threads = defaultThreads
 	}
 
+	if options.RetryRounds > 0 && options.RetryDelay <= 0 {
+		return fmt.Errorf("invalid retry-delay: must be >0 when retry-rounds=%d (got %d)", options.RetryRounds, options.RetryDelay)
+	}
+
 	return nil
 }
 
diff --git a/runner/runner.go b/runner/runner.go
index f6ec504d..9f003544 100644
--- a/runner/runner.go
+++ b/runner/runner.go
@@ -23,6 +23,7 @@ import (
 	"strconv"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"golang.org/x/exp/maps"
@@ -1257,6 +1258,9 @@ func (r *Runner) RunEnumeration() {
 	}(nextStep)
 
 	wg, _ := syncutil.New(syncutil.WithSize(r.options.Threads))
+	retryCh := make(chan retryJob)
+
+	_, drainedCh := r.retryLoop(context.Background(), retryCh, output, r.analyze)
 
 	processItem := func(k string) error {
 		if r.options.resumeCfg != nil {
@@ -1279,10 +1283,10 @@
 			for _, p := range r.options.requestURIs {
 				scanopts := r.scanopts.Clone()
 				scanopts.RequestURI = p
-				r.process(k, wg, r.hp, protocol, scanopts, output)
+				r.process(k, wg, r.hp, protocol, scanopts, output, retryCh)
 			}
 		} else {
-			r.process(k, wg, r.hp, protocol, &r.scanopts, output)
+			r.process(k, wg, r.hp, protocol, &r.scanopts, output, retryCh)
 		}
 
 		return nil
@@ -1299,9 +1303,10 @@
 	}
 
 	wg.Wait()
-
+	if r.options.RetryRounds > 0 {
+		<-drainedCh
+	}
 	close(output)
-
 	wgoutput.Wait()
 
 	if r.scanopts.StoreVisionReconClusters {
@@ -1323,6 +1328,70 @@
 	}
 }
 
+type analyzeFunc func(*httpx.HTTPX, string, httpx.Target, string, string, *ScanOptions) Result
+
+func (r *Runner) retryLoop(
+	parent context.Context,
+	retryCh chan retryJob,
+	output chan<- Result,
+	analyze analyzeFunc,
+) (stop func(), drained <-chan struct{}) {
+	var remaining atomic.Int64
+	ctx, cancel := context.WithCancel(parent)
+	drainedCh := make(chan struct{})
+
+	go func() {
+		defer close(retryCh)
+
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case job, ok := <-retryCh:
+				if !ok {
+					return
+				}
+				if job.attempt == 1 {
+					remaining.Add(1)
+				}
+
+				go func(j retryJob) {
+					if wait := time.Until(j.when); wait > 0 {
+						timer := time.NewTimer(wait)
+						select {
+						case <-ctx.Done():
+							timer.Stop()
+							return
+						case <-timer.C:
+						}
+					}
+
+					res := analyze(j.hp, j.protocol, j.target, j.method, j.origInput, j.scanopts)
+					output <- res
+
+					if res.StatusCode == http.StatusTooManyRequests && j.attempt < r.options.RetryRounds {
+						j.attempt++
+						j.when = time.Now().Add(time.Duration(r.options.RetryDelay) * time.Millisecond)
+
+						select {
+						case <-ctx.Done():
+							return
+						case retryCh <- j:
+							return
+						}
+					}
+
+					if remaining.Add(-1) == 0 {
+						close(drainedCh)
+					}
+				}(job)
+			}
+		}
+	}()
+
+	return func() { cancel() }, drainedCh
+}
+
 func logFilteredErrorPage(fileName, url string) {
 	dir := filepath.Dir(fileName)
 	if !fileutil.FolderExists(dir) {
@@ -1380,11 +1449,11 @@ func (r *Runner) GetScanOpts() ScanOptions {
 	return r.scanopts
 }
 
-func (r *Runner) Process(t string, wg *syncutil.AdaptiveWaitGroup, protocol string, scanopts *ScanOptions, output chan Result) {
-	r.process(t, wg, r.hp, protocol, scanopts, output)
+func (r *Runner) Process(t string, wg *syncutil.AdaptiveWaitGroup, protocol string, scanopts *ScanOptions, output chan Result, retryCh chan retryJob) {
+	r.process(t, wg, r.hp, protocol, scanopts, output, retryCh)
 }
 
-func (r *Runner) process(t string, wg *syncutil.AdaptiveWaitGroup, hp *httpx.HTTPX, protocol string, scanopts *ScanOptions, output chan Result) {
+func (r *Runner) process(t string, wg *syncutil.AdaptiveWaitGroup, hp *httpx.HTTPX, protocol string, scanopts *ScanOptions, output chan Result, retryCh chan retryJob) {
 	// attempts to set the workpool size to the number of threads
 	if r.options.Threads > 0 && wg.Size != r.options.Threads {
 		if err := wg.Resize(context.Background(), r.options.Threads); err != nil {
@@ -1409,15 +1478,28 @@
 				defer wg.Done()
 				result := r.analyze(hp, protocol, target, method, t, scanopts)
 				output <- result
+				if result.StatusCode == http.StatusTooManyRequests &&
+					r.options.RetryRounds > 0 {
+					retryCh <- retryJob{
+						hp:        hp,
+						protocol:  protocol,
+						target:    target,
+						method:    method,
+						origInput: t,
+						scanopts:  scanopts.Clone(),
+						attempt:   1,
+						when:      time.Now().Add(time.Duration(r.options.RetryDelay) * time.Millisecond),
+					}
+				}
 				if scanopts.TLSProbe && result.TLSData != nil {
 					for _, tt := range result.TLSData.SubjectAN {
 						if !r.testAndSet(tt) {
 							continue
 						}
-						r.process(tt, wg, hp, protocol, scanopts, output)
+						r.process(tt, wg, hp, protocol, scanopts, output, retryCh)
 					}
 					if r.testAndSet(result.TLSData.SubjectCN) {
-						r.process(result.TLSData.SubjectCN, wg, hp, protocol, scanopts, output)
+						r.process(result.TLSData.SubjectCN, wg, hp, protocol, scanopts, output, retryCh)
 					}
 				}
 				if scanopts.CSPProbe && result.CSPData != nil {
@@ -1428,7 +1510,7 @@
 					if !r.testAndSet(tt) {
 						continue
 					}
-					r.process(tt, wg, hp, protocol, scanopts, output)
+					r.process(tt, wg, hp, protocol, scanopts, output, retryCh)
 				}
 			}
 		}(target, method, prot)
@@ -1463,15 +1545,28 @@
 					}
 					result := r.analyze(hp, protocol, target, method, t, scanopts)
 					output <- result
+					if result.StatusCode == http.StatusTooManyRequests &&
+						r.options.RetryRounds > 0 {
+						retryCh <- retryJob{
+							hp:        hp,
+							protocol:  protocol,
+							target:    target,
+							method:    method,
+							origInput: t,
+							scanopts:  scanopts.Clone(),
+							attempt:   1,
+							when:      time.Now().Add(time.Duration(r.options.RetryDelay) * time.Millisecond),
+						}
+					}
 					if scanopts.TLSProbe && result.TLSData != nil {
 						for _, tt := range result.TLSData.SubjectAN {
 							if !r.testAndSet(tt) {
 								continue
 							}
-							r.process(tt, wg, hp, protocol, scanopts, output)
+							r.process(tt, wg, hp, protocol, scanopts, output, retryCh)
 						}
 						if r.testAndSet(result.TLSData.SubjectCN) {
-							r.process(result.TLSData.SubjectCN, wg, hp, protocol, scanopts, output)
+							r.process(result.TLSData.SubjectCN, wg, hp, protocol, scanopts, output, retryCh)
 						}
 					}
 				}(port, target, method, wantedProtocol)
diff --git a/runner/runner_test.go b/runner/runner_test.go
index 850566b8..a566356d 100644
--- a/runner/runner_test.go
+++ b/runner/runner_test.go
@@ -1,9 +1,14 @@
 package runner
 
 import (
+	"context"
 	"fmt"
+	"net/http"
+	"net/http/httptest"
 	"os"
 	"strings"
+	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -11,6 +16,7 @@ import (
 	"github.com/projectdiscovery/httpx/common/httpx"
 	"github.com/projectdiscovery/mapcidr/asn"
 	stringsutil "github.com/projectdiscovery/utils/strings"
+	syncutil "github.com/projectdiscovery/utils/sync"
 	"github.com/stretchr/testify/require"
 )
 
@@ -227,10 +233,10 @@ func TestCreateNetworkpolicyInstance_AllowDenyFlags(t *testing.T) {
 	runner := &Runner{}
 
 	tests := []struct {
-		name  string
-		allow []string
-		deny  []string
-		testCases []struct {
+		name      string
+		allow     []string
+		deny      []string
+		testCases []struct {
 			ip       string
 			expected bool
 			reason   string
@@ -312,3 +318,92 @@
 		})
 	}
 }
+
+func TestRunner_Process_And_RetryLoop(t *testing.T) {
+	var hits1, hits2 int32
+
+	// srv1: returns 429 for the first 3 requests, and 200 on the 4th request
+	srv1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if atomic.AddInt32(&hits1, 1) != 4 {
+			w.WriteHeader(http.StatusTooManyRequests)
+			return
+		}
+		w.WriteHeader(http.StatusOK)
+	}))
+	defer srv1.Close()
+
+	// srv2: returns 429 for the first 2 requests, and 200 on the 3rd request
+	srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if atomic.AddInt32(&hits2, 1) != 3 {
+			w.WriteHeader(http.StatusTooManyRequests)
+			return
+		}
+		w.WriteHeader(http.StatusOK)
+	}))
+	defer srv2.Close()
+
+	r, err := New(&Options{
+		Threads:     1,
+		RetryRounds: 2,
+		RetryDelay:  5,
+		Timeout:     3,
+	})
+	require.NoError(t, err)
+
+	output := make(chan Result)
+	retryCh := make(chan retryJob)
+
+	_, drainedCh := r.retryLoop(context.Background(), retryCh, output, r.analyze)
+
+	wg, _ := syncutil.New(syncutil.WithSize(r.options.Threads))
+	so := r.scanopts.Clone()
+	so.Methods = []string{"GET"}
+	so.TLSProbe = false
+	so.CSPProbe = false
+
+	seed := map[string]string{
+		"srv1": srv1.URL,
+		"srv2": srv2.URL,
+	}
+
+	var drainWG sync.WaitGroup
+	drainWG.Add(1)
+	var s1n429, s1n200, s2n429, s2n200 int
+	go func() {
+		defer drainWG.Done()
+		for res := range output {
+			switch res.StatusCode {
+			case http.StatusTooManyRequests:
+				if res.URL == srv1.URL {
+					s1n429++
+				} else {
+					s2n429++
+				}
+			case http.StatusOK:
+				if res.URL == srv1.URL {
+					s1n200++
+				} else {
+					s2n200++
+				}
+			}
+		}
+	}()
+
+	for _, url := range seed {
+		r.process(url, wg, r.hp, httpx.HTTP, so, output, retryCh)
+	}
+
+	wg.Wait()
+	<-drainedCh
+	close(output)
+	drainWG.Wait()
+
+	// Verify expected results
+	// srv1: should have 3x 429 responses and no 200 (never succeeds within retries)
+	require.Equal(t, 3, s1n429)
+	require.Equal(t, 0, s1n200)
+
+	// srv2: should have 2x 429 responses and 1x 200 (succeeds on 3rd attempt)
+	require.Equal(t, 2, s2n429)
+	require.Equal(t, 1, s2n200)
+}
diff --git a/runner/types.go b/runner/types.go
index 724e8697..5f4367c3 100644
--- a/runner/types.go
+++ b/runner/types.go
@@ -120,6 +120,17 @@ type Trace struct {
 	WroteRequest time.Time `json:"wrote_request,omitempty"`
 }
 
+type retryJob struct {
+	hp        *httpx.HTTPX
+	protocol  string
+	target    httpx.Target
+	method    string
+	origInput string
+	scanopts  *ScanOptions
+	attempt   int
+	when      time.Time
+}
+
 // function to get dsl variables from result struct
 func dslVariables() ([]string, error) {
 	fakeResult := Result{}
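
Usage note: a minimal sketch of driving the new options through the runner's library API. RetryRounds and RetryDelay (milliseconds) are the fields added by this patch; the remaining fields and calls (InputTargetHost, Methods, OnResult, Close, the goflags import) follow the upstream httpx library example and are assumptions about the existing API, not part of this change.

package main

import (
	"log"

	"github.com/projectdiscovery/goflags"
	"github.com/projectdiscovery/httpx/runner"
)

func main() {
	opts := runner.Options{
		// Fields other than RetryRounds/RetryDelay are assumed from the
		// upstream httpx library example, not introduced by this patch.
		Methods:         "GET",
		InputTargetHost: goflags.StringSlice{"scanme.sh"},
		RetryRounds:     3,   // re-probe hosts answering 429 up to 3 more times
		RetryDelay:      500, // wait 500 milliseconds between retry rounds
		OnResult: func(res runner.Result) {
			if res.Err != nil {
				return
			}
			log.Printf("%s [%d]", res.URL, res.StatusCode)
		},
	}
	if err := opts.ValidateOptions(); err != nil {
		log.Fatal(err)
	}

	httpxRunner, err := runner.New(&opts)
	if err != nil {
		log.Fatal(err)
	}
	defer httpxRunner.Close()

	httpxRunner.RunEnumeration()
}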