From cb1368d86b0e41b66c1659a5abd2b27d60a9b5fc Mon Sep 17 00:00:00 2001 From: andermatt64 Date: Mon, 11 Sep 2023 00:21:06 -0400 Subject: [PATCH] Add initial_prompt support for Whisper params --- src/whispercpp/api.pyi | 2 ++ src/whispercpp/context.h | 8 ++++++++ src/whispercpp/params.cc | 9 +++++++++ tests/params_export_test.py | 8 ++++++++ 4 files changed, 27 insertions(+) diff --git a/src/whispercpp/api.pyi b/src/whispercpp/api.pyi index 43eb11a..53cc26b 100644 --- a/src/whispercpp/api.pyi +++ b/src/whispercpp/api.pyi @@ -101,6 +101,8 @@ class Params: prompt_num_tokens: int language: str def with_language(self, language: str) -> Params: ... + initial_prompt: str + def with_initial_prompt(self, initial_prompt: str) -> Params: ... suppress_blank: bool def with_suppress_blank(self, suppress_blank: bool) -> Params: ... suppress_none_speech_tokens: bool diff --git a/src/whispercpp/context.h b/src/whispercpp/context.h index abaf48e..7eb6edf 100644 --- a/src/whispercpp/context.h +++ b/src/whispercpp/context.h @@ -145,6 +145,7 @@ struct Params { private: std::shared_ptr fp; std::string language; + std::string initial_prompt; CallbackAndContext new_segment_callback; CallbackAndContext progress_callback; @@ -348,6 +349,13 @@ struct Params { return this; } + // Set initial prompt + Params *with_initial_prompt(std::string initial_prompt) { + this->initial_prompt = initial_prompt; + fp->initial_prompt = this->initial_prompt.c_str(); + return this; + } + // Set suppress_blank. See // https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89 // for more information. diff --git a/src/whispercpp/params.cc b/src/whispercpp/params.cc index 8672840..0df61c4 100644 --- a/src/whispercpp/params.cc +++ b/src/whispercpp/params.cc @@ -604,6 +604,15 @@ void ExportParamsApi(py::module &m) { WITH_DEPRECATION("language"); self.with_language(language); }) + // NOTE set initial_prompt + .def("with_initial_prompt", &Params::with_initial_prompt, "initial_prompt"_a, + py::return_value_policy::reference) + .def_property( + "initial_prompt", [](Params &self) { return self.get()->initial_prompt; }, + [](Params &self, const char *initial_prompt) { + WITH_DEPRECATION("initial_prompt"); + self.with_initial_prompt(initial_prompt); + }) // NOTE setting suppress_blank .def("with_suppress_blank", &Params::with_suppress_blank, "suppress_blank"_a, py::return_value_policy::reference) diff --git a/tests/params_export_test.py b/tests/params_export_test.py index 5f1277b..46c0154 100644 --- a/tests/params_export_test.py +++ b/tests/params_export_test.py @@ -95,3 +95,11 @@ def test_set_language(): params_with_lang = params.with_language(lang) print(lang, params_with_lang.language) assert params_with_lang.language == lang + +def test_set_initial_prompt(): + params = w.api.Params.from_enum(w.api.StrategyType.SAMPLING_GREEDY) + for prompt in ["This is a test initial prompt", ""]: + assert params.initial_prompt != "" + params_with_initial_prompt = params.with_initial_prompt(prompt) + print(prompt, params_with_initial_prompt.initial_prompt) + assert params_with_initial_prompt.initial_prompt == prompt \ No newline at end of file