Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions packages/react-native-executorch/src/controllers/BaseOCRController.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import { symbols } from '../constants/ocr/symbols';
import { RnExecutorchErrorCode } from '../errors/ErrorCodes';
import { RnExecutorchError, parseUnknownError } from '../errors/errorUtils';
import { ResourceSource } from '../types/common';
import { OCRLanguage, OCRDetection } from '../types/ocr';
import { ResourceFetcher } from '../utils/ResourceFetcher';

export abstract class BaseOCRController {
protected nativeModule: any;
public isReady: boolean = false;
public isGenerating: boolean = false;
public error: RnExecutorchError | null = null;
protected isReadyCallback: (isReady: boolean) => void;
protected isGeneratingCallback: (isGenerating: boolean) => void;
protected errorCallback: (error: RnExecutorchError) => void;

constructor({
isReadyCallback = (_isReady: boolean) => {},
isGeneratingCallback = (_isGenerating: boolean) => {},
errorCallback = (_error: RnExecutorchError) => {},
} = {}) {
this.isReadyCallback = isReadyCallback;
this.isGeneratingCallback = isGeneratingCallback;
this.errorCallback = errorCallback;
}

protected abstract loadNativeModule(
detectorPath: string,
recognizerPath: string,
language: OCRLanguage,
extraParams?: any
): any;

protected internalLoad = async (
detectorSource: ResourceSource,
recognizerSource: ResourceSource,
language: OCRLanguage,
onDownloadProgressCallback?: (downloadProgress: number) => void,
extraParams?: any
) => {
try {
if (!detectorSource || !recognizerSource) return;

if (!symbols[language]) {
throw new RnExecutorchError(
RnExecutorchErrorCode.LanguageNotSupported,
'The provided language for OCR is not supported. Please try using other language.'
);
}

this.isReady = false;
this.isReadyCallback(false);

const paths = await ResourceFetcher.fetch(
onDownloadProgressCallback,
detectorSource,
recognizerSource
);
if (paths === null || paths.length < 2) {
throw new RnExecutorchError(
RnExecutorchErrorCode.DownloadInterrupted,
'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
);
}
this.nativeModule = this.loadNativeModule(
paths[0]!,
paths[1]!,
language,
extraParams
);
this.isReady = true;
this.isReadyCallback(this.isReady);
} catch (e) {
if (this.errorCallback) {
this.errorCallback(parseUnknownError(e));
} else {
throw parseUnknownError(e);
}
}
};

public forward = async (imageSource: string): Promise<OCRDetection[]> => {
if (!this.isReady) {
throw new RnExecutorchError(
RnExecutorchErrorCode.ModuleNotLoaded,
'The model is currently not loaded. Please load the model before calling forward().'
);
}
if (this.isGenerating) {
throw new RnExecutorchError(
RnExecutorchErrorCode.ModelGenerating,
'The model is currently generating. Please wait until previous model run is complete.'
);
}

try {
this.isGenerating = true;
this.isGeneratingCallback(this.isGenerating);
return await this.nativeModule.generate(imageSource);
} catch (e) {
throw parseUnknownError(e);
} finally {
this.isGenerating = false;
this.isGeneratingCallback(this.isGenerating);
}
};

public delete() {
if (this.isGenerating) {
throw new RnExecutorchError(
RnExecutorchErrorCode.ModelGenerating,
'The model is currently generating. Please wait until previous model run is complete.'
);
}
if (this.nativeModule) {
this.nativeModule.unload();
}
this.isReadyCallback(false);
this.isGeneratingCallback(false);
}
}
114 changes: 15 additions & 99 deletions packages/react-native-executorch/src/controllers/OCRController.ts
Original file line number Diff line number Diff line change
@@ -1,27 +1,15 @@
import { symbols } from '../constants/ocr/symbols';
import { RnExecutorchErrorCode } from '../errors/ErrorCodes';
import { RnExecutorchError, parseUnknownError } from '../errors/errorUtils';
import { ResourceSource } from '../types/common';
import { OCRLanguage } from '../types/ocr';
import { ResourceFetcher } from '../utils/ResourceFetcher';

export class OCRController {
private nativeModule: any;
public isReady: boolean = false;
public isGenerating: boolean = false;
public error: RnExecutorchError | null = null;
private isReadyCallback: (isReady: boolean) => void;
private isGeneratingCallback: (isGenerating: boolean) => void;
private errorCallback: (error: RnExecutorchError) => void;

constructor({
isReadyCallback = (_isReady: boolean) => {},
isGeneratingCallback = (_isGenerating: boolean) => {},
errorCallback = (_error: RnExecutorchError) => {},
} = {}) {
this.isReadyCallback = isReadyCallback;
this.isGeneratingCallback = isGeneratingCallback;
this.errorCallback = errorCallback;
import { BaseOCRController } from './BaseOCRController';

export class OCRController extends BaseOCRController {
protected loadNativeModule(
detectorPath: string,
recognizerPath: string,
language: OCRLanguage
): any {
return global.loadOCR(detectorPath, recognizerPath, symbols[language]);
}

public load = async (
Expand All @@ -30,83 +18,11 @@ export class OCRController {
language: OCRLanguage,
onDownloadProgressCallback?: (downloadProgress: number) => void
) => {
try {
if (!detectorSource || !recognizerSource) return;

if (!symbols[language]) {
throw new RnExecutorchError(
RnExecutorchErrorCode.LanguageNotSupported,
'The provided language for OCR is not supported. Please try using other language.'
);
}

this.isReady = false;
this.isReadyCallback(false);

const paths = await ResourceFetcher.fetch(
onDownloadProgressCallback,
detectorSource,
recognizerSource
);
if (paths === null || paths.length < 2) {
throw new RnExecutorchError(
RnExecutorchErrorCode.DownloadInterrupted,
'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
);
}
this.nativeModule = global.loadOCR(
paths[0]!,
paths[1]!,
symbols[language]
);
this.isReady = true;
this.isReadyCallback(this.isReady);
} catch (e) {
if (this.errorCallback) {
this.errorCallback(parseUnknownError(e));
} else {
throw parseUnknownError(e);
}
}
await this.internalLoad(
detectorSource,
recognizerSource,
language,
onDownloadProgressCallback
);
};

public forward = async (imageSource: string) => {
if (!this.isReady) {
throw new RnExecutorchError(
RnExecutorchErrorCode.ModuleNotLoaded,
'The model is currently not loaded. Please load the model before calling forward().'
);
}
if (this.isGenerating) {
throw new RnExecutorchError(
RnExecutorchErrorCode.ModelGenerating,
'The model is currently generating. Please wait until previous model run is complete.'
);
}

try {
this.isGenerating = true;
this.isGeneratingCallback(this.isGenerating);
return await this.nativeModule.generate(imageSource);
} catch (e) {
throw parseUnknownError(e);
} finally {
this.isGenerating = false;
this.isGeneratingCallback(this.isGenerating);
}
};

public delete() {
if (this.isGenerating) {
throw new RnExecutorchError(
RnExecutorchErrorCode.ModelGenerating,
'The model is currently generating. Please wait until previous model run is complete.'
);
}
if (this.nativeModule) {
this.nativeModule.unload();
}
this.isReadyCallback(false);
this.isGeneratingCallback(false);
}
}
Original file line number Diff line number Diff line change
@@ -1,27 +1,21 @@
import { symbols } from '../constants/ocr/symbols';
import { RnExecutorchErrorCode } from '../errors/ErrorCodes';
import { RnExecutorchError, parseUnknownError } from '../errors/errorUtils';
import { ResourceSource } from '../types/common';
import { OCRLanguage } from '../types/ocr';
import { ResourceFetcher } from '../utils/ResourceFetcher';
import { BaseOCRController } from './BaseOCRController';

export class VerticalOCRController {
private ocrNativeModule: any;
public isReady: boolean = false;
public isGenerating: boolean = false;
public error: string | null = null;
private isReadyCallback: (isReady: boolean) => void;
private isGeneratingCallback: (isGenerating: boolean) => void;
private errorCallback: (error: RnExecutorchError) => void;

constructor({
isReadyCallback = (_isReady: boolean) => {},
isGeneratingCallback = (_isGenerating: boolean) => {},
errorCallback = (_error: RnExecutorchError) => {},
} = {}) {
this.isReadyCallback = isReadyCallback;
this.isGeneratingCallback = isGeneratingCallback;
this.errorCallback = errorCallback;
export class VerticalOCRController extends BaseOCRController {
protected loadNativeModule(
detectorPath: string,
recognizerPath: string,
language: OCRLanguage,
independentCharacters?: boolean
): any {
return global.loadVerticalOCR(
detectorPath,
recognizerPath,
symbols[language],
independentCharacters
);
}

public load = async (
Expand All @@ -31,85 +25,12 @@ export class VerticalOCRController {
independentCharacters: boolean,
onDownloadProgressCallback: (downloadProgress: number) => void
) => {
try {
if (!detectorSource || !recognizerSource) return;

if (!symbols[language]) {
throw new RnExecutorchError(
RnExecutorchErrorCode.LanguageNotSupported,
'The provided language for OCR is not supported. Please try using other language.'
);
}

this.isReady = false;
this.isReadyCallback(this.isReady);

const paths = await ResourceFetcher.fetch(
onDownloadProgressCallback,
detectorSource,
recognizerSource
);
if (paths === null || paths.length < 3) {
throw new RnExecutorchError(
RnExecutorchErrorCode.DownloadInterrupted,
'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
);
}
this.ocrNativeModule = global.loadVerticalOCR(
paths[0]!,
paths[1]!,
symbols[language],
independentCharacters
);

this.isReady = true;
this.isReadyCallback(this.isReady);
} catch (e) {
if (this.errorCallback) {
this.errorCallback(parseUnknownError(e));
} else {
throw parseUnknownError(e);
}
}
await this.internalLoad(
detectorSource,
recognizerSource,
language,
onDownloadProgressCallback,
independentCharacters
);
};

public forward = async (imageSource: string) => {
if (!this.isReady) {
throw new RnExecutorchError(
RnExecutorchErrorCode.ModuleNotLoaded,
'The model is currently not loaded. Please load the model before calling forward().'
);
}
if (this.isGenerating) {
throw new RnExecutorchError(
RnExecutorchErrorCode.ModelGenerating,
'The model is currently generating. Please wait until previous model run is complete.'
);
}

try {
this.isGenerating = true;
this.isGeneratingCallback(this.isGenerating);
return await this.ocrNativeModule.generate(imageSource);
} catch (e) {
throw parseUnknownError(e);
} finally {
this.isGenerating = false;
this.isGeneratingCallback(this.isGenerating);
}
};

public delete() {
if (this.isGenerating) {
throw new RnExecutorchError(
RnExecutorchErrorCode.ModelGenerating,
'The model is currently generating. Please wait until previous model run is complete.'
);
}
if (this.ocrNativeModule) {
this.ocrNativeModule.unload();
}
this.isReadyCallback(false);
this.isGeneratingCallback(false);
}
}
Loading