diff --git a/frameworks/libs/distributeddb/sqlite_adapter/src/tokenizer_export_type.h b/frameworks/libs/distributeddb/sqlite_adapter/src/tokenizer_export_type.h index 8a7a96c1a77a4ee6804969e7bee43fa9b92cf7b1..1460cbadeee4bf32c4d5a243e2fcc918b675c11a 100644 --- a/frameworks/libs/distributeddb/sqlite_adapter/src/tokenizer_export_type.h +++ b/frameworks/libs/distributeddb/sqlite_adapter/src/tokenizer_export_type.h @@ -46,6 +46,7 @@ typedef enum GRD_CutMode { typedef enum GRD_CutScene { DEFAULT = 0, SEARCH = 1, + MULTI_WORDS = 2, SCENE_BUTT // INVALID mode } GRD_CutSceneE; diff --git a/frameworks/libs/distributeddb/sqlite_adapter/src/tokenizer_sqlite.cpp b/frameworks/libs/distributeddb/sqlite_adapter/src/tokenizer_sqlite.cpp index e0490e44333c9cbfdc829c406a3f291d0886e924..46c6bf48604506ae16ec21ae6a50b57e9373c60d 100644 --- a/frameworks/libs/distributeddb/sqlite_adapter/src/tokenizer_sqlite.cpp +++ b/frameworks/libs/distributeddb/sqlite_adapter/src/tokenizer_sqlite.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include "tokenizer_api.h" #include "tokenizer_export_type.h" @@ -38,20 +39,20 @@ static uint32_t g_refCount = 0; constexpr int FTS5_MAX_VERSION = 2; constexpr int MAGIC_CODE = 0x12345678; constexpr const char *CUT_SCENE_PARAM_NAME = "cut_mode"; -constexpr const char *CUT_SCENE_SHORT_WORDS = "short_words"; -constexpr const char *CUT_SCENE_DEFAULT = "default"; constexpr const char *CUT_CASE_SENSITIVE = "case_sensitive"; +static std::unordered_map g_cutModeMap = { + {"default", DEFAULT}, + {"short_words", SEARCH}, + {"multi_words", MULTI_WORDS} +}; int AnalyzeCutMode(std::string &value, Fts5TokenizerParamT *para) { - if (value == CUT_SCENE_SHORT_WORDS) { - para->cutScene = SEARCH; - } else if (value == CUT_SCENE_DEFAULT) { - para->cutScene = DEFAULT; - } else { + if (g_cutModeMap.find(value) == g_cutModeMap.end()) { sqlite3_log(SQLITE_ERROR, "invalid arg value of cut scene"); return SQLITE_ERROR; } + para->cutScene = g_cutModeMap[value]; return SQLITE_OK; } @@ -151,6 +152,34 @@ static char *CpyStr(const char *pText, int nText) return ptr; } +static int IterateCutResults(void *pCtx, int nText, XTokenFn xToken, GRD_CutScene cutScene, GRD_WordEntryListT *list) +{ + GRD_WordEntryT entry; + int start = 0; // 词在句子中的起始位置 + int end = 0; // 词在句子中的结束位置 + int startBefore = -1; + int ret = 0; + while ((ret = GRD_TokenizerNext(list, &entry)) == GRD_OK) { + start = static_cast(entry.offset); + end = start + static_cast(entry.length); + if (end > nText || start < 0) { + sqlite3_log(SQLITE_ERROR, "|fts5 custom tokenizer xTokenize| offset wrong"); + return SQLITE_ERROR; + } + int tokenFlag = 0; + if (start == startBefore && cutScene == MULTI_WORDS) { + tokenFlag = FTS5_TOKEN_COLOCATED; + } + startBefore = start; + ret = xToken(pCtx, tokenFlag, entry.word, entry.length, start, end); + if (ret != SQLITE_OK) { + sqlite3_log(ret, "xToken wrong"); + return ret; + } + } + return ret; +} + int fts5_customtokenizer_xTokenize( Fts5Tokenizer *tokenizer_ptr, void *pCtx, int flags, const char *pText, int nText, XTokenFn xToken) { @@ -175,26 +204,11 @@ int fts5_customtokenizer_xTokenize( free(ptr); return ret; } - GRD_WordEntryT entry; - int start = 0; // 词在句子中的起始位置 - int end = 0; // 词在句子中的结束位置 - while ((ret = GRD_TokenizerNext(entryList, &entry)) == GRD_OK) { - start = static_cast(entry.offset); - end = start + static_cast(entry.length); - if (end > nText || start < 0) { - ret = SQLITE_ERROR; - sqlite3_log(SQLITE_ERROR, "|fts5 custom tokenizer xTokenize| offset wrong"); - break; - } - ret = xToken(pCtx, 0, entry.word, entry.length, start, end); - if (ret != SQLITE_OK) { - sqlite3_log(ret, "xToken wrong"); - break; - } - } + ret = IterateCutResults(pCtx, nText, xToken, pFts5TokenizerParam->cutScene, entryList); GRD_TokenizerFreeWordEntryList(entryList); free(ptr); if (ret != GRD_OK && ret != GRD_NO_DATA) { + sqlite3_log(ret, "Iterate Cut Results wrong"); return ret; } return SQLITE_OK; diff --git a/frameworks/libs/distributeddb/test/unittest/common/tokenizer/sqlite_adapter_test.cpp b/frameworks/libs/distributeddb/test/unittest/common/tokenizer/sqlite_adapter_test.cpp index c9fd3f96a64b320dbce55668eb21dcfed5ba46a8..ceea66a81d5fe79d2827dc6d202aafd808d9395a 100644 --- a/frameworks/libs/distributeddb/test/unittest/common/tokenizer/sqlite_adapter_test.cpp +++ b/frameworks/libs/distributeddb/test/unittest/common/tokenizer/sqlite_adapter_test.cpp @@ -107,6 +107,34 @@ static void SQLTest(const char *sql) } } +static int HighlightCallback(void *data, int argc, char **argv, char **colName) +{ + const int expectedArgc = 2; + if (argc != expectedArgc) { + return SQLITE_ERROR; + } + string result = argv[1]; + string expectResult = reinterpret_cast(data); + EXPECT_EQ(expectResult, result); + return SQLITE_OK; +} + +static void HighlightTest(const char *sql, char *expectResult) +{ + if (g_needSkip) { + return; + } + int errCode = sqlite3_exec(g_sqliteDb, sql, HighlightCallback, reinterpret_cast(expectResult), &g_errMsg); + if (errCode != SQLITE_OK) { + if (g_errMsg != nullptr) { + LOGE("SQL error: %s\n", g_errMsg); + sqlite3_free(g_errMsg); + g_errMsg = nullptr; + ASSERT_TRUE(false); + } + } +} + /** * @tc.name: SqliteAdapterTest001 * @tc.desc: Get blob size over limit @@ -468,7 +496,8 @@ HWTEST_F(SqliteAdapterTest, SqliteAdapterTest009, TestSize.Level0) const char *SQLQUERY1 = "SELECT name, highlight(example, 1, '【', '】') as highlighted_content " "FROM example WHERE example MATCH '测试';"; - SQLTest(SQLQUERY1); + char expectResult[] = "这是一个【测试】文档,用于【测试】中文文本的分词和索引。"; + HighlightTest(SQLQUERY1, expectResult); EXPECT_EQ(sqlite3_close(g_sqliteDb), SQLITE_OK); } @@ -519,7 +548,8 @@ HWTEST_F(SqliteAdapterTest, SqliteAdapterTest010, TestSize.Level0) const char *SQLQUERY1 = "SELECT name, highlight(example, 1, '【', '】') as highlighted_content FROM" " example WHERE example MATCH '测试';"; - SQLTest(SQLQUERY1); + char expectResult[] = "这是一个【测试】文档,用于【测试】中文文本的分词和索引。"; + HighlightTest(SQLQUERY1, expectResult); EXPECT_EQ(sqlite3_close(g_sqliteDb), SQLITE_OK); } @@ -564,15 +594,86 @@ HWTEST_F(SqliteAdapterTest, SqliteAdapterTest011, TestSize.Level0) const char *SQLQUERY1 = "SELECT name, highlight(example, 1, '【', '】') as highlighted_content FROM" " example WHERE example MATCH '\"C++\"';"; - SQLTest(SQLQUERY1); + char expectResult1[] = "\"C语言设计【c++】C语言设计X射线哆啦A梦qqq号250G硬盘usb接口k歌【C++】卡拉ok卡拉OK\""; + HighlightTest(SQLQUERY1, expectResult1); const char *SQLQUERY2 = "SELECT name, highlight(example, 1, '【', '】') as highlighted_content FROM" " example WHERE example MATCH '\"卡拉OK\"';"; - SQLTest(SQLQUERY2); + char expectResult2[] = "\"C语言设计c++C语言设计X射线哆啦A梦qqq号250G硬盘usb接口k歌C++【卡拉ok卡拉OK】\""; + HighlightTest(SQLQUERY2, expectResult2); const char *SQLQUERY3 = "SELECT name, highlight(example, 1, '【', '】') as highlighted_content FROM" " example WHERE example MATCH '\"qq号\"';"; - SQLTest(SQLQUERY3); + char expectResult3[] = "\"C语言设计c++C语言设计X射线哆啦A梦q【qq号】250G硬盘usb接口k歌C++卡拉ok卡拉OK\""; + HighlightTest(SQLQUERY3, expectResult3); EXPECT_EQ(sqlite3_close(g_sqliteDb), SQLITE_OK); } + +/** + * @tc.name: SqliteAdapterTest012 + * @tc.desc: Test cut_mode multi_words + * @tc.type: FUNC + * @tc.require: + * @tc.author: lxl + */ +HWTEST_F(SqliteAdapterTest, SqliteAdapterTest012, TestSize.Level0) +{ + /** + * @tc.steps: step1. prepare db + * @tc.expected: step1. OK. + */ + // Save any error messages + char *zErrMsg = nullptr; + + // Save the connection result + int rc = sqlite3_open_v2(g_dbPath, &g_sqliteDb, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr); + HandleRc(g_sqliteDb, rc); + + rc = sqlite3_db_config(g_sqliteDb, SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION, 1, nullptr); + HandleRc(g_sqliteDb, rc); + + rc = sqlite3_load_extension(g_sqliteDb, "libcustomtokenizer.z.so", nullptr, nullptr); + HandleRc(g_sqliteDb, rc); + /** + * @tc.steps: step2. create table + * @tc.expected: step2. OK. + */ + string sql = "CREATE VIRTUAL TABLE example USING fts5(content, tokenize = 'customtokenizer cut_mode multi_words')"; + rc = sqlite3_exec(g_sqliteDb, sql.c_str(), Callback, 0, &zErrMsg); + HandleRc(g_sqliteDb, rc); + /** + * @tc.steps: step3. insert records + * @tc.expected: step3. OK. + */ + std::vector records = { + "传承认知心理学", + "中国社会科学院", + "123...321 hello--?+--world生物学家,是研究生物学的专家!HelLo wOrLD!!!" + }; + for (const auto &record : records) { + std::string insertSql = "insert into example values('" + record + "');"; + SQLTest(insertSql.c_str()); + } + /** + * @tc.steps: step4. test cut for short words + * @tc.expected: step4. OK. + */ + std::vector> expectResult = { + {"传承", 1}, {"认知", 1}, {"心理", 1}, {"心理学", 1}, {"认知心理学", 1}, {"传承认知心理学", 1}, {"承认", 0}, + {"中国", 1}, {"社会", 1}, {"社会科学", 1}, {"社会科学院", 1}, {"科学", 1}, {"科学院", 1}, {"学院", 1}, + {"中国社会科学院", 1}, {"生物", 1}, {"生物学", 1}, {"生物学家", 1}, {"专家", 1} + }; + // 蓝区没有so导致失败,直接跳过测试 + if (!g_needSkip) { + for (const auto &[word, expectMatchNum] : expectResult) { + std::string querySql = "SELECT count(*) FROM example WHERE content MATCH '" + word + "';"; + EXPECT_EQ(sqlite3_exec(g_sqliteDb, querySql.c_str(), QueryCallback, + reinterpret_cast(expectMatchNum), nullptr), SQLITE_OK); + } + } + + const char *SQLDROP = "DROP TABLE IF EXISTS example;"; + SQLTest(SQLDROP); + EXPECT_EQ(sqlite3_close(g_sqliteDb), SQLITE_OK); +}