From 83b5927e6797204ae831c4e2281ec728ecc42722 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B0=8F=E5=A5=B6=E5=A5=B6=E8=8A=B1=E7=94=9F=E7=B1=B3?= <279094122@qq.com> Date: Tue, 3 Feb 2026 22:27:38 +0800 Subject: [PATCH] =?UTF-8?q?feat(weaviate):=20=E6=B7=BB=E5=8A=A0=20Weaviate?= =?UTF-8?q?=20=E7=9F=A2=E9=87=8F=E5=AD=98=E5=82=A8=E7=9F=A5=E8=AF=86?= =?UTF-8?q?=E5=BA=93=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 实现 Weaviate 矢量存储知识库的核心功能,包括: - 文档的存储、检索和删除 - 支持元数据字段定义和过滤查询 - 集成外部 Embedding 模型生成向量 - 提供 GraphQL 查询接口封装 - 添加完整的单元测试和集成测试 --- .../solon-ai-repo-weaviate/docker-compose.yml | 22 + .../solon-ai-repo-weaviate/pom.xml | 37 ++ .../ai/rag/repository/WeaviateRepository.java | 611 ++++++++++++++++++ .../rag/repository/weaviate/Additional.java | 37 ++ .../ai/rag/repository/weaviate/ClassInfo.java | 22 + .../ai/rag/repository/weaviate/Data.java | 22 + .../rag/repository/weaviate/DocumentData.java | 60 ++ .../ai/rag/repository/weaviate/FieldType.java | 14 + .../weaviate/FilterTransformer.java | 284 ++++++++ .../repository/weaviate/GraphQLResponse.java | 34 + .../repository/weaviate/MetadataField.java | 41 ++ .../repository/weaviate/SchemaResponse.java | 33 + .../repository/weaviate/WeaviateClient.java | 201 ++++++ .../repo/weaviate/WeaviateRepositoryTest.java | 340 ++++++++++ 14 files changed, 1758 insertions(+) create mode 100644 solon-ai-rag-repositorys/solon-ai-repo-weaviate/docker-compose.yml create mode 100644 solon-ai-rag-repositorys/solon-ai-repo-weaviate/pom.xml create mode 100644 solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/WeaviateRepository.java create mode 100644 solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/weaviate/Additional.java create mode 100644 solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/weaviate/ClassInfo.java create mode 100644 
solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/weaviate/Data.java create mode 100644 solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/weaviate/DocumentData.java create mode 100644 solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/weaviate/FieldType.java create mode 100644 solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/weaviate/FilterTransformer.java create mode 100644 solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/weaviate/GraphQLResponse.java create mode 100644 solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/weaviate/MetadataField.java create mode 100644 solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/weaviate/SchemaResponse.java create mode 100644 solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/weaviate/WeaviateClient.java create mode 100644 solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/test/java/features/ai/repo/weaviate/WeaviateRepositoryTest.java diff --git a/solon-ai-rag-repositorys/solon-ai-repo-weaviate/docker-compose.yml b/solon-ai-rag-repositorys/solon-ai-repo-weaviate/docker-compose.yml new file mode 100644 index 00000000..696a2b69 --- /dev/null +++ b/solon-ai-rag-repositorys/solon-ai-repo-weaviate/docker-compose.yml @@ -0,0 +1,22 @@ +version: '3.5' + +networks: + net: + driver: bridge +services: + weaviate: + image: semitechnologies/weaviate:1.35.6 + volumes: + - ./weaviate:/var/lib/weaviate + environment: + - QUERY_DEFAULTS_LIMIT=25 + - AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED=true + - PERSISTENCE_DATA_PATH=/var/lib/weaviate + - DEFAULT_VECTORIZER_MODULE=none + - ENABLE_MODULES=text2vec-openai,generative-openai + - CLUSTER_HOSTNAME=node1 + ports: + - 8080:8080 + - 
/*
 * Patch residue that shares the first source line of this span (tail of docker-compose.yml
 * and the tag-stripped pom.xml). Kept here so that replacing the span drops no content:
 *
 *   docker-compose.yml (tail): 50051:50051 / networks: - net
 *   pom.xml: 4.0.0; parent org.noear : solon-ai-parent : 3.9.1-SNAPSHOT (../../solon-ai-parent/pom.xml);
 *            artifact solon-ai-repo-weaviate, name ${project.artifactId}, packaging jar;
 *            dependencies org.noear:solon-ai, org.noear:solon-logging-simple (test), org.noear:solon-test (test)
 */

// File: solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/WeaviateRepository.java
package org.noear.solon.ai.rag.repository;

import org.noear.solon.Utils;
import org.noear.solon.ai.embedding.EmbeddingModel;
import org.noear.solon.ai.rag.Document;
import org.noear.solon.ai.rag.RepositoryLifecycle;
import org.noear.solon.ai.rag.RepositoryStorable;
import org.noear.solon.ai.rag.repository.weaviate.*;
import org.noear.solon.ai.rag.repository.weaviate.MetadataField;
import org.noear.solon.ai.rag.util.ListUtil;
import org.noear.solon.ai.rag.util.QueryCondition;
import org.noear.solon.ai.rag.util.SimilarityUtil;
import org.noear.solon.lang.Preview;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;

/**
 * Weaviate vector-store knowledge repository, built on the Weaviate REST / GraphQL interface.
 *
 * <ol>
 *   <li>Vectors come from an external {@link EmbeddingModel}; similarity search is a GraphQL
 *       {@code nearVector} query.</li>
 *   <li>Document content and metadata are written as Weaviate properties: {@code content} (text),
 *       {@code url} (original address, when present), and every other metadata entry expanded as
 *       its own property.</li>
 *   <li>{@code Document.id} is used as the Weaviate object uuid (generated on this side and sent
 *       to Weaviate).</li>
 * </ol>
 *
 * NOTE(review): generic type arguments were missing from the source this was rebuilt from; the ones
 * used below ({@code Map<String, Object>}, {@code List<Document>}, ...) are reconstructions - confirm
 * against WeaviateClient and the solon-ai interfaces.
 *
 * @author 小奶奶花生米
 * @since 3.1
 */
@Preview("3.1")
public class WeaviateRepository implements RepositoryStorable, RepositoryLifecycle {

    private static final String DEFAULT_COLLECTION_NAME = "solon_ai";

    /** Batch size used by {@link #save(List, BiConsumer)} when no embedding model is configured. */
    private static final int DEFAULT_BATCH_SIZE = 64;

    /** Minimum {@code certainty} sent with every nearVector query. */
    private static final double MIN_CERTAINTY = 0.7;

    /** Builder instance doubling as the immutable-by-convention configuration holder. */
    private final Builder config;
    private final WeaviateClient client;

    private WeaviateRepository(Builder config) {
        this.config = config;
        // Credential precedence: token, then username + password, then anonymous.
        if (config.token != null) {
            this.client = new WeaviateClient(config.baseUrl, config.token);
        } else if (config.username != null && config.password != null) {
            this.client = new WeaviateClient(config.baseUrl, config.username, config.password);
        } else {
            this.client = new WeaviateClient(config.baseUrl);
        }
    }

    /**
     * Initializes the repository: creates the collection (class) schema if the schema returned by
     * the client does not already contain it.
     */
    @Override
    public void initRepository() throws Exception {
        String className = getClassName();

        SchemaResponse schema = client.getSchema();
        if (schema.hasClass(className)) {
            return;
        }

        List<Map<String, Object>> properties = new ArrayList<>();
        properties.add(property("content", "text"));
        properties.add(property("url", "string"));

        // Declared metadata fields become typed properties of the class.
        if (config.metadataFields != null) {
            for (MetadataField field : config.metadataFields) {
                properties.add(property(field.getName(), weaviateDataType(field.getFieldType())));
            }
        }

        client.createClass(className, properties);
    }

    /** Builds one schema property entry ({@code name} plus a single-element {@code dataType} array). */
    private static Map<String, Object> property(String name, String dataType) {
        Map<String, Object> prop = new HashMap<>();
        prop.put("name", name);
        prop.put("dataType", new String[]{dataType});
        return prop;
    }

    /** Maps a {@link FieldType} to the dataType name sent to Weaviate; unknown or null falls back to "string". */
    private static String weaviateDataType(FieldType type) {
        if (type == null) {
            return "string";
        }
        switch (type) {
            case INTEGER:
                return "int";
            case FLOAT:
                return "number";
            case BOOLEAN:
                return "boolean";
            case STRING:
            default:
                return "string";
        }
    }

    /**
     * Drops the whole collection (deletes the class).
     */
    @Override
    public void dropRepository() throws Exception {
        client.deleteClass(getClassName());
    }

    /**
     * Returns the configured collection name (or the default) with its first character upper-cased,
     * which is the form this repository uses as the Weaviate class name.
     */
    private String getClassName() {
        String name = (config.collectionName != null ? config.collectionName : DEFAULT_COLLECTION_NAME);
        if (name.isEmpty()) {
            return name;
        }
        return Character.toUpperCase(name.charAt(0)) + name.substring(1);
    }

    /**
     * Saves documents in batches without progress reporting.
     */
    public void save(List<Document> documents) throws IOException {
        save(documents, null);
    }

    /**
     * Saves documents in batches.
     *
     * @param documents        documents to store; documents without an id get a generated uuid
     * @param progressCallback optional; receives (finished batch count, total batch count) after
     *                         each batch, or (0, 0) when there is nothing to save
     * @throws IOException if the repository cannot be initialized or a batch fails
     */
    @Override
    public void save(List<Document> documents, BiConsumer<Integer, Integer> progressCallback) throws IOException {
        if (Utils.isEmpty(documents)) {
            if (progressCallback != null) {
                progressCallback.accept(0, 0);
            }
            return;
        }

        try {
            initRepository();
        } catch (Exception e) {
            throw new IOException("Failed to initialize Weaviate repository before saving: " + e.getMessage(), e);
        }

        // Every document needs an id before it is sent.
        for (Document doc : documents) {
            if (Utils.isEmpty(doc.getId())) {
                doc.id(Utils.uuid());
            }
        }

        // With an embedding model, its batch size drives partitioning and each batch is embedded
        // before being written; without one, documents are written as-is in fixed-size batches.
        EmbeddingModel model = config.embeddingModel;
        int batchSize = (model != null) ? model.batchSize() : DEFAULT_BATCH_SIZE;
        List<List<Document>> batches = ListUtil.partition(documents, batchSize);

        int finished = 0;
        for (List<Document> batch : batches) {
            if (model != null) {
                model.embed(batch);
            }
            batchSaveDo(batch);

            if (progressCallback != null) {
                progressCallback.accept(++finished, batches.size());
            }
        }
    }

    /**
     * Writes a single batch to Weaviate through the client's batch-objects call.
     */
    private void batchSaveDo(List<Document> batch) throws IOException {
        if (batch == null || batch.isEmpty()) {
            return;
        }

        String className = getClassName();
        List<Map<String, Object>> objects = new ArrayList<>(batch.size());

        for (Document doc : batch) {
            // Metadata entries are flattened into the property map next to content/url.
            Map<String, Object> props = new HashMap<>();
            if (doc.getMetadata() != null) {
                props.putAll(doc.getMetadata());
            }
            props.put("content", doc.getContent());
            if (!Utils.isEmpty(doc.getUrl())) {
                props.put("url", doc.getUrl());
            }

            Map<String, Object> obj = new HashMap<>();
            obj.put("id", doc.getId());
            obj.put("class", className);
            obj.put("properties", props);

            float[] embedding = doc.getEmbedding();
            if (embedding != null) {
                obj.put("vector", toDoubles(embedding));
            }

            objects.add(obj);
        }

        client.batchSaveObjects(objects);
    }

    /** Widens a float vector to double, element by element. */
    private static double[] toDoubles(float[] values) {
        double[] out = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            out[i] = values[i];
        }
        return out;
    }

    /**
     * Deletes objects by Weaviate uuid, one request per id; empty ids are skipped.
     */
    @Override
    public void deleteById(String... ids) throws IOException {
        if (ids == null || ids.length == 0) {
            return;
        }

        String className = getClassName();
        for (String id : ids) {
            if (!Utils.isEmpty(id)) {
                client.deleteObject(className, id);
            }
        }
    }

    /**
     * Checks whether an object with the given uuid exists; an empty id yields {@code false}.
     */
    @Override
    public boolean existsById(String id) throws IOException {
        if (Utils.isEmpty(id)) {
            return false;
        }
        return client.objectExists(getClassName(), id);
    }

    /**
     * Vector search through a GraphQL {@code nearVector} query.
     *
     * @param condition query text, limit and optional filter expression
     * @return matching documents after {@code SimilarityUtil.refilter}; empty when the condition or
     *         its query is null, or when the response carries no data for this class
     * @throws IOException if no embedding model is configured or the request fails
     */
    @Override
    public List<Document> search(QueryCondition condition) throws IOException {
        if (condition == null || condition.getQuery() == null) {
            return new ArrayList<>();
        }

        if (config.embeddingModel == null) {
            throw new IOException("EmbeddingModel is required for WeaviateRepository.search (nearVector)");
        }

        float[] embedding = config.embeddingModel.embed(condition.getQuery());
        String className = getClassName();
        String query = buildSearchQuery(className, embedding, condition);

        GraphQLResponse response = client.executeGraphQL(query, GraphQLResponse.class);

        List<Document> docs = new ArrayList<>();
        if (response == null || response.getData() == null) {
            return docs;
        }

        // "Get" can be absent (for example when the server only returned errors); treat as no hits.
        Map<String, List<DocumentData>> getResult = response.getData().getGet();
        List<DocumentData> hits = (getResult == null) ? null : getResult.get(className);
        if (hits == null || hits.isEmpty()) {
            return docs;
        }

        for (DocumentData item : hits) {
            Additional extra = item.getAdditional();
            // certainty is used directly as the score.
            double score = (extra != null) ? extra.getCertainty() : 0.0;
            String id = (extra != null) ? extra.getId() : null;
            String url = item.getUrl();

            Map<String, Object> metadata = new HashMap<>();
            if (!Utils.isEmpty(url)) {
                metadata.put("url", url);
            }
            if (item.getMetadata() != null && !item.getMetadata().isEmpty()) {
                metadata.putAll(item.getMetadata());
            }

            Document doc = new Document(id, item.getContent(), metadata, score);
            if (!Utils.isEmpty(url)) {
                doc.url(url);
            }
            docs.add(doc);
        }

        // Filter and sort once more on the client side according to the condition.
        return SimilarityUtil.refilter(docs.stream(), condition);
    }

    /**
     * Builds the GraphQL {@code Get} query text.
     *
     * <p>Arguments are separated by newlines only: commas are insignificant in GraphQL, so none are
     * needed between {@code nearVector}, {@code limit} and {@code where}. The previous implementation
     * emitted commas and then patched them with two regex replacements, the second of which wrote a
     * stray {@code n} after the certainty value whenever a filter was present and the limit was not
     * positive.
     */
    private String buildSearchQuery(String className, float[] embedding, QueryCondition condition) {
        StringBuilder sb = new StringBuilder();
        sb.append("{\n")
                .append("  Get {\n")
                .append("    ").append(className).append("(\n")
                .append("      nearVector: {\n")
                .append("        vector: [");
        for (int i = 0; i < embedding.length; i++) {
            if (i > 0) {
                sb.append(",");
            }
            if (i % 10 == 0) {
                // Break the vector literal every ten values to keep lines short.
                sb.append("\n          ");
            }
            sb.append((double) embedding[i]);
        }
        sb.append("\n        ],\n")
                .append("        certainty: ").append(MIN_CERTAINTY).append("\n")
                .append("      }");

        if (condition.getLimit() > 0) {
            sb.append("\n      limit: ").append(condition.getLimit());
        }

        String where = buildWhereArgument(condition);
        if (where != null) {
            sb.append("\n      where: ").append(where);
        }

        sb.append("\n    ) {\n")
                .append("      content\n")
                .append("      url\n");

        // Request every declared metadata field.
        if (config.metadataFields != null) {
            for (MetadataField field : config.metadataFields) {
                sb.append("      ").append(field.getName()).append("\n");
            }
        }

        sb.append("      _additional {\n")
                .append("        id\n")
                .append("        certainty\n")
                .append("      }\n")
                .append("    }\n")
                .append("  }\n")
                .append("}");

        return sb.toString();
    }

    /**
     * Renders the condition's filter expression as a GraphQL object literal for the {@code where}
     * argument, or returns {@code null} when there is no usable filter.
     */
    private String buildWhereArgument(QueryCondition condition) {
        if (condition.getFilterExpression() == null) {
            return null;
        }

        Map<String, Object> filter = FilterTransformer.getInstance().transform(condition.getFilterExpression());
        if (filter == null || filter.isEmpty()) {
            return null;
        }

        Map<String, Object> where = convertToWeaviateWhere(filter);
        if (where == null || where.isEmpty()) {
            return null;
        }

        return FilterTransformer.getInstance().convertToGraphQLObject(where);
    }

    /**
     * Converts the intermediate filter map produced by {@link FilterTransformer} ({@code $and},
     * {@code $or}, {@code $not}, or {@code field -> value / {$op: value}}) into the
     * operator/operands/path structure used for the {@code where} argument.
     *
     * <p>For plain field entries the first entry that yields a condition wins; FilterTransformer
     * emits one field per map, so this matches its output.
     *
     * @return the converted structure, or {@code null} if nothing could be converted
     */
    @SuppressWarnings("unchecked")
    private Map<String, Object> convertToWeaviateWhere(Map<String, Object> filter) {
        if (filter == null || filter.isEmpty()) {
            return null;
        }

        if (filter.containsKey("$and")) {
            Map<String, Object> combined = combine("And", (List<Map<String, Object>>) filter.get("$and"));
            if (combined != null) {
                return combined;
            }
        }

        if (filter.containsKey("$or")) {
            Map<String, Object> combined = combine("Or", (List<Map<String, Object>>) filter.get("$or"));
            if (combined != null) {
                return combined;
            }
        }

        if (filter.containsKey("$not")) {
            // NOTE(review): "Not" (and "NotContainsAny" below) are emitted as operator names as in
            // the original code; confirm the target Weaviate version accepts them.
            Map<String, Object> converted = convertToWeaviateWhere((Map<String, Object>) filter.get("$not"));
            if (converted != null) {
                Map<String, Object> result = new HashMap<>();
                result.put("operator", "Not");
                result.put("operands", new ArrayList<>(Collections.singletonList(converted)));
                return result;
            }
        }

        for (Map.Entry<String, Object> entry : filter.entrySet()) {
            String field = entry.getKey();
            if (field.startsWith("$")) {
                continue; // logical operators were handled above
            }

            Object value = entry.getValue();
            if (value instanceof Map) {
                // {$ne|$gt|$gte|$lt|$lte|$in|$nin: operand}
                for (Map.Entry<String, Object> op : ((Map<String, Object>) value).entrySet()) {
                    Map<String, Object> condition = createCondition(field, op.getKey(), op.getValue());
                    if (condition != null) {
                        return condition;
                    }
                }
            } else {
                // A bare value means equality.
                Map<String, Object> condition = createCondition(field, "$eq", value);
                if (condition != null) {
                    return condition;
                }
            }
        }

        return null;
    }

    /**
     * Converts each sub-filter and wraps the convertible ones in a logical operator node.
     *
     * @return the node, or {@code null} when no operand could be converted
     */
    private Map<String, Object> combine(String operator, List<Map<String, Object>> parts) {
        List<Map<String, Object>> operands = new ArrayList<>();
        for (Map<String, Object> part : parts) {
            Map<String, Object> converted = convertToWeaviateWhere(part);
            if (converted != null) {
                operands.add(converted);
            }
        }
        if (operands.isEmpty()) {
            return null;
        }

        Map<String, Object> result = new HashMap<>();
        result.put("operator", operator);
        result.put("operands", operands);
        return result;
    }

    /**
     * Creates a single comparison condition.
     *
     * @param field    property name, used as a one-element path
     * @param operator intermediate operator key ({@code $eq}, {@code $ne}, ...)
     * @param value    operand; for {@code $in}/{@code $nin} only a {@code List} is attached
     * @return the condition map, or {@code null} for an unknown operator
     */
    private Map<String, Object> createCondition(String field, String operator, Object value) {
        Map<String, Object> condition = new HashMap<>();
        condition.put("path", Collections.singletonList(field));

        switch (operator) {
            case "$eq":
                condition.put("operator", "Equal");
                setValueByType(condition, value);
                break;
            case "$ne":
                condition.put("operator", "NotEqual");
                setValueByType(condition, value);
                break;
            case "$gt":
                condition.put("operator", "GreaterThan");
                setValueByType(condition, value);
                break;
            case "$gte":
                condition.put("operator", "GreaterThanEqual");
                setValueByType(condition, value);
                break;
            case "$lt":
                condition.put("operator", "LessThan");
                setValueByType(condition, value);
                break;
            case "$lte":
                condition.put("operator", "LessThanEqual");
                setValueByType(condition, value);
                break;
            case "$in":
                condition.put("operator", "ContainsAny");
                // NOTE(review): the list is always sent as valueTextArray, whatever its element type.
                if (value instanceof List) {
                    condition.put("valueTextArray", value);
                }
                break;
            case "$nin":
                condition.put("operator", "NotContainsAny");
                if (value instanceof List) {
                    condition.put("valueTextArray", value);
                }
                break;
            default:
                return null;
        }

        return condition;
    }

    /**
     * Stores the operand under the value key matching its Java type; null and unsupported types
     * add nothing.
     */
    private void setValueByType(Map<String, Object> condition, Object value) {
        if (value instanceof String) {
            condition.put("valueString", value);
        } else if (value instanceof Integer || value instanceof Long) {
            condition.put("valueInt", value);
        } else if (value instanceof Double || value instanceof Float) {
            condition.put("valueNumber", value);
        } else if (value instanceof Boolean) {
            condition.put("valueBoolean", value);
        }
    }

    /**
     * Creates a builder.
     *
     * @param embeddingModel model used to embed documents and queries (search requires it)
     * @param baseUrl        base URL of the Weaviate server
     */
    public static Builder builder(EmbeddingModel embeddingModel, String baseUrl) {
        return new Builder(embeddingModel, baseUrl);
    }

    /**
     * Fluent configuration for {@link WeaviateRepository}.
     */
    public static class Builder {
        private final EmbeddingModel embeddingModel;
        private final String baseUrl;
        private String username;
        private String password;
        private String token;

        private String collectionName = DEFAULT_COLLECTION_NAME;
        private List<MetadataField> metadataFields = new ArrayList<>();

        private Builder(EmbeddingModel embeddingModel, String baseUrl) {
            this.embeddingModel = embeddingModel;
            this.baseUrl = baseUrl;
        }

        public Builder collectionName(String collectionName) {
            this.collectionName = collectionName;
            return this;
        }

        public Builder username(String username) {
            this.username = username;
            return this;
        }

        public Builder password(String password) {
            this.password = password;
            return this;
        }

        public Builder token(String token) {
            this.token = token;
            return this;
        }

        /**
         * Replaces the metadata field definitions. The list is copied, so a null or unmodifiable
         * argument cannot break a later {@link #addMetadataField(MetadataField)}.
         */
        public Builder metadataFields(List<MetadataField> metadataFields) {
            this.metadataFields = (metadataFields == null)
                    ? new ArrayList<>()
                    : new ArrayList<>(metadataFields);
            return this;
        }

        public Builder addMetadataField(MetadataField metadataField) {
            this.metadataFields.add(metadataField);
            return this;
        }

        public WeaviateRepository build() {
            return new WeaviateRepository(this);
        }
    }
}

// File: solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/weaviate/Additional.java
package org.noear.solon.ai.rag.repository.weaviate;

/**
 * Additional information attached to a search hit (the {@code _additional} selection).
 *
 * @author 小奶奶花生米
 * @since 3.1
 */
public class Additional {
    private String id;
    private double distance;
    private double certainty;

    public String getId() {
        return id;
    }

    public void setId(String id) {
        this.id = id;
    }

    public double getDistance() {
        return distance;
    }

    public void setDistance(double distance) {
        this.distance = distance;
    }

    public double getCertainty() {
        return certainty;
    }

    public void setCertainty(double certainty) {
        this.certainty = certainty;
    }
}

// File: solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/weaviate/ClassInfo.java
package org.noear.solon.ai.rag.repository.weaviate;

import org.noear.snack4.annotation.ONodeAttr;

/**
 * One class entry of a schema response.
 *
 * @author 小奶奶花生米
 * @since 3.1
 */
public class ClassInfo {
    /** Bound to the attribute named "class", which cannot be a Java identifier. */
    @ONodeAttr(name = "class")
    private String clazz;

    public String getClazz() {
        return clazz;
    }

    public void setClazz(String clazz) {
        this.clazz = clazz;
    }
}

// File: solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/weaviate/Data.java
package org.noear.solon.ai.rag.repository.weaviate;

import java.util.List;
import java.util.Map;

/**
 * The {@code data} part of a GraphQL response.
 *
 * @author 小奶奶花生米
 * @since 3.1
 */
public class Data {
    // Field name is capitalized as in the original source.
    // NOTE(review): type arguments reconstructed as class name -> hits, matching how
    // WeaviateRepository.search reads this map - confirm.
    private Map<String, List<DocumentData>> Get;

    public Map<String, List<DocumentData>> getGet() {
        return Get;
    }

    public void setGet(Map<String, List<DocumentData>> get) {
        Get = get;
    }
}
// File: solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/weaviate/DocumentData.java
package org.noear.solon.ai.rag.repository.weaviate;


import org.noear.snack4.annotation.ONodeAttr;
import java.util.HashMap;
import java.util.Map;

/**
 * One document entry of a search result.
 *
 * @author 小奶奶花生米
 * @since 3.1
 */
public class DocumentData {
    /** Bound to the attribute named "_additional". */
    @ONodeAttr(name = "_additional")
    private Additional additional;
    private String content;
    private String url;
    // NOTE(review): type arguments reconstructed as <String, Object> - confirm.
    private Map<String, Object> metadata = new HashMap<>();

    public String getContent() {
        return content;
    }

    public void setContent(String content) {
        this.content = content;
    }

    public String getUrl() {
        return url;
    }

    public void setUrl(String url) {
        this.url = url;
    }

    public Additional getAdditional() {
        return additional;
    }

    public void setAdditional(Additional additional) {
        this.additional = additional;
    }

    public Map<String, Object> getMetadata() {
        return metadata;
    }

    public void setMetadata(Map<String, Object> metadata) {
        this.metadata = metadata;
    }

    /**
     * Stores a dynamic field as metadata; the fixed keys "content", "url" and "_additional" are
     * ignored here because they have dedicated fields.
     *
     * NOTE(review): nothing in the visible code calls this method; whether the JSON binder routes
     * unknown attributes to it should be confirmed.
     */
    @ONodeAttr(ignore = true)
    public void set(String key, Object value) {
        if (!"content".equals(key) && !"url".equals(key) && !"_additional".equals(key)) {
            if (metadata == null) {
                // setMetadata(null) may have cleared the map.
                metadata = new HashMap<>();
            }
            metadata.put(key, value);
        }
    }
}

// File: solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/weaviate/FieldType.java
package org.noear.solon.ai.rag.repository.weaviate;

/**
 * Metadata field types.
 *
 * @author 小奶奶花生米
 * @since 3.1
 */
public enum FieldType {
    STRING,
    INTEGER,
    FLOAT,
    BOOLEAN
}

// File: solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/weaviate/FilterTransformer.java
package org.noear.solon.ai.rag.repository.weaviate;

import org.noear.solon.expression.Expression;
import org.noear.solon.expression.Transformer;
import org.noear.solon.expression.snel.*;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Filter transformer: turns a filter expression tree into an intermediate map
 * ({@code field -> value}, {@code field -> {$op: value}}, {@code $and}/{@code $or}/{@code $not}),
 * and renders maps as GraphQL object literals.
 *
 * NOTE(review): the type arguments of {@code Transformer} and {@code Expression} were missing from
 * the source this was rebuilt from; {@code <Boolean, Map<String, Object>>} is a reconstruction -
 * confirm against the solon-expression interfaces.
 *
 * @author 小奶奶花生米
 * @since 3.1
 */
public class FilterTransformer implements Transformer<Boolean, Map<String, Object>> {
    private static final FilterTransformer instance = new FilterTransformer();

    public static FilterTransformer getInstance() {
        return instance;
    }

    /**
     * Converts a filter expression into the intermediate filter map.
     *
     * @param filterExpr filter expression, may be null
     * @return the filter map, or {@code null} when the expression is null, yields nothing, or
     *         cannot be processed
     */
    @Override
    public Map<String, Object> transform(Expression<Boolean> filterExpr) {
        if (filterExpr == null) {
            return null;
        }

        try {
            Map<String, Object> filter = new HashMap<>();
            parseFilterExpression(filterExpr, filter);

            return filter.isEmpty() ? null : filter;
        } catch (Exception e) {
            // NOTE(review): best-effort behaviour kept from the original - a failing filter is
            // reported on stderr and dropped, so the caller then searches WITHOUT the filter.
            // Consider propagating the failure or using the project's logger.
            System.err.println("Error processing filter expression: " + e.getMessage());
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Recursively writes the representation of one expression node into {@code result}.
     * Variable and constant nodes write the helper keys {@code $field} / {@code $value}
     * (plus {@code $isCollection}); comparison and logical nodes write filter entries.
     *
     * @param filterExpression node to convert, ignored when null
     * @param result           map receiving the converted entries
     */
    private void parseFilterExpression(Expression<?> filterExpression, Map<String, Object> result) {
        if (filterExpression == null) {
            return;
        }

        // Dispatch on the node class rather than on a type name string.
        if (filterExpression instanceof VariableNode) {
            result.put("$field", ((VariableNode) filterExpression).getName());
        } else if (filterExpression instanceof ConstantNode) {
            ConstantNode constant = (ConstantNode) filterExpression;
            result.put("$value", constant.getValue());
            if (Boolean.TRUE.equals(constant.isCollection())) {
                result.put("$isCollection", true);
            }
        } else if (filterExpression instanceof ComparisonNode) {
            parseComparison((ComparisonNode) filterExpression, result);
        } else if (filterExpression instanceof LogicalNode) {
            parseLogical((LogicalNode) filterExpression, result);
        }
    }

    /** Converts {@code field <op> constant}; comparisons without a field or a non-null value are skipped. */
    private void parseComparison(ComparisonNode node, Map<String, Object> result) {
        ComparisonOp operator = node.getOperator();
        Expression<?> left = node.getLeft();
        Expression<?> right = node.getRight();

        Map<String, Object> leftMap = new HashMap<>();
        parseFilterExpression(left, leftMap);

        Map<String, Object> rightMap = new HashMap<>();
        parseFilterExpression(right, rightMap);

        String field = (String) leftMap.get("$field");
        Object value = rightMap.get("$value");
        if (field == null || value == null) {
            return;
        }

        switch (operator) {
            case eq:
                // Equality is stored as a bare value.
                result.put(field, value);
                break;
            case neq:
                result.put(field, singleOp("$ne", value));
                break;
            case gt:
                result.put(field, singleOp("$gt", value));
                break;
            case gte:
                result.put(field, singleOp("$gte", value));
                break;
            case lt:
                result.put(field, singleOp("$lt", value));
                break;
            case lte:
                result.put(field, singleOp("$lte", value));
                break;
            case in:
                result.put(field, singleOp("$in", value));
                break;
            case nin:
                result.put(field, singleOp("$nin", value));
                break;
            default:
                // Unrecognized operators are ignored.
                break;
        }
    }

    /** Builds the one-entry map {@code {opKey: value}}. */
    private static Map<String, Object> singleOp(String opKey, Object value) {
        Map<String, Object> map = new HashMap<>();
        map.put(opKey, value);
        return map;
    }

    /** Converts AND / OR (two operands) and NOT (left operand only). */
    private void parseLogical(LogicalNode node, Map<String, Object> result) {
        LogicalOp logicalOp = node.getOperator();
        Expression<?> leftExpr = node.getLeft();
        Expression<?> rightExpr = node.getRight();

        if (rightExpr != null) {
            Map<String, Object> leftMap = new HashMap<>();
            parseFilterExpression(leftExpr, leftMap);

            Map<String, Object> rightMap = new HashMap<>();
            parseFilterExpression(rightExpr, rightMap);

            switch (logicalOp) {
                case AND:
                    result.put("$and", Arrays.asList(leftMap, rightMap));
                    break;
                case OR:
                    result.put("$or", Arrays.asList(leftMap, rightMap));
                    break;
                default:
                    // Unrecognized operators are ignored.
                    break;
            }
        } else if (leftExpr != null) {
            Map<String, Object> operand = new HashMap<>();
            parseFilterExpression(leftExpr, operand);

            if (logicalOp == LogicalOp.NOT) {
                result.put("$not", operand);
            }
        }
    }

    /**
     * Renders a map as a GraphQL object literal. A String stored under the key {@code operator} is
     * written without quotes so that it is read as an enum value; every other value goes through
     * the generic value conversion.
     *
     * @return the literal, or the text {@code null} for a null map
     */
    public String convertToGraphQLObject(Map<String, Object> map) {
        if (map == null) {
            return "null";
        }

        StringBuilder sb = new StringBuilder();
        sb.append("{");

        boolean first = true;
        for (Map.Entry<String, Object> entry : map.entrySet()) {
            if (!first) {
                sb.append(", ");
            }
            first = false;

            String key = entry.getKey();
            Object value = entry.getValue();

            if ("operator".equals(key) && value instanceof String) {
                sb.append(key).append(": ").append(value);
            } else {
                sb.append(key).append(": ").append(convertToGraphQLValue(value));
            }
        }

        sb.append("}");
        return sb.toString();
    }

    /**
     * Renders one value: strings quoted and escaped, numbers and booleans as-is, maps and lists
     * recursively, anything else as the quoted {@code toString()} text.
     */
    @SuppressWarnings("unchecked")
    private String convertToGraphQLValue(Object value) {
        if (value == null) {
            return "null";
        } else if (value instanceof String) {
            return "\"" + escapeString((String) value) + "\"";
        } else if (value instanceof Number || value instanceof Boolean) {
            return value.toString();
        } else if (value instanceof Map) {
            return convertToGraphQLObject((Map<String, Object>) value);
        } else if (value instanceof List) {
            return convertToGraphQLArray((List<?>) value);
        } else {
            return "\"" + escapeString(value.toString()) + "\"";
        }
    }

    /** Renders a list as a GraphQL list literal. */
    private String convertToGraphQLArray(List<?> list) {
        if (list == null) {
            return "null";
        }

        StringBuilder sb = new StringBuilder();
        sb.append("[");

        boolean first = true;
        for (Object item : list) {
            if (!first) {
                sb.append(", ");
            }
            first = false;

            sb.append(convertToGraphQLValue(item));
        }

        sb.append("]");
        return sb.toString();
    }

    /**
     * Escapes a string for use inside a double-quoted GraphQL string.
     *
     * <p>The backslash must be escaped first. Escaping the quote first (as the previous code did)
     * turns {@code "} into {@code \"} and then doubles that new backslash into {@code \\"}, which
     * leaves the quote unescaped and lets a filter value terminate the string literal.
     */
    private String escapeString(String str) {
        if (str == null) {
            return "";
        }

        return str.replace("\\", "\\\\")
                .replace("\"", "\\\"")
                .replace("\n", "\\n")
                .replace("\r", "\\r")
                .replace("\t", "\\t");
    }
}
b/solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/weaviate/GraphQLResponse.java
@@ -0,0 +1,34 @@
+package org.noear.solon.ai.rag.repository.weaviate;
+
+import java.util.Map;
+
+/**
+ * GraphQL response: holds the {@code data} payload and the {@code errors} reported with it.
+ * NOTE(review): the GraphQL spec defines "errors" as a list; confirm that binding it to a Map deserializes as intended.
+ * @author 小奶奶花生米
+ * @since 3.1
+ */
+public class GraphQLResponse {
+    private Data data;
+    private Map errors;
+
+    public Data getData() {
+        return data;
+    }
+
+    public void setData(Data data) {
+        this.data = data;
+    }
+
+    public Map getErrors() {
+        return errors;
+    }
+
+    public void setErrors(Map errors) {
+        this.errors = errors;
+    }
+
+    public boolean hasError() {
+        return errors != null && !errors.isEmpty();
+    }
+}
diff --git a/solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/weaviate/MetadataField.java b/solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/weaviate/MetadataField.java
new file mode 100644
index 00000000..0a03f965
--- /dev/null
+++ b/solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/weaviate/MetadataField.java
@@ -0,0 +1,41 @@
+package org.noear.solon.ai.rag.repository.weaviate;
+
+/**
+ * Metadata field, used to define an indexed field of the vector store: an immutable name / type pair.
+ *
+ * @author 小奶奶花生米
+ * @since 3.1
+ */
+public class MetadataField {
+    private final String name;
+    private final FieldType fieldType;
+
+    /**
+     * Creates a metadata field.
+     *
+     * @param name field name
+     * @param fieldType field type
+     */
+    public MetadataField(String name, FieldType fieldType) {
+        this.name = name;
+        this.fieldType = fieldType;
+    }
+
+    /**
+     * Returns the field name.
+     *
+     * @return the field name
+     */
+    public String getName() {
+        return name;
+    }
+
+    /**
+     * Returns the field type.
+     *
+     * @return the field type
+     */
+    public FieldType getFieldType() {
+        return fieldType;
+    }
+}
diff --git a/solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/weaviate/SchemaResponse.java 
b/solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/weaviate/SchemaResponse.java
new file mode 100644
index 00000000..8bd2c6f3
--- /dev/null
+++ b/solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/weaviate/SchemaResponse.java
@@ -0,0 +1,33 @@
+package org.noear.solon.ai.rag.repository.weaviate;
+
+import java.util.List;
+
+/**
+ * Schema response: the list of classes returned by the schema endpoint.
+ *
+ * @author 小奶奶花生米
+ * @since 3.1
+ */
+public class SchemaResponse {
+    private List<ClassInfo> classes;
+
+    public List<ClassInfo> getClasses() {
+        return classes;
+    }
+
+    public void setClasses(List<ClassInfo> classes) {
+        this.classes = classes;
+    }
+
+    public boolean hasClass(String className) {
+        if (classes == null) {
+            return false;
+        }
+        for (ClassInfo cls : classes) {
+            if (className.equals(cls.getClazz())) {
+                return true;
+            }
+        }
+        return false;
+    }
+}
diff --git a/solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/weaviate/WeaviateClient.java b/solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/weaviate/WeaviateClient.java
new file mode 100644
index 00000000..d0fc828d
--- /dev/null
+++ b/solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/main/java/org/noear/solon/ai/rag/repository/weaviate/WeaviateClient.java
@@ -0,0 +1,201 @@
+package org.noear.solon.ai.rag.repository.weaviate;
+
+import org.noear.snack4.ONode;
+import org.noear.solon.Utils;
+import org.noear.solon.core.util.MultiMap;
+import org.noear.solon.net.http.HttpUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Weaviate API client.
+ * Wraps calls to the Weaviate REST and GraphQL endpoints.
+ *
+ * @author 小奶奶花生米
+ * @since 3.1
+ */
+public class WeaviateClient {
+
+
+    private final String baseUrl;
+    private final MultiMap<String> headers = new MultiMap<>();
+
+    public WeaviateClient(String baseUrl) {
+        if (Utils.isEmpty(baseUrl)) {
+            throw new IllegalArgumentException("The baseurl cannot be empty.");
+        }
+
+        this.baseUrl = baseUrl.endsWith("/") ? baseUrl : baseUrl + "/";
+    }
+
+    /**
+     * Creates a client that sends HTTP Basic credentials with every request.
+     */
+    public WeaviateClient(String baseUrl, String username, String password) {
+        this(baseUrl);
+        setBasicAuth(username, password);
+    }
+
+    /**
+     * Creates a client that sends a Bearer token with every request.
+     */
+    public WeaviateClient(String baseUrl, String token) {
+        this(baseUrl);
+        setBearerAuth(token);
+    }
+
+    /**
+     * Sets the Authorization header to HTTP Basic credentials; "username:password" is encoded as UTF-8 so the result does not depend on the platform charset.
+     */
+    public void setBasicAuth(String username, String password) {
+        String plainCredentials = username + ":" + password;
+        String base64Credentials = java.util.Base64.getEncoder().encodeToString(plainCredentials.getBytes(java.nio.charset.StandardCharsets.UTF_8));
+        headers.put("Authorization", "Basic " + base64Credentials);
+    }
+
+    /**
+     * Sets the Authorization header to the given Bearer token.
+     */
+    public void setBearerAuth(String token) {
+        headers.put("Authorization", "Bearer " + token);
+    }
+
+    /**
+     * Builds a JSON request for the endpoint, carrying the configured headers.
+     */
+    private HttpUtils http(String endpoint) {
+        return HttpUtils.http(endpoint)
+                .header("Accept", "application/json")
+                .header("Content-Type", "application/json")
+                .headers(headers);
+    }
+
+    /**
+     * Joins the path parts with "/" and appends them to the base URL (which always ends with "/"); parts are not URL-encoded.
+     */
+    private String buildEndpoint(String... pathParts) {
+        return baseUrl + String.join("/", pathParts);
+    }
+
+    /**
+     * Fetches the schema (GET v1/schema); any failure is rethrown as an IOException.
+     */
+    public SchemaResponse getSchema() throws IOException {
+        try {
+            String endpoint = buildEndpoint("v1", "schema");
+            String response = http(endpoint).get();
+            return ONode.deserialize(response, SchemaResponse.class);
+        } catch (Exception e) {
+            throw new IOException("Failed to get schema: " + e.getMessage(), e);
+        }
+    }
+
+    /**
+     * Creates a class (POST v1/schema) with the given properties and the built-in vectorizer disabled.
+     */
+    public void createClass(String className, List<Map<String, Object>> properties) throws IOException {
+        try {
+            String endpoint = buildEndpoint("v1", "schema");
+
+            Map<String, Object> body = new HashMap<>();
+            body.put("class", className);
+            body.put("properties", properties);
+
+            // Disable Weaviate's built-in vectorization; vectors are supplied by the caller
+            body.put("vectorizer", "none");
+
+            // Vector index settings. NOTE(review): Weaviate documents "vectorIndexType" at class level and "distance" directly under vectorIndexConfig - confirm this nesting is honored.
+            Map<String, Object> vectorIndexConfig = new HashMap<>();
+            vectorIndexConfig.put("skip", false);
+            vectorIndexConfig.put("type", "hnsw");
+            Map<String, Object> hnswConfig = new HashMap<>();
+            hnswConfig.put("distance", "cosine");
+            vectorIndexConfig.put("hnsw", hnswConfig);
+            body.put("vectorIndexConfig", vectorIndexConfig);
+
+            http(endpoint).bodyOfJson(ONode.serialize(body)).post();
+        } catch (Exception e) {
+            throw new IOException("Failed to create class: " + e.getMessage(), e);
+        }
+    }
+
+    /**
+     * Deletes a class (DELETE v1/schema/{className}); any failure is rethrown as an IOException.
+     */
+    public void deleteClass(String className) throws IOException {
+        try {
+            String endpoint = buildEndpoint("v1", "schema", className);
+            http(endpoint).delete();
+        } catch (Exception e) {
+            throw new IOException("Failed to delete class: " + e.getMessage(), e);
+        }
+    }
+
+    /**
+     * Saves objects in one batch (POST v1/batch/objects); a response that contains an "errors" field is reported as an IOException.
+     */
+    public void batchSaveObjects(List<Map<String, Object>> objects) throws IOException {
+        try {
+            String endpoint = buildEndpoint("v1", "batch", "objects");
+
+            Map<String, Object> body = new HashMap<>();
+            body.put("objects", objects);
+
+            String response = http(endpoint).bodyOfJson(ONode.serialize(body)).post();
+
+            // Crude check: look for an "errors" field anywhere in the raw response text
+            if (response != null && response.contains("\"errors\"")) {
+                throw new IOException("Weaviate batch insert has errors: " + response);
+            }
+        } catch (Exception e) {
+            throw new IOException("Failed to batch save objects: " + e.getMessage(), e);
+        }
+    }
+
+    /**
+     * Deletes one object (DELETE v1/objects/{className}/{id}); any failure is rethrown as an IOException.
+     */
+    public void deleteObject(String className, String id) throws IOException {
+        try {
+            String endpoint = buildEndpoint("v1", "objects", className, id);
+            http(endpoint).delete();
+        } catch (Exception e) {
+            throw new IOException("Failed to delete object: " + e.getMessage(), e);
+        }
+    }
+
+    /**
+     * Checks whether an object exists (GET v1/objects/{className}/{id}); an empty response or any exception is reported as {@code false}.
+     */
+    public boolean objectExists(String className, String id) throws IOException {
+        try {
+            String endpoint = buildEndpoint("v1", "objects", className, id);
+            String response = http(endpoint).get();
+            return response != null && !response.isEmpty();
+        } catch (Exception e) {
+            return false;
+        }
+    }
+
+    /**
+     * Posts the query to v1/graphql as {"query": ...} and deserializes the response into {@code responseType}.
+     */
+    public <T> T executeGraphQL(String query, Class<T> responseType) throws IOException {
+        try {
+            String endpoint = buildEndpoint("v1", "graphql");
+
+            Map<String, Object> body = new HashMap<>();
+            body.put("query", query);
+
+            String requestBody = ONode.serialize(body);
+            String response = http(endpoint).bodyOfJson(requestBody).post();
+            return ONode.deserialize(response, responseType);
+        } catch (Exception e) {
+            throw new IOException("Failed to execute GraphQL query: " + e.getMessage(), e);
+        }
+    }
+}
+
diff --git a/solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/test/java/features/ai/repo/weaviate/WeaviateRepositoryTest.java b/solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/test/java/features/ai/repo/weaviate/WeaviateRepositoryTest.java
new file mode 100644
index 00000000..81bb9c48
--- /dev/null
+++ b/solon-ai-rag-repositorys/solon-ai-repo-weaviate/src/test/java/features/ai/repo/weaviate/WeaviateRepositoryTest.java
@@ -0,0 +1,340 @@
+package features.ai.repo.weaviate;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+import org.noear.solon.ai.embedding.EmbeddingModel;
+import org.noear.solon.ai.rag.Document;
+import 
org.noear.solon.ai.rag.repository.WeaviateRepository;
+import org.noear.solon.ai.rag.repository.weaviate.MetadataField;
+import org.noear.solon.ai.rag.repository.weaviate.FieldType;
+import org.noear.solon.ai.rag.splitter.TokenSizeTextSplitter;
+import org.noear.solon.ai.rag.util.QueryCondition;
+import org.noear.solon.net.http.HttpUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+/**
+ * WeaviateRepository integration test
+ *
+ * Notes:
+ * - requires a local or remote Weaviate instance;
+ * - connects to http://localhost:8080 by default (gRPC port 50051);
+ * - uses ollama's bge-m3 model for external embeddings by default.
+ *
+ * Start weaviate locally via docker-compose before running this test.
+ *
+ * @author 小奶奶花生米
+ * @since 3.1
+ */
+@EnabledIfEnvironmentVariable(named = "WEAVIATE_TEST_ENABLED", matches = "true")
+public class WeaviateRepositoryTest {
+
+    private WeaviateRepository repository;
+
+    private static final String SERVER_BASE_URL = "http://localhost:8080";
+
+    private static final String COLLECTION_NAME = "solon_ai_test";
+
+    private static final String EMBEDDING_API_URL = "http://localhost:11434/api/embed";
+    private static final String EMBEDDING_PROVIDER = "ollama";
+    private static final String EMBEDDING_MODEL = "bge-m3";
+
+    @BeforeEach
+    public void setup() {
+        try {
+            // Create the external embedding model
+            EmbeddingModel embeddingModel = EmbeddingModel.of(EMBEDDING_API_URL)
+                    .provider(EMBEDDING_PROVIDER)
+                    .model(EMBEDDING_MODEL)
+                    .build();
+
+            // Configure the connection through the WeaviateRepository builder (REST calls); the timestamp suffix gives each run its own collection
+            repository = WeaviateRepository.builder(embeddingModel, SERVER_BASE_URL)
+                    .collectionName(COLLECTION_NAME+"_"+System.currentTimeMillis())
+                    .addMetadataField(new MetadataField("title", FieldType.STRING))
+                    .addMetadataField(new MetadataField("category", FieldType.STRING))
+                    .addMetadataField(new MetadataField("price", FieldType.INTEGER))
+                    .addMetadataField(new MetadataField("stock", FieldType.INTEGER))
+                    .build();
+
+            // Preload two documents so the search tests have something to find
+            load(repository, "https://solon.noear.org/article/about?format=md");
+            load(repository, "https://h5.noear.org/readme.htm");
+
+            Thread.sleep(1000L);
+        } catch (Exception e) {
+            System.err.println("Failed to setup WeaviateRepositoryTest: " + e.getMessage());
+            e.printStackTrace();
+        }
+    }
+
+    private void assumeRepositoryReady() {
+        org.junit.jupiter.api.Assumptions.assumeTrue(repository != null, "Weaviate repository is not initialized");
+    }
+
+    /**
+     * Simple end-to-end test of search, save and delete.
+     */
+    @Test
+    public void testSearchAndCrud() throws IOException {
+        assumeRepositoryReady();
+
+        // Basic search
+        List<Document> list = repository.search(new QueryCondition("solon"));
+        assertNotNull(list);
+
+        // Insert a document and check that it exists
+        Document doc = new Document("Test content for weaviate");
+        repository.save(Collections.singletonList(doc));
+        String key = doc.getId();
+
+        assertTrue(repository.existsById(key), "Document should exist after storing");
+
+        // Delete it and check again
+        repository.deleteById(key);
+        assertFalse(repository.existsById(key), "Document should not exist after removal");
+    }
+
+    /**
+     * Delete test.
+     */
+    @Test
+    public void testRemove() {
+        assumeRepositoryReady();
+
+        // Prepare and store the test data
+        List<Document> documents = new ArrayList<>();
+        Document doc = new Document("Document to be removed", new HashMap<>());
+        documents.add(doc);
+
+        try {
+            repository.save(documents);
+            Thread.sleep(1000);
+            // Delete the document
+            repository.deleteById(doc.getId());
+
+            Thread.sleep(1000);
+            // Verify the document has been deleted
+            assertFalse(repository.existsById(doc.getId()), "文档应该已被删除");
+
+        } catch (Exception e) {
+            e.printStackTrace();
+            fail("测试过程中发生异常: " + e.getMessage());
+        }
+    }
+
+    /**
+     * Score output test: every returned document has a non-negative score and the first two are in descending order.
+     */
+    @Test
+    public void testScoreOutput() throws IOException {
+        assumeRepositoryReady();
+
+        try {
+            QueryCondition condition = new QueryCondition("solon").disableRefilter(true);
+            List<Document> results = repository.search(condition);
+
+            assertFalse(results.isEmpty(), "搜索结果不应为空");
+
+            for (Document doc : results) {
+                assertTrue(doc.getScore() >= 0, "文档评分应该是非负数");
+            }
+
+            if (results.size() > 1) {
+                double firstScore = results.get(0).getScore();
+                double secondScore = results.get(1).getScore();
+                assertTrue(firstScore >= secondScore, "结果应该按评分降序排序");
+            }
+        } catch (Exception e) {
+            fail("测试过程中发生异常: " + e.getMessage());
+        }
+    }
+
+    /**
+     * Expression filter test.
+     */
+    @Test
+    public void testExpressionFilter() throws IOException {
+        assumeRepositoryReady();
+
+        // Add documents that carry metadata
+        Document doc1 = new Document("Document about Solon framework");
+        doc1.getMetadata().put("title", "solon");
+        doc1.getMetadata().put("category", "framework");
+
+        Document doc2 = new Document("Document about Java settings");
+        doc2.getMetadata().put("title", "设置");
+        doc2.getMetadata().put("category", "tutorial");
+
+        Document doc3 = new Document("Document about Spring framework");
+        doc3.getMetadata().put("title", "spring");
+        doc3.getMetadata().put("category", "framework");
+
+        List<Document> documents = new ArrayList<>();
+        documents.add(doc1);
+        documents.add(doc2);
+        documents.add(doc3);
+        repository.save(documents);
+
+        try {
+            // 1. Search with an OR filter expression
+            String orExpression = "title == 'solon' OR title == '设置'";
+            List<Document> orResults = repository.search(new QueryCondition("framework").filterExpression(orExpression).disableRefilter(true));
+
+            System.out.println("Found " + orResults.size() + " documents with OR filter expression: " + orExpression);
+
+            // Expect exactly 2 documents (a JUnit assertion, so the check also runs when JVM assertions are disabled)
+            assertEquals(2, orResults.size());
+
+            // 2. Filter with an AND expression
+            String andExpression = "title == 'solon' AND category == 'framework'";
+            List<Document> andResults = repository.search(new QueryCondition("framework").filterExpression(andExpression).disableRefilter(true));
+
+            System.out.println("Found " + andResults.size() + " documents with AND filter expression: " + andExpression);
+
+            // Expect exactly 1 document
+            assertEquals(1, andResults.size());
+
+            // 3. Filter by category
+            String categoryExpression = "category == 'framework'";
+            List<Document> categoryResults = repository.search(new QueryCondition("framework").filterExpression(categoryExpression).disableRefilter(true));
+
+            System.out.println("Found " + categoryResults.size() + " documents with category filter: " + categoryExpression);
+
+            // Expect the 2 documents in the framework category
+            assertEquals(2, categoryResults.size());
+        } finally {
+            // Clean up the test data
+            repository.deleteById(doc1.getId(), doc2.getId(), doc3.getId());
+        }
+    }
+
+    /**
+     * Advanced expression filter test.
+     */
+    @Test
+    public void testAdvancedExpressionFilter() throws IOException {
+        assumeRepositoryReady();
+
+        // Create the test documents
+        Document doc1 = new Document("Document with numeric properties");
+        doc1.getMetadata().put("price", 100);
+        doc1.getMetadata().put("stock", 50);
+        doc1.getMetadata().put("category", "electronics");
+
+        Document doc2 = new Document("Document with different price");
+        doc2.getMetadata().put("price", 200);
+        doc2.getMetadata().put("stock", 10);
+        doc2.getMetadata().put("category", "electronics");
+
+        Document doc3 = new Document("Document with different category");
+        doc3.getMetadata().put("price", 150);
+        doc3.getMetadata().put("stock", 25);
+        doc3.getMetadata().put("category", "books");
+
+        List<Document> documents = new ArrayList<>();
+        documents.add(doc1);
+        documents.add(doc2);
+        documents.add(doc3);
+
+        try {
+            // Insert the test documents
+            repository.save(documents);
+
+            // Wait for the index to update
+            Thread.sleep(1000);
+
+            // 1. Numeric comparison (greater than)
+            String gtExpression = "price > 120";
+            QueryCondition gtCondition = new QueryCondition("document")
+                    .filterExpression(gtExpression)
+                    .disableRefilter(true);
+
+            List<Document> gtResults = repository.search(gtCondition);
+            System.out.println("找到 " + gtResults.size() + " 个文档,使用大于表达式: " + gtExpression);
+
+            // Verify: at least one document should be found
+            assertTrue(gtResults.size() > 0, "大于表达式应该找到文档");
+
+            // 2. Numeric comparison (less than or equal)
+            String lteExpression = "stock <= 25";
+            QueryCondition lteCondition = new QueryCondition("document")
+                    .filterExpression(lteExpression)
+                    .disableRefilter(true);
+
+            List<Document> lteResults = repository.search(lteCondition);
+            System.out.println("找到 " + lteResults.size() + " 个文档,使用小于等于表达式: " + lteExpression);
+
+            // Verify: at least one document should be found
+            assertTrue(lteResults.size() > 0, "小于等于表达式应该找到文档");
+
+            // 3. Compound expression (price range plus category)
+            String complexExpression = "(price >= 100 AND price <= 180) AND category == 'electronics'";
+            QueryCondition complexCondition = new QueryCondition("document")
+                    .filterExpression(complexExpression)
+                    .disableRefilter(true);
+
+            List<Document> complexResults = repository.search(complexCondition);
+            System.out.println("找到 " + complexResults.size() + " 个文档,使用复合表达式: " + complexExpression);
+
+            // Verify: at least one document should be found
+            assertTrue(complexResults.size() > 0, "复合表达式应该找到文档");
+
+            // Print a summary of the results
+            System.out.println("\n=== 高级表达式过滤测试结果 ===");
+            System.out.println("大于表达式结果数量: " + gtResults.size());
+            System.out.println("小于等于表达式结果数量: " + lteResults.size());
+            System.out.println("复合表达式结果数量: " + complexResults.size());
+
+        } catch (Exception e) {
+            e.printStackTrace();
+            fail("测试过程中发生异常: " + e.getMessage());
+        } finally {
+            // Clean up the test documents
+            try {
+                repository.deleteById(doc1.getId(), doc2.getId(), doc3.getId());
+            } catch (Exception e) {
+                System.err.println("清理测试文档失败: " + e.getMessage());
+            }
+        }
+    }
+
+    private void load(WeaviateRepository repository, String url) throws IOException {
+        System.out.println("Loading documents from: " + url);
+        String text = HttpUtils.http(url).get(); // fetch the raw document text
+        System.out.println("Loaded text with length: " + text.length());
+
+        // Split the text into documents and tag each one with its source url
+        List<Document> documents = new TokenSizeTextSplitter(200).split(text).stream()
+                .map(doc -> {
+                    doc.url(url);
+                    return doc;
+                })
+                .collect(Collectors.toList());
+
+        System.out.println("Split into " + documents.size() + " documents");
+
+        // Store the documents
+        repository.save(documents);
+        System.out.println("Inserted documents into repository");
+
+        // Verify that the documents were inserted (best effort: a failure here is only logged)
+        try {
+            if (!documents.isEmpty()) {
+                boolean exists = repository.existsById(documents.get(0).getId());
+                System.out.println("Verified document exists: " + exists);
+            }
+        } catch (Exception e) {
+            System.err.println("Failed to verify document: " + e.getMessage());
+        }
+    }
+}
+
-- 
Gitee