diff --git a/src/main/java/com/github/javpower/javavision/detect/translator/AbstractDjlTranslator.java b/src/main/java/com/github/javpower/javavision/detect/translator/AbstractDjlTranslator.java index 7ba00c200910e87342a72b1e595a4d1b524b1c51..853c76b124d135c59319a785f495152a51fd808b 100644 --- a/src/main/java/com/github/javpower/javavision/detect/translator/AbstractDjlTranslator.java +++ b/src/main/java/com/github/javpower/javavision/detect/translator/AbstractDjlTranslator.java @@ -4,13 +4,18 @@ import ai.djl.modality.cv.Image; import ai.djl.repository.zoo.Criteria; import ai.djl.training.util.ProgressBar; import ai.djl.translate.Translator; +import com.github.javpower.javavision.util.DjlHandlerUtil; import com.github.javpower.javavision.util.JarFileUtils; import com.github.javpower.javavision.util.PathConstants; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import java.io.File; import java.io.IOException; import java.util.Map; +@Slf4j public abstract class AbstractDjlTranslator { public String modelName; @@ -35,11 +40,15 @@ public abstract class AbstractDjlTranslator { } catch (IOException e) { throw new RuntimeException(e); } - String model_path = PathConstants.TEMP_DIR + PathConstants.ONNX + "/" + modelName; +// String model_path = PathConstants.TEMP_DIR + PathConstants.ONNX + "/" + modelName; + String modelPath = PathConstants.TEMP_DIR + File.separator+PathConstants.ONNX_NAME+ File.separator + modelName; + log.info("路径修改前:{}",modelPath); + modelPath= DjlHandlerUtil.getFixModelPath(modelPath); + log.info("路径修改后:{}",modelPath); Criteria criteria = Criteria.builder() .setTypes(Image.class, getClassOfT()) - .optModelUrls(model_path) + .optModelUrls(modelPath) .optTranslator(translator) .optEngine(getEngine()) // Use PyTorch engine .optProgress(new ProgressBar()) @@ -53,5 +62,6 @@ public abstract class AbstractDjlTranslator { protected abstract Class getClassOfT(); protected abstract String getEngine(); + } diff --git a/src/main/java/com/github/javpower/javavision/util/DjlHandlerUtil.java b/src/main/java/com/github/javpower/javavision/util/DjlHandlerUtil.java new file mode 100644 index 0000000000000000000000000000000000000000..88128cbde24b083ef49c5749ccd7ca0f714a4bc4 --- /dev/null +++ b/src/main/java/com/github/javpower/javavision/util/DjlHandlerUtil.java @@ -0,0 +1,23 @@ +package com.github.javpower.javavision.util; + +import org.apache.commons.lang3.StringUtils; + +public class DjlHandlerUtil { + + /** + * 获取修复后的模型路径 + * @param modelPath 如 C:\Users\wosui\AppData\Local\Temp\ocrJava\onnx\image_feature.zip + * @return file:///C:/Users/wosui/AppData/Local/Temp/ocrJava/onnx/image_feature.zip + */ + public static String getFixModelPath(String modelPath){ + if(StringUtils.isBlank(modelPath)){ + return ""; + } + StringBuffer path=new StringBuffer(); + if(!modelPath.startsWith("http")){ + modelPath=modelPath.replace("\\","/"); + path.append("file:///").append(modelPath); + } + return path.toString(); + } +} diff --git a/src/main/java/com/github/javpower/javavision/util/PathConstants.java b/src/main/java/com/github/javpower/javavision/util/PathConstants.java index e4b09df2394b187f1c40cc71ba45125c9712242d..b8d2d101403888c88a7f42cfddc012b7a0b8a754 100644 --- a/src/main/java/com/github/javpower/javavision/util/PathConstants.java +++ b/src/main/java/com/github/javpower/javavision/util/PathConstants.java @@ -16,7 +16,7 @@ public class PathConstants { */ public static final String NCNN = "/ncnn"; public static final String ONNX = "/onnx"; - + public static final String ONNX_NAME = "onnx"; /** * 模型相关 **/