diff --git a/packages/react-native-executorch/src/controllers/BaseOCRController.ts b/packages/react-native-executorch/src/controllers/BaseOCRController.ts new file mode 100644 index 000000000..9f0d5d611 --- /dev/null +++ b/packages/react-native-executorch/src/controllers/BaseOCRController.ts @@ -0,0 +1,121 @@ +import { symbols } from '../constants/ocr/symbols'; +import { RnExecutorchErrorCode } from '../errors/ErrorCodes'; +import { RnExecutorchError, parseUnknownError } from '../errors/errorUtils'; +import { ResourceSource } from '../types/common'; +import { OCRLanguage, OCRDetection } from '../types/ocr'; +import { ResourceFetcher } from '../utils/ResourceFetcher'; + +export abstract class BaseOCRController { + protected nativeModule: any; + public isReady: boolean = false; + public isGenerating: boolean = false; + public error: RnExecutorchError | null = null; + protected isReadyCallback: (isReady: boolean) => void; + protected isGeneratingCallback: (isGenerating: boolean) => void; + protected errorCallback: (error: RnExecutorchError) => void; + + constructor({ + isReadyCallback = (_isReady: boolean) => {}, + isGeneratingCallback = (_isGenerating: boolean) => {}, + errorCallback = (_error: RnExecutorchError) => {}, + } = {}) { + this.isReadyCallback = isReadyCallback; + this.isGeneratingCallback = isGeneratingCallback; + this.errorCallback = errorCallback; + } + + protected abstract loadNativeModule( + detectorPath: string, + recognizerPath: string, + language: OCRLanguage, + extraParams?: any + ): any; + + protected internalLoad = async ( + detectorSource: ResourceSource, + recognizerSource: ResourceSource, + language: OCRLanguage, + onDownloadProgressCallback?: (downloadProgress: number) => void, + extraParams?: any + ) => { + try { + if (!detectorSource || !recognizerSource) return; + + if (!symbols[language]) { + throw new RnExecutorchError( + RnExecutorchErrorCode.LanguageNotSupported, + 'The provided language for OCR is not supported. Please try using other language.' + ); + } + + this.isReady = false; + this.isReadyCallback(false); + + const paths = await ResourceFetcher.fetch( + onDownloadProgressCallback, + detectorSource, + recognizerSource + ); + if (paths === null || paths.length < 2) { + throw new RnExecutorchError( + RnExecutorchErrorCode.DownloadInterrupted, + 'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.' + ); + } + this.nativeModule = this.loadNativeModule( + paths[0]!, + paths[1]!, + language, + extraParams + ); + this.isReady = true; + this.isReadyCallback(this.isReady); + } catch (e) { + if (this.errorCallback) { + this.errorCallback(parseUnknownError(e)); + } else { + throw parseUnknownError(e); + } + } + }; + + public forward = async (imageSource: string): Promise => { + if (!this.isReady) { + throw new RnExecutorchError( + RnExecutorchErrorCode.ModuleNotLoaded, + 'The model is currently not loaded. Please load the model before calling forward().' + ); + } + if (this.isGenerating) { + throw new RnExecutorchError( + RnExecutorchErrorCode.ModelGenerating, + 'The model is currently generating. Please wait until previous model run is complete.' + ); + } + + try { + this.isGenerating = true; + this.isGeneratingCallback(this.isGenerating); + return await this.nativeModule.generate(imageSource); + } catch (e) { + throw parseUnknownError(e); + } finally { + this.isGenerating = false; + this.isGeneratingCallback(this.isGenerating); + } + }; + + public delete() { + if (this.isGenerating) { + throw new RnExecutorchError( + RnExecutorchErrorCode.ModelGenerating, + 'The model is currently generating. Please wait until previous model run is complete.' + ); + } + if (this.nativeModule) { + this.nativeModule.unload(); + } + this.isReadyCallback(false); + this.isGeneratingCallback(false); + } +} diff --git a/packages/react-native-executorch/src/controllers/OCRController.ts b/packages/react-native-executorch/src/controllers/OCRController.ts index 57f1e3489..31523563d 100644 --- a/packages/react-native-executorch/src/controllers/OCRController.ts +++ b/packages/react-native-executorch/src/controllers/OCRController.ts @@ -1,27 +1,15 @@ import { symbols } from '../constants/ocr/symbols'; -import { RnExecutorchErrorCode } from '../errors/ErrorCodes'; -import { RnExecutorchError, parseUnknownError } from '../errors/errorUtils'; import { ResourceSource } from '../types/common'; import { OCRLanguage } from '../types/ocr'; -import { ResourceFetcher } from '../utils/ResourceFetcher'; - -export class OCRController { - private nativeModule: any; - public isReady: boolean = false; - public isGenerating: boolean = false; - public error: RnExecutorchError | null = null; - private isReadyCallback: (isReady: boolean) => void; - private isGeneratingCallback: (isGenerating: boolean) => void; - private errorCallback: (error: RnExecutorchError) => void; - - constructor({ - isReadyCallback = (_isReady: boolean) => {}, - isGeneratingCallback = (_isGenerating: boolean) => {}, - errorCallback = (_error: RnExecutorchError) => {}, - } = {}) { - this.isReadyCallback = isReadyCallback; - this.isGeneratingCallback = isGeneratingCallback; - this.errorCallback = errorCallback; +import { BaseOCRController } from './BaseOCRController'; + +export class OCRController extends BaseOCRController { + protected loadNativeModule( + detectorPath: string, + recognizerPath: string, + language: OCRLanguage + ): any { + return global.loadOCR(detectorPath, recognizerPath, symbols[language]); } public load = async ( @@ -30,83 +18,11 @@ export class OCRController { language: OCRLanguage, onDownloadProgressCallback?: (downloadProgress: number) => void ) => { - try { - if (!detectorSource || !recognizerSource) return; - - if (!symbols[language]) { - throw new RnExecutorchError( - RnExecutorchErrorCode.LanguageNotSupported, - 'The provided language for OCR is not supported. Please try using other language.' - ); - } - - this.isReady = false; - this.isReadyCallback(false); - - const paths = await ResourceFetcher.fetch( - onDownloadProgressCallback, - detectorSource, - recognizerSource - ); - if (paths === null || paths.length < 2) { - throw new RnExecutorchError( - RnExecutorchErrorCode.DownloadInterrupted, - 'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.' - ); - } - this.nativeModule = global.loadOCR( - paths[0]!, - paths[1]!, - symbols[language] - ); - this.isReady = true; - this.isReadyCallback(this.isReady); - } catch (e) { - if (this.errorCallback) { - this.errorCallback(parseUnknownError(e)); - } else { - throw parseUnknownError(e); - } - } + await this.internalLoad( + detectorSource, + recognizerSource, + language, + onDownloadProgressCallback + ); }; - - public forward = async (imageSource: string) => { - if (!this.isReady) { - throw new RnExecutorchError( - RnExecutorchErrorCode.ModuleNotLoaded, - 'The model is currently not loaded. Please load the model before calling forward().' - ); - } - if (this.isGenerating) { - throw new RnExecutorchError( - RnExecutorchErrorCode.ModelGenerating, - 'The model is currently generating. Please wait until previous model run is complete.' - ); - } - - try { - this.isGenerating = true; - this.isGeneratingCallback(this.isGenerating); - return await this.nativeModule.generate(imageSource); - } catch (e) { - throw parseUnknownError(e); - } finally { - this.isGenerating = false; - this.isGeneratingCallback(this.isGenerating); - } - }; - - public delete() { - if (this.isGenerating) { - throw new RnExecutorchError( - RnExecutorchErrorCode.ModelGenerating, - 'The model is currently generating. Please wait until previous model run is complete.' - ); - } - if (this.nativeModule) { - this.nativeModule.unload(); - } - this.isReadyCallback(false); - this.isGeneratingCallback(false); - } } diff --git a/packages/react-native-executorch/src/controllers/VerticalOCRController.ts b/packages/react-native-executorch/src/controllers/VerticalOCRController.ts index eaf4b0849..2509d8331 100644 --- a/packages/react-native-executorch/src/controllers/VerticalOCRController.ts +++ b/packages/react-native-executorch/src/controllers/VerticalOCRController.ts @@ -1,27 +1,21 @@ import { symbols } from '../constants/ocr/symbols'; -import { RnExecutorchErrorCode } from '../errors/ErrorCodes'; -import { RnExecutorchError, parseUnknownError } from '../errors/errorUtils'; import { ResourceSource } from '../types/common'; import { OCRLanguage } from '../types/ocr'; -import { ResourceFetcher } from '../utils/ResourceFetcher'; +import { BaseOCRController } from './BaseOCRController'; -export class VerticalOCRController { - private ocrNativeModule: any; - public isReady: boolean = false; - public isGenerating: boolean = false; - public error: string | null = null; - private isReadyCallback: (isReady: boolean) => void; - private isGeneratingCallback: (isGenerating: boolean) => void; - private errorCallback: (error: RnExecutorchError) => void; - - constructor({ - isReadyCallback = (_isReady: boolean) => {}, - isGeneratingCallback = (_isGenerating: boolean) => {}, - errorCallback = (_error: RnExecutorchError) => {}, - } = {}) { - this.isReadyCallback = isReadyCallback; - this.isGeneratingCallback = isGeneratingCallback; - this.errorCallback = errorCallback; +export class VerticalOCRController extends BaseOCRController { + protected loadNativeModule( + detectorPath: string, + recognizerPath: string, + language: OCRLanguage, + independentCharacters?: boolean + ): any { + return global.loadVerticalOCR( + detectorPath, + recognizerPath, + symbols[language], + independentCharacters + ); } public load = async ( @@ -31,85 +25,12 @@ export class VerticalOCRController { independentCharacters: boolean, onDownloadProgressCallback: (downloadProgress: number) => void ) => { - try { - if (!detectorSource || !recognizerSource) return; - - if (!symbols[language]) { - throw new RnExecutorchError( - RnExecutorchErrorCode.LanguageNotSupported, - 'The provided language for OCR is not supported. Please try using other language.' - ); - } - - this.isReady = false; - this.isReadyCallback(this.isReady); - - const paths = await ResourceFetcher.fetch( - onDownloadProgressCallback, - detectorSource, - recognizerSource - ); - if (paths === null || paths.length < 3) { - throw new RnExecutorchError( - RnExecutorchErrorCode.DownloadInterrupted, - 'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.' - ); - } - this.ocrNativeModule = global.loadVerticalOCR( - paths[0]!, - paths[1]!, - symbols[language], - independentCharacters - ); - - this.isReady = true; - this.isReadyCallback(this.isReady); - } catch (e) { - if (this.errorCallback) { - this.errorCallback(parseUnknownError(e)); - } else { - throw parseUnknownError(e); - } - } + await this.internalLoad( + detectorSource, + recognizerSource, + language, + onDownloadProgressCallback, + independentCharacters + ); }; - - public forward = async (imageSource: string) => { - if (!this.isReady) { - throw new RnExecutorchError( - RnExecutorchErrorCode.ModuleNotLoaded, - 'The model is currently not loaded. Please load the model before calling forward().' - ); - } - if (this.isGenerating) { - throw new RnExecutorchError( - RnExecutorchErrorCode.ModelGenerating, - 'The model is currently generating. Please wait until previous model run is complete.' - ); - } - - try { - this.isGenerating = true; - this.isGeneratingCallback(this.isGenerating); - return await this.ocrNativeModule.generate(imageSource); - } catch (e) { - throw parseUnknownError(e); - } finally { - this.isGenerating = false; - this.isGeneratingCallback(this.isGenerating); - } - }; - - public delete() { - if (this.isGenerating) { - throw new RnExecutorchError( - RnExecutorchErrorCode.ModelGenerating, - 'The model is currently generating. Please wait until previous model run is complete.' - ); - } - if (this.ocrNativeModule) { - this.ocrNativeModule.unload(); - } - this.isReadyCallback(false); - this.isGeneratingCallback(false); - } }