|
| 1 | +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' |
| 2 | + |
| 3 | +const { postMock, getMock, putMock, deleteMock, streamMock, MockRequestError, MockHTTPError } = |
| 4 | + vi.hoisted(() => { |
| 5 | + class RequestError extends Error { |
| 6 | + response?: unknown |
| 7 | + |
| 8 | + constructor(message = 'RequestError', response?: unknown) { |
| 9 | + super(message) |
| 10 | + this.name = 'RequestError' |
| 11 | + this.response = response |
| 12 | + } |
| 13 | + } |
| 14 | + |
| 15 | + class HTTPError extends RequestError { |
| 16 | + constructor(response: unknown) { |
| 17 | + super('HTTPError', response) |
| 18 | + this.name = 'HTTPError' |
| 19 | + } |
| 20 | + } |
| 21 | + |
| 22 | + return { |
| 23 | + postMock: vi.fn(), |
| 24 | + getMock: vi.fn(), |
| 25 | + putMock: vi.fn(), |
| 26 | + deleteMock: vi.fn(), |
| 27 | + streamMock: vi.fn(), |
| 28 | + MockRequestError: RequestError, |
| 29 | + MockHTTPError: HTTPError, |
| 30 | + } |
| 31 | + }) |
| 32 | + |
| 33 | +vi.mock('got', () => { |
| 34 | + const mockClient = { |
| 35 | + post: postMock, |
| 36 | + get: getMock, |
| 37 | + put: putMock, |
| 38 | + delete: deleteMock, |
| 39 | + stream: streamMock, |
| 40 | + } |
| 41 | + |
| 42 | + return { |
| 43 | + default: mockClient, |
| 44 | + ...mockClient, |
| 45 | + HTTPError: MockHTTPError, |
| 46 | + RequestError: MockRequestError, |
| 47 | + } |
| 48 | +}) |
| 49 | + |
| 50 | +import { ApiError } from '../../src/ApiError.ts' |
| 51 | +import type { AssemblyStatus } from '../../src/alphalib/types/assemblyStatus.ts' |
| 52 | +import PaginationStream from '../../src/PaginationStream.ts' |
| 53 | +import PollingTimeoutError from '../../src/PollingTimeoutError.ts' |
| 54 | +import { Transloadit } from '../../src/Transloadit.ts' |
| 55 | + |
| 56 | +const getInternalRemoteJson = (client: Transloadit) => |
| 57 | + (client as unknown as { _remoteJson: Transloadit['_remoteJson'] })._remoteJson.bind(client) |
| 58 | + |
| 59 | +describe('Transloadit advanced behaviors', () => { |
| 60 | + let client: Transloadit |
| 61 | + |
| 62 | + beforeEach(() => { |
| 63 | + client = new Transloadit({ authKey: 'key', authSecret: 'secret', maxRetries: 2 }) |
| 64 | + postMock.mockReset() |
| 65 | + getMock.mockReset() |
| 66 | + putMock.mockReset() |
| 67 | + deleteMock.mockReset() |
| 68 | + }) |
| 69 | + |
| 70 | + afterEach(() => { |
| 71 | + vi.useRealTimers() |
| 72 | + vi.restoreAllMocks() |
| 73 | + }) |
| 74 | + |
| 75 | + it('retries rate limited requests before succeeding', async () => { |
| 76 | + vi.useFakeTimers() |
| 77 | + const remoteJson = getInternalRemoteJson(client) |
| 78 | + |
| 79 | + const body = { |
| 80 | + error: 'RATE_LIMIT_REACHED', |
| 81 | + info: { |
| 82 | + retryIn: 1, |
| 83 | + }, |
| 84 | + } |
| 85 | + |
| 86 | + const retryError = new MockHTTPError({ statusCode: 429, body }) |
| 87 | + postMock.mockRejectedValueOnce(retryError).mockResolvedValueOnce({ body: { ok: true } }) |
| 88 | + |
| 89 | + const randomSpy = vi.spyOn(Math, 'random').mockReturnValue(0) |
| 90 | + |
| 91 | + const resultPromise = remoteJson({ urlSuffix: '/foo', method: 'post' }) |
| 92 | + |
| 93 | + await vi.advanceTimersByTimeAsync(1000) |
| 94 | + |
| 95 | + const result = await resultPromise |
| 96 | + |
| 97 | + expect(result).toEqual({ ok: true }) |
| 98 | + expect(postMock).toHaveBeenCalledTimes(2) |
| 99 | + |
| 100 | + randomSpy.mockRestore() |
| 101 | + }) |
| 102 | + |
| 103 | + it('wraps non-retryable HTTP errors in ApiError', async () => { |
| 104 | + const remoteJson = getInternalRemoteJson(client) |
| 105 | + |
| 106 | + const errorBody = { error: 'SOME_ERROR', info: {} } |
| 107 | + const httpError = new MockHTTPError({ statusCode: 500, body: errorBody }) |
| 108 | + postMock.mockRejectedValueOnce(httpError) |
| 109 | + |
| 110 | + await expect(remoteJson({ urlSuffix: '/foo', method: 'post' })).rejects.toBeInstanceOf(ApiError) |
| 111 | + }) |
| 112 | + |
| 113 | + it('polls assemblies until a terminal status is reached', async () => { |
| 114 | + const statuses: AssemblyStatus[] = [ |
| 115 | + { ok: 'ASSEMBLY_UPLOADING' } as AssemblyStatus, |
| 116 | + { ok: 'ASSEMBLY_EXECUTING' } as AssemblyStatus, |
| 117 | + { ok: 'ASSEMBLY_COMPLETED' } as AssemblyStatus, |
| 118 | + ] |
| 119 | + |
| 120 | + const getAssembly = vi |
| 121 | + .spyOn(client, 'getAssembly') |
| 122 | + .mockImplementation(async () => statuses.shift() as AssemblyStatus) |
| 123 | + |
| 124 | + const onAssemblyProgress = vi.fn() |
| 125 | + |
| 126 | + const result = await client.awaitAssemblyCompletion('assembly-id', { |
| 127 | + onAssemblyProgress, |
| 128 | + interval: 1, |
| 129 | + }) |
| 130 | + |
| 131 | + expect(result).toEqual({ ok: 'ASSEMBLY_COMPLETED' }) |
| 132 | + expect(onAssemblyProgress).toHaveBeenCalledTimes(2) |
| 133 | + expect(getAssembly).toHaveBeenCalledTimes(3) |
| 134 | + }) |
| 135 | + |
| 136 | + it('throws a timeout error when polling exceeds the allowed duration', async () => { |
| 137 | + vi.useFakeTimers() |
| 138 | + vi.spyOn(client, 'getAssembly').mockResolvedValue({ |
| 139 | + ok: 'ASSEMBLY_UPLOADING', |
| 140 | + } as AssemblyStatus) |
| 141 | + |
| 142 | + const promise = client.awaitAssemblyCompletion('assembly-id', { |
| 143 | + timeout: 0, |
| 144 | + startTimeMs: 0, |
| 145 | + interval: 1, |
| 146 | + }) |
| 147 | + |
| 148 | + await expect(promise).rejects.toBeInstanceOf(PollingTimeoutError) |
| 149 | + }) |
| 150 | + |
| 151 | + it('streams assemblies page by page until all items are read', async () => { |
| 152 | + type ListAssembliesReturn = Awaited<ReturnType<Transloadit['listAssemblies']>> |
| 153 | + |
| 154 | + const listAssemblies = vi.spyOn(client, 'listAssemblies').mockImplementation(async (params) => { |
| 155 | + const page = params?.page ?? 1 |
| 156 | + |
| 157 | + if (page === 1) { |
| 158 | + return { |
| 159 | + items: [{ id: 1 }, { id: 2 }], |
| 160 | + count: 3, |
| 161 | + } as unknown as ListAssembliesReturn |
| 162 | + } |
| 163 | + if (page === 2) { |
| 164 | + return { |
| 165 | + items: [{ id: 3 }], |
| 166 | + count: 3, |
| 167 | + } as unknown as ListAssembliesReturn |
| 168 | + } |
| 169 | + return { items: [], count: 3 } as unknown as ListAssembliesReturn |
| 170 | + }) |
| 171 | + |
| 172 | + const stream = client.streamAssemblies({ page: 1 } as never) |
| 173 | + |
| 174 | + const collected: Array<{ id: number }> = [] |
| 175 | + |
| 176 | + await new Promise<void>((resolve, reject) => { |
| 177 | + stream.on('data', (item) => { |
| 178 | + collected.push(item as { id: number }) |
| 179 | + }) |
| 180 | + stream.on('end', resolve) |
| 181 | + stream.on('error', reject) |
| 182 | + }) |
| 183 | + |
| 184 | + expect(collected).toEqual([{ id: 1 }, { id: 2 }, { id: 3 }]) |
| 185 | + expect(listAssemblies).toHaveBeenCalledTimes(2) |
| 186 | + expect(listAssemblies).toHaveBeenNthCalledWith(1, expect.objectContaining({ page: 1 })) |
| 187 | + expect(listAssemblies).toHaveBeenNthCalledWith(2, expect.objectContaining({ page: 2 })) |
| 188 | + }) |
| 189 | +}) |
| 190 | + |
| 191 | +describe('PaginationStream edge cases', () => { |
| 192 | + it('stops requesting pages once the reported count is reached', async () => { |
| 193 | + const fetchPage = vi |
| 194 | + .fn() |
| 195 | + .mockResolvedValueOnce({ items: [1, 2], count: 2 }) |
| 196 | + .mockResolvedValueOnce({ items: [3, 4] }) |
| 197 | + |
| 198 | + const stream = new PaginationStream<number>(fetchPage) |
| 199 | + const items: number[] = [] |
| 200 | + |
| 201 | + await new Promise<void>((resolve, reject) => { |
| 202 | + stream.on('data', (item) => { |
| 203 | + items.push(item) |
| 204 | + }) |
| 205 | + stream.on('end', resolve) |
| 206 | + stream.on('error', reject) |
| 207 | + }) |
| 208 | + |
| 209 | + expect(items).toEqual([1, 2]) |
| 210 | + expect(fetchPage).toHaveBeenCalledTimes(1) |
| 211 | + }) |
| 212 | +}) |
0 commit comments