diff --git a/examples/src/main/java/smartai/examples/nlp/TextTranslationCPU.java b/examples/src/main/java/smartai/examples/nlp/TextTranslationCPU.java
new file mode 100644
index 0000000000000000000000000000000000000000..67d5710b468f95d91eba523eaa0fb452fc817d71
--- /dev/null
+++ b/examples/src/main/java/smartai/examples/nlp/TextTranslationCPU.java
@@ -0,0 +1,57 @@
+package smartai.examples.nlb;
+
+import ai.djl.Device;
+import ai.djl.ModelException;
+import ai.djl.translate.TranslateException;
+import smartai.examples.nlb.generate.SearchConfig;
+import smartai.examples.nlb.model.NllbModel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+
+/**
+ * 文本翻译,支持202种语言互译
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public final class TextTranslationCPU {
+
+ private static final Logger logger = LoggerFactory.getLogger(TextTranslationCPU.class);
+
+ private TextTranslationCPU() {
+ }
+
+ public static void main(String[] args) throws ModelException, IOException,
+ TranslateException {
+
+ SearchConfig config = new SearchConfig();
+ // 设置输出文字的最大长度
+ config.setMaxSeqLength(128);
+ // 设置源语言:中文 "zho_Hans": 256200
+ config.setSrcLangId(256200);
+ // 设置目标语言:英文 "eng_Latn": 256047
+ config.setForcedBosTokenId(256047);
+ config.setForcedBosTokenId(256201);
+
+ // 输入文字
+ String input = "智利北部的丘基卡马塔矿是世界上最大的露天矿之一,长约4公里,宽3公里,深1公里。";
+
+ String modelPath = "E:\\ai\\models\\nlp\\";
+ String cpuModelName = "traced_translation_cpu.pt";
+ String gpuModelName = "traced_translation_gpu.pt";
+ try (NllbModel nllbModel = new NllbModel(config, modelPath, cpuModelName, Device.cpu())) {
+
+ System.setProperty("ai.djl.pytorch.graph_optimizer", "false");
+
+ // 运行模型,获取翻译结果
+ String result = nllbModel.translate(input);
+
+ logger.info("result========={}", result);
+ } finally {
+ System.clearProperty("ai.djl.pytorch.graph_optimizer");
+ }
+ }
+}
\ No newline at end of file
diff --git a/examples/src/main/java/smartai/examples/nlp/generate/BatchTensorList.java b/examples/src/main/java/smartai/examples/nlp/generate/BatchTensorList.java
new file mode 100644
index 0000000000000000000000000000000000000000..03ed9576f91928d7a8fcbc47f9d97895e83fc772
--- /dev/null
+++ b/examples/src/main/java/smartai/examples/nlp/generate/BatchTensorList.java
@@ -0,0 +1,174 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package smartai.examples.nlb.generate;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+
+/**
+ * BatchTensorList represents a search state, and the NDArrays inside are updated in each iteration
+ * of the autoregressive loop.
+ *
+ *
It is a struct consisting of NDArrays, whose first dimension is batch, and also contains
+ * sequence dimension (whose position in tensor's shape is specified by seqDimOrder). The SeqBatcher
+ * batch operations will operate on these two dimensions.
+ */
+public abstract class BatchTensorList {
+ // [batch, seq_past]. seq-dim-size == |past_seq| + |inputIds|. Will grow.
+ private NDArray pastOutputIds;
+
+ // [batch, seq_past]
+ // The cache of past attentionMask. seq-dim-size == |past_seq| + |inputIds|. Will grow.
+ private NDArray pastAttentionMask;
+
+ // (k, v) * numLayer,
+ // kv: [batch, heads, seq_past, kvfeature]
+ // The cache of past sequence. seq-dim-size == |past_seq| + |inputIds|. Will grow.
+ private NDList pastKeyValues;
+
+ // Sequence dimension order among all dimensions for each element in the batch list.
+ private long[] seqDimOrder;
+
+ BatchTensorList() {}
+
+ /**
+ * Constructs a new {@code BatchTensorList} instance.
+ *
+ * @param list the NDList that contains the serialized version of the batch tensors
+ * @param seqDimOrder the sequence dimension order that specifies where the sequence dimension
+ * is in a tensor's shape
+ */
+ BatchTensorList(NDList list, long[] seqDimOrder) {
+ this.seqDimOrder = seqDimOrder;
+ pastOutputIds = list.get(0);
+ pastAttentionMask = list.get(1);
+ pastKeyValues = list.subNDList(2);
+ }
+
+ /**
+ * Constructs a new {@code BatchTensorList} instance.
+ *
+ * @param pastOutputIds past output token ids
+ * @param pastAttentionMask past attention mask
+ * @param pastKeyValues past kv cache
+ * @param seqDimOrder the sequence dimension order that specifies where the sequence dimension
+ * is in a tensor's shape
+ */
+ BatchTensorList(
+ NDArray pastOutputIds,
+ NDArray pastAttentionMask,
+ NDList pastKeyValues,
+ long[] seqDimOrder) {
+ this.pastKeyValues = pastKeyValues;
+ this.pastOutputIds = pastOutputIds;
+ this.pastAttentionMask = pastAttentionMask;
+ this.seqDimOrder = seqDimOrder;
+ }
+
+ /**
+ * Constructs a new {@code BatchTensorList} instance from the serialized version of the batch
+ * tensors.
+ *
+ *
The pastOutputIds has to be the first in the output list.
+ *
+ * @param inputList the serialized version of the batch tensors
+ * @param seqDimOrder the sequence dimension order that specifies where the sequence dimension
+ * is in a tensor's shape
+ * @return BatchTensorList
+ */
+ public abstract BatchTensorList fromList(NDList inputList, long[] seqDimOrder);
+
+ /**
+ * Returns the serialized version of the BatchTensorList. The pastOutputIds has to be the first
+ * in the output list.
+ *
+ * @return the NDList that contains the serialized BatchTensorList
+ */
+ public abstract NDList getList();
+
+ /**
+ * Returns the sequence dimension order which specifies where the sequence dimension is in a
+ * tensor's shape.
+ *
+ * @return the sequence dimension order which specifies where the sequence dimension is in a
+ * tensor's shape
+ */
+ public long[] getSeqDimOrder() {
+ return seqDimOrder;
+ }
+
+ /**
+ * Returns the value of the pastOutputIds.
+ *
+ * @return the value of pastOutputIds
+ */
+ public NDArray getPastOutputIds() {
+ return pastOutputIds;
+ }
+
+ /**
+ * Sets the past output token ids.
+ *
+ * @param pastOutputIds the past output token ids
+ */
+ public void setPastOutputIds(NDArray pastOutputIds) {
+ this.pastOutputIds = pastOutputIds;
+ }
+
+ /**
+ * Returns the value of the pastAttentionMask.
+ *
+ * @return the value of pastAttentionMask
+ */
+ public NDArray getPastAttentionMask() {
+ return pastAttentionMask;
+ }
+
+ /**
+ * Sets the attention mask.
+ *
+ * @param pastAttentionMask the attention mask
+ */
+ public void setPastAttentionMask(NDArray pastAttentionMask) {
+ this.pastAttentionMask = pastAttentionMask;
+ }
+
+ /**
+ * Returns the value of the pastKeyValues.
+ *
+ * @return the value of pastKeyValues
+ */
+ public NDList getPastKeyValues() {
+ return pastKeyValues;
+ }
+
+ /**
+ * Sets the kv cache.
+ *
+ * @param pastKeyValues the kv cache
+ */
+ public void setPastKeyValues(NDList pastKeyValues) {
+ this.pastKeyValues = pastKeyValues;
+ }
+
+ /**
+ * Sets the sequence dimension order which specifies where the sequence dimension is in a
+ * tensor's shape.
+ *
+ * @param seqDimOrder the sequence dimension order which specifies where the sequence dimension
+ * is in a tensor's shape
+ */
+ public void setSeqDimOrder(long[] seqDimOrder) {
+ this.seqDimOrder = seqDimOrder;
+ }
+}
\ No newline at end of file
diff --git a/examples/src/main/java/smartai/examples/nlp/generate/CausalLMOutput.java b/examples/src/main/java/smartai/examples/nlp/generate/CausalLMOutput.java
new file mode 100644
index 0000000000000000000000000000000000000000..80b9aa4757849f0e9238b3721f791f72accdd4bc
--- /dev/null
+++ b/examples/src/main/java/smartai/examples/nlp/generate/CausalLMOutput.java
@@ -0,0 +1,32 @@
+package smartai.examples.nlb.generate;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+/**
+ * 解码输出对象
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class CausalLMOutput {
+ private NDArray logits;
+ private NDList pastKeyValuesList;
+
+ public CausalLMOutput(NDArray logits, NDList pastKeyValues) {
+ this.logits = logits;
+ this.pastKeyValuesList = pastKeyValues;
+ }
+
+ public NDArray getLogits() {
+ return logits;
+ }
+
+ public void setLogits(NDArray logits) {
+ this.logits = logits;
+ }
+
+ public NDList getPastKeyValuesList() {
+ return pastKeyValuesList;
+ }
+}
\ No newline at end of file
diff --git a/examples/src/main/java/smartai/examples/nlp/generate/GreedyBatchTensorList.java b/examples/src/main/java/smartai/examples/nlp/generate/GreedyBatchTensorList.java
new file mode 100644
index 0000000000000000000000000000000000000000..388de1455e0f304cdaea932568341ddfb11ea321
--- /dev/null
+++ b/examples/src/main/java/smartai/examples/nlp/generate/GreedyBatchTensorList.java
@@ -0,0 +1,84 @@
+package smartai.examples.nlb.generate;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+
+/**
+ * 贪婪搜索张量对象列表
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class GreedyBatchTensorList extends BatchTensorList {
+ // [batch, 1]
+ private NDArray nextInputIds;
+
+ private NDArray pastOutputIds;
+
+ private NDArray encoderHiddenStates;
+ private NDArray attentionMask;
+ private NDList pastKeyValues;
+
+ public GreedyBatchTensorList(
+ NDArray nextInputIds,
+ NDArray pastOutputIds,
+ NDList pastKeyValues,
+ NDArray encoderHiddenStates,
+ NDArray attentionMask) {
+ this.nextInputIds = nextInputIds;
+ this.pastKeyValues = pastKeyValues;
+ this.pastOutputIds = pastOutputIds;
+ this.attentionMask = attentionMask;
+ this.encoderHiddenStates = encoderHiddenStates;
+ }
+
+ public GreedyBatchTensorList() {}
+
+ public BatchTensorList fromList(NDList inputList, long[] seqDimOrder) {
+ return new GreedyBatchTensorList();
+ }
+
+ public NDList getList() {
+ return new NDList();
+ }
+
+ public NDArray getNextInputIds() {
+ return nextInputIds;
+ }
+
+ public void setNextInputIds(NDArray nextInputIds) {
+ this.nextInputIds = nextInputIds;
+ }
+ public NDArray getPastOutputIds() {
+ return pastOutputIds;
+ }
+
+ public void setPastOutputIds(NDArray pastOutputIds) {
+ this.pastOutputIds = pastOutputIds;
+ }
+
+ public NDList getPastKeyValues() {
+ return pastKeyValues;
+ }
+
+ public void setPastKeyValues(NDList pastKeyValues) {
+ this.pastKeyValues = pastKeyValues;
+ }
+
+ public NDArray getEncoderHiddenStates() {
+ return encoderHiddenStates;
+ }
+
+ public void setEncoderHiddenStates(NDArray encoderHiddenStates) {
+ this.encoderHiddenStates = encoderHiddenStates;
+ }
+
+ public NDArray getAttentionMask() {
+ return attentionMask;
+ }
+
+ public void setAttentionMask(NDArray attentionMask) {
+ this.attentionMask = attentionMask;
+ }
+}
\ No newline at end of file
diff --git a/examples/src/main/java/smartai/examples/nlp/generate/SearchConfig.java b/examples/src/main/java/smartai/examples/nlp/generate/SearchConfig.java
new file mode 100644
index 0000000000000000000000000000000000000000..a040b504641110ca2f653909da37c8909c2f4cc3
--- /dev/null
+++ b/examples/src/main/java/smartai/examples/nlp/generate/SearchConfig.java
@@ -0,0 +1,104 @@
+package smartai.examples.nlb.generate;
+/**
+ * 配置信息
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class SearchConfig {
+
+ private int maxSeqLength;
+ private long padTokenId;
+ private long eosTokenId;
+ private long bosTokenId;
+ private long decoderStartTokenId;
+ private float encoderRepetitionPenalty;
+ private long forcedBosTokenId;
+ private long srcLangId;
+ private float lengthPenalty;
+ public SearchConfig() {
+ this.maxSeqLength = 512;
+ this.eosTokenId = 2;
+ this.bosTokenId = 0;
+ this.padTokenId = 1;
+ this.decoderStartTokenId = 2;
+ this.encoderRepetitionPenalty = 1.0f;
+ this.srcLangId = 0;
+ this.forcedBosTokenId = 0;
+ this.lengthPenalty = 1.0f;
+
+ }
+
+ public long getSrcLangId() {
+ return srcLangId;
+ }
+
+ public void setSrcLangId(long srcLangId) {
+ this.srcLangId = srcLangId;
+ }
+
+ public void setEosTokenId(long eosTokenId) {
+ this.eosTokenId = eosTokenId;
+ }
+
+ public int getMaxSeqLength() {
+ return maxSeqLength;
+ }
+
+ public void setMaxSeqLength(int maxSeqLength) {
+ this.maxSeqLength = maxSeqLength;
+ }
+
+ public long getPadTokenId() {
+ return padTokenId;
+ }
+
+ public void setPadTokenId(long padTokenId) {
+ this.padTokenId = padTokenId;
+ }
+
+ public long getEosTokenId() {
+ return eosTokenId;
+ }
+
+ public long getDecoderStartTokenId() {
+ return decoderStartTokenId;
+ }
+
+ public void setDecoderStartTokenId(long decoderStartTokenId) {
+ this.decoderStartTokenId = decoderStartTokenId;
+ }
+
+ public float getEncoderRepetitionPenalty() {
+ return encoderRepetitionPenalty;
+ }
+
+ public void setEncoderRepetitionPenalty(float encoderRepetitionPenalty) {
+ this.encoderRepetitionPenalty = encoderRepetitionPenalty;
+ }
+
+ public long getForcedBosTokenId() {
+ return forcedBosTokenId;
+ }
+
+ public void setForcedBosTokenId(long forcedBosTokenId) {
+ this.forcedBosTokenId = forcedBosTokenId;
+ }
+
+ public float getLengthPenalty() {
+ return lengthPenalty;
+ }
+
+ public void setLengthPenalty(float lengthPenalty) {
+ this.lengthPenalty = lengthPenalty;
+ }
+
+ public long getBosTokenId() {
+ return bosTokenId;
+ }
+
+ public void setBosTokenId(long bosTokenId) {
+ this.bosTokenId = bosTokenId;
+ }
+}
\ No newline at end of file
diff --git a/examples/src/main/java/smartai/examples/nlp/model/Decoder2Translator.java b/examples/src/main/java/smartai/examples/nlp/model/Decoder2Translator.java
new file mode 100644
index 0000000000000000000000000000000000000000..26255477f5ec3384d88b4f6cef3705e66d919f9b
--- /dev/null
+++ b/examples/src/main/java/smartai/examples/nlp/model/Decoder2Translator.java
@@ -0,0 +1,45 @@
+package smartai.examples.nlb.model;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.translate.NoBatchifyTranslator;
+import ai.djl.translate.TranslatorContext;
+import smartai.examples.nlb.generate.CausalLMOutput;
+
+/**
+ * 解碼器,參數支持 pastKeyValues
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class Decoder2Translator implements NoBatchifyTranslator {
+ private String tupleName;
+
+ public Decoder2Translator() {
+ tupleName = "past_key_values(" + 12 + ',' + 4 + ')';
+ }
+
+ @Override
+ public NDList processInput(TranslatorContext ctx, NDList input) {
+
+ NDArray placeholder = ctx.getNDManager().create(0);
+ placeholder.setName("module_method:decoder2");
+
+ input.add(placeholder);
+
+ return input;
+ }
+
+ @Override
+ public CausalLMOutput processOutput(TranslatorContext ctx, NDList output) {
+ NDArray logitsOutput = output.get(0);
+ NDList pastKeyValuesOutput = output.subNDList(1, 12 * 4 + 1);
+
+ for (NDArray array : pastKeyValuesOutput) {
+ array.setName(tupleName);
+ }
+
+ return new CausalLMOutput(logitsOutput, pastKeyValuesOutput);
+ }
+}
\ No newline at end of file
diff --git a/examples/src/main/java/smartai/examples/nlp/model/DecoderTranslator.java b/examples/src/main/java/smartai/examples/nlp/model/DecoderTranslator.java
new file mode 100644
index 0000000000000000000000000000000000000000..9211a8282eac8d3fbed6be051cfcf4e64dce9069
--- /dev/null
+++ b/examples/src/main/java/smartai/examples/nlp/model/DecoderTranslator.java
@@ -0,0 +1,44 @@
+package smartai.examples.nlb.model;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.translate.NoBatchifyTranslator;
+import ai.djl.translate.TranslatorContext;
+import smartai.examples.nlb.generate.CausalLMOutput;
+/**
+ * 解碼器,參數沒有 pastKeyValues
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class DecoderTranslator implements NoBatchifyTranslator {
+ private String tupleName;
+
+ public DecoderTranslator() {
+ tupleName = "past_key_values(" + 12 + ',' + 4 + ')';
+ }
+
+ @Override
+ public NDList processInput(TranslatorContext ctx, NDList input) {
+
+ NDArray placeholder = ctx.getNDManager().create(0);
+ placeholder.setName("module_method:decoder");
+
+ input.add(placeholder);
+
+ return input;
+ }
+
+ @Override
+ public CausalLMOutput processOutput(TranslatorContext ctx, NDList output) {
+ NDArray logitsOutput = output.get(0);
+ NDList pastKeyValuesOutput = output.subNDList(1, 12 * 4 + 1);
+
+ for (NDArray array : pastKeyValuesOutput) {
+ array.setName(tupleName);
+ }
+
+ return new CausalLMOutput(logitsOutput, pastKeyValuesOutput);
+ }
+}
\ No newline at end of file
diff --git a/examples/src/main/java/smartai/examples/nlp/model/EncoderTranslator.java b/examples/src/main/java/smartai/examples/nlp/model/EncoderTranslator.java
new file mode 100644
index 0000000000000000000000000000000000000000..8ba292ee71311e5af70c4fc7f65eea72f7a12a72
--- /dev/null
+++ b/examples/src/main/java/smartai/examples/nlp/model/EncoderTranslator.java
@@ -0,0 +1,49 @@
+package smartai.examples.nlb.model;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
+import ai.djl.translate.NoBatchifyTranslator;
+import ai.djl.translate.TranslatorContext;
+
+import java.util.Arrays;
+
+/**
+ * 编码器前后处理
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class EncoderTranslator implements NoBatchifyTranslator {
+
+
+ public EncoderTranslator() {
+ }
+
+ @Override
+ public NDList processInput(TranslatorContext ctx, long[] input) throws Exception {
+ NDManager manager = ctx.getNDManager();
+
+ NDArray inputIdArray = manager.create(input).expandDims(0);
+ inputIdArray.setName("input_ids");
+
+ long[] attentionMask = new long[input.length];
+ Arrays.fill(attentionMask, 1);
+ NDArray attentionMaskArray = manager.create(attentionMask).expandDims(0);
+ attentionMaskArray.setName("attention_mask");
+
+ NDArray placeholder = ctx.getNDManager().create(0);
+ placeholder.setName("module_method:encoder");
+
+ return new NDList(inputIdArray, attentionMaskArray, placeholder);
+ }
+
+ @Override
+ public NDArray processOutput(TranslatorContext ctx, NDList list) {
+ NDArray encoderHiddenStates = list.get(0);
+ encoderHiddenStates.detach();
+ return encoderHiddenStates;
+ }
+
+}
\ No newline at end of file
diff --git a/examples/src/main/java/smartai/examples/nlp/model/NllbModel.java b/examples/src/main/java/smartai/examples/nlp/model/NllbModel.java
new file mode 100644
index 0000000000000000000000000000000000000000..a987ab26fb334c2bf9c1cff0a7178ee7d4a0f392
--- /dev/null
+++ b/examples/src/main/java/smartai/examples/nlp/model/NllbModel.java
@@ -0,0 +1,183 @@
+package smartai.examples.nlb.model;
+
+import ai.djl.Device;
+import ai.djl.ModelException;
+import ai.djl.huggingface.tokenizers.Encoding;
+import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
+import ai.djl.inference.Predictor;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.index.NDIndex;
+import ai.djl.repository.zoo.Criteria;
+import ai.djl.repository.zoo.ZooModel;
+import ai.djl.translate.NoopTranslator;
+import ai.djl.translate.TranslateException;
+import smartai.examples.nlb.generate.CausalLMOutput;
+import smartai.examples.nlb.generate.GreedyBatchTensorList;
+import smartai.examples.nlb.generate.SearchConfig;
+import smartai.examples.nlb.tokenizer.TokenUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.nio.file.Paths;
+import java.util.Arrays;
+/**
+ * 模型载入及推理
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class NllbModel implements AutoCloseable {
+ private static final Logger logger = LoggerFactory.getLogger(NllbModel.class);
+ private SearchConfig config;
+ private ZooModel nllbModel;
+ private HuggingFaceTokenizer tokenizer;
+ private Predictor encoderPredictor;
+ private Predictor decoderPredictor;
+ private Predictor decoder2Predictor;
+ private NDManager manager;
+
+ public NllbModel(SearchConfig config, String modelPath, String modelName, Device device) throws ModelException, IOException {
+ this.config = config;
+ Criteria criteria =
+ Criteria.builder()
+ .setTypes(NDList.class, NDList.class)
+ .optModelPath(Paths.get(modelPath + modelName))
+ .optEngine("PyTorch")
+ .optDevice(device)
+ .optTranslator(new NoopTranslator())
+ .build();
+
+ manager = NDManager.newBaseManager(device);
+ nllbModel = criteria.loadModel();
+ tokenizer = HuggingFaceTokenizer.newInstance(Paths.get(modelPath + "tokenizer.json"));
+ encoderPredictor = nllbModel.newPredictor(new EncoderTranslator());
+ decoderPredictor = nllbModel.newPredictor(new DecoderTranslator());
+ decoder2Predictor = nllbModel.newPredictor(new Decoder2Translator());
+ }
+
+ public NDArray encoder(long[] ids) throws TranslateException {
+ return encoderPredictor.predict(ids);
+ }
+
+ public CausalLMOutput decoder(NDList input) throws TranslateException {
+ return decoderPredictor.predict(input);
+ }
+
+ public CausalLMOutput decoder2(NDList input) throws TranslateException {
+ return decoder2Predictor.predict(input);
+ }
+
+ @Override
+ public void close() {
+ encoderPredictor.close();
+ decoderPredictor.close();
+ decoder2Predictor.close();
+ nllbModel.close();
+ manager.close();
+ tokenizer.close();
+ }
+
+ public String translate(String input) throws TranslateException {
+
+ Encoding encoding = tokenizer.encode(input);
+ long[] ids = encoding.getIds();
+ // 1. Encoder
+ long[] inputIds = new long[ids.length];
+ // 设置源语言编码
+ inputIds[0] = config.getSrcLangId();
+ for (int i = 0; i < ids.length - 1; i++) {
+ inputIds[i + 1] = ids[i];
+ }
+ logger.info("inputIds: " + Arrays.toString(inputIds));
+ long[] attentionMask = encoding.getAttentionMask();
+ NDArray attentionMaskArray = manager.create(attentionMask).expandDims(0);
+
+ NDArray encoderHiddenStates = encoder(inputIds);
+
+ NDArray decoder_input_ids = manager.create(new long[]{config.getDecoderStartTokenId()}).reshape(1, 1);
+ NDList decoderInput = new NDList(decoder_input_ids, encoderHiddenStates, attentionMaskArray);
+
+ // 2. Initial Decoder
+ CausalLMOutput modelOutput = decoder(decoderInput);
+ modelOutput.getLogits().attach(manager);
+ modelOutput.getPastKeyValuesList().attach(manager);
+
+ GreedyBatchTensorList searchState =
+ new GreedyBatchTensorList(null, decoder_input_ids, modelOutput.getPastKeyValuesList(), encoderHiddenStates, attentionMaskArray);
+
+ while (true) {
+// try (NDScope ignore = new NDScope()) {
+ NDArray pastOutputIds = searchState.getPastOutputIds();
+
+ if (searchState.getNextInputIds() != null) {
+ decoderInput = new NDList(searchState.getNextInputIds(), searchState.getEncoderHiddenStates(), searchState.getAttentionMask());
+ decoderInput.addAll(searchState.getPastKeyValues());
+ // 3. Decoder loop
+ modelOutput = decoder2(decoderInput);
+ }
+
+ NDArray outputIds = greedyStepGen(config, pastOutputIds, modelOutput.getLogits());
+
+ searchState.setNextInputIds(outputIds);
+ pastOutputIds = pastOutputIds.concat(outputIds, 1);
+ searchState.setPastOutputIds(pastOutputIds);
+
+ searchState.setPastKeyValues(modelOutput.getPastKeyValuesList());
+
+ // memory management
+// NDScope.unregister(outputIds, pastOutputIds);
+// }
+
+ // Termination Criteria
+ long id = searchState.getNextInputIds().toLongArray()[0];
+ if (config.getEosTokenId() == id) {
+ searchState.setNextInputIds(null);
+ break;
+ }
+ if (searchState.getPastOutputIds() != null && searchState.getPastOutputIds().getShape().get(1) + 1 >= config.getMaxSeqLength()) {
+ break;
+ }
+ }
+
+ if (searchState.getNextInputIds() == null) {
+ NDArray resultIds = searchState.getPastOutputIds();
+ String result = TokenUtils.decode(config, tokenizer, resultIds);
+ return result;
+ } else {
+ NDArray resultIds = searchState.getPastOutputIds(); // .concat(searchState.getNextInputIds(), 1)
+ String result = TokenUtils.decode(config, tokenizer, resultIds);
+ return result;
+ }
+
+ }
+
+ public NDArray greedyStepGen(SearchConfig config, NDArray pastOutputIds, NDArray next_token_scores) {
+ next_token_scores = next_token_scores.get(":, -1, :");
+
+ NDArray new_next_token_scores = manager.create(next_token_scores.getShape(), next_token_scores.getDataType());
+ next_token_scores.copyTo(new_next_token_scores);
+
+ // LogitsProcessor 1. ForcedBOSTokenLogitsProcessor
+ // 设置目标语言
+ long cur_len = pastOutputIds.getShape().getLastDimension();
+ if (cur_len == 1) {
+ long num_tokens = new_next_token_scores.getShape().getLastDimension();
+ for (long i = 0; i < num_tokens; i++) {
+ if (i != config.getForcedBosTokenId()) {
+ new_next_token_scores.set(new NDIndex(":," + i), Float.NEGATIVE_INFINITY);
+ }
+ }
+ new_next_token_scores.set(new NDIndex(":," + config.getForcedBosTokenId()), 0);
+ }
+
+ NDArray probs = new_next_token_scores.softmax(-1);
+ NDArray next_tokens = probs.argMax(-1);
+
+ return next_tokens.expandDims(0);
+ }
+
+}
diff --git a/examples/src/main/java/smartai/examples/objectdetection/ObjectDetection.java b/examples/src/main/java/smartai/examples/objectdetection/ObjectDetection.java
index 4fc321a95c5ce989abb4f51662bb79a8156d65f0..f8ee492c954e24bd5feadd29f58e285ad2e7a504 100644
--- a/examples/src/main/java/smartai/examples/objectdetection/ObjectDetection.java
+++ b/examples/src/main/java/smartai/examples/objectdetection/ObjectDetection.java
@@ -111,18 +111,21 @@ public class ObjectDetection {
/**
- * 使用yolo官方模型检测
+ * 使用yolo官方模型检测物品识别
*/
@Test
public void objectDetectionWithOfficialModel(){
DetectorModelConfig config = new DetectorModelConfig();
+ config.setThreshold(0.3f);
//也支持YoloV8:YOLOV8_OFFICIAL 模型可以从文档中提供的地址下载
config.setModelEnum(DetectorModelEnum.YOLOV12_OFFICIAL);//检测模型,目前支持19种模型
// 指定模型路径,需要更改为自己的模型路径
- config.setModelPath("/Users/xxx/Documents/develop/face_model/yolov12n.onnx");
+ config.setModelPath("E:\\ai\\models\\yolo12m\\yolov12m.onnx");
DetectorModel detectorModel = ObjectDetectionModelFactory.getInstance().getModel(config);
//一定要将yolo官方的类别文件:synset.txt(文档中下载)放在模型同目录下,否则报错
- detectorModel.detectAndDraw("src/main/resources/object_detection.jpg","output/object_detection_detected.png");
+ DetectionResponse detect = detectorModel.detect("E:\\ai\\testimage\\1.jpg");
+ log.info("目标检测结果:{}", JSONObject.toJSONString(detect));
+ detectorModel.detectAndDraw("E:\\ai\\testimage\\1.jpg","E:\\ai\\outimage\\11.png");
}
/**
diff --git a/pom.xml b/pom.xml
index c5aeed1bda50fd6b8bd405d446456b24a346e86d..ef044b7ed566a7d07cadd80bc4e930c6f92e58e4 100644
--- a/pom.xml
+++ b/pom.xml
@@ -11,6 +11,7 @@
SmartJavaAI
smartjavaai-face
+ smartjavaai-translate
smartjavaai-common
smartjavaai-objectdetection
smartjavaai-all
@@ -52,14 +53,16 @@
ai.djl
model-zoo
+
+ ai.djl.huggingface
+ tokenizers
+ ${djl.version}
+
-
-
-
-
+
diff --git a/smartjavaai-common/src/main/java/cn/smartjavaai/common/pool/ZooModelFactory.java b/smartjavaai-common/src/main/java/cn/smartjavaai/common/pool/ZooModelFactory.java
new file mode 100644
index 0000000000000000000000000000000000000000..4c8bdaf6e870ed494f1df830b0bee8597d73eef4
--- /dev/null
+++ b/smartjavaai-common/src/main/java/cn/smartjavaai/common/pool/ZooModelFactory.java
@@ -0,0 +1,35 @@
+package cn.smartjavaai.common.pool;
+
+import ai.djl.inference.Predictor;
+import ai.djl.repository.zoo.ZooModel;
+import org.apache.commons.pool2.BasePooledObjectFactory;
+import org.apache.commons.pool2.PooledObject;
+import org.apache.commons.pool2.impl.DefaultPooledObject;
+
+/**
+ * ZooModel 工厂类
+ * @author lwx
+ * @date 2025/6/06
+ */
+public class ZooModelFactory extends BasePooledObjectFactory> {
+ private final ZooModel model;
+
+ public ZooModelFactory(ZooModel model) {
+ this.model = model;
+ }
+
+ @Override
+ public ZooModel create() {
+ return model;
+ }
+
+ @Override
+ public PooledObject> wrap(ZooModel predictor) {
+ return new DefaultPooledObject<>(predictor);
+ }
+
+ @Override
+ public void destroyObject(PooledObject> p) {
+ p.getObject().close();
+ }
+}
diff --git a/smartjavaai-objectdetection/src/main/java/cn/smartjavaai/objectdetection/enums/DetectorModelEnum.java b/smartjavaai-objectdetection/src/main/java/cn/smartjavaai/objectdetection/enums/DetectorModelEnum.java
index ad69fb4dcf8cbe5edce0e03f90528ed8e223d118..4fb6d8a282272fc3afcb5cfe6d7650696b798553 100644
--- a/smartjavaai-objectdetection/src/main/java/cn/smartjavaai/objectdetection/enums/DetectorModelEnum.java
+++ b/smartjavaai-objectdetection/src/main/java/cn/smartjavaai/objectdetection/enums/DetectorModelEnum.java
@@ -9,7 +9,7 @@ public enum DetectorModelEnum {
// resnet50 系列
SSD_300_RESNET50("ai.djl.pytorch/ssd/0.0.1/ssd_300_resnet50"),
- SSD_512_RESNET50_V1_VOC("ai.djl.mxnet/ssd/0.0.1/ssd_512_resnet50_v1_voc"),
+ SSD_512_RESNET50_V1_VOC("ai.djl./ssd/0.0.1/ssd_512_resnet50_v1_voc"),
// vgg16 系列
SSD_512_VGG16_ATROUS_COCO("ai.djl.mxnet/ssd/0.0.1/ssd_512_vgg16_atrous_coco"),
diff --git a/smartjavaai-translate/pom.xml b/smartjavaai-translate/pom.xml
new file mode 100644
index 0000000000000000000000000000000000000000..a957198a2d21562a38ff8194ffaa9000624b28e0
--- /dev/null
+++ b/smartjavaai-translate/pom.xml
@@ -0,0 +1,134 @@
+
+
+ 4.0.0
+
+ cn.smartjavaai
+ smartjavaai-parent
+ 1.0.15
+
+
+ smartjavaai-translate
+
+
+
+
+ cn.smartjavaai
+ smartjavaai-common
+ ${project.version}
+
+
+ junit
+ junit
+ 4.13.1
+ compile
+
+
+
+ 1.0.15
+ smartjavaai-ocr
+ SmartJavaAI
+ https://github.com/geekwenjie/SmartJavaAI
+
+
+ MIT License
+ https://opensource.org/licenses/MIT
+
+
+
+
+
+
+ org.sonatype.central
+ central-publishing-maven-plugin
+ 0.4.0
+ true
+
+ dengwenjie
+ true
+ ${project.groupId}:${project.artifactId}:${project.version}
+
+
+
+
+ org.apache.maven.plugins
+ maven-source-plugin
+ 3.1.0
+
+
+ attach-sources
+
+ jar-no-fork
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-javadoc-plugin
+ 3.1.0
+
+
+ none
+
+ -Xdoclint:none
+
+
+
+
+ attach-javadocs
+
+ jar
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-gpg-plugin
+ 3.1.0
+
+
+ sign-artifacts
+ verify
+
+ sign
+
+
+
+
+
+
+
+
+
+ scm:git:git://github.com/geekwenjie/SmartJavaAI.git
+ scm:git:ssh://github.com/geekwenjie/SmartJavaAI.git
+ http://github.com/geekwenjie/SmartJavaAI/tree/master
+
+
+
+
+
+ dengwenjie
+ https://s01.oss.sonatype.org/content/repositories/snapshots
+
+
+ dengwenjie
+ https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/
+
+
+
+
+
+ dengwenjie
+ 775747758@qq.com
+
+ Project Manager
+ Architect
+
+
+
+
+
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/config/MachineTranslationModelConfig.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/config/MachineTranslationModelConfig.java
new file mode 100644
index 0000000000000000000000000000000000000000..1b9ee4065c044289b797bda0a3418f59fb76799b
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/config/MachineTranslationModelConfig.java
@@ -0,0 +1,40 @@
+package cn.smartjavaai.translation.config;
+
+import cn.smartjavaai.common.enums.DeviceEnum;
+
+import cn.smartjavaai.translation.enums.MachineTranslationModeEnum;
+import lombok.Data;
+
+/**
+ * 机器翻译模型配置
+ * @author lwx
+ * @date 2025/6/05
+ */
+@Data
+public class MachineTranslationModelConfig {
+ /**
+ * 翻译模型
+ */
+ private MachineTranslationModeEnum modelEnum;
+
+ /**
+ * 设备类型
+ */
+ private DeviceEnum device;
+
+ /**
+ * 翻译模型路径
+ */
+ private String modelPath;
+ /**
+ * 翻译模型路径
+ */
+ private String modelName;
+ /**
+ * 翻译模型配置
+ */
+ private SearchConfig searchConfig;
+
+
+
+}
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/config/SearchConfig.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/config/SearchConfig.java
new file mode 100644
index 0000000000000000000000000000000000000000..06799af45ecf730a6e23edd77cf7f39e44714d38
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/config/SearchConfig.java
@@ -0,0 +1,102 @@
+package cn.smartjavaai.translation.config;
+/**
+ * 配置信息
+ * @author lwx
+ * @date 2025/6/05
+ */
+public class SearchConfig {
+
+ private int maxSeqLength;
+ private long padTokenId;
+ private long eosTokenId;
+ private long bosTokenId;
+ private long decoderStartTokenId;
+ private float encoderRepetitionPenalty;
+ private long forcedBosTokenId;
+ private long srcLangId;
+ private float lengthPenalty;
+ public SearchConfig() {
+ this.maxSeqLength = 512;
+ this.eosTokenId = 2;
+ this.bosTokenId = 0;
+ this.padTokenId = 1;
+ this.decoderStartTokenId = 2;
+ this.encoderRepetitionPenalty = 1.0f;
+ this.srcLangId = 0;
+ this.forcedBosTokenId = 0;
+ this.lengthPenalty = 1.0f;
+
+ }
+
+ public long getSrcLangId() {
+ return srcLangId;
+ }
+
+ public void setSrcLangId(long srcLangId) {
+ this.srcLangId = srcLangId;
+ }
+
+ public void setEosTokenId(long eosTokenId) {
+ this.eosTokenId = eosTokenId;
+ }
+
+ public int getMaxSeqLength() {
+ return maxSeqLength;
+ }
+
+ public void setMaxSeqLength(int maxSeqLength) {
+ this.maxSeqLength = maxSeqLength;
+ }
+
+ public long getPadTokenId() {
+ return padTokenId;
+ }
+
+ public void setPadTokenId(long padTokenId) {
+ this.padTokenId = padTokenId;
+ }
+
+ public long getEosTokenId() {
+ return eosTokenId;
+ }
+
+ public long getDecoderStartTokenId() {
+ return decoderStartTokenId;
+ }
+
+ public void setDecoderStartTokenId(long decoderStartTokenId) {
+ this.decoderStartTokenId = decoderStartTokenId;
+ }
+
+ public float getEncoderRepetitionPenalty() {
+ return encoderRepetitionPenalty;
+ }
+
+ public void setEncoderRepetitionPenalty(float encoderRepetitionPenalty) {
+ this.encoderRepetitionPenalty = encoderRepetitionPenalty;
+ }
+
+ public long getForcedBosTokenId() {
+ return forcedBosTokenId;
+ }
+
+ public void setForcedBosTokenId(long forcedBosTokenId) {
+ this.forcedBosTokenId = forcedBosTokenId;
+ }
+
+ public float getLengthPenalty() {
+ return lengthPenalty;
+ }
+
+ public void setLengthPenalty(float lengthPenalty) {
+ this.lengthPenalty = lengthPenalty;
+ }
+
+ public long getBosTokenId() {
+ return bosTokenId;
+ }
+
+ public void setBosTokenId(long bosTokenId) {
+ this.bosTokenId = bosTokenId;
+ }
+}
\ No newline at end of file
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/BatchTensorList.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/BatchTensorList.java
new file mode 100644
index 0000000000000000000000000000000000000000..87bea9b63a9977b7aa6e897bfaf0ce67957fe44b
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/BatchTensorList.java
@@ -0,0 +1,174 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package cn.smartjavaai.translation.entity;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+
+/**
+ * BatchTensorList represents a search state, and the NDArrays inside are updated in each iteration
+ * of the autoregressive loop.
+ *
+ * It is a struct consisting of NDArrays, whose first dimension is batch, and also contains
+ * sequence dimension (whose position in tensor's shape is specified by seqDimOrder). The SeqBatcher
+ * batch operations will operate on these two dimensions.
+ */
+public abstract class BatchTensorList {
+ // [batch, seq_past]. seq-dim-size == |past_seq| + |inputIds|. Will grow.
+ private NDArray pastOutputIds;
+
+ // [batch, seq_past]
+ // The cache of past attentionMask. seq-dim-size == |past_seq| + |inputIds|. Will grow.
+ private NDArray pastAttentionMask;
+
+ // (k, v) * numLayer,
+ // kv: [batch, heads, seq_past, kvfeature]
+ // The cache of past sequence. seq-dim-size == |past_seq| + |inputIds|. Will grow.
+ private NDList pastKeyValues;
+
+ // Sequence dimension order among all dimensions for each element in the batch list.
+ private long[] seqDimOrder;
+
+ BatchTensorList() {}
+
+ /**
+ * Constructs a new {@code BatchTensorList} instance.
+ *
+ * @param list the NDList that contains the serialized version of the batch tensors
+ * @param seqDimOrder the sequence dimension order that specifies where the sequence dimension
+ * is in a tensor's shape
+ */
+ BatchTensorList(NDList list, long[] seqDimOrder) {
+ this.seqDimOrder = seqDimOrder;
+ pastOutputIds = list.get(0);
+ pastAttentionMask = list.get(1);
+ pastKeyValues = list.subNDList(2);
+ }
+
+ /**
+ * Constructs a new {@code BatchTensorList} instance.
+ *
+ * @param pastOutputIds past output token ids
+ * @param pastAttentionMask past attention mask
+ * @param pastKeyValues past kv cache
+ * @param seqDimOrder the sequence dimension order that specifies where the sequence dimension
+ * is in a tensor's shape
+ */
+ BatchTensorList(
+ NDArray pastOutputIds,
+ NDArray pastAttentionMask,
+ NDList pastKeyValues,
+ long[] seqDimOrder) {
+ this.pastKeyValues = pastKeyValues;
+ this.pastOutputIds = pastOutputIds;
+ this.pastAttentionMask = pastAttentionMask;
+ this.seqDimOrder = seqDimOrder;
+ }
+
+ /**
+ * Constructs a new {@code BatchTensorList} instance from the serialized version of the batch
+ * tensors.
+ *
+ *
The pastOutputIds has to be the first in the output list.
+ *
+ * @param inputList the serialized version of the batch tensors
+ * @param seqDimOrder the sequence dimension order that specifies where the sequence dimension
+ * is in a tensor's shape
+ * @return BatchTensorList
+ */
+ public abstract BatchTensorList fromList(NDList inputList, long[] seqDimOrder);
+
+ /**
+ * Returns the serialized version of the BatchTensorList. The pastOutputIds has to be the first
+ * in the output list.
+ *
+ * @return the NDList that contains the serialized BatchTensorList
+ */
+ public abstract NDList getList();
+
+ /**
+ * Returns the sequence dimension order which specifies where the sequence dimension is in a
+ * tensor's shape.
+ *
+ * @return the sequence dimension order which specifies where the sequence dimension is in a
+ * tensor's shape
+ */
+ public long[] getSeqDimOrder() {
+ return seqDimOrder;
+ }
+
+ /**
+ * Returns the value of the pastOutputIds.
+ *
+ * @return the value of pastOutputIds
+ */
+ public NDArray getPastOutputIds() {
+ return pastOutputIds;
+ }
+
+ /**
+ * Sets the past output token ids.
+ *
+ * @param pastOutputIds the past output token ids
+ */
+ public void setPastOutputIds(NDArray pastOutputIds) {
+ this.pastOutputIds = pastOutputIds;
+ }
+
+ /**
+ * Returns the value of the pastAttentionMask.
+ *
+ * @return the value of pastAttentionMask
+ */
+ public NDArray getPastAttentionMask() {
+ return pastAttentionMask;
+ }
+
+ /**
+ * Sets the attention mask.
+ *
+ * @param pastAttentionMask the attention mask
+ */
+ public void setPastAttentionMask(NDArray pastAttentionMask) {
+ this.pastAttentionMask = pastAttentionMask;
+ }
+
+ /**
+ * Returns the value of the pastKeyValues.
+ *
+ * @return the value of pastKeyValues
+ */
+ public NDList getPastKeyValues() {
+ return pastKeyValues;
+ }
+
+ /**
+ * Sets the kv cache.
+ *
+ * @param pastKeyValues the kv cache
+ */
+ public void setPastKeyValues(NDList pastKeyValues) {
+ this.pastKeyValues = pastKeyValues;
+ }
+
+ /**
+ * Sets the sequence dimension order which specifies where the sequence dimension is in a
+ * tensor's shape.
+ *
+ * @param seqDimOrder the sequence dimension order which specifies where the sequence dimension
+ * is in a tensor's shape
+ */
+ public void setSeqDimOrder(long[] seqDimOrder) {
+ this.seqDimOrder = seqDimOrder;
+ }
+}
\ No newline at end of file
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/CausalLMOutput.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/CausalLMOutput.java
new file mode 100644
index 0000000000000000000000000000000000000000..acbc11de829def6de43c967ead4dc982b0267636
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/CausalLMOutput.java
@@ -0,0 +1,33 @@
+package cn.smartjavaai.translation.entity;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+
+/**
+ * 解码输出对象
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class CausalLMOutput {
+ private NDArray logits;
+ private NDList pastKeyValuesList;
+
+ public CausalLMOutput(NDArray logits, NDList pastKeyValues) {
+ this.logits = logits;
+ this.pastKeyValuesList = pastKeyValues;
+ }
+
+ public NDArray getLogits() {
+ return logits;
+ }
+
+ public void setLogits(NDArray logits) {
+ this.logits = logits;
+ }
+
+ public NDList getPastKeyValuesList() {
+ return pastKeyValuesList;
+ }
+}
\ No newline at end of file
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/DirectionInfo.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/DirectionInfo.java
new file mode 100644
index 0000000000000000000000000000000000000000..c0d7be281cb1138098fbda2209bf4fb95f7dd16a
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/DirectionInfo.java
@@ -0,0 +1,41 @@
+package cn.smartjavaai.translation.entity;
+
+/**
+ * 方向检测结果
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class DirectionInfo {
+
+ /**
+ * 方向 0 90 180 270
+ */
+ private String name;
+
+ /**
+ * 置信度
+ */
+ private Double prob;
+
+ public DirectionInfo(String name, Double prob) {
+ this.name = name;
+ this.prob = prob;
+ }
+
+ public String getName() {
+ return name;
+ }
+
+ public void setName(String name) {
+ this.name = name;
+ }
+
+ public Double getProb() {
+ return prob;
+ }
+
+ public void setProb(Double prob) {
+ this.prob = prob;
+ }
+}
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/GreedyBatchTensorList.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/GreedyBatchTensorList.java
new file mode 100644
index 0000000000000000000000000000000000000000..2745c151c34f87f10b9c14a15fc07e166e33ab1a
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/GreedyBatchTensorList.java
@@ -0,0 +1,84 @@
+package cn.smartjavaai.translation.entity;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+
+/**
+ * 贪婪搜索张量对象列表
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class GreedyBatchTensorList extends BatchTensorList {
+ // [batch, 1]
+ private NDArray nextInputIds;
+
+ private NDArray pastOutputIds;
+
+ private NDArray encoderHiddenStates;
+ private NDArray attentionMask;
+ private NDList pastKeyValues;
+
+ public GreedyBatchTensorList(
+ NDArray nextInputIds,
+ NDArray pastOutputIds,
+ NDList pastKeyValues,
+ NDArray encoderHiddenStates,
+ NDArray attentionMask) {
+ this.nextInputIds = nextInputIds;
+ this.pastKeyValues = pastKeyValues;
+ this.pastOutputIds = pastOutputIds;
+ this.attentionMask = attentionMask;
+ this.encoderHiddenStates = encoderHiddenStates;
+ }
+
+ public GreedyBatchTensorList() {}
+
+ public BatchTensorList fromList(NDList inputList, long[] seqDimOrder) {
+ return new GreedyBatchTensorList();
+ }
+
+ public NDList getList() {
+ return new NDList();
+ }
+
+ public NDArray getNextInputIds() {
+ return nextInputIds;
+ }
+
+ public void setNextInputIds(NDArray nextInputIds) {
+ this.nextInputIds = nextInputIds;
+ }
+ public NDArray getPastOutputIds() {
+ return pastOutputIds;
+ }
+
+ public void setPastOutputIds(NDArray pastOutputIds) {
+ this.pastOutputIds = pastOutputIds;
+ }
+
+ public NDList getPastKeyValues() {
+ return pastKeyValues;
+ }
+
+ public void setPastKeyValues(NDList pastKeyValues) {
+ this.pastKeyValues = pastKeyValues;
+ }
+
+ public NDArray getEncoderHiddenStates() {
+ return encoderHiddenStates;
+ }
+
+ public void setEncoderHiddenStates(NDArray encoderHiddenStates) {
+ this.encoderHiddenStates = encoderHiddenStates;
+ }
+
+ public NDArray getAttentionMask() {
+ return attentionMask;
+ }
+
+ public void setAttentionMask(NDArray attentionMask) {
+ this.attentionMask = attentionMask;
+ }
+}
\ No newline at end of file
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/IdCardInfo.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/IdCardInfo.java
new file mode 100644
index 0000000000000000000000000000000000000000..8eae0b930efa2cd271fdad4836156b5af134144b
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/IdCardInfo.java
@@ -0,0 +1,12 @@
+package cn.smartjavaai.translation.entity;
+
+/**
+ * 身份证信息
+ * @author dwj
+ * @date 2025/5/22
+ */
+public class IdCardInfo {
+
+
+
+}
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/enums/MachineTranslationModeEnum.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/enums/MachineTranslationModeEnum.java
new file mode 100644
index 0000000000000000000000000000000000000000..1ffa2fdd1d04248322a754bf1fdd7b71d0378262
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/enums/MachineTranslationModeEnum.java
@@ -0,0 +1,23 @@
+package cn.smartjavaai.translation.enums;
+/**
+ * 机器翻译模型枚举
+ * @author lwx
+ * @date 2025/6/05
+ */
+public enum MachineTranslationModeEnum {
+
+ TRACED_TRANSLATION_CPU;
+
+ /**
+ * 根据名称获取枚举 (忽略大小写和下划线变体)
+ */
+ public static MachineTranslationModeEnum fromName(String name) {
+ String formatted = name.trim().toUpperCase().replaceAll("[-_]", "");
+ for (MachineTranslationModeEnum model : values()) {
+ if (model.name().replaceAll("_", "").equals(formatted)) {
+ return model;
+ }
+ }
+ throw new IllegalArgumentException("未知模型名称: " + name);
+ }
+}
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/exception/TranslationException.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/exception/TranslationException.java
new file mode 100644
index 0000000000000000000000000000000000000000..f9fab822eef1cb472433b028c31c7f3c4504ddef
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/exception/TranslationException.java
@@ -0,0 +1,30 @@
+package cn.smartjavaai.translation.exception;
+
+/**
+ * 翻译异常
+ * @author lwx
+ * @date 2025/6/5
+ */
+public class TranslationException extends RuntimeException{
+
+ public TranslationException() {
+ super();
+ }
+
+ public TranslationException(String message, Throwable cause, boolean enableSuppression, boolean writableStackTrace) {
+ super(message, cause, enableSuppression, writableStackTrace);
+ }
+
+ public TranslationException(String message, Throwable cause) {
+ super(message, cause);
+ }
+
+ public TranslationException(String message) {
+ super(message);
+ }
+
+ public TranslationException(Throwable cause) {
+ super(cause);
+ }
+
+}
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/factory/TranslationModelFactory.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/factory/TranslationModelFactory.java
new file mode 100644
index 0000000000000000000000000000000000000000..1cab02bd394a10e677b96e1d045953fbf028f4af
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/factory/TranslationModelFactory.java
@@ -0,0 +1,102 @@
+package cn.smartjavaai.translation.factory;
+
+import cn.smartjavaai.common.config.Config;
+
+
+import cn.smartjavaai.translation.config.MachineTranslationModelConfig;
+import cn.smartjavaai.translation.exception.TranslationException;
+import cn.smartjavaai.translation.model.common.TracedTranslationModel;
+import cn.smartjavaai.translation.model.common.TranslationCommonModel;
+import lombok.extern.slf4j.Slf4j;
+
+import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * 机器翻译模型工厂
+ * @author dwj
+ */
+@Slf4j
+public class TranslationModelFactory {
+
+ // 使用 volatile 和双重检查锁定来确保线程安全的单例模式
+ private static volatile TranslationModelFactory instance;
+
+ private static final ConcurrentHashMap commonDetModelMap = new ConcurrentHashMap<>();
+
+
+
+ /**
+ * 检测模型注册表
+ */
+ private static final Map> commonDetRegistry =
+ new ConcurrentHashMap<>();
+
+
+ public static TranslationModelFactory getInstance() {
+ if (instance == null) {
+ synchronized (TranslationModelFactory.class) {
+ if (instance == null) {
+ instance = new TranslationModelFactory();
+ }
+ }
+ }
+ return instance;
+ }
+
+
+
+ /**
+ * 注册通用检测模型
+ * @param name
+ * @param clazz
+ */
+ private static void registerCommonDetModel(String name, Class extends TranslationCommonModel> clazz) {
+ commonDetRegistry.put(name.toLowerCase(), clazz);
+ }
+
+ /**
+ * 获取检测模型(通过配置)
+ * @param config
+ * @return
+ */
+ public TranslationCommonModel getDetModel(MachineTranslationModelConfig config) {
+ if(Objects.isNull(config) || Objects.isNull(config.getModelEnum())){
+ throw new TranslationException("未配置OCR模型");
+ }
+ return commonDetModelMap.computeIfAbsent(config.getModelEnum().name(), k -> {
+ return createCommonDetModel(config);
+ });
+ }
+
+
+ /**
+ * 创建OCR通用检测模型
+ * @param config
+ * @return
+ */
+ private TranslationCommonModel createCommonDetModel(MachineTranslationModelConfig config) {
+ Class> clazz = commonDetRegistry.get(config.getModelEnum().name().toLowerCase());
+ if(clazz == null){
+ throw new TranslationException("Unsupported model");
+ }
+ TranslationCommonModel model = null;
+ try {
+ model = (TranslationCommonModel) clazz.newInstance();
+ } catch (InstantiationException | IllegalAccessException e) {
+ throw new TranslationException(e);
+ }
+ model.loadModel(config);
+ return model;
+ }
+
+
+ // 初始化默认算法
+ static {
+ registerCommonDetModel("TRACED_TRANSLATION_CPU", TracedTranslationModel.class);
+
+ log.info("缓存目录:{}", Config.getCachePath());
+ }
+
+}
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/TracedTranslationModel.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/TracedTranslationModel.java
new file mode 100644
index 0000000000000000000000000000000000000000..af01b8f6d06bc288425e67269b27b17e791f5e51
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/TracedTranslationModel.java
@@ -0,0 +1,254 @@
+package cn.smartjavaai.translation.model.common;
+
+import ai.djl.Device;
+import ai.djl.MalformedModelException;
+import ai.djl.engine.Engine;
+import ai.djl.huggingface.tokenizers.Encoding;
+import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
+import ai.djl.inference.Predictor;
+import ai.djl.modality.cv.Image;
+import ai.djl.modality.cv.output.DetectedObjects;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.index.NDIndex;
+import ai.djl.repository.zoo.Criteria;
+import ai.djl.repository.zoo.ModelNotFoundException;
+import ai.djl.repository.zoo.ModelZoo;
+import ai.djl.repository.zoo.ZooModel;
+import ai.djl.translate.NoopTranslator;
+import ai.djl.translate.TranslateException;
+import cn.smartjavaai.common.config.Config;
+import cn.smartjavaai.common.enums.DeviceEnum;
+import cn.smartjavaai.common.pool.PredictorFactory;
+import cn.smartjavaai.common.pool.ZooModelFactory;
+import cn.smartjavaai.translation.config.MachineTranslationModelConfig;
+import cn.smartjavaai.translation.config.SearchConfig;
+import cn.smartjavaai.translation.entity.CausalLMOutput;
+import cn.smartjavaai.translation.entity.GreedyBatchTensorList;
+import cn.smartjavaai.translation.enums.MachineTranslationModeEnum;
+import cn.smartjavaai.translation.exception.TranslationException;
+import cn.smartjavaai.translation.factory.TranslationModelFactory;
+import cn.smartjavaai.translation.model.common.translator.Decoder2Translator;
+import cn.smartjavaai.translation.model.common.translator.DecoderTranslator;
+import cn.smartjavaai.translation.model.common.translator.EncoderTranslator;
+import cn.smartjavaai.translation.utils.TokenUtils;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.pool2.ObjectPool;
+import org.apache.commons.pool2.impl.GenericObjectPool;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.nio.file.Paths;
+import java.util.Objects;
+
+/**
+ * 机器翻译通用检测模型
+ *
+ * @author lwx
+ * @date 2025/6/05
+ */
+@Slf4j
+public class TracedTranslationModel implements TranslationCommonModel {
+
+ private ObjectPool> detPredictorPool;
+ private ZooModel nllbModel;
+ private HuggingFaceTokenizer tokenizer;
+ private Predictor encoderPredictor;
+ private Predictor decoderPredictor;
+ private Predictor decoder2Predictor;
+ private SearchConfig searchConfig;
+ private MachineTranslationModelConfig config;
+ private NDManager manager;
+
+ @Test
+ public void detect(){
+ Config.setCachePath("E:\\ai\\models\\libs");
+ MachineTranslationModelConfig config = new MachineTranslationModelConfig();
+ SearchConfig searchConfig = new SearchConfig();
+ // 设置输出文字的最大长度
+ searchConfig.setMaxSeqLength(128);
+ // 设置源语言:中文 "zho_Hans": 256200
+ searchConfig.setSrcLangId(256200);
+ // 设置目标语言:英文 "eng_Latn": 256047
+ searchConfig.setForcedBosTokenId(256047);
+ config.setSearchConfig(searchConfig);
+ config.setDevice(DeviceEnum.CPU);
+ config.setModelEnum(MachineTranslationModeEnum.TRACED_TRANSLATION_CPU);
+ config.setModelPath("E:\\ai\\models\\nlp\\");
+ config.setModelName("traced_translation_cpu.pt");
+ // 输入文字
+ String input2 = "智利北部的丘基卡马塔矿是世界上最大的露天矿之一,长约4公里,宽3公里,深1公里。";
+ String input = "你好,欢迎使用SmartJavaAI!";
+ TranslationCommonModel detModel = TranslationModelFactory.getInstance().getDetModel(config);
+ // detModel.loadModel(config);
+ String translate = detModel.translate(input);
+ System.out.println("识别结果 translate"+translate);
+ }
+ @Override
+ public void loadModel(MachineTranslationModelConfig config) {
+ if (StringUtils.isBlank(config.getModelPath())) {
+ throw new TranslationException("modelPath is null");
+ }
+ Device device = null;
+ if (!Objects.isNull(config.getDevice())) {
+ device = config.getDevice() == DeviceEnum.CPU ? Device.cpu() : Device.gpu();
+ }
+ this.config = config;
+ this.searchConfig = config.getSearchConfig();
+ Criteria criteria =
+ Criteria.builder()
+ .setTypes(NDList.class, NDList.class)
+ .optModelPath(Paths.get(config.getModelPath()+config.getModelName()))
+ .optEngine("PyTorch")
+ .optDevice(device)
+ .optTranslator(new NoopTranslator())
+ .build();
+ try {
+ nllbModel = ModelZoo.loadModel(criteria);
+ // 创建池子:每个线程独享 Predictor
+ this.detPredictorPool = new GenericObjectPool<>(new ZooModelFactory<>(nllbModel));
+ log.info("当前设备: " + nllbModel.getNDManager().getDevice());
+ log.info("当前引擎: " + Engine.getInstance().getEngineName());
+ } catch (IOException | ModelNotFoundException | MalformedModelException e) {
+ throw new TranslationException("模型加载失败", e);
+ }
+ }
+
+ @Override
+ public String translate(String input) throws TranslationException {
+ ZooModel zooModel = null;
+ try (NDManager manager = NDManager.newBaseManager()) {
+ zooModel = detPredictorPool.borrowObject();
+ tokenizer = HuggingFaceTokenizer.newInstance(Paths.get(config.getModelPath() + "tokenizer.json"));
+ encoderPredictor = zooModel.newPredictor(new EncoderTranslator());
+ decoderPredictor = zooModel.newPredictor(new DecoderTranslator());
+ decoder2Predictor = zooModel.newPredictor(new Decoder2Translator());
+ this.manager=manager;
+ return translateLanguage(input);
+ } catch (Exception e) {
+ throw new TranslationException("翻译错误", e);
+ } finally {
+ if (zooModel != null) {
+ try {
+ detPredictorPool.returnObject(zooModel); //归还
+ } catch (Exception e) {
+ log.warn("归还Predictor失败", e);
+ try {
+ zooModel.close(); // 归还失败才销毁
+ } catch (Exception ex) {
+ log.error("关闭Predictor失败", ex);
+ }
+ }
+ }
+
+ }
+ }
+
+ private String translateLanguage(String input) throws TranslateException {
+
+ Encoding encoding = tokenizer.encode(input);
+ long[] ids = encoding.getIds();
+ // 1. Encoder
+ long[] inputIds = new long[ids.length];
+ // 设置源语言编码
+ inputIds[0] = searchConfig.getSrcLangId();
+ for (int i = 0; i < ids.length - 1; i++) {
+ inputIds[i + 1] = ids[i];
+ }
+
+ long[] attentionMask = encoding.getAttentionMask();
+ NDArray attentionMaskArray = manager.create(attentionMask).expandDims(0);
+
+ NDArray encoderHiddenStates = encoder(inputIds);
+
+ NDArray decoder_input_ids = manager.create(new long[]{searchConfig.getDecoderStartTokenId()}).reshape(1, 1);
+ NDList decoderInput = new NDList(decoder_input_ids, encoderHiddenStates, attentionMaskArray);
+
+ // 2. Initial Decoder
+ CausalLMOutput modelOutput = decoder(decoderInput);
+ modelOutput.getLogits().attach(manager);
+ modelOutput.getPastKeyValuesList().attach(manager);
+
+ GreedyBatchTensorList searchState =
+ new GreedyBatchTensorList(null, decoder_input_ids, modelOutput.getPastKeyValuesList(), encoderHiddenStates, attentionMaskArray);
+
+ while (true) {
+// try (NDScope ignore = new NDScope()) {
+ NDArray pastOutputIds = searchState.getPastOutputIds();
+
+ if (searchState.getNextInputIds() != null) {
+ decoderInput = new NDList(searchState.getNextInputIds(), searchState.getEncoderHiddenStates(), searchState.getAttentionMask());
+ decoderInput.addAll(searchState.getPastKeyValues());
+ // 3. Decoder loop
+ modelOutput = decoder2(decoderInput);
+ }
+
+ NDArray outputIds = greedyStepGen(searchConfig, pastOutputIds, modelOutput.getLogits());
+
+ searchState.setNextInputIds(outputIds);
+ pastOutputIds = pastOutputIds.concat(outputIds, 1);
+ searchState.setPastOutputIds(pastOutputIds);
+
+ searchState.setPastKeyValues(modelOutput.getPastKeyValuesList());
+
+ long id = searchState.getNextInputIds().toLongArray()[0];
+ if (searchConfig.getEosTokenId() == id) {
+ searchState.setNextInputIds(null);
+ break;
+ }
+ if (searchState.getPastOutputIds() != null && searchState.getPastOutputIds().getShape().get(1) + 1 >= searchConfig.getMaxSeqLength()) {
+ break;
+ }
+ }
+
+ if (searchState.getNextInputIds() == null) {
+ NDArray resultIds = searchState.getPastOutputIds();
+ String result = TokenUtils.decode(searchConfig, tokenizer, resultIds);
+ return result;
+ } else {
+ NDArray resultIds = searchState.getPastOutputIds(); // .concat(searchState.getNextInputIds(), 1)
+ String result = TokenUtils.decode(searchConfig, tokenizer, resultIds);
+ return result;
+ }
+
+ }
+
+ public NDArray greedyStepGen(SearchConfig config, NDArray pastOutputIds, NDArray next_token_scores) {
+ next_token_scores = next_token_scores.get(":, -1, :");
+
+ NDArray new_next_token_scores = manager.create(next_token_scores.getShape(), next_token_scores.getDataType());
+ next_token_scores.copyTo(new_next_token_scores);
+
+ // LogitsProcessor 1. ForcedBOSTokenLogitsProcessor
+ // 设置目标语言
+ long cur_len = pastOutputIds.getShape().getLastDimension();
+ if (cur_len == 1) {
+ long num_tokens = new_next_token_scores.getShape().getLastDimension();
+ for (long i = 0; i < num_tokens; i++) {
+ if (i != config.getForcedBosTokenId()) {
+ new_next_token_scores.set(new NDIndex(":," + i), Float.NEGATIVE_INFINITY);
+ }
+ }
+ new_next_token_scores.set(new NDIndex(":," + config.getForcedBosTokenId()), 0);
+ }
+
+ NDArray probs = new_next_token_scores.softmax(-1);
+ NDArray next_tokens = probs.argMax(-1);
+
+ return next_tokens.expandDims(0);
+ }
+
+ public NDArray encoder(long[] ids) throws TranslateException {
+ return encoderPredictor.predict(ids);
+ }
+
+ public CausalLMOutput decoder(NDList input) throws TranslateException {
+ return decoderPredictor.predict(input);
+ }
+
+ public CausalLMOutput decoder2(NDList input) throws TranslateException {
+ return decoder2Predictor.predict(input);
+ }
+}
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/TranslationCommonModel.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/TranslationCommonModel.java
new file mode 100644
index 0000000000000000000000000000000000000000..513331d0181e776e3a177e56208375fc6dc4bb34
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/TranslationCommonModel.java
@@ -0,0 +1,32 @@
+package cn.smartjavaai.translation.model.common;
+
+import cn.smartjavaai.translation.config.MachineTranslationModelConfig;
+
+/**
+ * 机器翻译通用检测模型
+ * @author lwx
+ * @date 2025/6/05
+ */
+public interface TranslationCommonModel {
+
+ /**
+ * 加载模型
+ * @param config
+ */
+ void loadModel(MachineTranslationModelConfig config); // 加载模型
+
+ /**
+ * 机器翻译
+ * @param input 翻译内容
+ * @return
+ */
+ default String translate(String input) {
+ throw new UnsupportedOperationException("默认不支持该功能");
+ }
+
+
+
+
+
+
+}
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/translator/Decoder2Translator.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/translator/Decoder2Translator.java
new file mode 100644
index 0000000000000000000000000000000000000000..2c32518841e998a0cf0c0efcf8498cf0578bb9d4
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/translator/Decoder2Translator.java
@@ -0,0 +1,46 @@
+package cn.smartjavaai.translation.model.common.translator;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.translate.NoBatchifyTranslator;
+import ai.djl.translate.TranslatorContext;
+import cn.smartjavaai.translation.entity.CausalLMOutput;
+
+
+/**
+ * 解碼器,參數支持 pastKeyValues
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class Decoder2Translator implements NoBatchifyTranslator {
+ private String tupleName;
+
+ public Decoder2Translator() {
+ tupleName = "past_key_values(" + 12 + ',' + 4 + ')';
+ }
+
+ @Override
+ public NDList processInput(TranslatorContext ctx, NDList input) {
+
+ NDArray placeholder = ctx.getNDManager().create(0);
+ placeholder.setName("module_method:decoder2");
+
+ input.add(placeholder);
+
+ return input;
+ }
+
+ @Override
+ public CausalLMOutput processOutput(TranslatorContext ctx, NDList output) {
+ NDArray logitsOutput = output.get(0);
+ NDList pastKeyValuesOutput = output.subNDList(1, 12 * 4 + 1);
+
+ for (NDArray array : pastKeyValuesOutput) {
+ array.setName(tupleName);
+ }
+
+ return new CausalLMOutput(logitsOutput, pastKeyValuesOutput);
+ }
+}
\ No newline at end of file
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/translator/DecoderTranslator.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/translator/DecoderTranslator.java
new file mode 100644
index 0000000000000000000000000000000000000000..981b70823e9057a50633fc2b022ea3f4cd894d59
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/translator/DecoderTranslator.java
@@ -0,0 +1,45 @@
+package cn.smartjavaai.translation.model.common.translator;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.translate.NoBatchifyTranslator;
+import ai.djl.translate.TranslatorContext;
+import cn.smartjavaai.translation.entity.CausalLMOutput;
+
+/**
+ * 解碼器,參數沒有 pastKeyValues
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class DecoderTranslator implements NoBatchifyTranslator {
+ private String tupleName;
+
+ public DecoderTranslator() {
+ tupleName = "past_key_values(" + 12 + ',' + 4 + ')';
+ }
+
+ @Override
+ public NDList processInput(TranslatorContext ctx, NDList input) {
+
+ NDArray placeholder = ctx.getNDManager().create(0);
+ placeholder.setName("module_method:decoder");
+
+ input.add(placeholder);
+
+ return input;
+ }
+
+ @Override
+ public CausalLMOutput processOutput(TranslatorContext ctx, NDList output) {
+ NDArray logitsOutput = output.get(0);
+ NDList pastKeyValuesOutput = output.subNDList(1, 12 * 4 + 1);
+
+ for (NDArray array : pastKeyValuesOutput) {
+ array.setName(tupleName);
+ }
+
+ return new CausalLMOutput(logitsOutput, pastKeyValuesOutput);
+ }
+}
\ No newline at end of file
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/translator/EncoderTranslator.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/translator/EncoderTranslator.java
new file mode 100644
index 0000000000000000000000000000000000000000..8c9961963cb5f4af64e05ffd7216a40f098afbd0
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/translator/EncoderTranslator.java
@@ -0,0 +1,49 @@
+package cn.smartjavaai.translation.model.common.translator;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
+import ai.djl.translate.NoBatchifyTranslator;
+import ai.djl.translate.TranslatorContext;
+
+import java.util.Arrays;
+
+/**
+ * 编码器前后处理
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class EncoderTranslator implements NoBatchifyTranslator {
+
+
+ public EncoderTranslator() {
+ }
+
+ @Override
+ public NDList processInput(TranslatorContext ctx, long[] input) throws Exception {
+ NDManager manager = ctx.getNDManager();
+
+ NDArray inputIdArray = manager.create(input).expandDims(0);
+ inputIdArray.setName("input_ids");
+
+ long[] attentionMask = new long[input.length];
+ Arrays.fill(attentionMask, 1);
+ NDArray attentionMaskArray = manager.create(attentionMask).expandDims(0);
+ attentionMaskArray.setName("attention_mask");
+
+ NDArray placeholder = ctx.getNDManager().create(0);
+ placeholder.setName("module_method:encoder");
+
+ return new NDList(inputIdArray, attentionMaskArray, placeholder);
+ }
+
+ @Override
+ public NDArray processOutput(TranslatorContext ctx, NDList list) {
+ NDArray encoderHiddenStates = list.get(0);
+ encoderHiddenStates.detach();
+ return encoderHiddenStates;
+ }
+
+}
\ No newline at end of file
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/utils/TokenUtils.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/utils/TokenUtils.java
new file mode 100644
index 0000000000000000000000000000000000000000..582804ee9f8f78d17d3defea8b16c4c1cbcf87a8
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/utils/TokenUtils.java
@@ -0,0 +1,48 @@
+package cn.smartjavaai.translation.utils;
+
+import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
+import ai.djl.ndarray.NDArray;
+import cn.smartjavaai.translation.config.SearchConfig;
+
+
+import java.util.ArrayList;
+
+/**
+ *
+ *
+ * @author lwx
+ * @date 2025/4/22
+ */
+public final class TokenUtils {
+
+ private TokenUtils() {
+ }
+
+ /**
+ * 语言解码
+ *
+ * @param tokenizer
+ * @param output
+ * @return
+ */
+ public static String decode(SearchConfig config, HuggingFaceTokenizer tokenizer, NDArray output) {
+ long[] outputIds = output.toLongArray();
+ ArrayList outputIdsList = new ArrayList<>();
+
+ for (long id : outputIds) {
+ if (id == config.getEosTokenId() || id==config.getSrcLangId() || id==config.getForcedBosTokenId()) {
+ continue;
+ }
+ outputIdsList.add(id);
+ }
+
+ Long[] objArr = outputIdsList.toArray(new Long[0]);
+ long[] ids = new long[objArr.length];
+ for (int i = 0; i < objArr.length; i++) {
+ ids[i] = objArr[i];
+ }
+ String text = tokenizer.decode(ids);
+
+ return text;
+ }
+}
\ No newline at end of file