diff --git a/runtime/binding/python/cpp/binding.cc b/runtime/binding/python/cpp/binding.cc index cff4f545e1..28f5e989fb 100644 --- a/runtime/binding/python/cpp/binding.cc +++ b/runtime/binding/python/cpp/binding.cc @@ -1,4 +1,5 @@ // Copyright (c) 2022 Binbin Zhang(binbzha@qq.com) +// 2022 SoundDataConverge Co.LTD (Weiliang Chong) // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,8 +14,10 @@ // limitations under the License. #include +#include #include "api/wenet_api.h" +#include "api/batch_recognizer.h" namespace py = pybind11; @@ -37,4 +40,12 @@ PYBIND11_MODULE(_wenet, m) { m.def("wenet_set_language", &wenet_set_language, "set language"); m.def("wenet_set_continuous_decoding", &wenet_set_continuous_decoding, "enable continuous decoding or not"); + py::class_(m, "BatchRecognizer") + .def(py::init()) + .def("set_enable_timestamp", &BatchRecognizer::set_enable_timestamp) + .def("AddContext", &BatchRecognizer::AddContext) + .def("set_context_score", &BatchRecognizer::set_context_score) + .def("set_language", &BatchRecognizer::set_language) + .def("DecodeData", &BatchRecognizer::DecodeData) + .def("Decode", &BatchRecognizer::Decode); } diff --git a/runtime/binding/python/py/__init__.py b/runtime/binding/python/py/__init__.py index 58cd2aef82..dcb72e0162 100644 --- a/runtime/binding/python/py/__init__.py +++ b/runtime/binding/python/py/__init__.py @@ -1,2 +1,3 @@ from .decoder import Decoder # noqa +from .batch_decoder import BatchDecoder # noqa from _wenet import wenet_set_log_level as set_log_level # noqa diff --git a/runtime/binding/python/py/batch_decoder.py b/runtime/binding/python/py/batch_decoder.py new file mode 100644 index 0000000000..d8e3b71ad6 --- /dev/null +++ b/runtime/binding/python/py/batch_decoder.py @@ -0,0 +1,79 @@ +# Copyright (c) 2022 Binbin Zhang(binbzha@qq.com) +# 2022 SoundDataConverge Co.LTD (Weiliang Chong) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +import _wenet + +from .hub import Hub + + +class BatchDecoder: + + def __init__(self, + model_dir: Optional[str] = None, + lang: str = 'chs', + nbest: int = 1, + enable_timestamp: bool = False, + context: Optional[List[str]] = None, + context_score: float = 3.0): + """ Init WeNet decoder + Args: + lang: language type of the model + nbest: nbest number for the final result + enable_timestamp: whether to enable word level timestamp + for the final result + context: context words + context_score: bonus score when the context is matched + """ + if model_dir is None: + model_dir = Hub.get_model_by_lang(lang) + + self.d = _wenet.BatchRecognizer(model_dir) + + self.set_language(lang) + self.enable_timestamp(enable_timestamp) + if context is not None: + self.add_context(context) + self.set_context_score(context_score) + + def __del__(self): + del self.d + + def enable_timestamp(self, flag: bool): + tag = 1 if flag else 0 + self.d.set_enable_timestamp(tag) + + def add_context(self, contexts: List[str]): + for c in contexts: + assert isinstance(c, str) + self.d.AddContext(c) + + def set_context_score(self, score: float): + self.d.set_context_score(score) + + def set_language(self, lang: str): + assert lang in ['chs', 'en'] + self.d.set_language(lang) + + def decode(self, pcms: List[bytes]) -> str: + """ Decode the input data + + Args: + pcms: a list of wav pcm + """ + assert isinstance(pcms[0], bytes) + result = self.d.Decode(pcms) + return result diff --git a/runtime/core/api/batch_recognizer.h b/runtime/core/api/batch_recognizer.h new file mode 100644 index 0000000000..02418a19de --- /dev/null +++ b/runtime/core/api/batch_recognizer.h @@ -0,0 +1,148 @@ +// Copyright (c) 2022 Binbin Zhang (binbzha@qq.com) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef API_BATCH_RECOGNIZER_H_ +#define API_BATCH_RECOGNIZER_H_ + +#include +#include +#include +#include + +#include "decoder/asr_decoder.h" +#include "decoder/batch_asr_decoder.h" +#include "decoder/batch_torch_asr_model.h" +#include "post_processor/post_processor.h" +#include "utils/file.h" +#include "utils/json.h" +#include "utils/string.h" + +class BatchRecognizer { + public: + explicit BatchRecognizer(const std::string& model_dir, int num_threads = 1) { + // FeaturePipeline init + feature_config_ = std::make_shared(80, 16000); + // Resource init + resource_ = std::make_shared(); + wenet::BatchTorchAsrModel::InitEngineThreads(num_threads); + std::string model_path = wenet::JoinPath(model_dir, "final.zip"); + CHECK(wenet::FileExists(model_path)); + + auto model = std::make_shared(); + model->Read(model_path); + resource_->batch_model = model; + + // units.txt: E2E model unit + std::string unit_path = wenet::JoinPath(model_dir, "units.txt"); + CHECK(wenet::FileExists(unit_path)); + resource_->unit_table = std::shared_ptr( + fst::SymbolTable::ReadText(unit_path)); + + std::string fst_path = wenet::JoinPath(model_dir, "TLG.fst"); + if (wenet::FileExists(fst_path)) { // With LM + resource_->fst = std::shared_ptr>( + fst::Fst::Read(fst_path)); + + std::string symbol_path = wenet::JoinPath(model_dir, "words.txt"); + CHECK(wenet::FileExists(symbol_path)); + resource_->symbol_table = std::shared_ptr( + fst::SymbolTable::ReadText(symbol_path)); + } else { // Without LM, symbol_table is the same as unit_table + resource_->symbol_table = resource_->unit_table; + } + + // Context config init + context_config_ = std::make_shared(); + decode_options_ = std::make_shared(); + post_process_opts_ = std::make_shared(); + } + + void InitDecoder() { + CHECK(decoder_ == nullptr); + // Optional init context graph + if (context_.size() > 0) { + context_config_->context_score = context_score_; + auto context_graph = + std::make_shared(*context_config_); + context_graph->BuildContextGraph(context_, resource_->symbol_table); + resource_->context_graph = context_graph; + } + // PostProcessor + if (language_ == "chs") { // TODO(Binbin Zhang): CJK(chs, jp, kr) + post_process_opts_->language_type = wenet::kMandarinEnglish; + } else { + post_process_opts_->language_type = wenet::kIndoEuropean; + } + resource_->post_processor = + std::make_shared(*post_process_opts_); + // Init decoder + decoder_ = std::make_shared( + feature_config_, resource_, + *decode_options_); + } + + std::string Decode(const std::vector& wavs) { + // Init decoder when it is called first time + if (decoder_ == nullptr) { + InitDecoder(); + } + std::vector> wavs_float; + for (auto& wav : wavs) { + const int16_t* pcm = reinterpret_cast(wav.data()); + int pcm_len = wav.size() / sizeof(int16_t); + std::vector wav_float(pcm_len); + for (size_t i = 0; i < pcm_len; i++) { + wav_float[i] = static_cast(*(pcm + i)); + } + wavs_float.push_back(std::move(wav_float)); + } + decoder_->Reset(); + decoder_->Decode(wavs_float); + return decoder_->get_batch_result(nbest_, enable_timestamp_); + } + + std::string DecodeData(const std::vector>& wavs) { + // Init decoder when it is called first time + if (decoder_ == nullptr) { + InitDecoder(); + } + decoder_->Reset(); + decoder_->Decode(wavs); + return decoder_->get_batch_result(nbest_, enable_timestamp_); + } + + + + void set_nbest(int n) { nbest_ = n; } + void set_enable_timestamp(bool flag) { enable_timestamp_ = flag; } + void AddContext(const char* word) { context_.emplace_back(word); } + void set_context_score(float score) { context_score_ = score; } + void set_language(const char* lang) { language_ = lang; } + + private: + std::shared_ptr feature_config_ = nullptr; + std::shared_ptr resource_ = nullptr; + std::shared_ptr decode_options_ = nullptr; + std::shared_ptr decoder_ = nullptr; + std::shared_ptr context_config_ = nullptr; + std::shared_ptr post_process_opts_ = nullptr; + + int nbest_ = 1; + bool enable_timestamp_ = false; + std::vector context_; + float context_score_; + std::string language_ = "chs"; +}; + +#endif // API_BATCH_RECOGNIZER_H_ diff --git a/runtime/core/bin/CMakeLists.txt b/runtime/core/bin/CMakeLists.txt index 03b19247a2..46252f254b 100644 --- a/runtime/core/bin/CMakeLists.txt +++ b/runtime/core/bin/CMakeLists.txt @@ -1,6 +1,9 @@ add_executable(decoder_main decoder_main.cc) target_link_libraries(decoder_main PUBLIC decoder) +add_executable(decoder_main_batch decoder_main_batch.cc) +target_link_libraries(decoder_main_batch PUBLIC decoder kaldifeat_core) + add_executable(label_checker_main label_checker_main.cc) target_link_libraries(label_checker_main PUBLIC decoder) diff --git a/runtime/core/bin/api_batch_main.cc b/runtime/core/bin/api_batch_main.cc new file mode 100644 index 0000000000..e80321c21a --- /dev/null +++ b/runtime/core/bin/api_batch_main.cc @@ -0,0 +1,51 @@ +// Copyright (c) 2022 Binbin Zhang (binbzha@qq.com) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "api/batch_recognizer.h" +#include "api/wenet_api.h" +#include "frontend/wav.h" +#include "utils/flags.h" +#include "utils/timer.h" + +DEFINE_string(model_dir, "", "model dir path"); +DEFINE_string(wav_path, "", "single wave path"); +DEFINE_int32(batch_size, 1, "batch size of input"); +DEFINE_int32(num_threads, 1, "number threads of intraop"); +DEFINE_bool(enable_timestamp, false, "enable timestamps"); + +int main(int argc, char* argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + + wenet_set_log_level(2); + + BatchRecognizer br(FLAGS_model_dir, FLAGS_num_threads); + if (FLAGS_enable_timestamp) br.set_enable_timestamp(true); + wenet::WavReader wav_reader(FLAGS_wav_path); + std::vector data; + data.insert( + data.end(), wav_reader.data(), + wav_reader.data() + wav_reader.num_samples()); + std::vector> wavs; + for (size_t i = 0; i < FLAGS_batch_size - 1; i++) { + wavs.push_back(data); + } + wavs.push_back(std::move(data)); + wenet::Timer timer; + std::string result = br.DecodeData(wavs); + int forward_time = timer.Elapsed(); + VLOG(1) << "Decode() takes " << forward_time << " ms"; + LOG(INFO) << result; + return 0; +} diff --git a/runtime/core/bin/decoder_main_batch.cc b/runtime/core/bin/decoder_main_batch.cc new file mode 100644 index 0000000000..3420baa260 --- /dev/null +++ b/runtime/core/bin/decoder_main_batch.cc @@ -0,0 +1,84 @@ +// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "decoder/params.h" +#include "frontend/wav.h" +#include "utils/flags.h" +#include "utils/string.h" +#include "utils/timer.h" +#include "utils/utils.h" + +DEFINE_string(wav_path, "", "single wave path"); +DEFINE_int32(thread_num, 1, "num of decode thread"); +DEFINE_int32(batch_size, 1, "batch size of input"); + +std::shared_ptr g_decode_config; +std::shared_ptr g_feature_config; +std::shared_ptr g_decode_resource; + +int g_total_waves_dur = 0; +int g_total_decode_time = 0; + +// using namespace wenet; + +void decode(const std::string& wav) { + wenet::WavReader wav_reader(wav); + std::vector wav_data; + int num_samples = wav_reader.num_samples(); + wav_data.insert( + wav_data.end(), wav_reader.data(), wav_reader.data() + num_samples); + std::vector> batch_wav_data; + int wav_dur = static_cast( + static_cast(num_samples) / wav_reader.sample_rate() * 1000); + for (int i = 0; i < FLAGS_batch_size; ++i) { + batch_wav_data.push_back(wav_data); + g_total_waves_dur += wav_dur; + } + + auto decoder = std::make_unique( + g_feature_config, g_decode_resource, *g_decode_config); + wenet::Timer timer; + decoder->Decode(batch_wav_data); + int decode_time = timer.Elapsed(); + std::string result = decoder->get_batch_result(1, false); + std::cout << result << std::endl; + + LOG(INFO) << "batch_size : " << FLAGS_batch_size << std::endl; + LOG(INFO) << "Total: decoded " << g_total_waves_dur << "ms audio taken " + << decode_time << "ms."; + LOG(INFO) << "RTF: " << std::setprecision(4) + << static_cast(decode_time) / g_total_waves_dur; +} + + +int main(int argc, char* argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + + g_decode_config = wenet::InitDecodeOptionsFromFlags(); + g_feature_config = wenet::InitFeaturePipelineConfigFromFlags(); + g_decode_resource = wenet::InitDecodeResourceFromFlags(); + + if (FLAGS_wav_path.empty()) { + LOG(FATAL) << "Please provide the wave path."; + } + LOG(INFO) << "decoding " << FLAGS_wav_path; + decode(FLAGS_wav_path); + + return 0; +} diff --git a/runtime/core/bin/websocket_server_main.cc b/runtime/core/bin/websocket_server_main.cc index 796d9d2e6d..3bf7308b6b 100644 --- a/runtime/core/bin/websocket_server_main.cc +++ b/runtime/core/bin/websocket_server_main.cc @@ -29,6 +29,7 @@ int main(int argc, char* argv[]) { wenet::WebSocketServer server(FLAGS_port, feature_config, decode_config, decode_resource); LOG(INFO) << "Listening at port " << FLAGS_port; - server.Start(); + LOG(INFO) << "run for batch decoding: " << FLAGS_run_batch; + server.Start(FLAGS_run_batch); return 0; } diff --git a/runtime/core/cmake/kaldifeat.cmake b/runtime/core/cmake/kaldifeat.cmake new file mode 100644 index 0000000000..544826a844 --- /dev/null +++ b/runtime/core/cmake/kaldifeat.cmake @@ -0,0 +1,32 @@ +# Copyright 2022 veelion (veelion@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if(GPU) + set(kaldifeat_URL "https://github.com/csukuangfj/kaldifeat/archive/refs/tags/v1.21.zip") + set(kaldifeat_HASH "SHA256=10652d930dee12d71d04da3f5b3b1bd618fa2f1af6723eb0e70d7267bfa57fe1") + set(kaldifeat_BUILD_TESTS OFF CACHE BOOL "" FORCE) + set(kaldifeat_BUILD_PYMODULE OFF CACHE BOOL "" FORCE) + set(PYTHON_EXECUTABLE "python") + list(REMOVE_AT CMAKE_MODULE_PATH 0) # hide wenet's cmake/xx.cmake from kaldifeat's + + FetchContent_Declare(kaldifeat + URL ${kaldifeat_URL} + URL_HASH ${kaldifeat_HASH} + ) + FetchContent_MakeAvailable(kaldifeat) + include_directories( + ${kaldifeat_SOURCE_DIR} + ) + list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) # use wenet's cmake/xx.cmake +endif() diff --git a/runtime/core/cmake/libtorch.cmake b/runtime/core/cmake/libtorch.cmake index 40a64ff84f..bd4a9248fc 100644 --- a/runtime/core/cmake/libtorch.cmake +++ b/runtime/core/cmake/libtorch.cmake @@ -1,6 +1,6 @@ if(TORCH) if(NOT ANDROID) - set(PYTORCH_VERSION "1.10.0") + set(PYTORCH_VERSION "1.12.0") if(GPU) add_definitions(-DUSE_GPU) set(CUDA_NAME "cu113") @@ -20,10 +20,18 @@ if(TORCH) if(CXX11_ABI) if(NOT GPU) set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip") - set(URL_HASH "SHA256=6d7be1073d1bd76f6563572b2aa5548ad51d5bc241d6895e3181b7dc25554426") + if(${PYTORCH_VERSION} STREQUAL "1.12.0") + set(URL_HASH "SHA256=0f0f36219862a4ed0ad0522c4de97e9e189194b44eb09036d2b94bea456260c6") + else() + set(URL_HASH "SHA256=6d7be1073d1bd76f6563572b2aa5548ad51d5bc241d6895e3181b7dc25554426") + endif() else() set(LIBTORCH_URL "https://download.pytorch.org/libtorch/${CUDA_NAME}/libtorch-cxx11-abi-shared-with-deps-${PYTORCH_VERSION}%2B${CUDA_NAME}.zip") - set(URL_HASH "SHA256=190e963e739d5f7c2dcf94b3994de8fcd335706a4ebb333812ea7d8c841beb06") + if(${PYTORCH_VERSION} STREQUAL "1.12.0" AND ${CUDA_NAME} STREQUAL "cu113") + set(URL_HASH "SHA256=80f089939de20e68e3fcad4dfa72a26c8bf91b5e77b11042f671f39ebac35865") + else() + set(URL_HASH "SHA256=190e963e739d5f7c2dcf94b3994de8fcd335706a4ebb333812ea7d8c841beb06") + endif() endif() else() if(NOT GPU) @@ -31,7 +39,11 @@ if(TORCH) set(URL_HASH "SHA256=16961222938b205a6a767b0b0b9f5e3b1f8740aa1f3475580e33cfd5952b1a44") else() set(LIBTORCH_URL "https://download.pytorch.org/libtorch/${CUDA_NAME}/libtorch-shared-with-deps-${PYTORCH_VERSION}%2B${CUDA_NAME}.zip") - set(URL_HASH "SHA256=0996a6a4ea8bbc1137b4fb0476eeca25b5efd8ed38955218dec1b73929090053") + if(${PYTORCH_VERSION} STREQUAL "1.12.0" AND ${CUDA_NAME} STREQUAL "cu113") + set(URL_HASH "SHA256=8e35371403f7052d9e9b43bcff383980dbde4df028986dc1dab539953481d55f") + else() + set(URL_HASH "SHA256=0996a6a4ea8bbc1137b4fb0476eeca25b5efd8ed38955218dec1b73929090053") + endif() endif() endif() elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Darwin") @@ -57,7 +69,7 @@ if(TORCH) file(COPY ${TORCH_DLLS} DESTINATION ${CMAKE_BINARY_DIR}) endif() else() - # Change version in runtime/android/app/build.gradle. + # Change version in runtime/device/android/wenet/app/build.gradle. file(GLOB PYTORCH_INCLUDE_DIRS "${build_DIR}/pytorch_android*.aar/headers") file(GLOB PYTORCH_LINK_DIRS "${build_DIR}/pytorch_android*.aar/jni/${ANDROID_ABI}") find_library(PYTORCH_LIBRARY pytorch_jni diff --git a/runtime/core/cmake/onnx.cmake b/runtime/core/cmake/onnx.cmake index bd55402cb2..e159e979d4 100644 --- a/runtime/core/cmake/onnx.cmake +++ b/runtime/core/cmake/onnx.cmake @@ -1,5 +1,9 @@ if(ONNX) set(ONNX_VERSION "1.12.0") + if(GPU) + add_definitions(-DUSE_GPU) + set(ONNX_VERSION "1.13.1") + endif() if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows") set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-win-x64-${ONNX_VERSION}.zip") set(URL_HASH "SHA256=8b5d61204989350b7904ac277f5fbccd3e6736ddbb6ec001e412723d71c9c176") @@ -10,6 +14,19 @@ if(ONNX) else() set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-linux-x64-${ONNX_VERSION}.tgz") set(URL_HASH "SHA256=5d503ce8540358b59be26c675e42081be14a3e833a5301926f555451046929c5") + if(GPU) + set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-linux-x64-gpu-${ONNX_VERSION}.tgz") + if(${ONNX_VERSION} STREQUAL "1.12.1") + set(URL_HASH "SHA256=41fcb4b0bb162c2788240d5f21d18714238817a78fb68e5733c5caef326a7306") + elseif(${ONNX_VERSION} STREQUAL "1.12.0") + set(URL_HASH "SHA256=bc2e615314df0a871c560b7af6d4ce5896f351d23cad476562d2715208c9c7f7") + elseif(${ONNX_VERSION} STREQUAL "1.11.1") + set(URL_HASH "SHA256=b96e3e266f66f6e1293841e0a5b5ec3b0a602512d68e5cc73c014546092c87c8") + elseif(${ONNX_VERSION} STREQUAL "1.13.1") + set(URL_HASH "SHA256=7725c232c78b9b49037fa7409f3ae255ba81d9a7e1af910c2443b1174171d8b1") + endif() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mf16c") + endif() endif() elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Darwin") set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-osx-x86_64-${ONNX_VERSION}.tgz") diff --git a/runtime/core/decoder/CMakeLists.txt b/runtime/core/decoder/CMakeLists.txt index 098fdcdb5e..066620720f 100644 --- a/runtime/core/decoder/CMakeLists.txt +++ b/runtime/core/decoder/CMakeLists.txt @@ -5,16 +5,17 @@ set(decoder_srcs ctc_prefix_beam_search.cc ctc_wfst_beam_search.cc ctc_endpoint.cc + batch_asr_decoder.cc ) if(NOT TORCH AND NOT ONNX AND NOT XPU) message(FATAL_ERROR "Please build with TORCH or ONNX or XPU!!!") endif() if(TORCH) - list(APPEND decoder_srcs torch_asr_model.cc) + list(APPEND decoder_srcs torch_asr_model.cc batch_torch_asr_model.cc) endif() if(ONNX) - list(APPEND decoder_srcs onnx_asr_model.cc) + list(APPEND decoder_srcs onnx_asr_model.cc batch_onnx_asr_model.cc) endif() add_library(decoder STATIC ${decoder_srcs}) diff --git a/runtime/core/decoder/asr_decoder.h b/runtime/core/decoder/asr_decoder.h index df71f5b7ba..31c3a99a7d 100644 --- a/runtime/core/decoder/asr_decoder.h +++ b/runtime/core/decoder/asr_decoder.h @@ -26,6 +26,7 @@ #include "fst/symbol-table.h" #include "decoder/asr_model.h" +#include "decoder/batch_asr_model.h" #include "decoder/context_graph.h" #include "decoder/ctc_endpoint.h" #include "decoder/ctc_prefix_beam_search.h" @@ -90,6 +91,7 @@ enum DecodeState { // decoding threads struct DecodeResource { std::shared_ptr model = nullptr; + std::shared_ptr batch_model = nullptr; std::shared_ptr symbol_table = nullptr; std::shared_ptr> fst = nullptr; std::shared_ptr unit_table = nullptr; diff --git a/runtime/core/decoder/batch_asr_decoder.cc b/runtime/core/decoder/batch_asr_decoder.cc new file mode 100644 index 0000000000..892328d399 --- /dev/null +++ b/runtime/core/decoder/batch_asr_decoder.cc @@ -0,0 +1,364 @@ +// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu) +// 2022 Binbin Zhang (binbzha@qq.com) +// 2022 SoundDataConverge Co.LTD (Weiliang Chong) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include "decoder/batch_asr_decoder.h" + +#include + +#include +#include +#include +#include + +#include "utils/timer.h" + +namespace wenet { + +BatchAsrDecoder::BatchAsrDecoder(std::shared_ptr config, + std::shared_ptr resource, + const DecodeOptions& opts) + : feature_config_(config), + beam_size_(opts.ctc_prefix_search_opts.first_beam_size), + fbank_(config->num_bins, config->sample_rate, + config->frame_length, config->frame_shift), + fbank_cuda_(config->num_bins, config->sample_rate), + model_(resource->batch_model->Copy()), + post_processor_(resource->post_processor), + symbol_table_(resource->symbol_table), + fst_(resource->fst), + unit_table_(resource->unit_table), + resource_(resource), + opts_(opts) { + if (opts_.reverse_weight > 0) { + // Check if model has a right to left decoder + CHECK(model_->is_bidirectional_decoder()); + } + if (nullptr == fst_) { + searcher_.reset(new CtcPrefixBeamSearch(opts.ctc_prefix_search_opts, + resource->context_graph)); + } else { + searcher_.reset(new CtcWfstBeamSearch(*fst_, opts.ctc_wfst_search_opts, + resource->context_graph)); + } +} + +void BatchAsrDecoder::Reset() { + batch_result_.clear(); + searcher_->Reset(); +} + +void BatchAsrDecoder::SearchWorker( + const std::vector>& topk_scores, + const std::vector>& topk_indexs, + int index) { + Timer ctc_timer; + std::unique_ptr searcher; + if (nullptr == fst_) { + searcher.reset(new CtcPrefixBeamSearch(opts_.ctc_prefix_search_opts, + resource_->context_graph)); + } else { + searcher.reset(new CtcWfstBeamSearch(*fst_, opts_.ctc_wfst_search_opts, + resource_->context_graph)); + } + // 3.1. ctc search + ctc_timer.Reset(); + searcher->Search(topk_scores, topk_indexs); + searcher->FinalizeSearch(); + std::vector result; + UpdateResult(searcher.get(), &result); + VLOG(1) << "\tctc search i==" << index + << " takes " << ctc_timer.Elapsed() << " ms"; + std::lock_guard lock(mutex_); + batch_pair_result_.emplace_back(std::make_pair(index, std::move(result))); + const auto& hypotheses = searcher->Inputs(); + if (hypotheses.size() < beam_size_) { + VLOG(2) << "=== searcher->Inputs() size < beam_size_, padding..."; + std::vector> hyps = hypotheses; + int to_pad = beam_size_ - hypotheses.size(); + for (size_t i = 0; i < to_pad; i++) { + std::vector pad = {0}; + hyps.push_back(std::move(pad)); + } + batch_hyps_.emplace_back(std::make_pair(index, std::move(hyps))); + } else { + batch_hyps_.emplace_back(std::make_pair(index, std::move(hypotheses))); + } +} + +void BatchAsrDecoder::FbankWorker(const std::vector& wav, int index) { + Timer timer; + feature_t feats; + int num_frames = fbank_.Compute(wav, &feats); + std::lock_guard lock(mutex_); + batch_feats_.push_back(std::make_pair(index, std::move(feats))); + batch_feats_lens_.push_back(std::make_pair(index, num_frames)); + VLOG(1) << "\tfeature comput i==" << index + << ", takes " << timer.Elapsed() << " ms."; +} + +void BatchAsrDecoder::ComputeFeatureCpu( + const std::vector>& wavs, + batch_feature_t* feats, + std::vector* feats_lens) { + Timer timer; + batch_feature_t& batch_feats = *feats; + std::vector& batch_feats_lens = *feats_lens; + if (wavs.size() > 1) { + std::vector fbank_threads; + for (size_t i = 0; i < wavs.size(); i++) { + const std::vector& wav = wavs[i]; + std::thread thd(&BatchAsrDecoder::FbankWorker, this, wav, i); + fbank_threads.push_back(std::move(thd)); + } + for (auto& thd : fbank_threads) { + thd.join(); + } + std::sort(batch_feats_.begin(), batch_feats_.end()); + std::sort(batch_feats_lens_.begin(), batch_feats_lens_.end()); + for (auto& pair : batch_feats_) { + batch_feats.push_back(std::move(pair.second)); + } + for (auto& pair : batch_feats_lens_) { + batch_feats_lens.push_back(pair.second); + } + } else { + // only one wave + feature_t feats; + int num_frames = fbank_.Compute(wavs[0], &feats); + batch_feats.push_back(feats); + batch_feats_lens.push_back(num_frames); + } + VLOG(1) << "feature Compute takes " << timer.Elapsed() << " ms."; + + // 1.1 feature padding + if (wavs.size() > 1) { + timer.Reset(); + int max_len = *std::max_element( + batch_feats_lens.begin(), batch_feats_lens.end()); + for (auto& feat : batch_feats) { + if (feat.size() == max_len) continue; + int pad_len = max_len - feat.size(); + for (size_t i = 0; i< pad_len; i++) { + std::vector one(feature_config_->num_bins, 0.0); + feat.push_back(std::move(one)); + } + } + VLOG(1) << "padding feautre takes " << timer.Elapsed() << " ms."; + }} + +void BatchAsrDecoder::Decode(const std::vector>& wavs) { + // 1. calc fbank feature of the batch of wavs + std::vector>> batch_topk_scores; + std::vector>> batch_topk_indexs; + Timer timer; + bool gpu_feature = true; + if (gpu_feature) { + std::vector batch_feats_lens; + timer.Reset(); + auto batch_feats = fbank_cuda_.Compute(wavs, &batch_feats_lens); + VLOG(1) << "fbank_cuda_.Comput() takes " << timer.Elapsed() << " ms."; + timer.Reset(); + // 2. encoder forward + model_->ForwardEncoder( + batch_feats, batch_feats_lens, &batch_topk_scores, &batch_topk_indexs); + VLOG(1) << "encoder forward takes " << timer.Elapsed() << " ms."; + + } else { + batch_feature_t batch_feats; + std::vector batch_feats_lens; + ComputeFeatureCpu(wavs, &batch_feats, &batch_feats_lens); + timer.Reset(); + // 2. encoder forward + model_->ForwardEncoder( + batch_feats, batch_feats_lens, &batch_topk_scores, &batch_topk_indexs); + VLOG(1) << "encoder forward takes " << timer.Elapsed() << " ms."; + } + + // 3. ctc search one by one of the batch + // create batch of tct search result for attention decoding + timer.Reset(); + int batch_size = wavs.size(); + std::vector>> batch_hyps; + if (batch_size > 1) { + batch_pair_result_.clear(); + batch_hyps_.clear(); + std::vector search_threads; + for (size_t i = 0; i < batch_size; i++) { + const auto& topk_scores = batch_topk_scores[i]; + const auto& topk_indexs = batch_topk_indexs[i]; + std::thread thd( + &BatchAsrDecoder::SearchWorker, this, topk_scores, topk_indexs, i); + search_threads.push_back(std::move(thd)); + } + for (auto& thd : search_threads) { + thd.join(); + } + std::sort(batch_hyps_.begin(), batch_hyps_.end()); + std::sort(batch_pair_result_.begin(), batch_pair_result_.end(), + [](auto& a, auto& b) { + return a.first < b.first; }); + for (auto& pair : batch_hyps_) { + batch_hyps.push_back(std::move(pair.second)); + } + batch_result_.clear(); + for (auto& pair : batch_pair_result_) { + batch_result_.push_back(std::move(pair.second)); + } + } else { + // one wav + searcher_->Search(batch_topk_scores[0], batch_topk_indexs[0]); + searcher_->FinalizeSearch(); + std::vector result; + UpdateResult(searcher_.get(), &result); + batch_result_.push_back(std::move(result)); + const auto& hypotheses = searcher_->Inputs(); + if (hypotheses.size() < beam_size_) { + VLOG(2) << "=== searcher->Inputs() size < beam_size_, padding..."; + std::vector> hyps = hypotheses; + int to_pad = beam_size_ - hypotheses.size(); + for (size_t i = 0; i < to_pad; i++) { + std::vector pad = {0}; + hyps.push_back(std::move(pad)); + } + batch_hyps.push_back(std::move(hyps)); + } else { + batch_hyps.push_back(std::move(hypotheses)); + } + } + VLOG(1) << "ctc search batch(" << batch_size << ") takes " + << timer.Elapsed() << " ms."; + std::vector> ctc_scores(batch_size); + for (int i = 0; i < batch_result_.size(); ++i) { + ctc_scores[i].resize(beam_size_); + for (int j = 0; j < beam_size_; ++j) { + ctc_scores[i][j] = batch_result_[i][j].score; + } + } + // 4. attention rescoring + timer.Reset(); + std::vector> attention_scores; + model_->AttentionRescoring(batch_hyps, ctc_scores, &attention_scores); + VLOG(1) << "attention rescoring takes " << timer.Elapsed() << " ms."; + for (size_t i = 0; i < batch_size; i++) { + std::vector& result = batch_result_[i]; + for (size_t j = 0; j < beam_size_; j++) { + result[j].score = attention_scores[i][j]; + } + std::sort(result.begin(), result.end(), DecodeResult::CompareFunc); + } +} + +void BatchAsrDecoder::UpdateResult(SearchInterface* searcher, + std::vector* result) { + bool finish = true; + const auto& hypotheses = searcher->Outputs(); + const auto& inputs = searcher->Inputs(); + const auto& likelihood = searcher->Likelihood(); + const auto& times = searcher->Times(); + result->clear(); + + CHECK_EQ(hypotheses.size(), likelihood.size()); + for (size_t i = 0; i < hypotheses.size(); i++) { + const std::vector& hypothesis = hypotheses[i]; + + DecodeResult path; + path.score = likelihood[i]; + for (size_t j = 0; j < hypothesis.size(); j++) { + std::string word = symbol_table_->Find(hypothesis[j]); + // A detailed explanation of this if-else branch can be found in + // https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058 + if (searcher->Type() == kWfstBeamSearch) { + path.sentence += (' ' + word); + } else { + path.sentence += (word); + } + } + + // TimeStamp is only supported in final result + // TimeStamp of the output of CtcWfstBeamSearch may be inaccurate due to + // various FST operations when building the decoding graph. So here we use + // time stamp of the input(e2e model unit), which is more accurate, and it + // requires the symbol table of the e2e model used in training. + if (unit_table_ != nullptr && finish) { + const std::vector& input = inputs[i]; + const std::vector& time_stamp = times[i]; + CHECK_EQ(input.size(), time_stamp.size()); + for (size_t j = 0; j < input.size(); j++) { + std::string word = unit_table_->Find(input[j]); + int start = time_stamp[j] * frame_shift_in_ms() - time_stamp_gap_ > 0 + ? time_stamp[j] * frame_shift_in_ms() - time_stamp_gap_ + : 0; + if (j > 0) { + start = (time_stamp[j] - time_stamp[j - 1]) * frame_shift_in_ms() < + time_stamp_gap_ + ? (time_stamp[j - 1] + time_stamp[j]) / 2 * + frame_shift_in_ms() + : start; + } + int end = time_stamp[j] * frame_shift_in_ms(); + if (j < input.size() - 1) { + end = (time_stamp[j + 1] - time_stamp[j]) * frame_shift_in_ms() < + time_stamp_gap_ + ? (time_stamp[j + 1] + time_stamp[j]) / 2 * + frame_shift_in_ms() + : end; + } + WordPiece word_piece(word, start, end); + path.word_pieces.emplace_back(word_piece); + } + } + + if (post_processor_ != nullptr) { + path.sentence = post_processor_->Process(path.sentence, finish); + } + result->emplace_back(path); + } +} + +const std::string BatchAsrDecoder::get_batch_result(int nbest, + bool enable_timestamp) { + json::JSON obj; + obj["status"] = "ok"; + obj["type"] = "final_result"; + obj["batch_size"] = batch_result_.size(); + obj["batch_result"] = json::Array(); + for (const auto& result : batch_result_) { + json::JSON batch_one; + batch_one["nbest"] = json::Array(); + for (int i = 0; i < nbest && i < result.size(); i++) { + json::JSON one; + one["sentence"] = result[i].sentence; + // one["score"] = result[i].score; + if (enable_timestamp) { + one["word_pieces"] = json::Array(); + for (const auto& word_piece : result[i].word_pieces) { + json::JSON piece; + piece["word"] = word_piece.word; + piece["start"] = word_piece.start; + piece["end"] = word_piece.end; + one["word_pieces"].append(piece); + } + } + one["sentence"] = result[i].sentence; + batch_one["nbest"].append(one); + } + obj["batch_result"].append(batch_one); + } + return obj.dump(); + } + +} // namespace wenet diff --git a/runtime/core/decoder/batch_asr_decoder.h b/runtime/core/decoder/batch_asr_decoder.h new file mode 100644 index 0000000000..a3ec6c3e98 --- /dev/null +++ b/runtime/core/decoder/batch_asr_decoder.h @@ -0,0 +1,112 @@ +// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu) +// 2022 Binbin Zhang (binbzha@qq.com) +// 2022 SoundDataConverge Co.LTD (Weiliang Chong) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#ifndef DECODER_BATCH_ASR_DECODER_H_ +#define DECODER_BATCH_ASR_DECODER_H_ + +#include +#include +#include +#include + +#include "fst/fstlib.h" +#include "fst/symbol-table.h" + +#include "decoder/batch_asr_model.h" +#include "decoder/asr_decoder.h" +#include "decoder/context_graph.h" +#include "decoder/ctc_prefix_beam_search.h" +#include "decoder/ctc_wfst_beam_search.h" +#include "decoder/search_interface.h" +#include "frontend/feature_pipeline.h" +#include "post_processor/post_processor.h" +#include "utils/utils.h" +#include "frontend/fbank.h" +#include "utils/json.h" +#include "frontend/fbank_cuda.h" + +namespace wenet { + +// Torch ASR batch decoder +class BatchAsrDecoder { + public: + BatchAsrDecoder(std::shared_ptr feature_config, + std::shared_ptr resource, + const DecodeOptions& opts); + void Decode(const std::vector>& wavs); + void Reset(); + + int frame_shift_in_ms() const { + return model_->subsampling_rate() * + feature_config_->frame_shift * 1000 / + feature_config_->sample_rate; + } + int feature_frame_shift_in_ms() const { + return feature_config_->frame_shift * 1000 / + feature_config_->sample_rate; + } + const std::vector>& batch_result() const { + return batch_result_; } + const std::string get_batch_result(int nbest, bool enable_timestamp); + + private: + Fbank fbank_; + FbankCuda fbank_cuda_; + + void ComputeFeatureCpu( + const std::vector>& wavs, + batch_feature_t* batch_feats, + std::vector* batch_feats_lens); + void FbankWorker(const std::vector& wav, int index); + std::vector> batch_feats_; // for FbankWorker + std::vector> batch_feats_lens_; // for FbankWorker + + void SearchWorker( + const std::vector>& topk_scores, + const std::vector>& topk_indexs, + int index); + std::mutex mutex_; + // for SearchWorker + std::vector>>> batch_hyps_; + std::vector>> batch_pair_result_; + std::vector> batch_result_; + + void UpdateResult(SearchInterface* searcher, + std::vector* result); + + std::shared_ptr feature_config_; + std::shared_ptr model_; + std::shared_ptr post_processor_; + + std::shared_ptr> fst_ = nullptr; + // output symbol table + std::shared_ptr symbol_table_; + // e2e unit symbol table + std::shared_ptr unit_table_ = nullptr; + std::shared_ptr resource_ = nullptr; + const DecodeOptions& opts_; + int beam_size_; + const int time_stamp_gap_ = 100; // timestamp gap between words in a sentence + std::unique_ptr searcher_; + + public: + WENET_DISALLOW_COPY_AND_ASSIGN(BatchAsrDecoder); +}; + +} // namespace wenet + +#endif // DECODER_BATCH_ASR_DECODER_H_ diff --git a/runtime/core/decoder/batch_asr_model.h b/runtime/core/decoder/batch_asr_model.h new file mode 100644 index 0000000000..073d821a4c --- /dev/null +++ b/runtime/core/decoder/batch_asr_model.h @@ -0,0 +1,63 @@ +// Copyright 2022 Horizon Robotics. All Rights Reserved. +// Author: binbin.zhang@horizon.ai (Binbin Zhang) +// SoundDataConverge Co.LTD (Weiliang Chong) + +#ifndef DECODER_BATCH_ASR_MODEL_H_ +#define DECODER_BATCH_ASR_MODEL_H_ + +#include +#include +#include +#include + +#include "torch/torch.h" +#include "utils/timer.h" +#include "utils/utils.h" + +namespace wenet { + +using feature_t = std::vector>; +using batch_feature_t = std::vector; +using ctc_log_prob_t = std::vector>; +using batch_ctc_log_prob_t = std::vector; + +class BatchAsrModel { + public: + virtual int right_context() const { return right_context_; } + virtual int subsampling_rate() const { return subsampling_rate_; } + virtual int sos() const { return sos_; } + virtual int eos() const { return eos_; } + virtual bool is_bidirectional_decoder() const { + return is_bidirectional_decoder_; + } + + virtual void ForwardEncoder( + const batch_feature_t& batch_feats, + const std::vector& batch_feats_lens, + std::vector>>* batch_topk_scores, + std::vector>>* batch_topk_indexs) = 0; + virtual void ForwardEncoder( + const std::vector& batch_feats, + const std::vector& batch_feats_lens, + std::vector>>* batch_topk_scores, + std::vector>>* batch_topk_indexs) {}; + + virtual void AttentionRescoring( + const std::vector>>& batch_hyps, + const std::vector>& ctc_scores, + std::vector>* attention_scores) = 0; + + virtual std::shared_ptr Copy() const = 0; + + protected: + int right_context_ = 1; + int subsampling_rate_ = 1; + int sos_ = 0; + int eos_ = 0; + bool is_bidirectional_decoder_ = false; + bool is_fp16_ = false; +}; + +} // namespace wenet + +#endif // DECODER_BATCH_ASR_MODEL_H_ diff --git a/runtime/core/decoder/batch_onnx_asr_model.cc b/runtime/core/decoder/batch_onnx_asr_model.cc new file mode 100644 index 0000000000..bad6c54019 --- /dev/null +++ b/runtime/core/decoder/batch_onnx_asr_model.cc @@ -0,0 +1,505 @@ +// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu) +// 2022 ZeXuan Li (lizexuan@huya.com) +// Xingchen Song(sxc19@mails.tsinghua.edu.cn) +// hamddct@gmail.com (Mddct) +// 2022 SoundDataConverge Co.LTD (Weiliang Chong) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include "decoder/batch_onnx_asr_model.h" + +#include +#include +#include +#include + +#include "glog/logging.h" +#include "utils/string.h" +#include "utils/Yaml.hpp" +#include "utils/timer.h" + +namespace wenet { + +Ort::Env BatchOnnxAsrModel::env_ = Ort::Env(ORT_LOGGING_LEVEL_VERBOSE, ""); +Ort::SessionOptions BatchOnnxAsrModel::session_options_ = Ort::SessionOptions(); +Ort::RunOptions BatchOnnxAsrModel::run_option_ = Ort::RunOptions(); +std::vector BatchOnnxAsrModel::node_names_; + +void BatchOnnxAsrModel::InitEngineThreads(int num_threads) { + session_options_.SetIntraOpNumThreads(num_threads); + session_options_.SetInterOpNumThreads(num_threads); +} + +void BatchOnnxAsrModel::GetInputOutputInfo( + const std::shared_ptr& session, + std::vector* in_names, std::vector* out_names) { + Ort::AllocatorWithDefaultOptions allocator; + // Input info + int num_nodes = session->GetInputCount(); + in_names->resize(num_nodes); + for (int i = 0; i < num_nodes; ++i) { + auto name = session->GetInputNameAllocated(i, allocator); + Ort::TypeInfo type_info = session->GetInputTypeInfo(i); + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + ONNXTensorElementDataType type = tensor_info.GetElementType(); + std::vector node_dims = tensor_info.GetShape(); + std::stringstream shape; + for (auto j : node_dims) { + shape << j; + shape << " "; + } + LOG(INFO) << "\tInput " << i << " : name=" << name.get() + << " type=" << type << " dims=" << shape.str(); + node_names_.push_back(std::move(name)); + (*in_names)[i] = node_names_.back().get(); + } + // Output info + num_nodes = session->GetOutputCount(); + out_names->resize(num_nodes); + for (int i = 0; i < num_nodes; ++i) { + auto name = session->GetOutputNameAllocated(i, allocator); + Ort::TypeInfo type_info = session->GetOutputTypeInfo(i); + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + ONNXTensorElementDataType type = tensor_info.GetElementType(); + std::vector node_dims = tensor_info.GetShape(); + std::stringstream shape; + for (auto j : node_dims) { + shape << j; + shape << " "; + } + LOG(INFO) << "\tOutput " << i << " : name=" << name.get() + << " type=" << type << " dims=" << shape.str(); + node_names_.push_back(std::move(name)); + (*out_names)[i] = node_names_.back().get(); + } +} + +void BatchOnnxAsrModel::Read(const std::string& model_dir, + bool is_fp16, int gpu_id) { + is_fp16_ = is_fp16; + VLOG(1) << "is_fp16_ " << is_fp16_; + std::vector providers = Ort::GetAvailableProviders(); + VLOG(1) << "providers.size(): " << providers.size(); + bool cuda_is_available = false; + for (auto& prd : providers) { + VLOG(1) << "available provider: " << prd; + if (prd.find("CUDA") != std::string::npos) { + cuda_is_available = true; + } + } + if (!cuda_is_available) { + VLOG(1) << "CUDA is not available! Please check your GPU settings!"; + throw std::runtime_error("CUDA is not available!"); + } + std::string encoder_onnx_path = model_dir + "/encoder.onnx"; + std::string rescore_onnx_path = model_dir + "/decoder.onnx"; + if (is_fp16) { + encoder_onnx_path = model_dir + "/encoder_fp16.onnx"; + rescore_onnx_path = model_dir + "/decoder_fp16.onnx"; + } + + // release GPU memory: + // https://github.com/microsoft/onnxruntime/issues/9509#issuecomment-951546580 + // 1. Not allocate weights memory through the arena + session_options_.AddConfigEntry( + kOrtSessionOptionsUseDeviceAllocatorForInitializers, "1"); + // 2. Configure the arena to have high enough initial chunk + // to support most Run() calls. See "initial_chunk_size_bytes" + const char* keys[] = { + "max_mem", "arena_extend_strategy", "initial_chunk_size_bytes", + "max_dead_bytes_per_chunk", "initial_growth_chunk_size_bytes"}; + const size_t values[] = {0, 0, 1024, 0, 256}; + + OrtArenaCfg* arena_cfg = nullptr; + const auto& api = Ort::GetApi(); + auto zz = api.CreateArenaCfgV2(keys, values, 5, &arena_cfg); + std::unique_ptr rel_arena_cfg( + arena_cfg, api.ReleaseArenaCfg); + + OrtCUDAProviderOptions cuda_options{}; + + cuda_options.device_id = 0; + cuda_options.cudnn_conv_algo_search = + OrtCudnnConvAlgoSearch::OrtCudnnConvAlgoSearchExhaustive; + // cuda_options.gpu_mem_limit = 16 * 1024 * 1024 * 1024ul; + cuda_options.arena_extend_strategy = 1; + cuda_options.do_copy_in_default_stream = true; + cuda_options.has_user_compute_stream = 0; + cuda_options.user_compute_stream = nullptr; + // TODO(veelion): arena_cfg didn't work, it blocked when session.Run() + // Just comment this out until find a work way. + cuda_options.default_memory_arena_cfg = arena_cfg; + session_options_.AppendExecutionProvider_CUDA(cuda_options); + + /* TODO(veelion): use OrtCUDAProviderOptionsV2 until it support ArenaCfg + // 1. Load sessions + // config for CUDA + std::string device_id = std::to_string(gpu_id); + std::vector keys2{ + "device_id", + "gpu_mem_limit", + "arena_extend_strategy", + "cudnn_conv_algo_search", + "do_copy_in_default_stream", + "cudnn_conv_use_max_workspace", + "cudnn_conv1d_pad_to_nc1d" // supported from 1.12.0 + }; + std::vector values2{ + device_id.data(), + //"2147483648", + "8589934592", + "kSameAsRequested", + "DEFAULT", + "1", + "1", + "1" + }; + + const auto& api = Ort::GetApi(); + OrtCUDAProviderOptionsV2* cuda_options = nullptr; + Ort::ThrowOnError(api.CreateCUDAProviderOptions(&cuda_options)); + Ort::ThrowOnError(api.UpdateCUDAProviderOptions( + cuda_options, keys2.data(), values2.data(), keys2.size())); + Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_CUDA_V2( + session_options_, cuda_options)); + api.ReleaseCUDAProviderOptions(cuda_options); + */ + + try { + encoder_session_ = std::make_shared( + env_, encoder_onnx_path.c_str(), session_options_); + rescore_session_ = std::make_shared( + env_, rescore_onnx_path.c_str(), session_options_); + } catch (std::exception const& e) { + LOG(ERROR) << "error when load onnx model: " << e.what(); + exit(0); + } + std::cout << "read onnx model done \n"; + + // 2. Read config + std::string config_path = JoinPath(model_dir, "config.yaml"); + VLOG(1) << "Read " << config_path; + Yaml::Node root; + Yaml::Parse(root, config_path.c_str()); + sos_ = root["sos"].As(); + eos_ = root["eos"].As(); + is_bidirectional_decoder_ = root["is_bidirectional_decoder"].As(); + + LOG(INFO) << "Onnx Model Info:"; + LOG(INFO) << "\tsos " << sos_; + LOG(INFO) << "\teos " << eos_; + LOG(INFO) << "\tis bidirectional decoder " << is_bidirectional_decoder_; + + // 3. Read model nodes + LOG(INFO) << "Onnx Encoder:"; + GetInputOutputInfo(encoder_session_, &encoder_in_names_, &encoder_out_names_); + LOG(INFO) << "Onnx Rescore:"; + GetInputOutputInfo(rescore_session_, &rescore_in_names_, &rescore_out_names_); +} + +BatchOnnxAsrModel::BatchOnnxAsrModel(const BatchOnnxAsrModel& other) { + // metadatas + sos_ = other.sos_; + eos_ = other.eos_; + is_bidirectional_decoder_ = other.is_bidirectional_decoder_; + is_fp16_ = other.is_fp16_; + + // sessions + encoder_session_ = other.encoder_session_; + rescore_session_ = other.rescore_session_; + + // node names + encoder_in_names_ = other.encoder_in_names_; + encoder_out_names_ = other.encoder_out_names_; + rescore_in_names_ = other.rescore_in_names_; + rescore_out_names_ = other.rescore_out_names_; +} + +std::shared_ptr BatchOnnxAsrModel::Copy() const { + auto asr_model = std::make_shared(*this); + // Reset the inner states for new decoding + return asr_model; +} + +void BatchOnnxAsrModel::ForwardEncoder( + const batch_feature_t& batch_feats, + const std::vector& batch_feats_lens, + std::vector>>* batch_topk_scores, + std::vector>>* batch_topk_indexs) { + Ort::MemoryInfo memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + // 1. Prepare onnx required data + int batch_size = batch_feats.size(); + int num_frames = batch_feats[0].size(); + int feature_dim = batch_feats[0][0].size(); + + // generate data for CreateTensor + Ort::Value feats_ort{nullptr}; + // https://github.com/microsoft/onnxruntime/issues/9629#issuecomment-963828881 + // Ort::Value::CreateTensor does NOT copy the data + std::vector feats_fp16; // for holding feats of fp16 + std::vector feats_fp32; // for holding feats of float + + // speech + const int64_t feats_shape[3] = {batch_size, num_frames, feature_dim}; + Timer timer; + if (is_fp16_) { + feats_fp16.resize(batch_size * num_frames * feature_dim); + for (size_t i = 0; i < batch_size; ++i) { + for (size_t j = 0; j < num_frames; ++j) { + for (size_t k = 0; k < feature_dim; ++k) { + int p = i * num_frames * feature_dim + j * feature_dim + k; + feats_fp16[p] = Ort::Float16_t(_cvtss_sh(batch_feats[i][j][k], 0)); + } + } + } + auto tensor = Ort::Value::CreateTensor( + memory_info, + feats_fp16.data(), + feats_fp16.size(), + feats_shape, 3); + feats_ort = std::move(tensor); + VLOG(1) << "feats to fp16 takes " << timer.Elapsed() << " ms."; + } else { + for (size_t i = 0; i < batch_size; ++i) { + for (size_t j = 0; j < num_frames; ++j) { + feats_fp32.insert(feats_fp32.end(), batch_feats[i][j].begin(), + batch_feats[i][j].end()); + } + } + feats_ort = std::move(Ort::Value::CreateTensor( + memory_info, feats_fp32.data(), feats_fp32.size(), feats_shape, 3)); + } + + // speech_lens + const int64_t feats_lens_shape[1] = {batch_size}; + Ort::Value feats_lens_ort = Ort::Value::CreateTensor( + memory_info, const_cast(batch_feats_lens.data()), + batch_feats_lens.size(), feats_lens_shape, 1); + + // 2. Encoder forward + std::vector inputs; + for (auto name : encoder_in_names_) { + if (!strcmp(name, "speech")) { + inputs.push_back(std::move(feats_ort)); + } else if (!strcmp(name, "speech_lengths")) { + inputs.push_back(std::move(feats_lens_ort)); + } + } + + timer.Reset(); + Ort::RunOptions ro; + // ro.AddConfigEntry(kOrtRunOptionsConfigEnableMemoryArenaShrinkage, "gpu:0"); + ro.AddConfigEntry("memory.enable_memory_arena_shrinkage", "cpu:0;gpu:0"); + std::vector ort_outputs = encoder_session_->Run( + ro, encoder_in_names_.data(), inputs.data(), + inputs.size(), encoder_out_names_.data(), encoder_out_names_.size()); + VLOG(1) << "\tencoder ->Run() takes " << timer.Elapsed() << " ms."; + + // get topk_scores + auto out_shape = ort_outputs[3].GetTensorTypeAndShapeInfo().GetShape(); + int num_outputs = out_shape[1]; + int output_dim = out_shape[2]; + float* topk_scores_ptr = nullptr; + std::vector topk_scores_data; // for holding topk_scores in fp16 + if (is_fp16_) { + timer.Reset(); + auto probs = ort_outputs[3].GetTensorMutableData(); + int length = out_shape[0] * out_shape[1] * out_shape[2]; + topk_scores_data.resize(length); + for (size_t i = 0; i < length; ++i) { + topk_scores_data[i] = _cvtsh_ss(probs[i]); + } + topk_scores_ptr = topk_scores_data.data(); + VLOG(1) << "topk_scores from GPU-fp16 to float takes " << timer.Elapsed() + << " ms. data lenght " << length; + } else { + topk_scores_ptr = ort_outputs[3].GetTensorMutableData(); + } + + batch_topk_scores->resize(batch_size); + for (size_t i = 0; i < batch_size; ++i) { + (*batch_topk_scores)[i].resize(num_outputs); + for (size_t j = 0; j < num_outputs; j++) { + (*batch_topk_scores)[i][j].resize(output_dim); + float* p = topk_scores_ptr + (i * num_outputs + j) * output_dim; + memcpy((*batch_topk_scores)[i][j].data(), p, sizeof(float) * output_dim); + } + } + // get batch_topk_indexs + std::vector topk_indexs_data; // for holding topk_indexs from fp16 + timer.Reset(); + auto probs = ort_outputs[4].GetTensorMutableData(); + int length = out_shape[0] * out_shape[1] * out_shape[2]; + topk_indexs_data.resize(length); + for (size_t i = 0; i < length; ++i) { + topk_indexs_data[i] = probs[i]; + } + int32_t* topk_indexs_ptr = topk_indexs_data.data(); + VLOG(1) << "topk_indexs from GPU-fp16 to float takes " + << timer.Elapsed() << " ms. data lenght " << length; + + batch_topk_indexs->resize(batch_size); + for (size_t i = 0; i < batch_size; ++i) { + (*batch_topk_indexs)[i].resize(num_outputs); + for (size_t j = 0; j < num_outputs; j++) { + (*batch_topk_indexs)[i][j].resize(output_dim); + int32_t* p = topk_indexs_ptr + (i * num_outputs + j) * output_dim; + memcpy((*batch_topk_indexs)[i][j].data(), p, + sizeof(int32_t) * output_dim); + } + } + // 3. cache encoder outs + encoder_outs_ = std::move(ort_outputs[0]); + encoder_outs_lens_ = std::move(ort_outputs[1]); +} + +void BatchOnnxAsrModel::AttentionRescoring( + const std::vector>>& batch_hyps, + const std::vector>& ctc_scores, + std::vector>* attention_scores) { + Ort::MemoryInfo memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + // 1. prepare input for onnx + int batch_size = batch_hyps.size(); + int beam_size = batch_hyps[0].size(); + + // 1.1 generate hyps_lens_sos data for ort (batch_size, beam_size) + std::vector hyps_lens_sos(batch_size * beam_size, 0); + int max_hyps_len = 0; + for (size_t i = 0; i < batch_size; ++i) { + for (size_t j = 0; j < beam_size; ++j) { + int length = batch_hyps[i][j].size() + 1; + max_hyps_len = std::max(length, max_hyps_len); + hyps_lens_sos[i * beam_size + j] = length; + } + } + + // 1.2 generate hyps_pad_sos_eos, r_hyps_pad_sos_eos + std::vector hyps_pad_sos_eos( + batch_size * beam_size * (max_hyps_len + 1), 0); + std::vector r_hyps_pad_sos_eos( + batch_size * beam_size * (max_hyps_len + 1), 0); + for (size_t i = 0; i < batch_size; ++i) { + for (size_t j = 0; j < beam_size; ++j) { + const std::vector& hyps = batch_hyps[i][j]; + hyps_pad_sos_eos[i * beam_size * max_hyps_len] = sos_; + size_t hyps_len = hyps.size(); + for (size_t k = 0; k < hyps_len; ++k) { + size_t p = i * beam_size * max_hyps_len + j * max_hyps_len + k + 1; + hyps_pad_sos_eos[p] = hyps[k]; + r_hyps_pad_sos_eos[p] = hyps[hyps_len - 1 - k]; + } + size_t p = i * beam_size * max_hyps_len + + j * max_hyps_len + hyps.size() + 1; + hyps_pad_sos_eos[p] = eos_; + r_hyps_pad_sos_eos[p] = eos_; + } + } + + // 1.3 ctc_scores_data + Ort::Value ctc_scores_tensor{nullptr}; + std::vector ctc_fp16; + std::vector ctc_fp32; + const int64_t ctc_shape[] = {batch_size, beam_size}; + if (is_fp16_) { + ctc_fp16.resize(batch_size * beam_size); + for (size_t i = 0; i < batch_size; ++i) { + for (size_t j = 0; j < beam_size; ++j) { + int p = i * beam_size + j; + ctc_fp16[p] = Ort::Float16_t(_cvtss_sh(ctc_scores[i][j], 0)); + } + } + ctc_scores_tensor = std::move(Ort::Value::CreateTensor( + memory_info, ctc_fp16.data(), ctc_fp16.size(), ctc_shape, 2)); + } else { + ctc_fp32.resize(batch_size * beam_size); + for (size_t i = 0; i < batch_size; ++i) { + memcpy(ctc_fp32.data() + i * beam_size, + ctc_scores[i].data(), sizeof(float) * beam_size); + } + ctc_scores_tensor = std::move(Ort::Value::CreateTensor( + memory_info, ctc_fp32.data(), ctc_fp32.size(), ctc_shape, 2)); + } + + // 2. forward attetion decoder + const int64_t hyps_lens_shape[] = {batch_size, beam_size}; + const int64_t hyps_pad_shape[] = {batch_size, beam_size, max_hyps_len}; + + Ort::Value hyps_lens_tensor = Ort::Value::CreateTensor( + memory_info, hyps_lens_sos.data(), + hyps_lens_sos.size(), hyps_lens_shape, 2); + Ort::Value hyps_pad_tensor = Ort::Value::CreateTensor( + memory_info, hyps_pad_sos_eos.data(), + hyps_pad_sos_eos.size(), hyps_pad_shape, 3); + Ort::Value r_hyps_pad_tensor = Ort::Value::CreateTensor( + memory_info, r_hyps_pad_sos_eos.data(), + r_hyps_pad_sos_eos.size(), hyps_pad_shape, 3); + + std::vector rescore_inputs; + for (auto name : rescore_in_names_) { + if (!strcmp(name, "encoder_out")) { + rescore_inputs.push_back(std::move(encoder_outs_)); + } else if (!strcmp(name, "encoder_out_lens")) { + rescore_inputs.push_back(std::move(encoder_outs_lens_)); + } else if (!strcmp(name, "hyps_pad_sos_eos")) { + rescore_inputs.push_back(std::move(hyps_pad_tensor)); + } else if (!strcmp(name, "hyps_lens_sos")) { + rescore_inputs.push_back(std::move(hyps_lens_tensor)); + } else if (!strcmp(name, "r_hyps_pad_sos_eos")) { + rescore_inputs.push_back(std::move(r_hyps_pad_tensor)); + } else if (!strcmp(name, "ctc_score")) { + rescore_inputs.push_back(std::move(ctc_scores_tensor)); + } else { + VLOG(1) << "invalid input name " << name; + } + } + + Timer timer; + Ort::RunOptions ro; + // ro.AddConfigEntry(kOrtRunOptionsConfigEnableMemoryArenaShrinkage, "gpu:0"); + ro.AddConfigEntry("memory.enable_memory_arena_shrinkage", "cpu:0;gpu:0"); + std::vector rescore_outputs = rescore_session_->Run( + ro, rescore_in_names_.data(), rescore_inputs.data(), + rescore_inputs.size(), rescore_out_names_.data(), + rescore_out_names_.size()); + VLOG(1) << "decoder->Run() takes " << timer.Elapsed() << " ms."; + + // (B, beam, T2) + auto scores_shape = rescore_outputs[1].GetTensorTypeAndShapeInfo().GetShape(); + attention_scores->resize(scores_shape[0]); + if (is_fp16_) { + Timer timer; + int length = scores_shape[0] * scores_shape[1]; + auto outs = rescore_outputs[1].GetTensorMutableData(); + for (size_t i = 0; i < scores_shape[0]; ++i) { + (*attention_scores)[i].resize(scores_shape[1]); + for (size_t j = 0; j < scores_shape[1]; ++j) { + (*attention_scores)[i][j] = _cvtsh_ss( + outs[i * scores_shape[1] + j].value); + } + } + VLOG(1) << "decoder_out from fp16 to float takes " + << timer.Elapsed() << " ms. data length " << length; + } else { + auto outs = rescore_outputs[0].GetTensorMutableData(); + for (size_t i = 0; i < scores_shape[0]; ++i) { + (*attention_scores)[i].resize(scores_shape[1]); + memcpy((*attention_scores)[i].data(), outs + i * scores_shape[1], + sizeof(float) * scores_shape[1]); + } + } +} + +} // namespace wenet diff --git a/runtime/core/decoder/batch_onnx_asr_model.h b/runtime/core/decoder/batch_onnx_asr_model.h new file mode 100644 index 0000000000..67e61d0f3c --- /dev/null +++ b/runtime/core/decoder/batch_onnx_asr_model.h @@ -0,0 +1,86 @@ +// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu) +// 2022 ZeXuan Li (lizexuan@huya.com) +// Xingchen Song(sxc19@mails.tsinghua.edu.cn) +// hamddct@gmail.com (Mddct) +// 2022 SoundDataConverge Co.LTD (Weiliang Chong) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#ifndef DECODER_BATCH_ONNX_ASR_MODEL_H_ +#define DECODER_BATCH_ONNX_ASR_MODEL_H_ + +#include +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT + +#include "decoder/batch_asr_model.h" +#include "utils/log.h" +#include "utils/utils.h" +#include "onnxruntime_run_options_config_keys.h" // NOLINT +#include "onnxruntime_session_options_config_keys.h" // NOLINT + +namespace wenet { + +class BatchOnnxAsrModel : public BatchAsrModel { + public: + // Note: Do not call the InitEngineThreads function more than once. + static void InitEngineThreads(int num_threads = 1); + + public: + BatchOnnxAsrModel() = default; + BatchOnnxAsrModel(const BatchOnnxAsrModel& other); + void Read(const std::string& model_dir, bool is_fp16 = false, int gpu_id = 0); + void AttentionRescoring( + const std::vector>>& batch_hyps, + const std::vector>& ctc_scores, + std::vector>* attention_scores) override; + std::shared_ptr Copy() const override; + + void GetInputOutputInfo(const std::shared_ptr& session, + std::vector* in_names, + std::vector* out_names); + void ForwardEncoder( + const batch_feature_t& batch_feats, + const std::vector& batch_feats_lens, + std::vector>>* batch_topk_scores, + std::vector>>* batch_topk_indexs) override; // NOLINT + + private: + int encoder_output_size_ = 0; + bool is_fp16_ = false; + + // sessions + // NOTE(Mddct): The Env holds the logging state used by all other objects. + // One Env must be created before using any other Onnxruntime functionality. + static Ort::Env env_; // shared environment across threads. + static Ort::SessionOptions session_options_; + static Ort::RunOptions run_option_; + std::shared_ptr encoder_session_ = nullptr; + std::shared_ptr rescore_session_ = nullptr; + + // node names + static std::vector node_names_; + std::vector encoder_in_names_, encoder_out_names_; + std::vector rescore_in_names_, rescore_out_names_; + + // cache encoder outs: [encoder_outs, encoder_outs_lens] + Ort::Value encoder_outs_{nullptr}; + Ort::Value encoder_outs_lens_{nullptr}; +}; + +} // namespace wenet + +#endif // DECODER_BATCH_ONNX_ASR_MODEL_H_ diff --git a/runtime/core/decoder/batch_torch_asr_model.cc b/runtime/core/decoder/batch_torch_asr_model.cc new file mode 100644 index 0000000000..18a00dbf8e --- /dev/null +++ b/runtime/core/decoder/batch_torch_asr_model.cc @@ -0,0 +1,286 @@ +// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu) +// 2022 Binbin Zhang (binbzha@qq.com) +// 2022 SoundDataConverge Co.LTD (Weiliang Chong) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include "decoder/batch_torch_asr_model.h" + +#ifdef USE_GPU +#include +#endif +#include +#include +#include +#include + +#include "torch/script.h" +#include "torch/torch.h" + +namespace wenet { + +void BatchTorchAsrModel::InitEngineThreads(int num_threads) { + VLOG(1) << "Num intra-op default threads: " << at::get_num_threads(); + // For multi-thread performance + at::set_num_threads(num_threads); + // Note: Do not call the set_num_interop_threads function more than once. + // Please see https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/ + // ParallelThreadPoolNative.cpp#L54-L56 + at::set_num_interop_threads(1); + VLOG(1) << "Num intra-op threads: " << at::get_num_threads(); + VLOG(1) << "Num inter-op threads: " << at::get_num_interop_threads(); +} + +void BatchTorchAsrModel::Read(const std::string& model_path) { +#ifdef USE_GPU + if (!torch::cuda::is_available()) { + VLOG(1) << "CUDA is not available! Please check your GPU settings"; + throw std::runtime_error("CUDA is not available!"); + } else { + VLOG(1) << "CUDA is available! Running on GPU"; + device_ = at::kCUDA; + } +#endif + torch::jit::script::Module model = torch::jit::load(model_path, device_); + model_ = std::make_shared(std::move(model)); + torch::NoGradGuard no_grad; + model_->eval(); + torch::jit::IValue o1 = model_->run_method("subsampling_rate"); + CHECK_EQ(o1.isInt(), true); + subsampling_rate_ = o1.toInt(); + torch::jit::IValue o2 = model_->run_method("right_context"); + CHECK_EQ(o2.isInt(), true); + torch::jit::IValue o3 = model_->run_method("sos_symbol"); + CHECK_EQ(o3.isInt(), true); + sos_ = o3.toInt(); + torch::jit::IValue o4 = model_->run_method("eos_symbol"); + CHECK_EQ(o4.isInt(), true); + eos_ = o4.toInt(); + torch::jit::IValue o5 = model_->run_method("is_bidirectional_decoder"); + CHECK_EQ(o5.isBool(), true); + is_bidirectional_decoder_ = o5.toBool(); + + VLOG(1) << "Torch Model Info:"; + VLOG(1) << "\tsubsampling_rate " << subsampling_rate_; + VLOG(1) << "\tsos " << sos_; + VLOG(1) << "\teos " << eos_; + VLOG(1) << "\tis bidirectional decoder " << is_bidirectional_decoder_; +} + +BatchTorchAsrModel::BatchTorchAsrModel(const BatchTorchAsrModel& other) { + // 1. Init the model info + subsampling_rate_ = other.subsampling_rate_; + sos_ = other.sos_; + eos_ = other.eos_; + is_bidirectional_decoder_ = other.is_bidirectional_decoder_; + // 2. Model copy, just copy the model ptr since: + // PyTorch allows using multiple CPU threads during TorchScript model + // inference, please see https://pytorch.org/docs/stable/notes/cpu_ + // threading_torchscript_inference.html + model_ = other.model_; + device_ = other.device_; +} + +std::shared_ptr BatchTorchAsrModel::Copy() const { + auto asr_model = std::make_shared(*this); + return asr_model; +} + +void BatchTorchAsrModel::ForwardEncoder( + const std::vector& batch_feats, + const std::vector& batch_feats_lens, + std::vector>>* batch_topk_scores, + std::vector>>* batch_topk_indexs) { + // 1. Prepare libtorch required data + int batch_size = batch_feats_lens.size(); + torch::Tensor feats_lens = + torch::from_blob(const_cast(batch_feats_lens.data()), + {batch_size}, torch::kInt).clone(); + // Note: math.log(1e-10) is -23.025850929940457 + auto feats = torch::nn::utils::rnn::pad_sequence(batch_feats, true, + -23.025850929940457f); + + // 2. Encoder batch forward + feats = feats.to(device_); + feats_lens = feats_lens.to(device_); + torch::NoGradGuard no_grad; + std::vector inputs = {feats, feats_lens}; + + auto outputs = + model_->get_method("batch_forward_encoder")(inputs).toTuple()->elements(); + VLOG(1) << "batch_forward_encoder done"; + CHECK_EQ(outputs.size(), 5); + encoder_out_ = outputs[0].toTensor(); // (B, Tmax, dim) + encoder_lens_ = outputs[1].toTensor(); // (B,) + + // Copy topk_scores + auto topk_scores = outputs[3].toTensor().to(at::kCPU); + int num_outputs = topk_scores.size(1); + int output_dim = topk_scores.size(2); + batch_topk_scores->resize(batch_size); + for (size_t i = 0; i < batch_size; i++) { + (*batch_topk_scores)[i].resize(num_outputs); + for (size_t j = 0; j < num_outputs; j++) { + (*batch_topk_scores)[i][j].resize(output_dim); + memcpy((*batch_topk_scores)[i][j].data(), topk_scores[i][j].data_ptr(), + sizeof(float) * output_dim); + } + } + // copy topk_indexes + auto topk_indexes = outputs[4].toTensor().to(at::kCPU); + batch_topk_indexs->resize(batch_size); + for (size_t i = 0; i < batch_size; ++i) { + (*batch_topk_indexs)[i].resize(num_outputs); + for (size_t j = 0; j < num_outputs; ++j) { + (*batch_topk_indexs)[i][j].resize(output_dim); + memcpy((*batch_topk_indexs)[i][j].data(), topk_indexes[i][j].data_ptr(), + sizeof(int) * output_dim); + } + } +} + +void BatchTorchAsrModel::ForwardEncoder( + const batch_feature_t& batch_feats, + const std::vector& batch_feats_lens, + std::vector>>* batch_topk_scores, + std::vector>>* batch_topk_indexs) { + // 1. Prepare libtorch required data + int batch_size = batch_feats.size(); + int num_frames = batch_feats[0].size(); + const int feature_dim = batch_feats[0][0].size(); + Timer timer; + torch::Tensor feats = + torch::zeros({batch_size, num_frames, feature_dim}, torch::kFloat); + for (size_t i = 0; i < batch_size; ++i) { + for (size_t j = 0; j < num_frames; ++j) { + torch::Tensor row = + torch::from_blob(const_cast(batch_feats[i][j].data()), + {feature_dim}, torch::kFloat).clone(); + feats[i][j] = std::move(row); + } + } + VLOG(1) << "feature to Tensor takes " << timer.Elapsed() << " ms."; + torch::Tensor feats_lens = + torch::from_blob(const_cast(batch_feats_lens.data()), + {batch_size}, torch::kInt).clone(); + + // 2. Encoder batch forward + feats = feats.to(device_); + feats_lens = feats_lens.to(device_); + torch::NoGradGuard no_grad; + std::vector inputs = {feats, feats_lens}; + + auto outputs = + model_->get_method("batch_forward_encoder")(inputs).toTuple()->elements(); + CHECK_EQ(outputs.size(), 5); + encoder_out_ = outputs[0].toTensor(); // (B, Tmax, dim) + encoder_lens_ = outputs[1].toTensor(); // (B,) + + // Copy topk_scores + auto topk_scores = outputs[3].toTensor().to(at::kCPU); + int num_outputs = topk_scores.size(1); + int output_dim = topk_scores.size(2); + batch_topk_scores->resize(batch_size); + for (size_t i = 0; i < batch_size; i++) { + (*batch_topk_scores)[i].resize(num_outputs); + for (size_t j = 0; j < num_outputs; j++) { + (*batch_topk_scores)[i][j].resize(output_dim); + memcpy((*batch_topk_scores)[i][j].data(), topk_scores[i][j].data_ptr(), + sizeof(float) * output_dim); + } + } + // copy topk_indexes + auto topk_indexes = outputs[4].toTensor().to(at::kCPU); + batch_topk_indexs->resize(batch_size); + for (size_t i = 0; i < batch_size; ++i) { + (*batch_topk_indexs)[i].resize(num_outputs); + for (size_t j = 0; j < num_outputs; ++j) { + (*batch_topk_indexs)[i][j].resize(output_dim); + memcpy((*batch_topk_indexs)[i][j].data(), topk_indexes[i][j].data_ptr(), + sizeof(int) * output_dim); + } + } +} + +void BatchTorchAsrModel::AttentionRescoring( + const std::vector>>& batch_hyps, + const std::vector>& ctc_scores, + std::vector>* attention_scores) { + // Step 1: Prepare input for libtorch + int batch_size = batch_hyps.size(); + int beam_size = batch_hyps[0].size(); + torch::Tensor hyps_lens_sos = torch::zeros( + {batch_size, beam_size}, torch::kLong); + int max_hyps_len = 0; + for (size_t i = 0; i < batch_size; i++) { + for (size_t j = 0; j < beam_size; j++) { + int length = batch_hyps[i][j].size() + 1; + max_hyps_len = std::max(length, max_hyps_len); + hyps_lens_sos[i][j] = static_cast(length); + } + } + + // 1.2 add sos, eos to hyps, r_hyps + torch::Tensor hyps_pad_sos_eos = torch::zeros( + {batch_size, beam_size, max_hyps_len + 1}, torch::kLong); + torch::Tensor r_hyps_pad_sos_eos = torch::zeros( + {batch_size, beam_size, max_hyps_len + 1}, torch::kLong); + for (size_t i = 0; i < batch_size; i++) { + for (size_t j = 0; j < beam_size; j++) { + const std::vector& hyp = batch_hyps[i][j]; + hyps_pad_sos_eos[i][j][0] = sos_; + r_hyps_pad_sos_eos[i][j][0] = sos_; + size_t hyps_len = hyp.size(); + for (size_t k = 0; k < hyps_len; k++) { + hyps_pad_sos_eos[i][j][k + 1] = hyp[k]; + r_hyps_pad_sos_eos[i][j][k + 1] = hyp[hyps_len - 1 - k]; + } + } + } + + // 1.3 ctc_scores_data + torch::Tensor ctc_scores_tensor = torch::zeros( + {batch_size, beam_size}, torch::kFloat); + for (size_t i = 0; i < batch_size; ++i) { + auto row = torch::from_blob(const_cast(ctc_scores[i].data()), + {beam_size}, torch::kFloat).clone(); + ctc_scores_tensor[i] = std::move(row); + } + + // Step 2: Forward attention decoder + hyps_pad_sos_eos = hyps_pad_sos_eos.to(device_); + hyps_lens_sos = hyps_lens_sos.to(device_); + r_hyps_pad_sos_eos = r_hyps_pad_sos_eos.to(device_); + ctc_scores_tensor = ctc_scores_tensor.to(device_); + // encoder_lens_ = encoder_lens_.to(device_); + // encoder_out_ = encoder_out_.to(device_); + torch::NoGradGuard no_grad; + auto outputs = model_->run_method( + "batch_forward_attention_decoder", + encoder_out_, encoder_lens_, + hyps_pad_sos_eos, hyps_lens_sos, + r_hyps_pad_sos_eos, ctc_scores_tensor).toTuple()->elements(); + auto rescores = outputs[1].toTensor().to(at::kCPU); +#ifdef USE_GPU + c10::cuda::CUDACachingAllocator::emptyCache(); +#endif + attention_scores->resize(batch_size); + for (size_t i = 0; i < batch_size; i++) { + (*attention_scores)[i].resize(beam_size); + memcpy((*attention_scores)[i].data(), rescores[i].data_ptr(), + sizeof(float) * beam_size); + } +} + +} // namespace wenet diff --git a/runtime/core/decoder/batch_torch_asr_model.h b/runtime/core/decoder/batch_torch_asr_model.h new file mode 100644 index 0000000000..c1b4a6d746 --- /dev/null +++ b/runtime/core/decoder/batch_torch_asr_model.h @@ -0,0 +1,69 @@ +// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu) +// 2022 Binbin Zhang (binbzha@qq.com) +// 2022 SoundDataConverge Co.LTD (Weiliang Chong) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#ifndef DECODER_BATCH_TORCH_ASR_MODEL_H_ +#define DECODER_BATCH_TORCH_ASR_MODEL_H_ + +#include +#include +#include + +#include "torch/script.h" +#include "torch/torch.h" + +#include "decoder/batch_asr_model.h" +#include "utils/utils.h" + +namespace wenet { + +class BatchTorchAsrModel : public BatchAsrModel { + public: + // Note: Do not call the InitEngineThreads function more than once. + static void InitEngineThreads(int num_threads = 1); + + public: + using TorchModule = torch::jit::script::Module; + BatchTorchAsrModel() = default; + BatchTorchAsrModel(const BatchTorchAsrModel& other); + void Read(const std::string& model_path); + void AttentionRescoring( + const std::vector>>& batch_hyps, + const std::vector>& ctc_scores, + std::vector>* attention_scores) override; + std::shared_ptr Copy() const override; + + void ForwardEncoder( + const batch_feature_t& batch_feats, + const std::vector& batch_feats_lens, + std::vector>>* batch_topk_scores, + std::vector>>* batch_topk_indexs) override; // NOLINT + void ForwardEncoder( + const std::vector& batch_feats, + const std::vector& batch_feats_lens, + std::vector>>* batch_topk_scores, + std::vector>>* batch_topk_indexs) override; // NOLINT + + private: + std::shared_ptr model_ = nullptr; + torch::Tensor encoder_out_; + torch::Tensor encoder_lens_; + torch::DeviceType device_; +}; + +} // namespace wenet + +#endif // DECODER_BATCH_TORCH_ASR_MODEL_H_ diff --git a/runtime/core/decoder/ctc_prefix_beam_search.cc b/runtime/core/decoder/ctc_prefix_beam_search.cc index 154c8864ba..8b9c4cacb4 100644 --- a/runtime/core/decoder/ctc_prefix_beam_search.cc +++ b/runtime/core/decoder/ctc_prefix_beam_search.cc @@ -209,6 +209,112 @@ void CtcPrefixBeamSearch::Search(const std::vector>& logp) { } } +void CtcPrefixBeamSearch::Search( + const std::vector>& topk_scores, + const std::vector>& topk_indexs) { + if (topk_scores.size() == 0) return; + int first_beam_size = + std::min(static_cast(topk_scores[0].size()), opts_.first_beam_size); + for (int t = 0; t < topk_scores.size(); ++t, ++abs_time_step_) { + std::unordered_map, PrefixScore, PrefixHash> next_hyps; + // 1. First beam prune, only select topk candidates + auto& topk_score = topk_scores[t]; + auto& topk_index = topk_indexs[t]; + + // 2. Token passing + for (int i = 0; i < topk_index.size(); ++i) { + int id = topk_index[i]; + auto prob = topk_score[i]; + for (const auto& it : cur_hyps_) { + const std::vector& prefix = it.first; + const PrefixScore& prefix_score = it.second; + // If prefix doesn't exist in next_hyps, next_hyps[prefix] will insert + // PrefixScore(-inf, -inf) by default, since the default constructor + // of PrefixScore will set fields s(blank ending score) and + // ns(none blank ending score) to -inf, respectively. + if (id == opts_.blank) { + // Case 0: *a + ε => *a + PrefixScore& next_score = next_hyps[prefix]; + next_score.s = LogAdd(next_score.s, prefix_score.score() + prob); + next_score.v_s = prefix_score.viterbi_score() + prob; + next_score.times_s = prefix_score.times(); + // Prefix not changed, copy the context from prefix. + if (context_graph_ && !next_score.has_context) { + next_score.CopyContext(prefix_score); + next_score.has_context = true; + } + } else if (!prefix.empty() && id == prefix.back()) { + // Case 1: *a + a => *a + PrefixScore& next_score1 = next_hyps[prefix]; + next_score1.ns = LogAdd(next_score1.ns, prefix_score.ns + prob); + if (next_score1.v_ns < prefix_score.v_ns + prob) { + next_score1.v_ns = prefix_score.v_ns + prob; + if (next_score1.cur_token_prob < prob) { + next_score1.cur_token_prob = prob; + next_score1.times_ns = prefix_score.times_ns; + CHECK_GT(next_score1.times_ns.size(), 0); + next_score1.times_ns.back() = abs_time_step_; + } + } + if (context_graph_ && !next_score1.has_context) { + next_score1.CopyContext(prefix_score); + next_score1.has_context = true; + } + + // Case 2: *aε + a => *aa + std::vector new_prefix(prefix); + new_prefix.emplace_back(id); + PrefixScore& next_score2 = next_hyps[new_prefix]; + next_score2.ns = LogAdd(next_score2.ns, prefix_score.s + prob); + if (next_score2.v_ns < prefix_score.v_s + prob) { + next_score2.v_ns = prefix_score.v_s + prob; + next_score2.cur_token_prob = prob; + next_score2.times_ns = prefix_score.times_s; + next_score2.times_ns.emplace_back(abs_time_step_); + } + if (context_graph_ && !next_score2.has_context) { + // Prefix changed, calculate the context score. + next_score2.UpdateContext(context_graph_, prefix_score, id, + prefix.size()); + next_score2.has_context = true; + } + } else { + // Case 3: *a + b => *ab, *aε + b => *ab + std::vector new_prefix(prefix); + new_prefix.emplace_back(id); + PrefixScore& next_score = next_hyps[new_prefix]; + next_score.ns = LogAdd(next_score.ns, prefix_score.score() + prob); + if (next_score.v_ns < prefix_score.viterbi_score() + prob) { + next_score.v_ns = prefix_score.viterbi_score() + prob; + next_score.cur_token_prob = prob; + next_score.times_ns = prefix_score.times(); + next_score.times_ns.emplace_back(abs_time_step_); + } + if (context_graph_ && !next_score.has_context) { + // Calculate the context score. + next_score.UpdateContext(context_graph_, prefix_score, id, + prefix.size()); + next_score.has_context = true; + } + } + } + } + + // 3. Second beam prune, only keep top n best paths + std::vector, PrefixScore>> arr(next_hyps.begin(), + next_hyps.end()); + int second_beam_size = + std::min(static_cast(arr.size()), opts_.second_beam_size); + std::nth_element(arr.begin(), arr.begin() + second_beam_size, arr.end(), + PrefixScoreCompare); + arr.resize(second_beam_size); + std::sort(arr.begin(), arr.end(), PrefixScoreCompare); + + // 4. Update cur_hyps_ and get new result + UpdateHypotheses(arr); + } +} + void CtcPrefixBeamSearch::FinalizeSearch() { UpdateFinalContext(); } void CtcPrefixBeamSearch::UpdateFinalContext() { diff --git a/runtime/core/decoder/ctc_prefix_beam_search.h b/runtime/core/decoder/ctc_prefix_beam_search.h index f44ec23c37..743752f97b 100644 --- a/runtime/core/decoder/ctc_prefix_beam_search.h +++ b/runtime/core/decoder/ctc_prefix_beam_search.h @@ -99,6 +99,8 @@ class CtcPrefixBeamSearch : public SearchInterface { const std::shared_ptr& context_graph = nullptr); void Search(const std::vector>& logp) override; + void Search(const std::vector>& topk_scores, + const std::vector>& topk_indexs) override; void Reset() override; void FinalizeSearch() override; SearchType Type() const override { return SearchType::kPrefixBeamSearch; } diff --git a/runtime/core/decoder/ctc_wfst_beam_search.h b/runtime/core/decoder/ctc_wfst_beam_search.h index 56967743d4..0215ed1eba 100644 --- a/runtime/core/decoder/ctc_wfst_beam_search.h +++ b/runtime/core/decoder/ctc_wfst_beam_search.h @@ -63,6 +63,8 @@ class CtcWfstBeamSearch : public SearchInterface { const fst::Fst& fst, const CtcWfstBeamSearchOptions& opts, const std::shared_ptr& context_graph); void Search(const std::vector>& logp) override; + void Search(const std::vector>& topk_scores, + const std::vector>& topk_indexs) override {}; void Reset() override; void FinalizeSearch() override; SearchType Type() const override { return SearchType::kWfstBeamSearch; } diff --git a/runtime/core/decoder/onnx_asr_model.cc b/runtime/core/decoder/onnx_asr_model.cc index 3097e2020c..7ce07e3309 100644 --- a/runtime/core/decoder/onnx_asr_model.cc +++ b/runtime/core/decoder/onnx_asr_model.cc @@ -42,7 +42,7 @@ void OnnxAsrModel::GetInputOutputInfo( int num_nodes = session->GetInputCount(); in_names->resize(num_nodes); for (int i = 0; i < num_nodes; ++i) { - char* name = session->GetInputName(i, allocator); + auto name = session->GetInputNameAllocated(i, allocator); Ort::TypeInfo type_info = session->GetInputTypeInfo(i); auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); ONNXTensorElementDataType type = tensor_info.GetElementType(); @@ -52,15 +52,16 @@ void OnnxAsrModel::GetInputOutputInfo( shape << j; shape << " "; } - LOG(INFO) << "\tInput " << i << " : name=" << name << " type=" << type + LOG(INFO) << "\tInput " << i << " : name=" << name.get() << " type=" << type << " dims=" << shape.str(); - (*in_names)[i] = name; + node_names_.push_back(std::move(name)); + (*in_names)[i] = node_names_.back().get(); } // Output info num_nodes = session->GetOutputCount(); out_names->resize(num_nodes); for (int i = 0; i < num_nodes; ++i) { - char* name = session->GetOutputName(i, allocator); + auto name = session->GetOutputNameAllocated(i, allocator); Ort::TypeInfo type_info = session->GetOutputTypeInfo(i); auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); ONNXTensorElementDataType type = tensor_info.GetElementType(); @@ -70,9 +71,10 @@ void OnnxAsrModel::GetInputOutputInfo( shape << j; shape << " "; } - LOG(INFO) << "\tOutput " << i << " : name=" << name << " type=" << type - << " dims=" << shape.str(); - (*out_names)[i] = name; + LOG(INFO) << "\tOutput " << i << " : name=" << name.get() + << " type=" << type << " dims=" << shape.str(); + node_names_.push_back(std::move(name)); + (*out_names)[i] = node_names_.back().get(); } } @@ -107,25 +109,39 @@ void OnnxAsrModel::Read(const std::string& model_dir) { auto model_metadata = encoder_session_->GetModelMetadata(); Ort::AllocatorWithDefaultOptions allocator; - encoder_output_size_ = - atoi(model_metadata.LookupCustomMetadataMap("output_size", allocator)); - num_blocks_ = - atoi(model_metadata.LookupCustomMetadataMap("num_blocks", allocator)); - head_ = atoi(model_metadata.LookupCustomMetadataMap("head", allocator)); + encoder_output_size_ = atoi( + model_metadata.LookupCustomMetadataMapAllocated( + "output_size", allocator).get()); + num_blocks_ = atoi( + model_metadata.LookupCustomMetadataMapAllocated( + "num_blocks", allocator).get()); + head_ = atoi( + model_metadata.LookupCustomMetadataMapAllocated( + "head", allocator).get()); cnn_module_kernel_ = atoi( - model_metadata.LookupCustomMetadataMap("cnn_module_kernel", allocator)); + model_metadata.LookupCustomMetadataMapAllocated( + "cnn_module_kernel", allocator).get()); subsampling_rate_ = atoi( - model_metadata.LookupCustomMetadataMap("subsampling_rate", allocator)); - right_context_ = - atoi(model_metadata.LookupCustomMetadataMap("right_context", allocator)); - sos_ = atoi(model_metadata.LookupCustomMetadataMap("sos_symbol", allocator)); - eos_ = atoi(model_metadata.LookupCustomMetadataMap("eos_symbol", allocator)); - is_bidirectional_decoder_ = atoi(model_metadata.LookupCustomMetadataMap( - "is_bidirectional_decoder", allocator)); - chunk_size_ = - atoi(model_metadata.LookupCustomMetadataMap("chunk_size", allocator)); - num_left_chunks_ = - atoi(model_metadata.LookupCustomMetadataMap("left_chunks", allocator)); + model_metadata.LookupCustomMetadataMapAllocated( + "subsampling_rate", allocator).get()); + right_context_ = atoi( + model_metadata.LookupCustomMetadataMapAllocated( + "right_context", allocator).get()); + sos_ = atoi( + model_metadata.LookupCustomMetadataMapAllocated( + "sos_symbol", allocator).get()); + eos_ = atoi( + model_metadata.LookupCustomMetadataMapAllocated( + "eos_symbol", allocator).get()); + is_bidirectional_decoder_ = atoi( + model_metadata.LookupCustomMetadataMapAllocated( + "is_bidirectional_decoder", allocator).get()); + chunk_size_ = atoi( + model_metadata.LookupCustomMetadataMapAllocated( + "chunk_size", allocator).get()); + num_left_chunks_ = atoi( + model_metadata.LookupCustomMetadataMapAllocated( + "left_chunks", allocator).get()); LOG(INFO) << "Onnx Model Info:"; LOG(INFO) << "\tencoder_output_size " << encoder_output_size_; diff --git a/runtime/core/decoder/onnx_asr_model.h b/runtime/core/decoder/onnx_asr_model.h index 906bd0d68c..f904df5baf 100644 --- a/runtime/core/decoder/onnx_asr_model.h +++ b/runtime/core/decoder/onnx_asr_model.h @@ -72,6 +72,7 @@ class OnnxAsrModel : public AsrModel { std::shared_ptr ctc_session_ = nullptr; // node names + std::vector node_names_; std::vector encoder_in_names_, encoder_out_names_; std::vector ctc_in_names_, ctc_out_names_; std::vector rescore_in_names_, rescore_out_names_; diff --git a/runtime/core/decoder/params.h b/runtime/core/decoder/params.h index 5831d28f82..99c3b53bfb 100644 --- a/runtime/core/decoder/params.h +++ b/runtime/core/decoder/params.h @@ -22,11 +22,14 @@ #include #include "decoder/asr_decoder.h" +#include "decoder/batch_asr_decoder.h" #ifdef USE_ONNX #include "decoder/onnx_asr_model.h" +#include "decoder/batch_onnx_asr_model.h" #endif #ifdef USE_TORCH #include "decoder/torch_asr_model.h" +#include "decoder/batch_torch_asr_model.h" #endif #ifdef USE_XPU #include "xpu/xpu_asr_model.h" @@ -94,6 +97,9 @@ DEFINE_int32(language_type, 0, "0x00 = kMandarinEnglish, " "0x01 = kIndoEuropean"); DEFINE_bool(lowercase, true, "lowercase final result if needed"); +DEFINE_bool(run_batch, false, "run websocket server for batch decoding"); +DEFINE_bool(is_fp16, false, "the model is of fp16"); +DEFINE_int32(gpu_id, 0, "which GPU to use"); namespace wenet { std::shared_ptr InitFeaturePipelineConfigFromFlags() { @@ -128,21 +134,39 @@ std::shared_ptr InitDecodeResourceFromFlags() { const int kNumGemmThreads = 1; if (!FLAGS_onnx_dir.empty()) { #ifdef USE_ONNX - LOG(INFO) << "Reading onnx model "; - OnnxAsrModel::InitEngineThreads(kNumGemmThreads); - auto model = std::make_shared(); - model->Read(FLAGS_onnx_dir); - resource->model = model; + if (FLAGS_run_batch) { + LOG(INFO) << "BatchOnnxAsrModel Reading ONNX model dir: " + << FLAGS_onnx_dir; + BatchOnnxAsrModel::InitEngineThreads(kNumGemmThreads); + auto model = std::make_shared(); + model->Read(FLAGS_onnx_dir, FLAGS_is_fp16, FLAGS_gpu_id); + resource->batch_model = model; + } else { + LOG(INFO) << "Reading onnx model "; + OnnxAsrModel::InitEngineThreads(kNumGemmThreads); + auto model = std::make_shared(); + model->Read(FLAGS_onnx_dir); + resource->model = model; + } #else LOG(FATAL) << "Please rebuild with cmake options '-DONNX=ON'."; #endif } else if (!FLAGS_model_path.empty()) { #ifdef USE_TORCH - LOG(INFO) << "Reading torch model " << FLAGS_model_path; - TorchAsrModel::InitEngineThreads(kNumGemmThreads); - auto model = std::make_shared(); - model->Read(FLAGS_model_path); - resource->model = model; + if (FLAGS_run_batch) { + LOG(INFO) << "BatchTorchAsrModel Reading torch model " + << FLAGS_model_path; + BatchTorchAsrModel::InitEngineThreads(kNumGemmThreads); + auto model = std::make_shared(); + model->Read(FLAGS_model_path); + resource->batch_model = model; + } else { + LOG(INFO) << "Reading torch model " << FLAGS_model_path; + TorchAsrModel::InitEngineThreads(kNumGemmThreads); + auto model = std::make_shared(); + model->Read(FLAGS_model_path); + resource->model = model; + } #else LOG(FATAL) << "Please rebuild with cmake options '-DTORCH=ON'."; #endif diff --git a/runtime/core/decoder/search_interface.h b/runtime/core/decoder/search_interface.h index 25bad26705..72722fb637 100644 --- a/runtime/core/decoder/search_interface.h +++ b/runtime/core/decoder/search_interface.h @@ -29,6 +29,8 @@ class SearchInterface { public: virtual ~SearchInterface() {} virtual void Search(const std::vector>& logp) = 0; + virtual void Search(const std::vector>& topk_scores, + const std::vector>& topk_indexs) = 0; virtual void Reset() = 0; virtual void FinalizeSearch() = 0; diff --git a/runtime/core/frontend/fbank_cuda.h b/runtime/core/frontend/fbank_cuda.h new file mode 100644 index 0000000000..3b7df06f71 --- /dev/null +++ b/runtime/core/frontend/fbank_cuda.h @@ -0,0 +1,62 @@ + +#ifndef FRONTEND_FBANK_CUDA_H_ +#define FRONTEND_FBANK_CUDA_H_ + +#include "kaldifeat/csrc/feature-fbank.h" + +namespace wenet { + +class FbankCuda { + public: + FbankCuda(int num_bins, int sample_rate) { + fbank_opts_.mel_opts.num_bins = num_bins; + fbank_opts_.frame_opts.samp_freq = sample_rate; + fbank_opts_.frame_opts.dither = 0; + fbank_opts_.frame_opts.frame_shift_ms = 10.0; + fbank_opts_.frame_opts.frame_length_ms = 25.0; + fbank_opts_.device = torch::Device(torch::kCUDA, 0); + fbank_ = std::make_shared(fbank_opts_); + device_ = torch::kCUDA; + } + + torch::Tensor Compute(torch::Tensor wave_data) { + return fbank_->ComputeFeatures(wave_data, 1.0f); + } + + std::vector Compute( + const std::vector> &wave_data, + std::vector *num_frames) { + const auto &frame_opts = fbank_->GetOptions().frame_opts; + std::vector num_frames_vec; + num_frames_vec.reserve(wave_data.size()); + + std::vector strided_vec; + strided_vec.reserve(wave_data.size()); + + for (const auto &w : wave_data) { + torch::Tensor t = torch::from_blob( + const_cast(w.data()), + {static_cast(w.size())}, torch::kFloat).to(device_); + // t = t / 32768.0; + torch::Tensor strided = kaldifeat::GetStrided(t, frame_opts); + num_frames_vec.push_back(strided.size(0)); + num_frames->push_back(strided.size(0)); + strided_vec.emplace_back(std::move(strided)); + } + + torch::Tensor strided = torch::cat(strided_vec, 0); + torch::Tensor features = fbank_->ComputeFeatures(strided, /*vtln_warp*/ 1.0f); + auto ans = features.split_with_sizes(num_frames_vec, /*dim*/ 0); + return ans; + } + + private: + kaldifeat::FbankOptions fbank_opts_; + std::shared_ptr fbank_; + torch::DeviceType device_; + +}; + +} // namespace wenet + +#endif // FRONTEND_FBANK_CUDA_H_ diff --git a/runtime/core/utils/CMakeLists.txt b/runtime/core/utils/CMakeLists.txt index 686362688c..1394fd50ca 100644 --- a/runtime/core/utils/CMakeLists.txt +++ b/runtime/core/utils/CMakeLists.txt @@ -1,6 +1,7 @@ add_library(utils STATIC string.cc utils.cc + Yaml.cpp ) if(NOT ANDROID) @@ -9,4 +10,4 @@ if(NOT ANDROID) else() target_link_libraries(utils PUBLIC fst dl) endif() -endif() \ No newline at end of file +endif() diff --git a/runtime/core/utils/Yaml.cpp b/runtime/core/utils/Yaml.cpp new file mode 100644 index 0000000000..77ed14ca54 --- /dev/null +++ b/runtime/core/utils/Yaml.cpp @@ -0,0 +1,2215 @@ +// Copyright (c) From https://github.com/jimmiebergmann/mini-yaml +// 2022 SoundDataConverge Co.LTD (Weiliang Chong) +/* + * MIT License + * + * Copyright(c) 2018 Jimmie Bergmann + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files(the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions : + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ + +#include "Yaml.hpp" // NOLINT + +#include + +#include +#include +#include +#include +#include +#include + +// Implementation access definitions. +#define NODE_IMP static_cast(m_pImp) +#define NODE_IMP_EXT(node) static_cast(node.m_pImp) +#define TYPE_IMP static_cast(m_pImp)->m_pImp + +#define IT_IMP static_cast(m_pImp) + +namespace Yaml { +class ReaderLine; + +// Exception message definitions. +static const std::string g_ErrorInvalidCharacter = "Invalid character found."; // NOLINT +static const std::string g_ErrorKeyMissing = "Missing key."; // NOLINT +static const std::string g_ErrorKeyIncorrect = "Incorrect key."; // NOLINT +static const std::string g_ErrorValueIncorrect = "Incorrect value."; // NOLINT +static const std::string g_ErrorTabInOffset = "Tab found in offset."; // NOLINT +static const std::string g_ErrorBlockSequenceNotAllowed = // NOLINT + "Sequence entries are not allowed in this context."; // NOLINT +static const std::string g_ErrorUnexpectedDocumentEnd = // NOLINT + "Unexpected document end."; // NOLINT +static const std::string g_ErrorDiffEntryNotAllowed = // NOLINT + "Different entry is not allowed in this context."; // NOLINT +static const std::string g_ErrorIncorrectOffset = "Incorrect offset."; // NOLINT +static const std::string g_ErrorSequenceError = "Error in sequence node."; // NOLINT +static const std::string g_ErrorCannotOpenFile = "Cannot open file."; // NOLINT +static const std::string g_ErrorIndentation = // NOLINT + "Space indentation is less than 2."; // NOLINT +static const std::string g_ErrorInvalidBlockScalar = "Invalid block scalar."; // NOLINT +static const std::string g_ErrorInvalidQuote = "Invalid quote."; // NOLINT +static const std::string g_EmptyString = ""; // NOLINT +static Yaml::Node g_NoneNode; + +// Global function definitions. Implemented at end of this source file. +static std::string ExceptionMessage(const std::string &message, + ReaderLine &line); // NOLINT +static std::string ExceptionMessage( + const std::string &message, + ReaderLine &line, const size_t errorPos); // NOLINT +static std::string ExceptionMessage(const std::string &message, + const size_t errorLine, + const size_t errorPos); +static std::string ExceptionMessage(const std::string &message, + const size_t errorLine, + const std::string &data); + +static bool FindQuote( + const std::string &input, size_t &start, size_t &end, // NOLINT + size_t searchPos = 0); +static size_t FindNotCited(const std::string &input, char token, + size_t &preQuoteCount); // NOLINT +static size_t FindNotCited(const std::string &input, char token); +static bool ValidateQuote(const std::string &input); +static void CopyNode(const Node &from, Node &to); // NOLINT +static bool ShouldBeCited(const std::string &key); +static void AddEscapeTokens( + std::string &input, const std::string &tokens); // NOLINT +static void RemoveAllEscapeTokens(std::string &input); // NOLINT + +// Exception implementations +Exception::Exception(const std::string &message, const eType type) + : std::runtime_error(message), m_Type(type) {} + +Exception::eType Exception::Type() const { return m_Type; } + +const char *Exception::Message() const { return what(); } + +InternalException::InternalException(const std::string &message) + : Exception(message, InternalError) {} + +ParsingException::ParsingException(const std::string &message) + : Exception(message, ParsingError) {} + +OperationException::OperationException(const std::string &message) + : Exception(message, OperationError) {} + +class TypeImp { + public: + virtual ~TypeImp() {} + + virtual const std::string &GetData() const = 0; + virtual bool SetData(const std::string &data) = 0; + virtual size_t GetSize() const = 0; + virtual Node *GetNode(const size_t index) = 0; + virtual Node *GetNode(const std::string &key) = 0; + virtual Node *Insert(const size_t index) = 0; + virtual Node *PushFront() = 0; + virtual Node *PushBack() = 0; + virtual void Erase(const size_t index) = 0; + virtual void Erase(const std::string &key) = 0; +}; + +class SequenceImp : public TypeImp { + public: + ~SequenceImp() { + for (auto it = m_Sequence.begin(); it != m_Sequence.end(); it++) { + delete it->second; + } + } + + virtual const std::string &GetData() const { return g_EmptyString; } + + virtual bool SetData(const std::string &data) { return false; } + + virtual size_t GetSize() const { return m_Sequence.size(); } + + virtual Node *GetNode(const size_t index) { + auto it = m_Sequence.find(index); + if (it != m_Sequence.end()) { + return it->second; + } + return nullptr; + } + + virtual Node *GetNode(const std::string &key) { return nullptr; } + + virtual Node *Insert(const size_t index) { + if (m_Sequence.size() == 0) { + Node *pNode = new Node; + m_Sequence.insert({0, pNode}); + return pNode; + } + + if (index >= m_Sequence.size()) { + auto it = m_Sequence.end(); + --it; + Node *pNode = new Node; + m_Sequence.insert({it->first, pNode}); + return pNode; + } + + auto it = m_Sequence.cbegin(); + while (it != m_Sequence.cend()) { + m_Sequence[it->first + 1] = it->second; + + if (it->first == index) { + break; + } + } + + Node *pNode = new Node; + m_Sequence.insert({index, pNode}); + return pNode; + } + + virtual Node *PushFront() { + for (auto it = m_Sequence.cbegin(); it != m_Sequence.cend(); it++) { + m_Sequence[it->first + 1] = it->second; + } + + Node *pNode = new Node; + m_Sequence.insert({0, pNode}); + return pNode; + } + + virtual Node *PushBack() { + size_t index = 0; + if (m_Sequence.size()) { + auto it = m_Sequence.end(); + --it; + index = it->first + 1; + } + + Node *pNode = new Node; + m_Sequence.insert({index, pNode}); + return pNode; + } + + virtual void Erase(const size_t index) { + auto it = m_Sequence.find(index); + if (it == m_Sequence.end()) { + return; + } + delete it->second; + m_Sequence.erase(index); + } + + virtual void Erase(const std::string &key) {} + + std::map m_Sequence; +}; + +class MapImp : public TypeImp { + public: + ~MapImp() { + for (auto it = m_Map.begin(); it != m_Map.end(); it++) { + delete it->second; + } + } + + virtual const std::string &GetData() const { return g_EmptyString; } + + virtual bool SetData(const std::string &data) { return false; } + + virtual size_t GetSize() const { return m_Map.size(); } + + virtual Node *GetNode(const size_t index) { return nullptr; } + + virtual Node *GetNode(const std::string &key) { + auto it = m_Map.find(key); + if (it == m_Map.end()) { + Node *pNode = new Node; + m_Map.insert({key, pNode}); + return pNode; + } + return it->second; + } + + virtual Node *Insert(const size_t index) { return nullptr; } + + virtual Node *PushFront() { return nullptr; } + + virtual Node *PushBack() { return nullptr; } + + virtual void Erase(const size_t index) {} + + virtual void Erase(const std::string &key) { + auto it = m_Map.find(key); + if (it == m_Map.end()) { + return; + } + delete it->second; + m_Map.erase(key); + } + + std::map m_Map; +}; + +class ScalarImp : public TypeImp { + public: + ~ScalarImp() {} + + virtual const std::string &GetData() const { return m_Value; } + + virtual bool SetData(const std::string &data) { + m_Value = data; + return true; + } + + virtual size_t GetSize() const { return 0; } + + virtual Node *GetNode(const size_t index) { return nullptr; } + + virtual Node *GetNode(const std::string &key) { return nullptr; } + + virtual Node *Insert(const size_t index) { return nullptr; } + + virtual Node *PushFront() { return nullptr; } + + virtual Node *PushBack() { return nullptr; } + + virtual void Erase(const size_t index) {} + + virtual void Erase(const std::string &key) {} + + std::string m_Value; +}; + +// Node implementations. +class NodeImp { + public: + NodeImp() : m_Type(Node::None), m_pImp(nullptr) {} + + ~NodeImp() { Clear(); } + + void Clear() { + if (m_pImp != nullptr) { + delete m_pImp; + m_pImp = nullptr; + } + m_Type = Node::None; + } + + void InitSequence() { + if (m_Type != Node::SequenceType || m_pImp == nullptr) { + if (m_pImp) { + delete m_pImp; + } + m_pImp = new SequenceImp; + m_Type = Node::SequenceType; + } + } + + void InitMap() { + if (m_Type != Node::MapType || m_pImp == nullptr) { + if (m_pImp) { + delete m_pImp; + } + m_pImp = new MapImp; + m_Type = Node::MapType; + } + } + + void InitScalar() { + if (m_Type != Node::ScalarType || m_pImp == nullptr) { + if (m_pImp) { + delete m_pImp; + } + m_pImp = new ScalarImp; + m_Type = Node::ScalarType; + } + } + + Node::eType m_Type; ///< Type of node. + TypeImp *m_pImp; ///< Imp of type. +}; + +// Iterator implementation class +class IteratorImp { + public: + virtual ~IteratorImp() {} + + virtual Node::eType GetType() const = 0; + virtual void InitBegin(SequenceImp *pSequenceImp) = 0; + virtual void InitEnd(SequenceImp *pSequenceImp) = 0; + virtual void InitBegin(MapImp *pMapImp) = 0; + virtual void InitEnd(MapImp *pMapImp) = 0; +}; + +class SequenceIteratorImp : public IteratorImp { + public: + virtual Node::eType GetType() const { return Node::SequenceType; } + + virtual void InitBegin(SequenceImp *pSequenceImp) { + m_Iterator = pSequenceImp->m_Sequence.begin(); + } + + virtual void InitEnd(SequenceImp *pSequenceImp) { + m_Iterator = pSequenceImp->m_Sequence.end(); + } + + virtual void InitBegin(MapImp *pMapImp) {} + + virtual void InitEnd(MapImp *pMapImp) {} + + void Copy(const SequenceIteratorImp &it) { m_Iterator = it.m_Iterator; } + + std::map::iterator m_Iterator; +}; + +class MapIteratorImp : public IteratorImp { + public: + virtual Node::eType GetType() const { return Node::MapType; } + + virtual void InitBegin(SequenceImp *pSequenceImp) {} + + virtual void InitEnd(SequenceImp *pSequenceImp) {} + + virtual void InitBegin(MapImp *pMapImp) { + m_Iterator = pMapImp->m_Map.begin(); + } + + virtual void InitEnd(MapImp *pMapImp) { m_Iterator = pMapImp->m_Map.end(); } + + void Copy(const MapIteratorImp &it) { m_Iterator = it.m_Iterator; } + + std::map::iterator m_Iterator; +}; + +class SequenceConstIteratorImp : public IteratorImp { + public: + virtual Node::eType GetType() const { return Node::SequenceType; } + + virtual void InitBegin(SequenceImp *pSequenceImp) { + m_Iterator = pSequenceImp->m_Sequence.begin(); + } + + virtual void InitEnd(SequenceImp *pSequenceImp) { + m_Iterator = pSequenceImp->m_Sequence.end(); + } + + virtual void InitBegin(MapImp *pMapImp) {} + + virtual void InitEnd(MapImp *pMapImp) {} + + void Copy(const SequenceConstIteratorImp &it) { m_Iterator = it.m_Iterator; } + + std::map::const_iterator m_Iterator; +}; + +class MapConstIteratorImp : public IteratorImp { + public: + virtual Node::eType GetType() const { return Node::MapType; } + + virtual void InitBegin(SequenceImp *pSequenceImp) {} + + virtual void InitEnd(SequenceImp *pSequenceImp) {} + + virtual void InitBegin(MapImp *pMapImp) { + m_Iterator = pMapImp->m_Map.begin(); + } + + virtual void InitEnd(MapImp *pMapImp) { m_Iterator = pMapImp->m_Map.end(); } + + void Copy(const MapConstIteratorImp &it) { m_Iterator = it.m_Iterator; } + + std::map::const_iterator m_Iterator; +}; + +// Iterator class +Iterator::Iterator() : m_Type(None), m_pImp(nullptr) {} + +Iterator::~Iterator() { + if (m_pImp) { + switch (m_Type) { + case SequenceType: + delete static_cast(m_pImp); + break; + case MapType: + delete static_cast(m_pImp); + break; + default: + break; + } + } +} + +Iterator::Iterator(const Iterator &it) : m_Type(None), m_pImp(nullptr) { + *this = it; +} + +Iterator &Iterator::operator=(const Iterator &it) { + if (m_pImp) { + switch (m_Type) { + case SequenceType: + delete static_cast(m_pImp); + break; + case MapType: + delete static_cast(m_pImp); + break; + default: + break; + } + m_pImp = nullptr; + m_Type = None; + } + + IteratorImp *pNewImp = nullptr; + + switch (it.m_Type) { + case SequenceType: + m_Type = SequenceType; + pNewImp = new SequenceIteratorImp; + static_cast(pNewImp)->m_Iterator = + static_cast(it.m_pImp)->m_Iterator; + break; + case MapType: + m_Type = MapType; + pNewImp = new MapIteratorImp; + static_cast(pNewImp)->m_Iterator = + static_cast(it.m_pImp)->m_Iterator; + break; + default: + break; + } + + m_pImp = pNewImp; + return *this; +} + +std::pair Iterator::operator*() { + switch (m_Type) { + case SequenceType: + return { + g_EmptyString, + *(static_cast(m_pImp)->m_Iterator->second)}; + break; + case MapType: + return {static_cast(m_pImp)->m_Iterator->first, + *(static_cast(m_pImp)->m_Iterator->second)}; + break; + default: + break; + } + + g_NoneNode.Clear(); + return {g_EmptyString, g_NoneNode}; +} + +Iterator &Iterator::operator++(int dummy) { + switch (m_Type) { + case SequenceType: + static_cast(m_pImp)->m_Iterator++; + break; + case MapType: + static_cast(m_pImp)->m_Iterator++; + break; + default: + break; + } + return *this; +} + +Iterator &Iterator::operator--(int dummy) { + switch (m_Type) { + case SequenceType: + static_cast(m_pImp)->m_Iterator--; + break; + case MapType: + static_cast(m_pImp)->m_Iterator--; + break; + default: + break; + } + return *this; +} + +bool Iterator::operator==(const Iterator &it) { + if (m_Type != it.m_Type) { + return false; + } + + switch (m_Type) { + case SequenceType: + return static_cast(m_pImp)->m_Iterator == + static_cast(it.m_pImp)->m_Iterator; + break; + case MapType: + return static_cast(m_pImp)->m_Iterator == + static_cast(it.m_pImp)->m_Iterator; + break; + default: + break; + } + + return false; +} + +bool Iterator::operator!=(const Iterator &it) { return !(*this == it); } + +// Const Iterator class +ConstIterator::ConstIterator() : m_Type(None), m_pImp(nullptr) {} + +ConstIterator::~ConstIterator() { + if (m_pImp) { + switch (m_Type) { + case SequenceType: + delete static_cast(m_pImp); + break; + case MapType: + delete static_cast(m_pImp); + break; + default: + break; + } + } +} + +ConstIterator::ConstIterator(const ConstIterator &it) + : m_Type(None), m_pImp(nullptr) { + *this = it; +} + +ConstIterator &ConstIterator::operator=(const ConstIterator &it) { + if (m_pImp) { + switch (m_Type) { + case SequenceType: + delete static_cast(m_pImp); + break; + case MapType: + delete static_cast(m_pImp); + break; + default: + break; + } + m_pImp = nullptr; + m_Type = None; + } + + IteratorImp *pNewImp = nullptr; + + switch (it.m_Type) { + case SequenceType: + m_Type = SequenceType; + pNewImp = new SequenceConstIteratorImp; + static_cast(pNewImp)->m_Iterator = + static_cast(it.m_pImp)->m_Iterator; + break; + case MapType: + m_Type = MapType; + pNewImp = new MapConstIteratorImp; + static_cast(pNewImp)->m_Iterator = + static_cast(it.m_pImp)->m_Iterator; + break; + default: + break; + } + + m_pImp = pNewImp; + return *this; +} + +std::pair ConstIterator::operator*() { + switch (m_Type) { + case SequenceType: + return {g_EmptyString, *(static_cast(m_pImp) + ->m_Iterator->second)}; + break; + case MapType: + return { + static_cast(m_pImp)->m_Iterator->first, + *(static_cast(m_pImp)->m_Iterator->second)}; + break; + default: + break; + } + + g_NoneNode.Clear(); + return {g_EmptyString, g_NoneNode}; +} + +ConstIterator &ConstIterator::operator++(int dummy) { + switch (m_Type) { + case SequenceType: + static_cast(m_pImp)->m_Iterator++; + break; + case MapType: + static_cast(m_pImp)->m_Iterator++; + break; + default: + break; + } + return *this; +} + +ConstIterator &ConstIterator::operator--(int dummy) { + switch (m_Type) { + case SequenceType: + static_cast(m_pImp)->m_Iterator--; + break; + case MapType: + static_cast(m_pImp)->m_Iterator--; + break; + default: + break; + } + return *this; +} + +bool ConstIterator::operator==(const ConstIterator &it) { + if (m_Type != it.m_Type) { + return false; + } + + switch (m_Type) { + case SequenceType: + return static_cast(m_pImp)->m_Iterator == + static_cast(it.m_pImp)->m_Iterator; + break; + case MapType: + return static_cast(m_pImp)->m_Iterator == + static_cast(it.m_pImp)->m_Iterator; + break; + default: + break; + } + + return false; +} + +bool ConstIterator::operator!=(const ConstIterator &it) { + return !(*this == it); +} + +// Node class +Node::Node() : m_pImp(new NodeImp) {} + +Node::Node(const Node &node) : Node() { *this = node; } + +Node::Node(const std::string &value) : Node() { *this = value; } + +Node::Node(const char *value) : Node() { *this = value; } + +Node::~Node() { delete static_cast(m_pImp); } + +Node::eType Node::Type() const { return NODE_IMP->m_Type; } + +bool Node::IsNone() const { return NODE_IMP->m_Type == Node::None; } + +bool Node::IsSequence() const { return NODE_IMP->m_Type == Node::SequenceType; } + +bool Node::IsMap() const { return NODE_IMP->m_Type == Node::MapType; } + +bool Node::IsScalar() const { return NODE_IMP->m_Type == Node::ScalarType; } + +void Node::Clear() { NODE_IMP->Clear(); } + +size_t Node::Size() const { + if (TYPE_IMP == nullptr) { + return 0; + } + + return TYPE_IMP->GetSize(); +} + +Node &Node::Insert(const size_t index) { + NODE_IMP->InitSequence(); + return *TYPE_IMP->Insert(index); +} + +Node &Node::PushFront() { + NODE_IMP->InitSequence(); + return *TYPE_IMP->PushFront(); +} +Node &Node::PushBack() { + NODE_IMP->InitSequence(); + return *TYPE_IMP->PushBack(); +} + +Node &Node::operator[](const size_t index) { + NODE_IMP->InitSequence(); + Node *pNode = TYPE_IMP->GetNode(index); + if (pNode == nullptr) { + g_NoneNode.Clear(); + return g_NoneNode; + } + return *pNode; +} + +Node &Node::operator[](const std::string &key) { + NODE_IMP->InitMap(); + return *TYPE_IMP->GetNode(key); +} + +void Node::Erase(const size_t index) { + if (TYPE_IMP == nullptr || NODE_IMP->m_Type != Node::SequenceType) { + return; + } + + return TYPE_IMP->Erase(index); +} + +void Node::Erase(const std::string &key) { + if (TYPE_IMP == nullptr || NODE_IMP->m_Type != Node::MapType) { + return; + } + + return TYPE_IMP->Erase(key); +} + +Node &Node::operator=(const Node &node) { + NODE_IMP->Clear(); + CopyNode(node, *this); + return *this; +} + +Node &Node::operator=(const std::string &value) { + NODE_IMP->InitScalar(); + TYPE_IMP->SetData(value); + return *this; +} + +Node &Node::operator=(const char *value) { + NODE_IMP->InitScalar(); + TYPE_IMP->SetData(value ? std::string(value) : ""); + return *this; +} + +Iterator Node::Begin() { + Iterator it; + + if (TYPE_IMP != nullptr) { + IteratorImp *pItImp = nullptr; + + switch (NODE_IMP->m_Type) { + case Node::SequenceType: + it.m_Type = Iterator::SequenceType; + pItImp = new SequenceIteratorImp; + pItImp->InitBegin(static_cast(TYPE_IMP)); + break; + case Node::MapType: + it.m_Type = Iterator::MapType; + pItImp = new MapIteratorImp; + pItImp->InitBegin(static_cast(TYPE_IMP)); + break; + default: + break; + } + + it.m_pImp = pItImp; + } + + return it; +} + +ConstIterator Node::Begin() const { + ConstIterator it; + + if (TYPE_IMP != nullptr) { + IteratorImp *pItImp = nullptr; + + switch (NODE_IMP->m_Type) { + case Node::SequenceType: + it.m_Type = ConstIterator::SequenceType; + pItImp = new SequenceConstIteratorImp; + pItImp->InitBegin(static_cast(TYPE_IMP)); + break; + case Node::MapType: + it.m_Type = ConstIterator::MapType; + pItImp = new MapConstIteratorImp; + pItImp->InitBegin(static_cast(TYPE_IMP)); + break; + default: + break; + } + + it.m_pImp = pItImp; + } + + return it; +} + +Iterator Node::End() { + Iterator it; + + if (TYPE_IMP != nullptr) { + IteratorImp *pItImp = nullptr; + + switch (NODE_IMP->m_Type) { + case Node::SequenceType: + it.m_Type = Iterator::SequenceType; + pItImp = new SequenceIteratorImp; + pItImp->InitEnd(static_cast(TYPE_IMP)); + break; + case Node::MapType: + it.m_Type = Iterator::MapType; + pItImp = new MapIteratorImp; + pItImp->InitEnd(static_cast(TYPE_IMP)); + break; + default: + break; + } + + it.m_pImp = pItImp; + } + + return it; +} + +ConstIterator Node::End() const { + ConstIterator it; + + if (TYPE_IMP != nullptr) { + IteratorImp *pItImp = nullptr; + + switch (NODE_IMP->m_Type) { + case Node::SequenceType: + it.m_Type = ConstIterator::SequenceType; + pItImp = new SequenceConstIteratorImp; + pItImp->InitEnd(static_cast(TYPE_IMP)); + break; + case Node::MapType: + it.m_Type = ConstIterator::MapType; + pItImp = new MapConstIteratorImp; + pItImp->InitEnd(static_cast(TYPE_IMP)); + break; + default: + break; + } + + it.m_pImp = pItImp; + } + + return it; +} + +const std::string &Node::AsString() const { + if (TYPE_IMP == nullptr) { + return g_EmptyString; + } + + return TYPE_IMP->GetData(); +} + +// Reader implementations +/** + * @breif Line information structure. + * + */ +class ReaderLine { + public: + /** + * @breif Constructor. + * + */ + ReaderLine(const std::string &data = "", const size_t no = 0, + const size_t offset = 0, const Node::eType type = Node::None, + const unsigned char flags = 0) + : Data(data), + No(no), + Offset(offset), + Type(type), + Flags(flags), + NextLine(nullptr) {} + + enum eFlag { + LiteralScalarFlag, ///< Literal scalar type, defined as "|". + FoldedScalarFlag, ///< Folded scalar type, defined as "<". + ScalarNewlineFlag ///< Scalar ends with a newline. + }; + + /** + * @breif Set flag. + * + */ + void SetFlag(const eFlag flag) { + Flags |= FlagMask[static_cast(flag)]; + } + + /** + * @breif Set flags by mask value. + * + */ + void SetFlags(const unsigned char flags) { Flags |= flags; } + + /** + * @breif Unset flag. + * + */ + void UnsetFlag(const eFlag flag) { + Flags &= ~FlagMask[static_cast(flag)]; + } + + /** + * @breif Unset flags by mask value. + * + */ + void UnsetFlags(const unsigned char flags) { Flags &= ~flags; } + + /** + * @breif Get flag value. + * + */ + bool GetFlag(const eFlag flag) const { + return Flags & FlagMask[static_cast(flag)]; + } + + /** + * @breif Copy and replace scalar flags from another ReaderLine. + * + */ + void CopyScalarFlags(ReaderLine *from) { + if (from == nullptr) { + return; + } + + unsigned char newFlags = + from->Flags & (FlagMask[0] | FlagMask[1] | FlagMask[2]); + Flags |= newFlags; + } + + static const unsigned char FlagMask[3]; + + std::string Data; ///< Data of line. + size_t No; ///< Line number. + size_t Offset; ///< Offset to first character in data. + Node::eType Type; ///< Type of line. + unsigned char Flags; ///< Flags of line. + ReaderLine *NextLine; ///< Pointer to next line. +}; + +const unsigned char ReaderLine::FlagMask[3] = {0x01, 0x02, 0x04}; + +/** + * @breif Implementation class of Yaml parsing. + * Parsing incoming stream and outputs a root node. + * + */ +class ParseImp { + public: + /** + * @breif Default constructor. + * + */ + ParseImp() {} + + /** + * @breif Destructor. + * + */ + ~ParseImp() { ClearLines(); } + + /** + * @breif Run full parsing procedure. + * + */ + void Parse(Node &root, std::iostream &stream) { // NOLINT + try { + root.Clear(); + ReadLines(stream); + PostProcessLines(); + // Print(); + ParseRoot(root); + } catch (Exception e) { + root.Clear(); + throw; + } + } + + private: + /** + * @breif Copy constructor. + * + */ + ParseImp(const ParseImp ©) {} + + /** + * @breif Read all lines. + * Ignoring: + * - Empty lines. + * - Comments. + * - Document start/end. + * + */ + void ReadLines(std::iostream &stream) { + std::string line = ""; + size_t lineNo = 0; + bool documentStartFound = false; + bool foundFirstNotEmpty = false; + std::streampos streamPos = 0; + + // Read all lines, as long as the stream is ok. + while (!stream.eof() && !stream.fail()) { + // Read line + streamPos = stream.tellg(); + std::getline(stream, line); + lineNo++; + + // Remove comment + const size_t commentPos = FindNotCited(line, '#'); + if (commentPos != std::string::npos) { + line.resize(commentPos); + } + + // Start of document. + if (documentStartFound == false && line == "---") { + // Erase all lines before this line. + ClearLines(); + documentStartFound = true; + continue; + } + + // End of document. + if (line == "...") { + break; + } else if (line == "---") { + stream.seekg(streamPos); + break; + } + + // Remove trailing return. + if (line.size()) { + if (line[line.size() - 1] == '\r') { + line.resize(line.size() - 1); + } + } + + // Validate characters. + for (size_t i = 0; i < line.size(); i++) { + if (line[i] != '\t' && (line[i] < 32 || line[i] > 125)) { + throw ParsingException( + ExceptionMessage(g_ErrorInvalidCharacter, lineNo, i + 1)); + } + } + + // Validate tabs + const size_t firstTabPos = line.find_first_of('\t'); + size_t startOffset = line.find_first_not_of(" \t"); + + // Make sure no tabs are in the very front. + if (startOffset != std::string::npos) { + if (firstTabPos < startOffset) { + throw ParsingException( + ExceptionMessage(g_ErrorTabInOffset, lineNo, firstTabPos)); + } + + // Remove front spaces. + line = line.substr(startOffset); + } else { + startOffset = 0; + line = ""; + } + + // Add line. + if (foundFirstNotEmpty == false) { + if (line.size()) { + foundFirstNotEmpty = true; + } else { + continue; + } + } + + ReaderLine *pLine = new ReaderLine(line, lineNo, startOffset); + m_Lines.push_back(pLine); + } + } + + /** + * @breif Run post-processing on all lines. + * Basically split lines into multiple lines if needed, to follow the + * parsing algorithm. + * + */ + void PostProcessLines() { + for (auto it = m_Lines.begin(); it != m_Lines.end();) { + // Sequence. + if (PostProcessSequenceLine(it) == true) { + continue; + } + + // Mapping. + if (PostProcessMappingLine(it) == true) { + continue; + } + + // Scalar. + PostProcessScalarLine(it); + } + + // Set next line of all lines. + if (m_Lines.size()) { + if (m_Lines.back()->Type != Node::ScalarType) { + throw ParsingException( + ExceptionMessage(g_ErrorUnexpectedDocumentEnd, *m_Lines.back())); + } + + if (m_Lines.size() > 1) { + auto prevEnd = m_Lines.end(); + --prevEnd; + + for (auto it = m_Lines.begin(); it != prevEnd; it++) { + auto nextIt = it; + ++nextIt; + + (*it)->NextLine = *nextIt; + } + } + } + } + + /** + * @breif Run post-processing and check for sequence. + * Split line into two lines if sequence token is not on it's own line. + * + * @return true if line is sequence, else false. + * + */ + bool PostProcessSequenceLine( + std::list::iterator &it) { // NOLINT + ReaderLine *pLine = *it; + + // Sequence split + if (IsSequenceStart(pLine->Data) == false) { + return false; + } + + pLine->Type = Node::SequenceType; + + ClearTrailingEmptyLines(++it); + + const size_t valueStart = pLine->Data.find_first_not_of(" \t", 1); + if (valueStart == std::string::npos) { + return true; + } + + // Create new line and insert + std::string newLine = pLine->Data.substr(valueStart); + it = m_Lines.insert( + it, new ReaderLine(newLine, pLine->No, pLine->Offset + valueStart)); + pLine->Data = ""; + + return false; + } + + /** + * @breif Run post-processing and check for mapping. + * Split line into two lines if mapping value is not on it's own line. + * + * @return true if line is mapping, else move on to scalar parsing. + * + */ + bool PostProcessMappingLine( + std::list::iterator &it) { // NOLINT + ReaderLine *pLine = *it; + + // Find map key. + size_t preKeyQuotes = 0; + size_t tokenPos = FindNotCited(pLine->Data, ':', preKeyQuotes); + if (tokenPos == std::string::npos) { + return false; + } + if (preKeyQuotes > 1) { + throw ParsingException(ExceptionMessage(g_ErrorKeyIncorrect, *pLine)); + } + + pLine->Type = Node::MapType; + + // Get key + std::string key = pLine->Data.substr(0, tokenPos); + const size_t keyEnd = key.find_last_not_of(" \t"); + if (keyEnd == std::string::npos) { + throw ParsingException(ExceptionMessage(g_ErrorKeyMissing, *pLine)); + } + key.resize(keyEnd + 1); + + // Handle cited key. + if (preKeyQuotes == 1) { + if (key.front() != '"' || key.back() != '"') { + throw ParsingException(ExceptionMessage(g_ErrorKeyIncorrect, *pLine)); + } + + key = key.substr(1, key.size() - 2); + } + RemoveAllEscapeTokens(key); + + // Get value + std::string value = ""; + size_t valueStart = std::string::npos; + if (tokenPos + 1 != pLine->Data.size()) { + valueStart = pLine->Data.find_first_not_of(" \t", tokenPos + 1); + if (valueStart != std::string::npos) { + value = pLine->Data.substr(valueStart); + } + } + + // Make sure the value is not a sequence start. + if (IsSequenceStart(value) == true) { + throw ParsingException( + ExceptionMessage(g_ErrorBlockSequenceNotAllowed, *pLine, valueStart)); + } + + pLine->Data = key; + + // Remove all empty lines after map key. + ClearTrailingEmptyLines(++it); + + // Add new empty line? + size_t newLineOffset = valueStart; + if (newLineOffset == std::string::npos) { + if (it != m_Lines.end() && (*it)->Offset > pLine->Offset) { + return true; + } + + newLineOffset = tokenPos + 2; + } else { + newLineOffset += pLine->Offset; + } + + // Add new line with value. + unsigned char dummyBlockFlags = 0; + if (IsBlockScalar(value, pLine->No, dummyBlockFlags) == true) { + newLineOffset = pLine->Offset; + } + ReaderLine *pNewLine = + new ReaderLine(value, pLine->No, newLineOffset, Node::ScalarType); + it = m_Lines.insert(it, pNewLine); + + // Return false in order to handle next line(scalar value). + return false; + } + + /** + * @breif Run post-processing and check for scalar. + * Checking for multi-line scalars. + * + * @return true if scalar search should continue, else false. + * + */ + void PostProcessScalarLine( + std::list::iterator &it) { // NOLINT + ReaderLine *pLine = *it; + pLine->Type = Node::ScalarType; + + size_t parentOffset = pLine->Offset; + if (pLine != m_Lines.front()) { + std::list::iterator lastIt = it; + --lastIt; + parentOffset = (*lastIt)->Offset; + } + + std::list::iterator lastNotEmpty = it++; + + // Find last empty lines + while (it != m_Lines.end()) { + pLine = *it; + pLine->Type = Node::ScalarType; + if (pLine->Data.size()) { + if (pLine->Offset <= parentOffset) { + break; + } else { + lastNotEmpty = it; + } + } + ++it; + } + + ClearTrailingEmptyLines(++lastNotEmpty); + } + + /** + * @breif Process root node and start of document. + * + */ + void ParseRoot(Node &root) { // NOLINT + // Get first line and start type. + auto it = m_Lines.begin(); + if (it == m_Lines.end()) { + return; + } + Node::eType type = (*it)->Type; + ReaderLine *pLine = *it; + + // Handle next line. + switch (type) { + case Node::SequenceType: + ParseSequence(root, it); + break; + case Node::MapType: + ParseMap(root, it); + break; + case Node::ScalarType: + ParseScalar(root, it); + break; + default: + break; + } + + if (it != m_Lines.end()) { + throw InternalException( + ExceptionMessage(g_ErrorUnexpectedDocumentEnd, *pLine)); + } + } + + /** + * @breif Process sequence node. + * + */ + void ParseSequence( + Node &node, std::list::iterator &it) { // NOLINT + ReaderLine *pNextLine = nullptr; + while (it != m_Lines.end()) { + ReaderLine *pLine = *it; + Node &childNode = node.PushBack(); + + // Move to next line, error check. + ++it; + if (it == m_Lines.end()) { + throw InternalException( + ExceptionMessage(g_ErrorUnexpectedDocumentEnd, *pLine)); + } + + // Handle value of map + Node::eType valueType = (*it)->Type; + switch (valueType) { + case Node::SequenceType: + ParseSequence(childNode, it); + break; + case Node::MapType: + ParseMap(childNode, it); + break; + case Node::ScalarType: + ParseScalar(childNode, it); + break; + default: + break; + } + + // Check next line. if sequence and correct level, go on, else exit. + // If same level but but of type map = error. + if (it == m_Lines.end() || ((pNextLine = *it)->Offset < pLine->Offset)) { + break; + } + if (pNextLine->Offset > pLine->Offset) { + throw ParsingException( + ExceptionMessage(g_ErrorIncorrectOffset, *pNextLine)); + } + if (pNextLine->Type != Node::SequenceType) { + throw InternalException( + ExceptionMessage(g_ErrorDiffEntryNotAllowed, *pNextLine)); + } + } + } + + /** + * @breif Process map node. + * + */ + void ParseMap(Node &node, // NOLINT + std::list::iterator &it) { // NOLINT + ReaderLine *pNextLine = nullptr; + while (it != m_Lines.end()) { + ReaderLine *pLine = *it; + Node &childNode = node[pLine->Data]; + + // Move to next line, error check. + ++it; + if (it == m_Lines.end()) { + throw InternalException( + ExceptionMessage(g_ErrorUnexpectedDocumentEnd, *pLine)); + } + + // Handle value of map + Node::eType valueType = (*it)->Type; + switch (valueType) { + case Node::SequenceType: + ParseSequence(childNode, it); + break; + case Node::MapType: + ParseMap(childNode, it); + break; + case Node::ScalarType: + ParseScalar(childNode, it); + break; + default: + break; + } + + // Check next line. if map and correct level, go on, else exit. + // if same level but but of type map = error. + if (it == m_Lines.end() || ((pNextLine = *it)->Offset < pLine->Offset)) { + break; + } + if (pNextLine->Offset > pLine->Offset) { + throw ParsingException( + ExceptionMessage(g_ErrorIncorrectOffset, *pNextLine)); + } + if (pNextLine->Type != pLine->Type) { + throw InternalException( + ExceptionMessage(g_ErrorDiffEntryNotAllowed, *pNextLine)); + } + } + } + + /** + * @breif Process scalar node. + * + */ + void ParseScalar( + Node &node, std::list::iterator &it) { // NOLINT + std::string data = ""; + ReaderLine *pFirstLine = *it; + ReaderLine *pLine = *it; + + // Check if current line is a block scalar. + unsigned char blockFlags = 0; + bool isBlockScalar = IsBlockScalar(pLine->Data, pLine->No, blockFlags); + const bool newLineFlag = + static_cast(blockFlags & ReaderLine::FlagMask[static_cast( + ReaderLine::ScalarNewlineFlag)]); + const bool foldedFlag = + static_cast(blockFlags & ReaderLine::FlagMask[static_cast( + ReaderLine::FoldedScalarFlag)]); + const bool literalFlag = + static_cast(blockFlags & ReaderLine::FlagMask[static_cast( + ReaderLine::LiteralScalarFlag)]); + size_t parentOffset = 0; + + // Find parent offset + if (it != m_Lines.begin()) { + std::list::iterator parentIt = it; + --parentIt; + parentOffset = (*parentIt)->Offset; + } + + // Move to next iterator/line if current line is a block scalar. + if (isBlockScalar) { + ++it; + if (it == m_Lines.end() || (pLine = *it)->Type != Node::ScalarType) { + return; + } + } + + // Not a block scalar, cut end spaces/tabs + if (isBlockScalar == false) { + while (1) { + pLine = *it; + + if (parentOffset != 0 && pLine->Offset <= parentOffset) { + throw ParsingException( + ExceptionMessage(g_ErrorIncorrectOffset, *pLine)); + } + + const size_t endOffset = pLine->Data.find_last_not_of(" \t"); + if (endOffset == std::string::npos) { + data += "\n"; + } else { + data += pLine->Data.substr(0, endOffset + 1); + } + + // Move to next line + ++it; + if (it == m_Lines.end() || (*it)->Type != Node::ScalarType) { + break; + } + + data += " "; + } + + if (ValidateQuote(data) == false) { + throw ParsingException( + ExceptionMessage(g_ErrorInvalidQuote, *pFirstLine)); + } + } else { // Block scalar + pLine = *it; + size_t blockOffset = pLine->Offset; + if (blockOffset <= parentOffset) { + throw ParsingException( + ExceptionMessage(g_ErrorIncorrectOffset, *pLine)); + } + + bool addedSpace = false; + while (it != m_Lines.end() && (*it)->Type == Node::ScalarType) { + pLine = *it; + + const size_t endOffset = pLine->Data.find_last_not_of(" \t"); + if (endOffset != std::string::npos && pLine->Offset < blockOffset) { + throw ParsingException( + ExceptionMessage(g_ErrorIncorrectOffset, *pLine)); + } + + if (endOffset == std::string::npos) { + if (addedSpace) { + data[data.size() - 1] = '\n'; + addedSpace = false; + } else { + data += "\n"; + } + + ++it; + continue; + } else { + if (blockOffset != pLine->Offset && foldedFlag) { + if (addedSpace) { + data[data.size() - 1] = '\n'; + addedSpace = false; + } else { + data += "\n"; + } + } + data += std::string(pLine->Offset - blockOffset, ' '); + data += pLine->Data; + } + + // Move to next line + ++it; + if (it == m_Lines.end() || (*it)->Type != Node::ScalarType) { + if (newLineFlag) { + data += "\n"; + } + break; + } + + if (foldedFlag) { + data += " "; + addedSpace = true; + } else if (literalFlag && endOffset != std::string::npos) { + data += "\n"; + } + } + } + + if (data.size() && (data[0] == '"' || data[0] == '\'')) { + data = data.substr(1, data.size() - 2); + } + + node = data; + } + + /** + * @breif Debug printing. + * + */ + void Print() { + for (auto it = m_Lines.begin(); it != m_Lines.end(); it++) { + ReaderLine *pLine = *it; + + // Print type + if (pLine->Type == Node::SequenceType) { + std::cout << "seq "; + } else if (pLine->Type == Node::MapType) { + std::cout << "map "; + } else if (pLine->Type == Node::ScalarType) { + std::cout << "sca "; + } else { + std::cout << " "; + } + + // Print flags + if (pLine->GetFlag(ReaderLine::FoldedScalarFlag)) { + std::cout << "f"; + } else { + std::cout << "-"; + } + if (pLine->GetFlag(ReaderLine::LiteralScalarFlag)) { + std::cout << "l"; + } else { + std::cout << "-"; + } + if (pLine->GetFlag(ReaderLine::ScalarNewlineFlag)) { + std::cout << "n"; + } else { + std::cout << "-"; + } + if (pLine->NextLine == nullptr) { + std::cout << "e"; + } else { + std::cout << "-"; + } + + std::cout << "| "; + std::cout << pLine->No << " "; + std::cout << std::string(pLine->Offset, ' '); + + if (pLine->Type == Node::ScalarType) { + std::string scalarValue = pLine->Data; + for (size_t i = 0; + (i = scalarValue.find("\n", i)) != std::string::npos;) { + scalarValue.replace(i, 1, "\\n"); + i += 2; + } + std::cout << scalarValue << std::endl; + } else if (pLine->Type == Node::MapType) { + std::cout << pLine->Data + ":" << std::endl; + } else if (pLine->Type == Node::SequenceType) { + std::cout << "-" << std::endl; + } else { + std::cout << "> UNKOWN TYPE <" << std::endl; + } + } + } + + /** + * @breif Clear all read lines. + * + */ + void ClearLines() { + for (auto it = m_Lines.begin(); it != m_Lines.end(); it++) { + delete *it; + } + m_Lines.clear(); + } + + void ClearTrailingEmptyLines( + std::list::iterator &it) { // NOLINT + while (it != m_Lines.end()) { + ReaderLine *pLine = *it; + if (pLine->Data.size() == 0) { + delete *it; + it = m_Lines.erase(it); + } else { + return; + } + } + } + + static bool IsSequenceStart(const std::string &data) { + if (data.size() == 0 || data[0] != '-') { + return false; + } + + if (data.size() >= 2 && data[1] != ' ') { + return false; + } + + return true; + } + + static bool IsBlockScalar(const std::string &data, + const size_t line, unsigned char &flags) { // NOLINT + flags = 0; + if (data.size() == 0) { + return false; + } + + if (data[0] == '|') { + if (data.size() >= 2) { + if (data[1] != '-' && data[1] != ' ' && data[1] != '\t') { + throw ParsingException( + ExceptionMessage(g_ErrorInvalidBlockScalar, line, data)); + } + } else { + flags |= ReaderLine::FlagMask[static_cast( + ReaderLine::ScalarNewlineFlag)]; + } + flags |= ReaderLine::FlagMask[static_cast( + ReaderLine::LiteralScalarFlag)]; + return true; + } + + if (data[0] == '>') { + if (data.size() >= 2) { + if (data[1] != '-' && data[1] != ' ' && data[1] != '\t') { + throw ParsingException( + ExceptionMessage(g_ErrorInvalidBlockScalar, line, data)); + } + } else { + flags |= ReaderLine::FlagMask[static_cast( + ReaderLine::ScalarNewlineFlag)]; + } + flags |= ReaderLine::FlagMask[static_cast( + ReaderLine::FoldedScalarFlag)]; + return true; + } + + return false; + } + + std::list m_Lines; ///< List of lines. +}; + +// Parsing functions +void Parse(Node &root, const char *filename) { // NOLINT + std::ifstream f(filename, std::ifstream::binary); + if (f.is_open() == false) { + throw OperationException(g_ErrorCannotOpenFile); + } + + f.seekg(0, f.end); + size_t fileSize = static_cast(f.tellg()); + f.seekg(0, f.beg); + + std::unique_ptr data(new char[fileSize]); + f.read(data.get(), fileSize); + f.close(); + + Parse(root, data.get(), fileSize); +} + +void Parse(Node &root, std::iostream &stream) { // NOLINT + ParseImp *pImp = nullptr; + + try { + pImp = new ParseImp; + pImp->Parse(root, stream); + delete pImp; + } catch (const Exception e) { + delete pImp; + throw; + } +} + +void Parse(Node &root, const std::string &string) { // NOLINT + std::stringstream ss(string); + Parse(root, ss); +} + +void Parse(Node &root, const char *buffer, const size_t size) { // NOLINT + std::stringstream ss(std::string(buffer, size)); + Parse(root, ss); +} + +// Serialize configuration structure. +SerializeConfig::SerializeConfig(const size_t spaceIndentation, + const size_t scalarMaxLength, + const bool sequenceMapNewline, + const bool mapScalarNewline) + : SpaceIndentation(spaceIndentation), + ScalarMaxLength(scalarMaxLength), + SequenceMapNewline(sequenceMapNewline), + MapScalarNewline(mapScalarNewline) {} + +// Serialization functions +void Serialize(const Node &root, const char *filename, + const SerializeConfig &config) { + std::stringstream stream; + Serialize(root, stream, config); + + std::ofstream f(filename); + if (f.is_open() == false) { + throw OperationException(g_ErrorCannotOpenFile); + } + + f.write(stream.str().c_str(), stream.str().size()); + f.close(); +} + +size_t LineFolding(const std::string &input, + std::vector &folded, // NOLINT + const size_t maxLength) { + folded.clear(); + if (input.size() == 0) { + return 0; + } + + size_t currentPos = 0; + size_t lastPos = 0; + size_t spacePos = std::string::npos; + while (currentPos < input.size()) { + currentPos = lastPos + maxLength; + + if (currentPos < input.size()) { + spacePos = input.find_first_of(' ', currentPos); + } + + if (spacePos == std::string::npos || currentPos >= input.size()) { + const std::string endLine = input.substr(lastPos); + if (endLine.size()) { + folded.push_back(endLine); + } + + return folded.size(); + } + + folded.push_back(input.substr(lastPos, spacePos - lastPos)); + + lastPos = spacePos + 1; + } + + return folded.size(); +} + +static void SerializeLoop(const Node &node, std::iostream &stream, + bool useLevel, const size_t level, + const SerializeConfig &config) { + const size_t indention = config.SpaceIndentation; + + switch (node.Type()) { + case Node::SequenceType: { + for (auto it = node.Begin(); it != node.End(); it++) { + const Node &value = (*it).second; + if (value.IsNone()) { + continue; + } + stream << std::string(level, ' ') << "- "; + useLevel = false; + if (value.IsSequence() || + (value.IsMap() && config.SequenceMapNewline == true)) { + useLevel = true; + stream << "\n"; + } + + SerializeLoop(value, stream, useLevel, level + 2, config); + } + } break; + case Node::MapType: { + size_t count = 0; + for (auto it = node.Begin(); it != node.End(); it++) { + const Node &value = (*it).second; + if (value.IsNone()) { + continue; + } + + if (useLevel || count > 0) { + stream << std::string(level, ' '); + } + + std::string key = (*it).first; + AddEscapeTokens(key, "\\\""); + if (ShouldBeCited(key)) { + stream << "\"" << key << "\"" + << ": "; + } else { + stream << key << ": "; + } + + useLevel = false; + if (value.IsScalar() == false || + (value.IsScalar() && config.MapScalarNewline)) { + useLevel = true; + stream << "\n"; + } + + SerializeLoop(value, stream, useLevel, level + indention, config); + + useLevel = true; + count++; + } + } break; + case Node::ScalarType: { + const std::string value = node.As(); + + // Empty scalar + if (value.size() == 0) { + stream << "\n"; + break; + } + + // Get lines of scalar. + std::string line = ""; + std::vector lines; + std::istringstream iss(value); + while (iss.eof() == false) { + std::getline(iss, line); + lines.push_back(line); + } + + // Block scalar + const std::string &lastLine = lines.back(); + const bool endNewline = lastLine.size() == 0; + if (endNewline) { + lines.pop_back(); + } + + // Literal + if (lines.size() > 1) { + stream << "|"; + } else { // Folded/plain + const std::string frontLine = lines.front(); + if (config.ScalarMaxLength == 0 || + lines.front().size() <= config.ScalarMaxLength || + LineFolding(frontLine, lines, config.ScalarMaxLength) == 1) { + if (useLevel) { + stream << std::string(level, ' '); + } + + if (ShouldBeCited(value)) { + stream << "\"" << value << "\"\n"; + break; + } + stream << value << "\n"; + break; + } else { + stream << ">"; + } + } + + if (endNewline == false) { + stream << "-"; + } + stream << "\n"; + + for (auto it = lines.begin(); it != lines.end(); it++) { + stream << std::string(level, ' ') << (*it) << "\n"; + } + } break; + + default: + break; + } +} + +void Serialize(const Node &root, std::iostream &stream, + const SerializeConfig &config) { + if (config.SpaceIndentation < 2) { + throw OperationException(g_ErrorIndentation); + } + + SerializeLoop(root, stream, false, 0, config); +} + +void Serialize(const Node &root, + std::string &string, // NOLINT + const SerializeConfig &config) { + std::stringstream stream; + Serialize(root, stream, config); + string = stream.str(); +} + +// Static function implementations +std::string ExceptionMessage(const std::string &message, + ReaderLine &line) { // NOLINT + return message + std::string(" Line ") + std::to_string(line.No) + + std::string(": ") + line.Data; +} + +std::string ExceptionMessage(const std::string &message, + ReaderLine &line, const size_t errorPos) { // NOLINT + return message + std::string(" Line ") + std::to_string(line.No) + + std::string(" column ") + std::to_string(errorPos + 1) + + std::string(": ") + line.Data; +} + +std::string ExceptionMessage(const std::string &message, const size_t errorLine, + const size_t errorPos) { + return message + std::string(" Line ") + std::to_string(errorLine) + + std::string(" column ") + std::to_string(errorPos); +} + +std::string ExceptionMessage(const std::string &message, const size_t errorLine, + const std::string &data) { + return message + std::string(" Line ") + std::to_string(errorLine) + + std::string(": ") + data; +} + +bool FindQuote(const std::string &input, + size_t &start, size_t &end, // NOLINT + size_t searchPos) { + start = end = std::string::npos; + size_t qPos = searchPos; + bool foundStart = false; + + while (qPos != std::string::npos) { + // Find first quote. + qPos = input.find_first_of("\"'", qPos); + if (qPos == std::string::npos) { + return false; + } + + const char token = input[qPos]; + if (token == '"' && (qPos == 0 || input[qPos - 1] != '\\')) { + // Found start quote. + if (foundStart == false) { + start = qPos; + foundStart = true; + } else { + end = qPos; + return true; + } + } + + // Check if it's possible for another loop. + if (qPos + 1 == input.size()) { + return false; + } + qPos++; + } + + return false; +} + +size_t FindNotCited(const std::string &input, char token, + size_t &preQuoteCount) { // NOLINT + preQuoteCount = 0; + size_t tokenPos = input.find_first_of(token); + if (tokenPos == std::string::npos) { + return std::string::npos; + } + + // Find all quotes + std::vector> quotes; + + size_t quoteStart = 0; + size_t quoteEnd = 0; + while (FindQuote(input, quoteStart, quoteEnd, quoteEnd)) { + quotes.push_back({quoteStart, quoteEnd}); + + if (quoteEnd + 1 == input.size()) { + break; + } + quoteEnd++; + } + + if (quotes.size() == 0) { + return tokenPos; + } + + size_t currentQuoteIndex = 0; + std::pair currentQuote = {0, 0}; + + while (currentQuoteIndex < quotes.size()) { + currentQuote = quotes[currentQuoteIndex]; + + if (tokenPos < currentQuote.first) { + return tokenPos; + } + preQuoteCount++; + if (tokenPos <= currentQuote.second) { + // Find next token + if (tokenPos + 1 == input.size()) { + return std::string::npos; + } + tokenPos = input.find_first_of(token, tokenPos + 1); + if (tokenPos == std::string::npos) { + return std::string::npos; + } + } + + currentQuoteIndex++; + } + + return tokenPos; +} + +size_t FindNotCited(const std::string &input, char token) { + size_t dummy = 0; + return FindNotCited(input, token, dummy); +} + +bool ValidateQuote(const std::string &input) { + if (input.size() == 0) { + return true; + } + + char token = 0; + size_t searchPos = 0; + if (input[0] == '\"' || input[0] == '\'') { + if (input.size() == 1) { + return false; + } + token = input[0]; + searchPos = 1; + } + + while (searchPos != std::string::npos && searchPos < input.size() - 1) { + searchPos = input.find_first_of("\"'", searchPos + 1); + if (searchPos == std::string::npos) { + break; + } + + const char foundToken = input[searchPos]; + + if (input[searchPos] == '\"' || input[searchPos] == '\'') { + if (token == 0 && input[searchPos - 1] != '\\') { + return false; + } + // if(foundToken == token) + //{ + + /*if(foundToken == token && searchPos == input.size() - 1 && + input[searchPos-1] != '\\') + { + return true; + if(searchPos == input.size() - 1) + { + return true; + } + return false; + } + else */ + if (foundToken == token && input[searchPos - 1] != '\\') { + if (searchPos == input.size() - 1) { + return true; + } + return false; + } + //} + } + } + + return token == 0; +} + +void CopyNode(const Node &from, Node &to) { // NOLINT + const Node::eType type = from.Type(); + + switch (type) { + case Node::SequenceType: + for (auto it = from.Begin(); it != from.End(); it++) { + const Node ¤tNode = (*it).second; + Node &newNode = to.PushBack(); + CopyNode(currentNode, newNode); + } + break; + case Node::MapType: + for (auto it = from.Begin(); it != from.End(); it++) { + const Node ¤tNode = (*it).second; + Node &newNode = to[(*it).first]; + CopyNode(currentNode, newNode); + } + break; + case Node::ScalarType: + to = from.As(); + break; + case Node::None: + break; + } +} + +bool ShouldBeCited(const std::string &key) { + return key.find_first_of("\":{}[],&*#?|-<>=!%@") != std::string::npos; +} + +void AddEscapeTokens(std::string &input, const std::string &tokens) { // NOLINT + for (auto it = tokens.begin(); it != tokens.end(); it++) { + const char token = *it; + const std::string replace = std::string("\\") + std::string(1, token); + size_t found = input.find_first_of(token); + while (found != std::string::npos) { + input.replace(found, 1, replace); + found = input.find_first_of(token, found + 2); + } + } +} + +void RemoveAllEscapeTokens(std::string &input) { // NOLINT + size_t found = input.find_first_of("\\"); + while (found != std::string::npos) { + if (found + 1 == input.size()) { + return; + } + + std::string replace(1, input[found + 1]); + input.replace(found, 2, replace); + found = input.find_first_of("\\", found + 1); + } +} + +} // namespace Yaml diff --git a/runtime/core/utils/Yaml.hpp b/runtime/core/utils/Yaml.hpp new file mode 100644 index 0000000000..2eb3027351 --- /dev/null +++ b/runtime/core/utils/Yaml.hpp @@ -0,0 +1,605 @@ +// Copyright (c) From https://github.com/jimmiebergmann/mini-yaml +// 2022 SoundDataConverge Co.LTD (Weiliang Chong) +/* +* MIT License +* +* Copyright(c) 2018 Jimmie Bergmann +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files(the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions : +* +* The above copyright notice and this permission notice shall be included in all +* copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +* +*/ + +/* +YAML documentation: +http://yaml.org/spec/1.0/index.html +https://www.codeproject.com/Articles/28720/YAML-Parser-in-C +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +/** +* @breif Namespace wrapping mini-yaml classes. +* +*/ +namespace Yaml { +/** +* @breif Forward declarations. +* +*/ +class Node; + + +/** +* @breif Helper classes and functions +* +*/ +namespace impl { +/** +* @breif Helper functionality, converting string to any data type. +* Strings are left untouched. +* +*/ +template +struct StringConverter { + static T Get(const std::string & data) { + T type; + std::stringstream ss(data); + ss >> type; + return type; + } + + static T Get(const std::string & data, const T & defaultValue) { + T type; + std::stringstream ss(data); + ss >> type; + + if (ss.fail()) { + return defaultValue; + } + + return type; + } +}; +template<> +struct StringConverter { + static std::string Get(const std::string & data) { + return data; + } + + static std::string Get( + const std::string & data, const std::string & defaultValue) { + if (data.size() == 0) { + return defaultValue; + } + return data; + } +}; + +template<> +struct StringConverter { + static bool Get(const std::string & data) { + std::string tmpData = data; + std::transform(tmpData.begin(), tmpData.end(), tmpData.begin(), ::tolower); + if (tmpData == "true" || tmpData == "yes" || tmpData == "1") { + return true; + } + + return false; + } + + static bool Get(const std::string & data, const bool & defaultValue) { + if (data.size() == 0) { + return defaultValue; + } + + return Get(data); + } +}; + +} // namespace impl + + +/** +* @breif Exception class. +* +*/ +class Exception : public std::runtime_error { + public: + /** + * @breif Enumeration of exception types. + * + */ + enum eType { + InternalError, ///< Internal error. + ParsingError, ///< Invalid parsing data. + OperationError ///< User operation error. + }; + + /** + * @breif Constructor. + * + * @param message Exception message. + * @param type Type of exception. + * + */ + Exception(const std::string & message, const eType type); + + /** + * @breif Get type of exception. + * + */ + eType Type() const; + + /** + * @breif Get message of exception. + * + */ + const char * Message() const; + + private: + eType m_Type; ///< Type of exception. +}; + +/** +* @breif Internal exception class. +* +* @see Exception +* +*/ +class InternalException : public Exception { + public: + /** + * @breif Constructor. + * + * @param message Exception message. + * + */ + explicit InternalException(const std::string & message); +}; + +/** +* @breif Parsing exception class. +* +* @see Exception +* +*/ +class ParsingException : public Exception { + public: + /** + * @breif Constructor. + * + * @param message Exception message. + * + */ + explicit ParsingException(const std::string & message); +}; + +/** +* @breif Operation exception class. +* +* @see Exception +* +*/ +class OperationException : public Exception { + public: + /** + * @breif Constructor. + * + * @param message Exception message. + * + */ + explicit OperationException(const std::string & message); +}; + +/** +* @breif Iterator class. +* +*/ +class Iterator { + public: + friend class Node; + + /** + * @breif Default constructor. + * + */ + Iterator(); + + /** + * @breif Copy constructor. + * + */ + Iterator(const Iterator & it); + + /** + * @breif Assignment operator. + * + */ + Iterator & operator = (const Iterator & it); + + /** + * @breif Destructor. + * + */ + ~Iterator(); + + /** + * @breif Get node of iterator. + * First pair item is the key of map value, empty if type is sequence. + * + */ + std::pair operator *(); + + /** + * @breif Post-increment operator. + * + */ + Iterator & operator++ (int); + + /** + * @breif Post-decrement operator. + * + */ + Iterator & operator-- (int); + + /** + * @breif Check if iterator is equal to other iterator. + * + */ + bool operator == (const Iterator & it); + + /** + * @breif Check if iterator is not equal to other iterator. + * + */ + bool operator != (const Iterator & it); + + private: + enum eType { + None, + SequenceType, + MapType + }; + + eType m_Type; // Type of iterator. + void * m_pImp; // Implementation of iterator class. +}; + +/** +* @breif Constant iterator class. +* +*/ +class ConstIterator { + public: + friend class Node; + + /** + * @breif Default constructor. + * + */ + ConstIterator(); + + /** + * @breif Copy constructor. + * + */ + ConstIterator(const ConstIterator & it); + + /** + * @breif Assignment operator. + * + */ + ConstIterator & operator = (const ConstIterator & it); + + /** + * @breif Destructor. + * + */ + ~ConstIterator(); + + /** + * @breif Get node of iterator. + * First pair item is the key of map value, empty if type is sequence. + * + */ + std::pair operator *(); + + /** + * @breif Post-increment operator. + * + */ + ConstIterator & operator++ (int); + + /** + * @breif Post-decrement operator. + * + */ + ConstIterator & operator-- (int); + + /** + * @breif Check if iterator is equal to other iterator. + * + */ + bool operator == (const ConstIterator & it); + + /** + * @breif Check if iterator is not equal to other iterator. + * + */ + bool operator != (const ConstIterator & it); + + private: + enum eType { + None, + SequenceType, + MapType + }; + + eType m_Type; // Type of iterator. + void * m_pImp; // Implementation of constant iterator class. +}; + +/** +* @breif Node class. +* +*/ +class Node { + public: + friend class Iterator; + + /** + * @breif Enumeration of node types. + * + */ + enum eType { + None, + SequenceType, + MapType, + ScalarType + }; + + /** + * @breif Default constructor. + * + */ + Node(); + + /** + * @breif Copy constructor. + * + */ + Node(const Node & node); + + /** + * @breif Assignment constructors. + * Converts node to scalar type if needed. + * + */ + explicit Node(const std::string & value); + explicit Node(const char * value); + + /** + * @breif Destructor. + * + */ + ~Node(); + + /** + * @breif Functions for checking type of node. + * + */ + eType Type() const; + bool IsNone() const; + bool IsSequence() const; + bool IsMap() const; + bool IsScalar() const; + + /** + * @breif Completely clear node. + * + */ + void Clear(); + + /** + * @breif Get node as given template type. + * + */ + template + T As() const { + return impl::StringConverter::Get(AsString()); + } + + /** + * @breif Get node as given template type. + * + */ + template + T As(const T & defaultValue) const { + return impl::StringConverter::Get(AsString(), defaultValue); + } + + /** + * @breif Get size of node. + * Nodes of type None or Scalar will return 0. + * + */ + size_t Size() const; + + // Sequence operators + + /** + * @breif Insert sequence item at given index. + * Converts node to sequence type if needed. + * Adding new item to end of sequence if index is larger than sequence size. + * + */ + Node & Insert(const size_t index); + + /** + * @breif Add new sequence index to back. + * Converts node to sequence type if needed. + * + */ + Node & PushFront(); + + /** + * @breif Add new sequence index to front. + * Converts node to sequence type if needed. + * + */ + Node & PushBack(); + + /** + * @breif Get sequence/map item. + * Converts node to sequence/map type if needed. + * + * @param index Sequence index. Returns None type Node if index is unknown. + * @param key Map key. Creates a new node if key is unknown. + * + */ + Node & operator[] (const size_t index); + Node & operator[] (const std::string & key); + + /** + * @breif Erase item. + * No action if node is not a sequence or map. + * + */ + void Erase(const size_t index); + void Erase(const std::string & key); + + /** + * @breif Assignment operators. + * + */ + Node & operator = (const Node & node); + Node & operator = (const std::string & value); + Node & operator = (const char * value); + + /** + * @breif Get start iterator. + * + */ + Iterator Begin(); + ConstIterator Begin() const; + + /** + * @breif Get end iterator. + * + */ + Iterator End(); + ConstIterator End() const; + + private: + /** + * @breif Get as string. If type is scalar, else empty. + * + */ + const std::string & AsString() const; + + void * m_pImp; // Implementation of node class. +}; + + +/** +* @breif Parsing functions. +* Population given root node with deserialized data. +* +* @param root Root node to populate. +* @param filename Path of input file. +* @param stream Input stream. +* @param string String of input data. +* @param buffer Char array of input data. +* @param size Buffer size. +* +* @throw InternalException An internal error occurred. +* @throw ParsingException Invalid input YAML data. +* @throw OperationException If filename or buffer pointer is invalid. +* +*/ +void Parse(Node & root, const char * filename); // NOLINT +void Parse(Node & root, std::iostream & stream); // NOLINT +void Parse(Node & root, const std::string & string); // NOLINT +void Parse(Node & root, const char * buffer, const size_t size); // NOLINT + + +/** +* @breif Serialization configuration structure, +* describing output behavior. +* +*/ +struct SerializeConfig { + /** + * @breif Constructor. + * + * @param spaceIndentation Number of spaces per indentation. + * @param scalarMaxLength Maximum length of scalars. Serialized as folder scalars if exceeded. + * Ignored if equal to 0. + * @param sequenceMapNewline Put maps on a new line if parent node is a sequence. + * @param mapScalarNewline Put scalars on a new line if parent node is a map. + * + */ + SerializeConfig(const size_t spaceIndentation = 2, + const size_t scalarMaxLength = 64, + const bool sequenceMapNewline = false, + const bool mapScalarNewline = false); + + size_t SpaceIndentation; // Number of spaces per indentation. + // Maximum length of scalars. Serialized as folder scalars if exceeded. + size_t ScalarMaxLength; + // Put maps on a new line if parent node is a sequence. + bool SequenceMapNewline; + // Put scalars on a new line if parent node is a map. + bool MapScalarNewline; +}; + + +/** +* @breif Serialization functions. +* +* @param root Root node to serialize. +* @param filename Path of output file. +* @param stream Output stream. +* @param string String of output data. +* @param config Serialization configurations. +* +* @throw InternalException An internal error occurred. +* @throw OperationException If filename or buffer pointer is invalid. +* If config is invalid. +* +*/ +void Serialize( + const Node & root, const char * filename, + const SerializeConfig & config = {2, 64, false, false}); +void Serialize( + const Node & root, std::iostream & stream, // NOLINT + const SerializeConfig & config = {2, 64, false, false}); +void Serialize( + const Node & root, std::string & string, // NOLINT + const SerializeConfig & config = {2, 64, false, false}); + +} // namespace Yaml diff --git a/runtime/core/utils/json.h b/runtime/core/utils/json.h index bf8d94a3e4..773bed3190 100644 --- a/runtime/core/utils/json.h +++ b/runtime/core/utils/json.h @@ -488,7 +488,8 @@ class JSON { Class Type = Class::Null; }; -JSON Array() { return std::move(JSON::Make(JSON::Class::Array)); } +// fix: https://github.com/nbsdx/SimpleJSON/issues/4 (veelion) +inline JSON Array() { return std::move(JSON::Make(JSON::Class::Array)); } template JSON Array(T... args) { @@ -497,9 +498,9 @@ JSON Array(T... args) { return std::move(arr); } -JSON Object() { return std::move(JSON::Make(JSON::Class::Object)); } +inline JSON Object() { return std::move(JSON::Make(JSON::Class::Object)); } -std::ostream& operator<<(std::ostream& os, const JSON& json) { +inline std::ostream& operator<<(std::ostream& os, const JSON& json) { os << json.dump(); return os; } @@ -744,7 +745,7 @@ JSON parse_next(const string& str, size_t& offset) { // NOLINT } } // namespace -JSON JSON::Load(const string& str) { +inline JSON JSON::Load(const string& str) { size_t offset = 0; return std::move(parse_next(str, offset)); } diff --git a/runtime/core/websocket/CMakeLists.txt b/runtime/core/websocket/CMakeLists.txt index 67447c42d9..4518667103 100644 --- a/runtime/core/websocket/CMakeLists.txt +++ b/runtime/core/websocket/CMakeLists.txt @@ -2,4 +2,4 @@ add_library(websocket STATIC websocket_client.cc websocket_server.cc ) -target_link_libraries(websocket PUBLIC decoder) +target_link_libraries(websocket PUBLIC decoder kaldifeat_core) diff --git a/runtime/core/websocket/batch_connection_handler.h b/runtime/core/websocket/batch_connection_handler.h new file mode 100644 index 0000000000..f6b6225a2c --- /dev/null +++ b/runtime/core/websocket/batch_connection_handler.h @@ -0,0 +1,217 @@ +// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) +// 2022 SoundDataConverge Co.LTD (Weiliang Chong) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef WEBSOCKET_BATCH_CONNECTION_HANDLER_H_ +#define WEBSOCKET_BATCH_CONNECTION_HANDLER_H_ + +#include +#include +#include +#include +#include +#include + +#include "boost/asio/connect.hpp" +#include "boost/asio/ip/tcp.hpp" +#include "boost/beast/core.hpp" +#include "boost/beast/websocket.hpp" +#include "boost/json/src.hpp" + +#include "decoder/asr_decoder.h" +#include "decoder/batch_asr_decoder.h" +#include "frontend/feature_pipeline.h" +#include "utils/log.h" + +namespace wenet { + +namespace beast = boost::beast; // from +namespace http = beast::http; // from +namespace websocket = beast::websocket; // from +namespace asio = boost::asio; // from +using tcp = boost::asio::ip::tcp; // from +namespace json = boost::json; + +class BatchConnectionHandler { + public: + BatchConnectionHandler( + tcp::socket&& socket, + std::shared_ptr feature_config, + std::shared_ptr decode_config, + std::shared_ptr decode_resource) + : ws_(std::move(socket)), + feature_config_(std::move(feature_config)), + decode_config_(std::move(decode_config)), + decode_resource_(std::move(decode_resource)) {} + + void operator()() { + try { + // Accept the websocket handshake + ws_.accept(); + for (;;) { + // This buffer will hold the incoming message + beast::flat_buffer buffer; + // Read a message + ws_.read(buffer); + if (ws_.got_text()) { + std::string message = beast::buffers_to_string(buffer.data()); + LOG(INFO) << message; + OnText(message); + if (got_end_tag_) { + break; + } + } else { + if (!got_start_tag_) { + OnError("Start signal is expected before binary data"); + } else { + OnSpeechData(buffer); + break; + } + } + } + ws_.close(websocket::close_code::normal); + LOG(INFO) << "ws_ is closed, bye :)"; + } catch (beast::system_error const& se) { + LOG(INFO) << se.code().message(); + // This indicates that the session was closed + if (se.code() == websocket::error::closed) { + OnSpeechEnd(); + } + } catch (std::exception const& e) { + LOG(ERROR) << e.what(); + OnError("Decoder got some exception!"); + } + } + + private: + void OnSpeechStart() { + LOG(INFO) << "Received speech start signal, start reading speech"; + got_start_tag_ = true; + json::value rv = {{"status", "ok"}, {"type", "server_ready"}}; + ws_.text(true); + ws_.write(asio::buffer(json::serialize(rv))); + decoder_ = std::make_shared( + feature_config_, decode_resource_, + *decode_config_); + } + + void OnSpeechEnd() { + LOG(INFO) << "Received speech end signal"; + got_end_tag_ = true; + } + + void OnText(const std::string& message) { + json::value v = json::parse(message); + if (v.is_object()) { + json::object obj = v.get_object(); + if (obj.find("signal") != obj.end()) { + json::string signal = obj["signal"].as_string(); + if (signal == "start") { + if (obj.find("nbest") != obj.end()) { + if (obj["nbest"].is_int64()) { + nbest_ = obj["nbest"].as_int64(); + } else { + OnError("integer is expected for nbest option"); + } + } + if (obj.find("enable_timestamp") != obj.end()) { + if (obj["enable_timestamp"].is_bool()) { + enable_timestamp_ = obj["enable_timestamp"].as_bool(); + } else { + OnError( + "boolean true or false is expected for " + "enable_timestamp option"); + } + } + if (obj.find("batch_lens") != obj.end()) { + if (obj["batch_lens"].is_array()) { + batch_lens_.clear(); + auto& batch_lens = obj["batch_lens"].as_array(); + for (size_t i = 0; i < batch_lens.size(); i++) { + int len = batch_lens[i].as_int64(); + batch_lens_.push_back(len); + } + } else { + OnError("a list of batch_lens should be given"); + } + } + OnSpeechStart(); + } else if (signal == "end") { + OnSpeechEnd(); + } else { + OnError("Unexpected signal type"); + } + } else { + OnError("Wrong message header"); + } + } else { + OnError("Wrong protocol"); + } + } + + void OnFinish() { + // Send finish tag + json::value rv = {{"status", "ok"}, {"type", "speech_end"}}; + ws_.text(true); + ws_.write(asio::buffer(json::serialize(rv))); + } + + void OnSpeechData(const beast::flat_buffer& buffer) { + // Read binary PCM data + std::vector> wavs; + size_t total = std::accumulate(batch_lens_.begin(), batch_lens_.end(), 0); + VLOG(1) << "buffer size " << buffer.size() << ", batch_lens_ sum " << total; + CHECK(buffer.size() == total); + const auto* pcm_data = static_cast(buffer.data().data()); + int offset = 0; + for (int len : batch_lens_) { + len /= 2; // lenght of int16_t data + std::vector wav(len); + for (size_t i = 0; i < len; i++) { + wav[i] = static_cast(pcm_data[offset+i]); + } + wavs.push_back(std::move(wav)); + offset += len; + } + CHECK(decoder_ != nullptr); + decoder_->Decode(wavs); + std::string result = decoder_->get_batch_result(nbest_, enable_timestamp_); + ws_.text(true); + ws_.write(asio::buffer(result)); + } + + void OnError(const std::string& message) { + json::value rv = {{"status", "failed"}, {"message", message}}; + ws_.text(true); + ws_.write(asio::buffer(json::serialize(rv))); + // Close websocket + ws_.close(websocket::close_code::normal); + } + + int nbest_ = 1; + bool enable_timestamp_ = false; + std::vector batch_lens_; + websocket::stream ws_; + std::shared_ptr feature_config_; + std::shared_ptr decode_config_; + std::shared_ptr decode_resource_; + + bool got_start_tag_ = false; + bool got_end_tag_ = false; + std::shared_ptr decoder_ = nullptr; +}; + +} // namespace wenet + +#endif // WEBSOCKET_BATCH_CONNECTION_HANDLER_H_ diff --git a/runtime/core/websocket/websocket_server.cc b/runtime/core/websocket/websocket_server.cc index 52ab088f46..708d936237 100644 --- a/runtime/core/websocket/websocket_server.cc +++ b/runtime/core/websocket/websocket_server.cc @@ -1,4 +1,5 @@ // Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) +// 2022 SoundDataConverge Co.LTD (Weiliang Chong) // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,6 +19,7 @@ #include #include +#include "websocket/batch_connection_handler.h" #include "boost/json/src.hpp" #include "utils/log.h" @@ -244,7 +246,7 @@ void ConnectionHandler::operator()() { } } -void WebSocketServer::Start() { +void WebSocketServer::Start(bool run_batch) { try { auto const address = asio::ip::make_address("0.0.0.0"); tcp::acceptor acceptor{ioc_, {address, static_cast(port_)}}; @@ -254,10 +256,17 @@ void WebSocketServer::Start() { // Block until we get a connection acceptor.accept(socket); // Launch the session, transferring ownership of the socket - ConnectionHandler handler(std::move(socket), feature_config_, - decode_config_, decode_resource_); - std::thread t(std::move(handler)); - t.detach(); + if (run_batch) { + BatchConnectionHandler handler(std::move(socket), feature_config_, + decode_config_, decode_resource_); + std::thread t(std::move(handler)); + t.detach(); + } else { + ConnectionHandler handler(std::move(socket), feature_config_, + decode_config_, decode_resource_); + std::thread t(std::move(handler)); + t.detach(); + } } } catch (const std::exception& e) { LOG(FATAL) << e.what(); diff --git a/runtime/core/websocket/websocket_server.h b/runtime/core/websocket/websocket_server.h index a124183422..5714f32512 100644 --- a/runtime/core/websocket/websocket_server.h +++ b/runtime/core/websocket/websocket_server.h @@ -85,7 +85,7 @@ class WebSocketServer { decode_config_(std::move(decode_config)), decode_resource_(std::move(decode_resource)) {} - void Start(); + void Start(bool run_batch = false); private: int port_; diff --git a/runtime/libtorch/CMakeLists.txt b/runtime/libtorch/CMakeLists.txt index a02f37ac7d..f2dee20235 100644 --- a/runtime/libtorch/CMakeLists.txt +++ b/runtime/libtorch/CMakeLists.txt @@ -33,6 +33,8 @@ endif() # Include all dependency if(TORCH) + include(kaldifeat) + include(FetchContent) # use wenet's, disable kaldifeat's custom: cmake/Modules/FetchContent include(libtorch) endif() if(ONNX) diff --git a/runtime/libtorch/README.md b/runtime/libtorch/README.md index bccfcc4fe1..0e97cd92e8 100644 --- a/runtime/libtorch/README.md +++ b/runtime/libtorch/README.md @@ -131,6 +131,36 @@ Here is a demo for command line based websocket server/client interaction. ![Runtime server demo](../../../docs/images/runtime_server.gif) +#### run_batch (offline) mode on GPU + +When start Websocket server with the option `--run_batch`, it will work on `run_batch` mode which accept a batch of wav data (batch_size >= 1). The encoding and decoding use the advantage of GPU batch processing to improve speed. + +This mode support both of libtorch and onnxruntime, but libtorch performs better due to some GPU memory issue of onnxruntime. + +Test result: + +* hardware-1: + Platinum 8358P CPU @ 2.60GHz 15 cores + 80G memory, A5000 * 1 + 24G memory + +* hardware-2: + Platinum 8369B CPU @ 2.90GHz 32 cores + 120GB memory, A100-SXM4-80GB * 1 + 80GB memory + +* data: + 3000 wavs with different durations in range [0.6, 15] seconds. + +| hardware | websocket_server | concurrency | batch_size | RTF | CER | +| --- | --- | --- | --- | --- | --- | +| hardware-1 | libtorch(CPU) | 30 | 1 | 0.01666 | 8.90 | +| hardware-1 | libtorch(GPU) | 10 | 1 | 0.00831 | 8.90 | +| hardware-1 | libtorch(GPU+batch) | 20 | 8 | 0.00339 | 9.61 | +| hardware-2 | libtorch(CPU) | 48 | 1 | 0.00753 | 8.90 | +| hardware-2 | libtorch(GPU) | 48 | 1 | 0.00234 | 8.90 | +| hardware-2 | libtorch(GPU+batch) | 48 | 8 | 0.00110 | 9.61 | + +With same CPU, GPU is 2~3 times faster than CPU, run_batch is 2.x times faster than non run_batch mode, but the CER has a little bigger. + + + ### gRPC Why grpc? You may find your answer in https://grpc.io/. diff --git a/runtime/libtorch/README_CN.md b/runtime/libtorch/README_CN.md index ee74968bd3..097c361d5a 100644 --- a/runtime/libtorch/README_CN.md +++ b/runtime/libtorch/README_CN.md @@ -110,6 +110,35 @@ model_dir=./20210602_unified_transformer_server 上述服务启动后,会监听 10086 端口。若想使用其他端口,请修改 `--port` 对应的参数. +#### run_batch (非流式) 模式运行在GPU上 + +启动 Websocket server 时添加 `--run_batch`,既可以开启 `run_batch` 模式,它的输入是一批wav数据(batch_size >= 1),模型在编码和解码阶段都可以利用GPU的批处理能力,从而提高推理速度。 + +该模式在 libtorch 和 onnxruntime 库上都已经实现,但是libtoch的表现更好(更大的并发性能),因为 onnxruntime 目前没有办法清除显存缓存而导致并发较大时显存不足。 + +测试结果: + +* hardware-1: + Platinum 8358P CPU @ 2.60GHz 15 cores + 80G memory, A5000 * 1 + 24G memory + +* hardware-2: + Platinum 8369B CPU @ 2.90GHz 32 cores + 120GB memory, A100-SXM4-80GB * 1 + 80GB memory + +* data: + 3000 wavs with different durations in range [0.6, 15] seconds. + +| hardware | websocket_server | concurrency | batch_size | RTF | CER | +| --- | --- | --- | --- | --- | --- | +| hardware-1 | libtorch(CPU) | 30 | 1 | 0.01666 | 8.90 | +| hardware-1 | libtorch(GPU) | 10 | 1 | 0.00831 | 8.90 | +| hardware-1 | libtorch(GPU+batch) | 20 | 8 | 0.00339 | 9.61 | +| hardware-2 | libtorch(CPU) | 48 | 1 | 0.00753 | 8.90 | +| hardware-2 | libtorch(GPU) | 48 | 1 | 0.00234 | 8.90 | +| hardware-2 | libtorch(GPU+batch) | 48 | 8 | 0.00110 | 9.61 | + +可以看出,同样的CPU下,GPU(batch_size == 1) 是 CPU 速度的 2-3 倍, 而 run_batch 速度又是 GPU(batch_size==1) 的 2.x 倍,但是CER有所提高。 + + ### websocket 识别客户端 客户端按 websocket 协议去请求服务,可以用不同语言来实现客户端。我们提供了两种客户端,一种是基于 C++ 的命令行工具。一种是基于网页形式的可视化客户端。 diff --git a/tools/websocket/performance-ws.py b/tools/websocket/performance-ws.py index af77dea06b..083ac59bf9 100755 --- a/tools/websocket/performance-ws.py +++ b/tools/websocket/performance-ws.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # coding:utf-8 -# Copyright (c) 2022 SDCI Co. Ltd (author: veelion) +# Copyright (c) 2022 SoundDataConverge Co.LTD (Weiliang Chong) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/wenet/bin/export_onnx_gpu.py b/wenet/bin/export_onnx_gpu.py index 14f107d5ed..de2beb88c5 100644 --- a/wenet/bin/export_onnx_gpu.py +++ b/wenet/bin/export_onnx_gpu.py @@ -262,7 +262,7 @@ def forward(self, score = torch.sum(score, axis=1) # B2 score = torch.reshape(score, (B, bz)) + self.ctc_weight * ctc_score best_index = torch.argmax(score, dim=1) - return best_index + return best_index, score def to_numpy(tensors): @@ -332,10 +332,15 @@ def export_offline_encoder(model, configs, args, logger, encoder_onnx_path): # check encoder output test(to_numpy([o0, o1, o2, o3, o4]), ort_outs) + is_bidirectional_decoder = 1 if configs['decoder'] == 'bitransformer' else 0 logger.info("export offline onnx encoder succeed!") + reverse_weight = configs['model_conf'].get('reverse_weight', 0) onnx_config = {"beam_size": args.beam_size, - "reverse_weight": args.reverse_weight, - "ctc_weight": args.ctc_weight, + "reverse_weight": reverse_weight, + "ctc_weight": configs['model_conf']['ctc_weight'], + "sos": configs["output_dim"] - 1, + "eos": configs["output_dim"] - 1, + "is_bidirectional_decoder": is_bidirectional_decoder, "fp16": args.fp16} return onnx_config @@ -500,7 +505,7 @@ def export_rescoring_decoder(model, configs, args, logger, decoder_onnx_path): ort_outs = ort_session.run(None, ort_inputs) # check decoder output - test(to_numpy([o0]), ort_outs, rtol=1e-03, atol=1e-05) + test(to_numpy(o0), ort_outs, rtol=1e-03, atol=1e-05) logger.info("export to onnx decoder succeed!") diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py index 367c9189a7..e3dbbcfba3 100644 --- a/wenet/transformer/asr_model.py +++ b/wenet/transformer/asr_model.py @@ -896,3 +896,97 @@ def forward_attention_decoder( # r_dccoder_out will be 0.0, if reverse_weight is 0.0 r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1) return decoder_out, r_decoder_out + + @torch.jit.export + def batch_forward_encoder( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + beam_size: int = 10, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ Export interface for c++ call, encode a batch of speech + + Args: + speech: padded input tensor (B, T, F) + speech_lengths: input length (B) + Returns: + encoder_out: B x T x F + encoder_out_lens: B + ctc_log_probs: B x T x V + beam_log_probs: B x T x beam_size + beam_log_probs_idx: B x T x beam_size + """ + encoder_out, encoder_mask = self.encoder( + speech, speech_lengths, -1, -1) + encoder_out_lens = encoder_mask.squeeze(1).sum(1) + encoder_out_lens = encoder_out_lens.int() + ctc_log_probs = self.ctc.log_softmax(encoder_out) + beam_log_probs, beam_log_probs_idx = torch.topk( + ctc_log_probs, beam_size, dim=2) + return encoder_out, encoder_out_lens, ctc_log_probs, \ + beam_log_probs, beam_log_probs_idx + + @torch.jit.export + def batch_forward_attention_decoder( + self, + encoder_out: torch.Tensor, + encoder_lens: torch.Tensor, + hyps_pad_sos_eos: torch.Tensor, + hyps_lens_sos: torch.Tensor, + r_hyps_pad_sos_eos: torch.Tensor, + ctc_score: torch.Tensor): + """Decoder + Args: + encoder_out: B x T x F + encoder_lens: B + hyps_pad_sos_eos: B x beam x (T2+1), + hyps with sos & eos and padded by ignore id + hyps_lens_sos: B x beam, length for each hyp with sos + r_hyps_pad_sos_eos: B x beam x (T2+1), + reversed hyps with sos & eos and padded by ignore id + ctc_score: B x beam, ctc score for each hyp + Returns: + best_index: B + score: B x beam + """ + B, T, F = encoder_out.shape + bz = hyps_pad_sos_eos.shape[1] # beam_size + B2 = B * bz + encoder_out = encoder_out.repeat(1, bz, 1).view(B2, T, F) + encoder_mask = ~make_pad_mask(encoder_lens, T).unsqueeze(1) + encoder_mask = encoder_mask.repeat(1, bz, 1).view(B2, 1, T) + T2 = hyps_pad_sos_eos.shape[2] - 1 + hyps_pad = hyps_pad_sos_eos.view(B2, T2 + 1) + hyps_lens = hyps_lens_sos.view(B2,) + hyps_pad_sos = hyps_pad[:, :-1].contiguous() + hyps_pad_eos = hyps_pad[:, 1:].contiguous() + + r_hyps_pad = r_hyps_pad_sos_eos.view(B2, T2 + 1) + r_hyps_pad_sos = r_hyps_pad[:, :-1].contiguous() + r_hyps_pad_eos = r_hyps_pad[:, 1:].contiguous() + + decoder_out, r_decoder_out, _ = self.decoder( + encoder_out, encoder_mask, hyps_pad_sos, hyps_lens, r_hyps_pad_sos, + self.reverse_weight) + decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1) + V = decoder_out.shape[-1] + decoder_out = decoder_out.view(B2, T2, V) + mask = ~make_pad_mask(hyps_lens, T2) # B2 x T2 + # mask index, remove ignore id + index = torch.unsqueeze(hyps_pad_eos * mask, 2) + score = decoder_out.gather(2, index).squeeze(2) # B2 X T2 + # mask padded part + score = score * mask + decoder_out = decoder_out.view(B, bz, T2, V) + if self.reverse_weight > 0: + r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1) + r_decoder_out = r_decoder_out.view(B2, T2, V) + index = torch.unsqueeze(r_hyps_pad_eos * mask, 2) + r_score = r_decoder_out.gather(2, index).squeeze(2) + r_score = r_score * mask + score = score * (1 - self.reverse_weight) + self.reverse_weight * r_score + r_decoder_out = r_decoder_out.view(B, bz, T2, V) + score = torch.sum(score, dim=1) # B2 + score = torch.reshape(score, (B, bz)) + self.ctc_weight * ctc_score + best_index = torch.argmax(score, dim=1) + return best_index, score