diff --git a/examples/src/main/java/smartai/examples/nlp/TextTranslationCPU.java b/examples/src/main/java/smartai/examples/nlp/TextTranslationCPU.java new file mode 100644 index 0000000000000000000000000000000000000000..67d5710b468f95d91eba523eaa0fb452fc817d71 --- /dev/null +++ b/examples/src/main/java/smartai/examples/nlp/TextTranslationCPU.java @@ -0,0 +1,57 @@ +package smartai.examples.nlb; + +import ai.djl.Device; +import ai.djl.ModelException; +import ai.djl.translate.TranslateException; +import smartai.examples.nlb.generate.SearchConfig; +import smartai.examples.nlb.model.NllbModel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; + +/** + * 文本翻译,支持202种语言互译 + * + * @author Calvin + * @mail 179209347@qq.com + * @website www.aias.top + */ +public final class TextTranslationCPU { + + private static final Logger logger = LoggerFactory.getLogger(TextTranslationCPU.class); + + private TextTranslationCPU() { + } + + public static void main(String[] args) throws ModelException, IOException, + TranslateException { + + SearchConfig config = new SearchConfig(); + // 设置输出文字的最大长度 + config.setMaxSeqLength(128); + // 设置源语言:中文 "zho_Hans": 256200 + config.setSrcLangId(256200); + // 设置目标语言:英文 "eng_Latn": 256047 + config.setForcedBosTokenId(256047); + config.setForcedBosTokenId(256201); + + // 输入文字 + String input = "智利北部的丘基卡马塔矿是世界上最大的露天矿之一,长约4公里,宽3公里,深1公里。"; + + String modelPath = "E:\\ai\\models\\nlp\\"; + String cpuModelName = "traced_translation_cpu.pt"; + String gpuModelName = "traced_translation_gpu.pt"; + try (NllbModel nllbModel = new NllbModel(config, modelPath, cpuModelName, Device.cpu())) { + + System.setProperty("ai.djl.pytorch.graph_optimizer", "false"); + + // 运行模型,获取翻译结果 + String result = nllbModel.translate(input); + + logger.info("result========={}", result); + } finally { + System.clearProperty("ai.djl.pytorch.graph_optimizer"); + } + } +} \ No newline at end of file diff --git 
a/examples/src/main/java/smartai/examples/nlp/generate/BatchTensorList.java b/examples/src/main/java/smartai/examples/nlp/generate/BatchTensorList.java new file mode 100644 index 0000000000000000000000000000000000000000..03ed9576f91928d7a8fcbc47f9d97895e83fc772 --- /dev/null +++ b/examples/src/main/java/smartai/examples/nlp/generate/BatchTensorList.java @@ -0,0 +1,174 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package smartai.examples.nlb.generate; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; + +/** + * BatchTensorList represents a search state, and the NDArrays inside are updated in each iteration + * of the autoregressive loop. + * + *

It is a struct consisting of NDArrays, whose first dimension is batch, and also contains + * sequence dimension (whose position in tensor's shape is specified by seqDimOrder). The SeqBatcher + * batch operations will operate on these two dimensions. + */ +public abstract class BatchTensorList { + // [batch, seq_past]. seq-dim-size == |past_seq| + |inputIds|. Will grow. + private NDArray pastOutputIds; + + // [batch, seq_past] + // The cache of past attentionMask. seq-dim-size == |past_seq| + |inputIds|. Will grow. + private NDArray pastAttentionMask; + + // (k, v) * numLayer, + // kv: [batch, heads, seq_past, kvfeature] + // The cache of past sequence. seq-dim-size == |past_seq| + |inputIds|. Will grow. + private NDList pastKeyValues; + + // Sequence dimension order among all dimensions for each element in the batch list. + private long[] seqDimOrder; + + BatchTensorList() {} + + /** + * Constructs a new {@code BatchTensorList} instance. + * + * @param list the NDList that contains the serialized version of the batch tensors + * @param seqDimOrder the sequence dimension order that specifies where the sequence dimension + * is in a tensor's shape + */ + BatchTensorList(NDList list, long[] seqDimOrder) { + this.seqDimOrder = seqDimOrder; + pastOutputIds = list.get(0); + pastAttentionMask = list.get(1); + pastKeyValues = list.subNDList(2); + } + + /** + * Constructs a new {@code BatchTensorList} instance. 
+ * + * @param pastOutputIds past output token ids + * @param pastAttentionMask past attention mask + * @param pastKeyValues past kv cache + * @param seqDimOrder the sequence dimension order that specifies where the sequence dimension + * is in a tensor's shape + */ + BatchTensorList( + NDArray pastOutputIds, + NDArray pastAttentionMask, + NDList pastKeyValues, + long[] seqDimOrder) { + this.pastKeyValues = pastKeyValues; + this.pastOutputIds = pastOutputIds; + this.pastAttentionMask = pastAttentionMask; + this.seqDimOrder = seqDimOrder; + } + + /** + * Constructs a new {@code BatchTensorList} instance from the serialized version of the batch + * tensors. + * + *

The pastOutputIds has to be the first in the output list. + * + * @param inputList the serialized version of the batch tensors + * @param seqDimOrder the sequence dimension order that specifies where the sequence dimension + * is in a tensor's shape + * @return BatchTensorList + */ + public abstract BatchTensorList fromList(NDList inputList, long[] seqDimOrder); + + /** + * Returns the serialized version of the BatchTensorList. The pastOutputIds has to be the first + * in the output list. + * + * @return the NDList that contains the serialized BatchTensorList + */ + public abstract NDList getList(); + + /** + * Returns the sequence dimension order which specifies where the sequence dimension is in a + * tensor's shape. + * + * @return the sequence dimension order which specifies where the sequence dimension is in a + * tensor's shape + */ + public long[] getSeqDimOrder() { + return seqDimOrder; + } + + /** + * Returns the value of the pastOutputIds. + * + * @return the value of pastOutputIds + */ + public NDArray getPastOutputIds() { + return pastOutputIds; + } + + /** + * Sets the past output token ids. + * + * @param pastOutputIds the past output token ids + */ + public void setPastOutputIds(NDArray pastOutputIds) { + this.pastOutputIds = pastOutputIds; + } + + /** + * Returns the value of the pastAttentionMask. + * + * @return the value of pastAttentionMask + */ + public NDArray getPastAttentionMask() { + return pastAttentionMask; + } + + /** + * Sets the attention mask. + * + * @param pastAttentionMask the attention mask + */ + public void setPastAttentionMask(NDArray pastAttentionMask) { + this.pastAttentionMask = pastAttentionMask; + } + + /** + * Returns the value of the pastKeyValues. + * + * @return the value of pastKeyValues + */ + public NDList getPastKeyValues() { + return pastKeyValues; + } + + /** + * Sets the kv cache. 
+ * + * @param pastKeyValues the kv cache + */ + public void setPastKeyValues(NDList pastKeyValues) { + this.pastKeyValues = pastKeyValues; + } + + /** + * Sets the sequence dimension order which specifies where the sequence dimension is in a + * tensor's shape. + * + * @param seqDimOrder the sequence dimension order which specifies where the sequence dimension + * is in a tensor's shape + */ + public void setSeqDimOrder(long[] seqDimOrder) { + this.seqDimOrder = seqDimOrder; + } +} \ No newline at end of file diff --git a/examples/src/main/java/smartai/examples/nlp/generate/CausalLMOutput.java b/examples/src/main/java/smartai/examples/nlp/generate/CausalLMOutput.java new file mode 100644 index 0000000000000000000000000000000000000000..80b9aa4757849f0e9238b3721f791f72accdd4bc --- /dev/null +++ b/examples/src/main/java/smartai/examples/nlp/generate/CausalLMOutput.java @@ -0,0 +1,32 @@ +package smartai.examples.nlb.generate; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +/** + * 解码输出对象 + * + * @author Calvin + * @mail 179209347@qq.com + * @website www.aias.top + */ +public class CausalLMOutput { + private NDArray logits; + private NDList pastKeyValuesList; + + public CausalLMOutput(NDArray logits, NDList pastKeyValues) { + this.logits = logits; + this.pastKeyValuesList = pastKeyValues; + } + + public NDArray getLogits() { + return logits; + } + + public void setLogits(NDArray logits) { + this.logits = logits; + } + + public NDList getPastKeyValuesList() { + return pastKeyValuesList; + } +} \ No newline at end of file diff --git a/examples/src/main/java/smartai/examples/nlp/generate/GreedyBatchTensorList.java b/examples/src/main/java/smartai/examples/nlp/generate/GreedyBatchTensorList.java new file mode 100644 index 0000000000000000000000000000000000000000..388de1455e0f304cdaea932568341ddfb11ea321 --- /dev/null +++ b/examples/src/main/java/smartai/examples/nlp/generate/GreedyBatchTensorList.java @@ -0,0 +1,84 @@ +package 
smartai.examples.nlb.generate; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; + +/** + * 贪婪搜索张量对象列表 + * + * @author Calvin + * @mail 179209347@qq.com + * @website www.aias.top + */ +public class GreedyBatchTensorList extends BatchTensorList { + // [batch, 1] + private NDArray nextInputIds; + + private NDArray pastOutputIds; + + private NDArray encoderHiddenStates; + private NDArray attentionMask; + private NDList pastKeyValues; + + public GreedyBatchTensorList( + NDArray nextInputIds, + NDArray pastOutputIds, + NDList pastKeyValues, + NDArray encoderHiddenStates, + NDArray attentionMask) { + this.nextInputIds = nextInputIds; + this.pastKeyValues = pastKeyValues; + this.pastOutputIds = pastOutputIds; + this.attentionMask = attentionMask; + this.encoderHiddenStates = encoderHiddenStates; + } + + public GreedyBatchTensorList() {} + + public BatchTensorList fromList(NDList inputList, long[] seqDimOrder) { + return new GreedyBatchTensorList(); + } + + public NDList getList() { + return new NDList(); + } + + public NDArray getNextInputIds() { + return nextInputIds; + } + + public void setNextInputIds(NDArray nextInputIds) { + this.nextInputIds = nextInputIds; + } + public NDArray getPastOutputIds() { + return pastOutputIds; + } + + public void setPastOutputIds(NDArray pastOutputIds) { + this.pastOutputIds = pastOutputIds; + } + + public NDList getPastKeyValues() { + return pastKeyValues; + } + + public void setPastKeyValues(NDList pastKeyValues) { + this.pastKeyValues = pastKeyValues; + } + + public NDArray getEncoderHiddenStates() { + return encoderHiddenStates; + } + + public void setEncoderHiddenStates(NDArray encoderHiddenStates) { + this.encoderHiddenStates = encoderHiddenStates; + } + + public NDArray getAttentionMask() { + return attentionMask; + } + + public void setAttentionMask(NDArray attentionMask) { + this.attentionMask = attentionMask; + } +} \ No newline at end of file diff --git 
/**
 * Configuration values for text generation.
 *
 * <p>A freshly created instance carries the defaults shown on the field declarations; every value
 * can be changed through its setter.
 *
 * @author Calvin
 * @mail 179209347@qq.com
 * @website www.aias.top
 */
public class SearchConfig {

    // Defaults are applied here instead of in the constructor body.
    private int maxSeqLength = 512;
    private long padTokenId = 1;
    private long eosTokenId = 2;
    private long bosTokenId = 0;
    private long decoderStartTokenId = 2;
    private float encoderRepetitionPenalty = 1.0f;
    private long forcedBosTokenId = 0;
    private long srcLangId = 0;
    private float lengthPenalty = 1.0f;

    /** Creates a configuration populated with the default values. */
    public SearchConfig() {
    }

    /** @return the maximum sequence length */
    public int getMaxSeqLength() {
        return maxSeqLength;
    }

    /** @param maxSeqLength the maximum sequence length */
    public void setMaxSeqLength(int maxSeqLength) {
        this.maxSeqLength = maxSeqLength;
    }

    /** @return the padding token id */
    public long getPadTokenId() {
        return padTokenId;
    }

    /** @param padTokenId the padding token id */
    public void setPadTokenId(long padTokenId) {
        this.padTokenId = padTokenId;
    }

    /** @return the end-of-sequence token id */
    public long getEosTokenId() {
        return eosTokenId;
    }

    /** @param eosTokenId the end-of-sequence token id */
    public void setEosTokenId(long eosTokenId) {
        this.eosTokenId = eosTokenId;
    }

    /** @return the beginning-of-sequence token id */
    public long getBosTokenId() {
        return bosTokenId;
    }

    /** @param bosTokenId the beginning-of-sequence token id */
    public void setBosTokenId(long bosTokenId) {
        this.bosTokenId = bosTokenId;
    }

    /** @return the token id the decoder starts from */
    public long getDecoderStartTokenId() {
        return decoderStartTokenId;
    }

    /** @param decoderStartTokenId the token id the decoder starts from */
    public void setDecoderStartTokenId(long decoderStartTokenId) {
        this.decoderStartTokenId = decoderStartTokenId;
    }

    /** @return the encoder repetition penalty */
    public float getEncoderRepetitionPenalty() {
        return encoderRepetitionPenalty;
    }

    /** @param encoderRepetitionPenalty the encoder repetition penalty */
    public void setEncoderRepetitionPenalty(float encoderRepetitionPenalty) {
        this.encoderRepetitionPenalty = encoderRepetitionPenalty;
    }

    /** @return the forced beginning-of-sequence token id */
    public long getForcedBosTokenId() {
        return forcedBosTokenId;
    }

    /** @param forcedBosTokenId the forced beginning-of-sequence token id */
    public void setForcedBosTokenId(long forcedBosTokenId) {
        this.forcedBosTokenId = forcedBosTokenId;
    }

    /** @return the source language id */
    public long getSrcLangId() {
        return srcLangId;
    }

    /** @param srcLangId the source language id */
    public void setSrcLangId(long srcLangId) {
        this.srcLangId = srcLangId;
    }

    /** @return the length penalty */
    public float getLengthPenalty() {
        return lengthPenalty;
    }

    /** @param lengthPenalty the length penalty */
    public void setLengthPenalty(float lengthPenalty) {
        this.lengthPenalty = lengthPenalty;
    }
}
setEncoderRepetitionPenalty(float encoderRepetitionPenalty) { + this.encoderRepetitionPenalty = encoderRepetitionPenalty; + } + + public long getForcedBosTokenId() { + return forcedBosTokenId; + } + + public void setForcedBosTokenId(long forcedBosTokenId) { + this.forcedBosTokenId = forcedBosTokenId; + } + + public float getLengthPenalty() { + return lengthPenalty; + } + + public void setLengthPenalty(float lengthPenalty) { + this.lengthPenalty = lengthPenalty; + } + + public long getBosTokenId() { + return bosTokenId; + } + + public void setBosTokenId(long bosTokenId) { + this.bosTokenId = bosTokenId; + } +} \ No newline at end of file diff --git a/examples/src/main/java/smartai/examples/nlp/model/Decoder2Translator.java b/examples/src/main/java/smartai/examples/nlp/model/Decoder2Translator.java new file mode 100644 index 0000000000000000000000000000000000000000..26255477f5ec3384d88b4f6cef3705e66d919f9b --- /dev/null +++ b/examples/src/main/java/smartai/examples/nlp/model/Decoder2Translator.java @@ -0,0 +1,45 @@ +package smartai.examples.nlb.model; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.translate.NoBatchifyTranslator; +import ai.djl.translate.TranslatorContext; +import smartai.examples.nlb.generate.CausalLMOutput; + +/** + * 解碼器,參數支持 pastKeyValues + * + * @author Calvin + * @mail 179209347@qq.com + * @website www.aias.top + */ +public class Decoder2Translator implements NoBatchifyTranslator { + private String tupleName; + + public Decoder2Translator() { + tupleName = "past_key_values(" + 12 + ',' + 4 + ')'; + } + + @Override + public NDList processInput(TranslatorContext ctx, NDList input) { + + NDArray placeholder = ctx.getNDManager().create(0); + placeholder.setName("module_method:decoder2"); + + input.add(placeholder); + + return input; + } + + @Override + public CausalLMOutput processOutput(TranslatorContext ctx, NDList output) { + NDArray logitsOutput = output.get(0); + NDList pastKeyValuesOutput = output.subNDList(1, 
12 * 4 + 1); + + for (NDArray array : pastKeyValuesOutput) { + array.setName(tupleName); + } + + return new CausalLMOutput(logitsOutput, pastKeyValuesOutput); + } +} \ No newline at end of file diff --git a/examples/src/main/java/smartai/examples/nlp/model/DecoderTranslator.java b/examples/src/main/java/smartai/examples/nlp/model/DecoderTranslator.java new file mode 100644 index 0000000000000000000000000000000000000000..9211a8282eac8d3fbed6be051cfcf4e64dce9069 --- /dev/null +++ b/examples/src/main/java/smartai/examples/nlp/model/DecoderTranslator.java @@ -0,0 +1,44 @@ +package smartai.examples.nlb.model; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.translate.NoBatchifyTranslator; +import ai.djl.translate.TranslatorContext; +import smartai.examples.nlb.generate.CausalLMOutput; +/** + * 解碼器,參數沒有 pastKeyValues + * + * @author Calvin + * @mail 179209347@qq.com + * @website www.aias.top + */ +public class DecoderTranslator implements NoBatchifyTranslator { + private String tupleName; + + public DecoderTranslator() { + tupleName = "past_key_values(" + 12 + ',' + 4 + ')'; + } + + @Override + public NDList processInput(TranslatorContext ctx, NDList input) { + + NDArray placeholder = ctx.getNDManager().create(0); + placeholder.setName("module_method:decoder"); + + input.add(placeholder); + + return input; + } + + @Override + public CausalLMOutput processOutput(TranslatorContext ctx, NDList output) { + NDArray logitsOutput = output.get(0); + NDList pastKeyValuesOutput = output.subNDList(1, 12 * 4 + 1); + + for (NDArray array : pastKeyValuesOutput) { + array.setName(tupleName); + } + + return new CausalLMOutput(logitsOutput, pastKeyValuesOutput); + } +} \ No newline at end of file diff --git a/examples/src/main/java/smartai/examples/nlp/model/EncoderTranslator.java b/examples/src/main/java/smartai/examples/nlp/model/EncoderTranslator.java new file mode 100644 index 
0000000000000000000000000000000000000000..8ba292ee71311e5af70c4fc7f65eea72f7a12a72 --- /dev/null +++ b/examples/src/main/java/smartai/examples/nlp/model/EncoderTranslator.java @@ -0,0 +1,49 @@ +package smartai.examples.nlb.model; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.translate.NoBatchifyTranslator; +import ai.djl.translate.TranslatorContext; + +import java.util.Arrays; + +/** + * 编码器前后处理 + * + * @author Calvin + * @mail 179209347@qq.com + * @website www.aias.top + */ +public class EncoderTranslator implements NoBatchifyTranslator { + + + public EncoderTranslator() { + } + + @Override + public NDList processInput(TranslatorContext ctx, long[] input) throws Exception { + NDManager manager = ctx.getNDManager(); + + NDArray inputIdArray = manager.create(input).expandDims(0); + inputIdArray.setName("input_ids"); + + long[] attentionMask = new long[input.length]; + Arrays.fill(attentionMask, 1); + NDArray attentionMaskArray = manager.create(attentionMask).expandDims(0); + attentionMaskArray.setName("attention_mask"); + + NDArray placeholder = ctx.getNDManager().create(0); + placeholder.setName("module_method:encoder"); + + return new NDList(inputIdArray, attentionMaskArray, placeholder); + } + + @Override + public NDArray processOutput(TranslatorContext ctx, NDList list) { + NDArray encoderHiddenStates = list.get(0); + encoderHiddenStates.detach(); + return encoderHiddenStates; + } + +} \ No newline at end of file diff --git a/examples/src/main/java/smartai/examples/nlp/model/NllbModel.java b/examples/src/main/java/smartai/examples/nlp/model/NllbModel.java new file mode 100644 index 0000000000000000000000000000000000000000..a987ab26fb334c2bf9c1cff0a7178ee7d4a0f392 --- /dev/null +++ b/examples/src/main/java/smartai/examples/nlp/model/NllbModel.java @@ -0,0 +1,183 @@ +package smartai.examples.nlb.model; + +import ai.djl.Device; +import ai.djl.ModelException; +import 
ai.djl.huggingface.tokenizers.Encoding; +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; +import ai.djl.inference.Predictor; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.index.NDIndex; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.translate.NoopTranslator; +import ai.djl.translate.TranslateException; +import smartai.examples.nlb.generate.CausalLMOutput; +import smartai.examples.nlb.generate.GreedyBatchTensorList; +import smartai.examples.nlb.generate.SearchConfig; +import smartai.examples.nlb.tokenizer.TokenUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.file.Paths; +import java.util.Arrays; +/** + * 模型载入及推理 + * + * @author Calvin + * @mail 179209347@qq.com + * @website www.aias.top + */ +public class NllbModel implements AutoCloseable { + private static final Logger logger = LoggerFactory.getLogger(NllbModel.class); + private SearchConfig config; + private ZooModel nllbModel; + private HuggingFaceTokenizer tokenizer; + private Predictor encoderPredictor; + private Predictor decoderPredictor; + private Predictor decoder2Predictor; + private NDManager manager; + + public NllbModel(SearchConfig config, String modelPath, String modelName, Device device) throws ModelException, IOException { + this.config = config; + Criteria criteria = + Criteria.builder() + .setTypes(NDList.class, NDList.class) + .optModelPath(Paths.get(modelPath + modelName)) + .optEngine("PyTorch") + .optDevice(device) + .optTranslator(new NoopTranslator()) + .build(); + + manager = NDManager.newBaseManager(device); + nllbModel = criteria.loadModel(); + tokenizer = HuggingFaceTokenizer.newInstance(Paths.get(modelPath + "tokenizer.json")); + encoderPredictor = nllbModel.newPredictor(new EncoderTranslator()); + decoderPredictor = nllbModel.newPredictor(new DecoderTranslator()); + decoder2Predictor = 
nllbModel.newPredictor(new Decoder2Translator()); + } + + public NDArray encoder(long[] ids) throws TranslateException { + return encoderPredictor.predict(ids); + } + + public CausalLMOutput decoder(NDList input) throws TranslateException { + return decoderPredictor.predict(input); + } + + public CausalLMOutput decoder2(NDList input) throws TranslateException { + return decoder2Predictor.predict(input); + } + + @Override + public void close() { + encoderPredictor.close(); + decoderPredictor.close(); + decoder2Predictor.close(); + nllbModel.close(); + manager.close(); + tokenizer.close(); + } + + public String translate(String input) throws TranslateException { + + Encoding encoding = tokenizer.encode(input); + long[] ids = encoding.getIds(); + // 1. Encoder + long[] inputIds = new long[ids.length]; + // 设置源语言编码 + inputIds[0] = config.getSrcLangId(); + for (int i = 0; i < ids.length - 1; i++) { + inputIds[i + 1] = ids[i]; + } + logger.info("inputIds: " + Arrays.toString(inputIds)); + long[] attentionMask = encoding.getAttentionMask(); + NDArray attentionMaskArray = manager.create(attentionMask).expandDims(0); + + NDArray encoderHiddenStates = encoder(inputIds); + + NDArray decoder_input_ids = manager.create(new long[]{config.getDecoderStartTokenId()}).reshape(1, 1); + NDList decoderInput = new NDList(decoder_input_ids, encoderHiddenStates, attentionMaskArray); + + // 2. 
Initial Decoder + CausalLMOutput modelOutput = decoder(decoderInput); + modelOutput.getLogits().attach(manager); + modelOutput.getPastKeyValuesList().attach(manager); + + GreedyBatchTensorList searchState = + new GreedyBatchTensorList(null, decoder_input_ids, modelOutput.getPastKeyValuesList(), encoderHiddenStates, attentionMaskArray); + + while (true) { +// try (NDScope ignore = new NDScope()) { + NDArray pastOutputIds = searchState.getPastOutputIds(); + + if (searchState.getNextInputIds() != null) { + decoderInput = new NDList(searchState.getNextInputIds(), searchState.getEncoderHiddenStates(), searchState.getAttentionMask()); + decoderInput.addAll(searchState.getPastKeyValues()); + // 3. Decoder loop + modelOutput = decoder2(decoderInput); + } + + NDArray outputIds = greedyStepGen(config, pastOutputIds, modelOutput.getLogits()); + + searchState.setNextInputIds(outputIds); + pastOutputIds = pastOutputIds.concat(outputIds, 1); + searchState.setPastOutputIds(pastOutputIds); + + searchState.setPastKeyValues(modelOutput.getPastKeyValuesList()); + + // memory management +// NDScope.unregister(outputIds, pastOutputIds); +// } + + // Termination Criteria + long id = searchState.getNextInputIds().toLongArray()[0]; + if (config.getEosTokenId() == id) { + searchState.setNextInputIds(null); + break; + } + if (searchState.getPastOutputIds() != null && searchState.getPastOutputIds().getShape().get(1) + 1 >= config.getMaxSeqLength()) { + break; + } + } + + if (searchState.getNextInputIds() == null) { + NDArray resultIds = searchState.getPastOutputIds(); + String result = TokenUtils.decode(config, tokenizer, resultIds); + return result; + } else { + NDArray resultIds = searchState.getPastOutputIds(); // .concat(searchState.getNextInputIds(), 1) + String result = TokenUtils.decode(config, tokenizer, resultIds); + return result; + } + + } + + public NDArray greedyStepGen(SearchConfig config, NDArray pastOutputIds, NDArray next_token_scores) { + next_token_scores = 
next_token_scores.get(":, -1, :"); + + NDArray new_next_token_scores = manager.create(next_token_scores.getShape(), next_token_scores.getDataType()); + next_token_scores.copyTo(new_next_token_scores); + + // LogitsProcessor 1. ForcedBOSTokenLogitsProcessor + // 设置目标语言 + long cur_len = pastOutputIds.getShape().getLastDimension(); + if (cur_len == 1) { + long num_tokens = new_next_token_scores.getShape().getLastDimension(); + for (long i = 0; i < num_tokens; i++) { + if (i != config.getForcedBosTokenId()) { + new_next_token_scores.set(new NDIndex(":," + i), Float.NEGATIVE_INFINITY); + } + } + new_next_token_scores.set(new NDIndex(":," + config.getForcedBosTokenId()), 0); + } + + NDArray probs = new_next_token_scores.softmax(-1); + NDArray next_tokens = probs.argMax(-1); + + return next_tokens.expandDims(0); + } + +} diff --git a/examples/src/main/java/smartai/examples/objectdetection/ObjectDetection.java b/examples/src/main/java/smartai/examples/objectdetection/ObjectDetection.java index 4fc321a95c5ce989abb4f51662bb79a8156d65f0..f8ee492c954e24bd5feadd29f58e285ad2e7a504 100644 --- a/examples/src/main/java/smartai/examples/objectdetection/ObjectDetection.java +++ b/examples/src/main/java/smartai/examples/objectdetection/ObjectDetection.java @@ -111,18 +111,21 @@ public class ObjectDetection { /** - * 使用yolo官方模型检测 + * 使用yolo官方模型检测物品识别 */ @Test public void objectDetectionWithOfficialModel(){ DetectorModelConfig config = new DetectorModelConfig(); + config.setThreshold(0.3f); //也支持YoloV8:YOLOV8_OFFICIAL 模型可以从文档中提供的地址下载 config.setModelEnum(DetectorModelEnum.YOLOV12_OFFICIAL);//检测模型,目前支持19种模型 // 指定模型路径,需要更改为自己的模型路径 - config.setModelPath("/Users/xxx/Documents/develop/face_model/yolov12n.onnx"); + config.setModelPath("E:\\ai\\models\\yolo12m\\yolov12m.onnx"); DetectorModel detectorModel = ObjectDetectionModelFactory.getInstance().getModel(config); //一定要将yolo官方的类别文件:synset.txt(文档中下载)放在模型同目录下,否则报错 - 
detectorModel.detectAndDraw("src/main/resources/object_detection.jpg","output/object_detection_detected.png"); + DetectionResponse detect = detectorModel.detect("E:\\ai\\testimage\\1.jpg"); + log.info("目标检测结果:{}", JSONObject.toJSONString(detect)); + detectorModel.detectAndDraw("E:\\ai\\testimage\\1.jpg","E:\\ai\\outimage\\11.png"); } /** diff --git a/pom.xml b/pom.xml index c5aeed1bda50fd6b8bd405d446456b24a346e86d..ef044b7ed566a7d07cadd80bc4e930c6f92e58e4 100644 --- a/pom.xml +++ b/pom.xml @@ -11,6 +11,7 @@ SmartJavaAI smartjavaai-face + smartjavaai-translate smartjavaai-common smartjavaai-objectdetection smartjavaai-all @@ -52,14 +53,16 @@ ai.djl model-zoo + + ai.djl.huggingface + tokenizers + ${djl.version} + - - - - + diff --git a/smartjavaai-common/src/main/java/cn/smartjavaai/common/pool/ZooModelFactory.java b/smartjavaai-common/src/main/java/cn/smartjavaai/common/pool/ZooModelFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..4c8bdaf6e870ed494f1df830b0bee8597d73eef4 --- /dev/null +++ b/smartjavaai-common/src/main/java/cn/smartjavaai/common/pool/ZooModelFactory.java @@ -0,0 +1,35 @@ +package cn.smartjavaai.common.pool; + +import ai.djl.inference.Predictor; +import ai.djl.repository.zoo.ZooModel; +import org.apache.commons.pool2.BasePooledObjectFactory; +import org.apache.commons.pool2.PooledObject; +import org.apache.commons.pool2.impl.DefaultPooledObject; + +/** + * ZooModel 工厂类 + * @author lwx + * @date 2025/6/06 + */ +public class ZooModelFactory extends BasePooledObjectFactory> { + private final ZooModel model; + + public ZooModelFactory(ZooModel model) { + this.model = model; + } + + @Override + public ZooModel create() { + return model; + } + + @Override + public PooledObject> wrap(ZooModel predictor) { + return new DefaultPooledObject<>(predictor); + } + + @Override + public void destroyObject(PooledObject> p) { + p.getObject().close(); + } +} diff --git 
a/smartjavaai-objectdetection/src/main/java/cn/smartjavaai/objectdetection/enums/DetectorModelEnum.java b/smartjavaai-objectdetection/src/main/java/cn/smartjavaai/objectdetection/enums/DetectorModelEnum.java index ad69fb4dcf8cbe5edce0e03f90528ed8e223d118..4fb6d8a282272fc3afcb5cfe6d7650696b798553 100644 --- a/smartjavaai-objectdetection/src/main/java/cn/smartjavaai/objectdetection/enums/DetectorModelEnum.java +++ b/smartjavaai-objectdetection/src/main/java/cn/smartjavaai/objectdetection/enums/DetectorModelEnum.java @@ -9,7 +9,7 @@ public enum DetectorModelEnum { // resnet50 系列 SSD_300_RESNET50("ai.djl.pytorch/ssd/0.0.1/ssd_300_resnet50"), - SSD_512_RESNET50_V1_VOC("ai.djl.mxnet/ssd/0.0.1/ssd_512_resnet50_v1_voc"), + SSD_512_RESNET50_V1_VOC("ai.djl./ssd/0.0.1/ssd_512_resnet50_v1_voc"), // vgg16 系列 SSD_512_VGG16_ATROUS_COCO("ai.djl.mxnet/ssd/0.0.1/ssd_512_vgg16_atrous_coco"), diff --git a/smartjavaai-translate/pom.xml b/smartjavaai-translate/pom.xml new file mode 100644 index 0000000000000000000000000000000000000000..a957198a2d21562a38ff8194ffaa9000624b28e0 --- /dev/null +++ b/smartjavaai-translate/pom.xml @@ -0,0 +1,134 @@ + + + 4.0.0 + + cn.smartjavaai + smartjavaai-parent + 1.0.15 + + + smartjavaai-translate + + + + + cn.smartjavaai + smartjavaai-common + ${project.version} + + + junit + junit + 4.13.1 + compile + + + + 1.0.15 + smartjavaai-ocr + SmartJavaAI + https://github.com/geekwenjie/SmartJavaAI + + + MIT License + https://opensource.org/licenses/MIT + + + + + + + org.sonatype.central + central-publishing-maven-plugin + 0.4.0 + true + + dengwenjie + true + ${project.groupId}:${project.artifactId}:${project.version} + + + + + org.apache.maven.plugins + maven-source-plugin + 3.1.0 + + + attach-sources + + jar-no-fork + + + + + + org.apache.maven.plugins + maven-javadoc-plugin + 3.1.0 + + + none + + -Xdoclint:none + + + + + attach-javadocs + + jar + + + + + + org.apache.maven.plugins + maven-gpg-plugin + 3.1.0 + + + sign-artifacts + verify + + sign + + + + + + + 
/**
 * Machine translation model configuration.
 *
 * <p>Accessors, {@code equals}, {@code hashCode} and {@code toString} are generated by Lombok's
 * {@code @Data}.
 *
 * @author lwx
 * @date 2025/6/05
 */
@Data
public class MachineTranslationModelConfig {
    /**
     * Translation model.
     */
    private MachineTranslationModeEnum modelEnum;

    /**
     * Device type.
     */
    private DeviceEnum device;

    /**
     * Translation model path.
     */
    private String modelPath;
    /**
     * Translation model name; presumably the model file name under {@code modelPath} -- confirm.
     */
    private String modelName;
    /**
     * Search configuration of the translation model.
     */
    private SearchConfig searchConfig;



}
2025/6/05 + */ +public class SearchConfig { + + private int maxSeqLength; + private long padTokenId; + private long eosTokenId; + private long bosTokenId; + private long decoderStartTokenId; + private float encoderRepetitionPenalty; + private long forcedBosTokenId; + private long srcLangId; + private float lengthPenalty; + public SearchConfig() { + this.maxSeqLength = 512; + this.eosTokenId = 2; + this.bosTokenId = 0; + this.padTokenId = 1; + this.decoderStartTokenId = 2; + this.encoderRepetitionPenalty = 1.0f; + this.srcLangId = 0; + this.forcedBosTokenId = 0; + this.lengthPenalty = 1.0f; + + } + + public long getSrcLangId() { + return srcLangId; + } + + public void setSrcLangId(long srcLangId) { + this.srcLangId = srcLangId; + } + + public void setEosTokenId(long eosTokenId) { + this.eosTokenId = eosTokenId; + } + + public int getMaxSeqLength() { + return maxSeqLength; + } + + public void setMaxSeqLength(int maxSeqLength) { + this.maxSeqLength = maxSeqLength; + } + + public long getPadTokenId() { + return padTokenId; + } + + public void setPadTokenId(long padTokenId) { + this.padTokenId = padTokenId; + } + + public long getEosTokenId() { + return eosTokenId; + } + + public long getDecoderStartTokenId() { + return decoderStartTokenId; + } + + public void setDecoderStartTokenId(long decoderStartTokenId) { + this.decoderStartTokenId = decoderStartTokenId; + } + + public float getEncoderRepetitionPenalty() { + return encoderRepetitionPenalty; + } + + public void setEncoderRepetitionPenalty(float encoderRepetitionPenalty) { + this.encoderRepetitionPenalty = encoderRepetitionPenalty; + } + + public long getForcedBosTokenId() { + return forcedBosTokenId; + } + + public void setForcedBosTokenId(long forcedBosTokenId) { + this.forcedBosTokenId = forcedBosTokenId; + } + + public float getLengthPenalty() { + return lengthPenalty; + } + + public void setLengthPenalty(float lengthPenalty) { + this.lengthPenalty = lengthPenalty; + } + + public long getBosTokenId() { + return 
bosTokenId; + } + + public void setBosTokenId(long bosTokenId) { + this.bosTokenId = bosTokenId; + } +} \ No newline at end of file diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/BatchTensorList.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/BatchTensorList.java new file mode 100644 index 0000000000000000000000000000000000000000..87bea9b63a9977b7aa6e897bfaf0ce67957fe44b --- /dev/null +++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/BatchTensorList.java @@ -0,0 +1,174 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package cn.smartjavaai.translation.entity; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; + +/** + * BatchTensorList represents a search state, and the NDArrays inside are updated in each iteration + * of the autoregressive loop. + * + *

It is a struct consisting of NDArrays, whose first dimension is batch, and also contains + * sequence dimension (whose position in tensor's shape is specified by seqDimOrder). The SeqBatcher + * batch operations will operate on these two dimensions. + */ +public abstract class BatchTensorList { + // [batch, seq_past]. seq-dim-size == |past_seq| + |inputIds|. Will grow. + private NDArray pastOutputIds; + + // [batch, seq_past] + // The cache of past attentionMask. seq-dim-size == |past_seq| + |inputIds|. Will grow. + private NDArray pastAttentionMask; + + // (k, v) * numLayer, + // kv: [batch, heads, seq_past, kvfeature] + // The cache of past sequence. seq-dim-size == |past_seq| + |inputIds|. Will grow. + private NDList pastKeyValues; + + // Sequence dimension order among all dimensions for each element in the batch list. + private long[] seqDimOrder; + + BatchTensorList() {} + + /** + * Constructs a new {@code BatchTensorList} instance. + * + * @param list the NDList that contains the serialized version of the batch tensors + * @param seqDimOrder the sequence dimension order that specifies where the sequence dimension + * is in a tensor's shape + */ + BatchTensorList(NDList list, long[] seqDimOrder) { + this.seqDimOrder = seqDimOrder; + pastOutputIds = list.get(0); + pastAttentionMask = list.get(1); + pastKeyValues = list.subNDList(2); + } + + /** + * Constructs a new {@code BatchTensorList} instance. 
+ * + * @param pastOutputIds past output token ids + * @param pastAttentionMask past attention mask + * @param pastKeyValues past kv cache + * @param seqDimOrder the sequence dimension order that specifies where the sequence dimension + * is in a tensor's shape + */ + BatchTensorList( + NDArray pastOutputIds, + NDArray pastAttentionMask, + NDList pastKeyValues, + long[] seqDimOrder) { + this.pastKeyValues = pastKeyValues; + this.pastOutputIds = pastOutputIds; + this.pastAttentionMask = pastAttentionMask; + this.seqDimOrder = seqDimOrder; + } + + /** + * Constructs a new {@code BatchTensorList} instance from the serialized version of the batch + * tensors. + * + *

The pastOutputIds has to be the first in the output list. + * + * @param inputList the serialized version of the batch tensors + * @param seqDimOrder the sequence dimension order that specifies where the sequence dimension + * is in a tensor's shape + * @return BatchTensorList + */ + public abstract BatchTensorList fromList(NDList inputList, long[] seqDimOrder); + + /** + * Returns the serialized version of the BatchTensorList. The pastOutputIds has to be the first + * in the output list. + * + * @return the NDList that contains the serialized BatchTensorList + */ + public abstract NDList getList(); + + /** + * Returns the sequence dimension order which specifies where the sequence dimension is in a + * tensor's shape. + * + * @return the sequence dimension order which specifies where the sequence dimension is in a + * tensor's shape + */ + public long[] getSeqDimOrder() { + return seqDimOrder; + } + + /** + * Returns the value of the pastOutputIds. + * + * @return the value of pastOutputIds + */ + public NDArray getPastOutputIds() { + return pastOutputIds; + } + + /** + * Sets the past output token ids. + * + * @param pastOutputIds the past output token ids + */ + public void setPastOutputIds(NDArray pastOutputIds) { + this.pastOutputIds = pastOutputIds; + } + + /** + * Returns the value of the pastAttentionMask. + * + * @return the value of pastAttentionMask + */ + public NDArray getPastAttentionMask() { + return pastAttentionMask; + } + + /** + * Sets the attention mask. + * + * @param pastAttentionMask the attention mask + */ + public void setPastAttentionMask(NDArray pastAttentionMask) { + this.pastAttentionMask = pastAttentionMask; + } + + /** + * Returns the value of the pastKeyValues. + * + * @return the value of pastKeyValues + */ + public NDList getPastKeyValues() { + return pastKeyValues; + } + + /** + * Sets the kv cache. 
+ * + * @param pastKeyValues the kv cache + */ + public void setPastKeyValues(NDList pastKeyValues) { + this.pastKeyValues = pastKeyValues; + } + + /** + * Sets the sequence dimension order which specifies where the sequence dimension is in a + * tensor's shape. + * + * @param seqDimOrder the sequence dimension order which specifies where the sequence dimension + * is in a tensor's shape + */ + public void setSeqDimOrder(long[] seqDimOrder) { + this.seqDimOrder = seqDimOrder; + } +} \ No newline at end of file diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/CausalLMOutput.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/CausalLMOutput.java new file mode 100644 index 0000000000000000000000000000000000000000..acbc11de829def6de43c967ead4dc982b0267636 --- /dev/null +++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/CausalLMOutput.java @@ -0,0 +1,33 @@ +package cn.smartjavaai.translation.entity; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; + +/** + * 解码输出对象 + * + * @author Calvin + * @mail 179209347@qq.com + * @website www.aias.top + */ +public class CausalLMOutput { + private NDArray logits; + private NDList pastKeyValuesList; + + public CausalLMOutput(NDArray logits, NDList pastKeyValues) { + this.logits = logits; + this.pastKeyValuesList = pastKeyValues; + } + + public NDArray getLogits() { + return logits; + } + + public void setLogits(NDArray logits) { + this.logits = logits; + } + + public NDList getPastKeyValuesList() { + return pastKeyValuesList; + } +} \ No newline at end of file diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/DirectionInfo.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/DirectionInfo.java new file mode 100644 index 0000000000000000000000000000000000000000..c0d7be281cb1138098fbda2209bf4fb95f7dd16a --- /dev/null +++ 
b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/DirectionInfo.java
@@ -0,0 +1,41 @@
+package cn.smartjavaai.translation.entity;
+
+/**
+ * Direction detection result.
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class DirectionInfo {
+
+    /**
+     * Direction: 0, 90, 180 or 270
+     */
+    private String name;
+
+    /**
+     * Confidence
+     */
+    private Double prob;
+
+    public DirectionInfo(String name, Double prob) {
+        this.name = name;
+        this.prob = prob;
+    }
+
+    public String getName() {
+        return name;
+    }
+
+    public void setName(String name) {
+        this.name = name;
+    }
+
+    public Double getProb() {
+        return prob;
+    }
+
+    public void setProb(Double prob) {
+        this.prob = prob;
+    }
+}
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/GreedyBatchTensorList.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/GreedyBatchTensorList.java
new file mode 100644
index 0000000000000000000000000000000000000000..2745c151c34f87f10b9c14a15fc07e166e33ab1a
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/GreedyBatchTensorList.java
@@ -0,0 +1,99 @@
+package cn.smartjavaai.translation.entity;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+
+/**
+ * Search state for greedy decoding: the tensors handed from one decoder step to the next.
+ *
+ * <p>This class declares its own {@code pastOutputIds} and {@code pastKeyValues} fields and
+ * overrides the matching accessors of {@link BatchTensorList}, so those accessors read and
+ * write the fields declared here rather than the superclass fields of the same name.
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class GreedyBatchTensorList extends BatchTensorList {
+    // [batch, 1]
+    private NDArray nextInputIds;
+
+    private NDArray pastOutputIds;
+
+    private NDArray encoderHiddenStates;
+    private NDArray attentionMask;
+    private NDList pastKeyValues;
+
+    /** Creates a search state holding the given tensors; the arguments are stored as-is. */
+    public GreedyBatchTensorList(
+            NDArray nextInputIds,
+            NDArray pastOutputIds,
+            NDList pastKeyValues,
+            NDArray encoderHiddenStates,
+            NDArray attentionMask) {
+        this.nextInputIds = nextInputIds;
+        this.pastKeyValues = pastKeyValues;
+        this.pastOutputIds = pastOutputIds;
+        this.attentionMask = attentionMask;
+        this.encoderHiddenStates = encoderHiddenStates;
+    }
+
+    /** Creates a search state with every tensor unset. */
+    public GreedyBatchTensorList() {}
+
+    /** Returns a new, empty {@code GreedyBatchTensorList}; both arguments are ignored. */
+    @Override
+    public BatchTensorList fromList(NDList inputList, long[] seqDimOrder) {
+        return new GreedyBatchTensorList();
+    }
+
+    /** Returns a new, empty {@link NDList}; the stored tensors are not included. */
+    @Override
+    public NDList getList() {
+        return new NDList();
+    }
+
+    public NDArray getNextInputIds() {
+        return nextInputIds;
+    }
+
+    public void setNextInputIds(NDArray nextInputIds) {
+        this.nextInputIds = nextInputIds;
+    }
+
+    @Override
+    public NDArray getPastOutputIds() {
+        return pastOutputIds;
+    }
+
+    @Override
+    public void setPastOutputIds(NDArray pastOutputIds) {
+        this.pastOutputIds = pastOutputIds;
+    }
+
+    @Override
+    public NDList getPastKeyValues() {
+        return pastKeyValues;
+    }
+
+    @Override
+    public void setPastKeyValues(NDList pastKeyValues) {
+        this.pastKeyValues = pastKeyValues;
+    }
+
+    public NDArray getEncoderHiddenStates() {
+        return encoderHiddenStates;
+    }
+
+    public void setEncoderHiddenStates(NDArray encoderHiddenStates) {
+        this.encoderHiddenStates = encoderHiddenStates;
+    }
+
+    public NDArray getAttentionMask() {
+        return attentionMask;
+    }
+
+    public void setAttentionMask(NDArray attentionMask) {
+        this.attentionMask = attentionMask;
+    }
+}
\ No newline at end of file
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/IdCardInfo.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/IdCardInfo.java
new file mode 100644
index 0000000000000000000000000000000000000000..8eae0b930efa2cd271fdad4836156b5af134144b
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/IdCardInfo.java
@@ -0,0 +1,12 @@
+package cn.smartjavaai.translation.entity;
+
+/**
+ * ID card information.
+ * @author dwj
+ * @date 2025/5/22
+ */
+public class IdCardInfo {
+
+
+
+}
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/enums/MachineTranslationModeEnum.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/enums/MachineTranslationModeEnum.java
new file mode 100644
index 0000000000000000000000000000000000000000..1ffa2fdd1d04248322a754bf1fdd7b71d0378262
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/enums/MachineTranslationModeEnum.java
@@ -0,0 +1,37 @@
+package cn.smartjavaai.translation.enums;
+
+import java.util.Locale;
+
+/**
+ * Machine translation models supported by this module.
+ *
+ * @author lwx
+ * @date 2025/6/05
+ */
+public enum MachineTranslationModeEnum {
+
+    TRACED_TRANSLATION_CPU;
+
+    /**
+     * Resolves a model by name, ignoring case, surrounding whitespace and
+     * {@code -}/{@code _} separators (for example {@code "traced-translation-cpu"}).
+     *
+     * @param name the model name to look up
+     * @return the matching enum constant
+     * @throws IllegalArgumentException if {@code name} is {@code null} or matches no constant
+     */
+    public static MachineTranslationModeEnum fromName(String name) {
+        if (name == null) {
+            throw new IllegalArgumentException("未知模型名称: " + name);
+        }
+        // Locale.ROOT keeps the upper-casing independent of the JVM default locale
+        // (under a Turkish locale "i" would otherwise become a dotted capital I).
+        String formatted = name.trim().toUpperCase(Locale.ROOT).replaceAll("[-_]", "");
+        for (MachineTranslationModeEnum model : values()) {
+            if (model.name().replace("_", "").equals(formatted)) {
+                return model;
+            }
+        }
+        throw new IllegalArgumentException("未知模型名称: " + name);
+    }
+}
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/exception/TranslationException.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/exception/TranslationException.java
new file mode 100644
index 0000000000000000000000000000000000000000..f9fab822eef1cb472433b028c31c7f3c4504ddef
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/exception/TranslationException.java
@@ -0,0 +1,30 @@
+package cn.smartjavaai.translation.exception;
+
+/**
+ * Unchecked exception raised by the translation module.
+ * @author lwx
+ * @date 2025/6/5
+ */
+public class TranslationException extends RuntimeException{
+
+    public TranslationException() {
+        super();
+    }
+
+    public TranslationException(String message, Throwable cause, boolean enableSuppression, boolean writableStackTrace) {
+        super(message, cause, enableSuppression, writableStackTrace);
+    }
+
+    public TranslationException(String message, Throwable cause) {
+        super(message, cause);
+    }
+
+    public TranslationException(String message) {
+        super(message);
+    }
+
+    public TranslationException(Throwable cause) {
+        super(cause);
+    }
+
+}
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/factory/TranslationModelFactory.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/factory/TranslationModelFactory.java
new file mode 100644
index 0000000000000000000000000000000000000000..1cab02bd394a10e677b96e1d045953fbf028f4af
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/factory/TranslationModelFactory.java
@@ -0,0 +1,119 @@
+package cn.smartjavaai.translation.factory;
+
+import cn.smartjavaai.common.config.Config;
+
+
+import cn.smartjavaai.translation.config.MachineTranslationModelConfig;
+import cn.smartjavaai.translation.exception.TranslationException;
+import cn.smartjavaai.translation.model.common.TracedTranslationModel;
+import cn.smartjavaai.translation.model.common.TranslationCommonModel;
+import lombok.extern.slf4j.Slf4j;
+
+import java.util.Locale;
+import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * Factory that creates and caches machine translation models.
+ *
+ * <p>Models are cached by the name of the configured model enum, so a later call with the
+ * same enum returns the instance created by the first call, whatever else the config holds.
+ *
+ * @author dwj
+ */
+@Slf4j
+public class TranslationModelFactory {
+
+    // volatile plus the double check in getInstance() make the lazy creation thread-safe.
+    private static volatile TranslationModelFactory instance;
+
+    /** Loaded models, keyed by model enum name. */
+    private static final ConcurrentHashMap<String, TranslationCommonModel> commonDetModelMap =
+            new ConcurrentHashMap<>();
+
+    /** Registry of model implementations, keyed by lower-cased model name. */
+    private static final Map<String, Class<? extends TranslationCommonModel>> commonDetRegistry =
+            new ConcurrentHashMap<>();
+
+    /**
+     * Returns the shared factory, creating it on first use.
+     *
+     * @return the singleton factory
+     */
+    public static TranslationModelFactory getInstance() {
+        if (instance == null) {
+            synchronized (TranslationModelFactory.class) {
+                if (instance == null) {
+                    instance = new TranslationModelFactory();
+                }
+            }
+        }
+        return instance;
+    }
+
+    /**
+     * Registers a model implementation under a case-insensitive name.
+     *
+     * @param name the model name
+     * @param clazz the implementation class
+     */
+    private static void registerCommonDetModel(
+            String name, Class<? extends TranslationCommonModel> clazz) {
+        commonDetRegistry.put(registryKey(name), clazz);
+    }
+
+    /** Normalizes a model name into a registry key, independent of the default locale. */
+    private static String registryKey(String name) {
+        return name.toLowerCase(Locale.ROOT);
+    }
+
+    /**
+     * Returns the model for the given config, creating and loading it on first request.
+     *
+     * @param config the model configuration; must specify a model enum
+     * @return the cached or newly created model
+     * @throws TranslationException if the config or its model enum is missing, or if the
+     *     model cannot be created
+     */
+    public TranslationCommonModel getDetModel(MachineTranslationModelConfig config) {
+        if (Objects.isNull(config) || Objects.isNull(config.getModelEnum())) {
+            throw new TranslationException("未配置翻译模型");
+        }
+        return commonDetModelMap.computeIfAbsent(
+                config.getModelEnum().name(), key -> createCommonDetModel(config));
+    }
+
+    /**
+     * Instantiates the implementation registered for the configured model and loads it.
+     *
+     * @param config the model configuration
+     * @return the loaded model
+     * @throws TranslationException if no implementation is registered or it cannot be created
+     */
+    private TranslationCommonModel createCommonDetModel(MachineTranslationModelConfig config) {
+        Class<? extends TranslationCommonModel> clazz =
+                commonDetRegistry.get(registryKey(config.getModelEnum().name()));
+        if (clazz == null) {
+            throw new TranslationException("Unsupported model");
+        }
+        TranslationCommonModel model;
+        try {
+            // Class.newInstance() is deprecated and lets checked constructor exceptions
+            // escape undeclared; the Constructor API wraps them instead.
+            model = clazz.getDeclaredConstructor().newInstance();
+        } catch (ReflectiveOperationException e) {
+            throw new TranslationException(e);
+        }
+        model.loadModel(config);
+        return model;
+    }
+
+    // Register the built-in models.
+    static {
+        registerCommonDetModel("TRACED_TRANSLATION_CPU", TracedTranslationModel.class);
+
+        log.info("缓存目录:{}", Config.getCachePath());
+    }
+
+}
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/TracedTranslationModel.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/TracedTranslationModel.java
new file mode 100644
index 0000000000000000000000000000000000000000..af01b8f6d06bc288425e67269b27b17e791f5e51
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/TracedTranslationModel.java
@@ -0,0 +1,325 @@
+package cn.smartjavaai.translation.model.common;
+
+import ai.djl.Device;
+import ai.djl.MalformedModelException;
+import ai.djl.engine.Engine;
+import ai.djl.huggingface.tokenizers.Encoding;
+import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
+import ai.djl.inference.Predictor;
+import ai.djl.modality.cv.Image;
+import ai.djl.modality.cv.output.DetectedObjects;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.index.NDIndex;
+import ai.djl.repository.zoo.Criteria;
+import ai.djl.repository.zoo.ModelNotFoundException;
+import ai.djl.repository.zoo.ModelZoo;
+import ai.djl.repository.zoo.ZooModel;
+import ai.djl.translate.NoopTranslator;
+import ai.djl.translate.TranslateException;
+import cn.smartjavaai.common.config.Config;
+import cn.smartjavaai.common.enums.DeviceEnum;
+import cn.smartjavaai.common.pool.PredictorFactory;
+import cn.smartjavaai.common.pool.ZooModelFactory;
+import cn.smartjavaai.translation.config.MachineTranslationModelConfig;
+import cn.smartjavaai.translation.config.SearchConfig;
+import cn.smartjavaai.translation.entity.CausalLMOutput;
+import cn.smartjavaai.translation.entity.GreedyBatchTensorList;
+import cn.smartjavaai.translation.enums.MachineTranslationModeEnum;
+import cn.smartjavaai.translation.exception.TranslationException;
+import cn.smartjavaai.translation.factory.TranslationModelFactory;
+import cn.smartjavaai.translation.model.common.translator.Decoder2Translator;
+import cn.smartjavaai.translation.model.common.translator.DecoderTranslator;
+import cn.smartjavaai.translation.model.common.translator.EncoderTranslator;
+import cn.smartjavaai.translation.utils.TokenUtils;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.pool2.ObjectPool;
+import org.apache.commons.pool2.impl.GenericObjectPool;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.nio.file.Paths;
+import java.util.Objects;
+
+/**
+ * Machine translation model that runs a PyTorch model through encoder, decoder and
+ * decoder2 predictors and generates the output with greedy search.
+ *
+ * <p>The tokenizer, predictors and NDManager of a call are held in instance fields, so
+ * {@link #translate(String)} is synchronized and the public helper methods only work while
+ * a translation is running.
+ *
+ * @author lwx
+ * @date 2025/6/05
+ */
+@Slf4j
+public class TracedTranslationModel implements TranslationCommonModel {
+
+    private ObjectPool<ZooModel<NDList, NDList>> detPredictorPool;
+    private ZooModel<NDList, NDList> nllbModel;
+    private HuggingFaceTokenizer tokenizer;
+    private Predictor<long[], NDArray> encoderPredictor;
+    private Predictor<NDList, CausalLMOutput> decoderPredictor;
+    private Predictor<NDList, CausalLMOutput> decoder2Predictor;
+    private SearchConfig searchConfig;
+    private MachineTranslationModelConfig config;
+    private NDManager manager;
+
+    /**
+     * Manual smoke test: loads the model from hard-coded local paths, translates a sample
+     * sentence and prints the result.
+     */
+    @Test
+    public void detect(){
+        Config.setCachePath("E:\\ai\\models\\libs");
+        MachineTranslationModelConfig config = new MachineTranslationModelConfig();
+        SearchConfig searchConfig = new SearchConfig();
+        // Maximum length of the generated text
+        searchConfig.setMaxSeqLength(128);
+        // Source language: Chinese, "zho_Hans": 256200
+        searchConfig.setSrcLangId(256200);
+        // Target language: English, "eng_Latn": 256047
+        searchConfig.setForcedBosTokenId(256047);
+        config.setSearchConfig(searchConfig);
+        config.setDevice(DeviceEnum.CPU);
+        config.setModelEnum(MachineTranslationModeEnum.TRACED_TRANSLATION_CPU);
+        config.setModelPath("E:\\ai\\models\\nlp\\");
+        config.setModelName("traced_translation_cpu.pt");
+        // Input text
+        String input2 = "智利北部的丘基卡马塔矿是世界上最大的露天矿之一,长约4公里,宽3公里,深1公里。";
+        String input = "你好,欢迎使用SmartJavaAI!";
+        TranslationCommonModel detModel = TranslationModelFactory.getInstance().getDetModel(config);
+        String translate = detModel.translate(input);
+        System.out.println("识别结果 translate"+translate);
+    }
+
+    /**
+     * Loads the PyTorch model named by the config and creates the pool it is served from.
+     *
+     * @param config the model configuration; model path, model name and search config are
+     *     required
+     * @throws TranslationException if a required setting is missing or the model cannot be
+     *     loaded
+     */
+    @Override
+    public void loadModel(MachineTranslationModelConfig config) {
+        if (StringUtils.isBlank(config.getModelPath())) {
+            throw new TranslationException("modelPath is null");
+        }
+        if (StringUtils.isBlank(config.getModelName())) {
+            throw new TranslationException("modelName is null");
+        }
+        if (Objects.isNull(config.getSearchConfig())) {
+            throw new TranslationException("searchConfig is null");
+        }
+        Device device = null;
+        if (!Objects.isNull(config.getDevice())) {
+            device = config.getDevice() == DeviceEnum.CPU ? Device.cpu() : Device.gpu();
+        }
+        this.config = config;
+        this.searchConfig = config.getSearchConfig();
+        Criteria<NDList, NDList> criteria =
+                Criteria.builder()
+                        .setTypes(NDList.class, NDList.class)
+                        // Resolve the file against the directory so modelPath works with
+                        // or without a trailing separator.
+                        .optModelPath(Paths.get(config.getModelPath(), config.getModelName()))
+                        .optEngine("PyTorch")
+                        .optDevice(device)
+                        .optTranslator(new NoopTranslator())
+                        .build();
+        try {
+            nllbModel = ModelZoo.loadModel(criteria);
+            // Pool for the loaded model; translate() borrows from it on every call.
+            this.detPredictorPool = new GenericObjectPool<>(new ZooModelFactory<>(nllbModel));
+            log.info("当前设备: " + nllbModel.getNDManager().getDevice());
+            log.info("当前引擎: " + Engine.getInstance().getEngineName());
+        } catch (IOException | ModelNotFoundException | MalformedModelException e) {
+            throw new TranslationException("模型加载失败", e);
+        }
+    }
+
+    /**
+     * Translates {@code input} between the languages set in the search config.
+     *
+     * <p>Synchronized because the per-call resources live in instance fields while the
+     * factory hands the same model instance to every caller.
+     *
+     * @param input the text to translate
+     * @return the translated text
+     * @throws TranslationException if any step of the translation fails
+     */
+    @Override
+    public synchronized String translate(String input) throws TranslationException {
+        ZooModel<NDList, NDList> zooModel = null;
+        try {
+            zooModel = detPredictorPool.borrowObject();
+            // try-with-resources closes the native tokenizer, the predictors and the call
+            // manager (with every NDArray attached to it) once the call is over.
+            try (NDManager callManager = NDManager.newBaseManager();
+                    HuggingFaceTokenizer callTokenizer =
+                            HuggingFaceTokenizer.newInstance(
+                                    Paths.get(config.getModelPath(), "tokenizer.json"));
+                    Predictor<long[], NDArray> encPredictor =
+                            zooModel.newPredictor(new EncoderTranslator());
+                    Predictor<NDList, CausalLMOutput> decPredictor =
+                            zooModel.newPredictor(new DecoderTranslator());
+                    Predictor<NDList, CausalLMOutput> dec2Predictor =
+                            zooModel.newPredictor(new Decoder2Translator())) {
+                this.manager = callManager;
+                this.tokenizer = callTokenizer;
+                this.encoderPredictor = encPredictor;
+                this.decoderPredictor = decPredictor;
+                this.decoder2Predictor = dec2Predictor;
+                return translateLanguage(input);
+            }
+        } catch (Exception e) {
+            throw new TranslationException("翻译错误", e);
+        } finally {
+            clearCallState();
+            if (zooModel != null) {
+                try {
+                    detPredictorPool.returnObject(zooModel); // return to the pool
+                } catch (Exception e) {
+                    log.warn("归还Predictor失败", e);
+                    try {
+                        zooModel.close(); // destroy only when it cannot be returned
+                    } catch (Exception ex) {
+                        log.error("关闭Predictor失败", ex);
+                    }
+                }
+            }
+        }
+    }
+
+    /** Drops the references to the per-call resources once they have been closed. */
+    private void clearCallState() {
+        this.manager = null;
+        this.tokenizer = null;
+        this.encoderPredictor = null;
+        this.decoderPredictor = null;
+        this.decoder2Predictor = null;
+    }
+
+    /**
+     * Encodes the input once, then decodes greedily until the EOS token is produced or the
+     * maximum sequence length is reached.
+     *
+     * @param input the text to translate
+     * @return the decoded text
+     * @throws TranslateException if a predictor fails
+     */
+    private String translateLanguage(String input) throws TranslateException {
+
+        Encoding encoding = tokenizer.encode(input);
+        long[] ids = encoding.getIds();
+        // 1. Encoder
+        long[] inputIds = new long[ids.length];
+        // Source language id first, then the token ids shifted right by one; the last token
+        // id does not fit and is dropped.
+        inputIds[0] = searchConfig.getSrcLangId();
+        System.arraycopy(ids, 0, inputIds, 1, ids.length - 1);
+
+        long[] attentionMask = encoding.getAttentionMask();
+        NDArray attentionMaskArray = manager.create(attentionMask).expandDims(0);
+
+        NDArray encoderHiddenStates = encoder(inputIds);
+        // Tie the encoder output to the call manager so it is released together with it.
+        encoderHiddenStates.attach(manager);
+
+        NDArray decoder_input_ids = manager.create(new long[]{searchConfig.getDecoderStartTokenId()}).reshape(1, 1);
+        NDList decoderInput = new NDList(decoder_input_ids, encoderHiddenStates, attentionMaskArray);
+
+        // 2. Initial Decoder
+        CausalLMOutput modelOutput = decoder(decoderInput);
+        modelOutput.getLogits().attach(manager);
+        modelOutput.getPastKeyValuesList().attach(manager);
+
+        GreedyBatchTensorList searchState =
+                new GreedyBatchTensorList(null, decoder_input_ids, modelOutput.getPastKeyValuesList(), encoderHiddenStates, attentionMaskArray);
+
+        while (true) {
+            NDArray pastOutputIds = searchState.getPastOutputIds();
+
+            if (searchState.getNextInputIds() != null) {
+                decoderInput = new NDList(searchState.getNextInputIds(), searchState.getEncoderHiddenStates(), searchState.getAttentionMask());
+                decoderInput.addAll(searchState.getPastKeyValues());
+                // 3. Decoder loop
+                modelOutput = decoder2(decoderInput);
+            }
+
+            NDArray outputIds = greedyStepGen(searchConfig, pastOutputIds, modelOutput.getLogits());
+
+            searchState.setNextInputIds(outputIds);
+            pastOutputIds = pastOutputIds.concat(outputIds, 1);
+            searchState.setPastOutputIds(pastOutputIds);
+
+            searchState.setPastKeyValues(modelOutput.getPastKeyValuesList());
+
+            long id = searchState.getNextInputIds().toLongArray()[0];
+            if (searchConfig.getEosTokenId() == id) {
+                searchState.setNextInputIds(null);
+                break;
+            }
+            if (searchState.getPastOutputIds() != null && searchState.getPastOutputIds().getShape().get(1) + 1 >= searchConfig.getMaxSeqLength()) {
+                break;
+            }
+        }
+
+        // Whether the loop ended on EOS or on the length limit, the result is the
+        // accumulated output ids.
+        return TokenUtils.decode(searchConfig, tokenizer, searchState.getPastOutputIds());
+    }
+
+    /**
+     * Chooses the next token greedily from the scores of the last sequence position.
+     *
+     * <p>On the first step ({@code pastOutputIds} has a last dimension of 1) every token
+     * except the forced BOS token is masked out, which selects the target language.
+     *
+     * @param config the search config supplying the forced BOS token id
+     * @param pastOutputIds the ids generated so far; only the size of its last dimension is read
+     * @param next_token_scores the decoder logits; the slice {@code [:, -1, :]} is used
+     * @return the index of the highest score, with a leading dimension added
+     */
+    public NDArray greedyStepGen(SearchConfig config, NDArray pastOutputIds, NDArray next_token_scores) {
+        next_token_scores = next_token_scores.get(":, -1, :");
+
+        NDArray new_next_token_scores;
+
+        // LogitsProcessor 1. ForcedBOSTokenLogitsProcessor
+        long cur_len = pastOutputIds.getShape().getLastDimension();
+        if (cur_len == 1) {
+            // Create the mask in one step (-inf everywhere, then 0 for the forced token)
+            // instead of writing -inf once per vocabulary entry.
+            new_next_token_scores = manager.full(
+                    next_token_scores.getShape(), Float.NEGATIVE_INFINITY, next_token_scores.getDataType());
+            new_next_token_scores.set(new NDIndex(":," + config.getForcedBosTokenId()), 0);
+        } else {
+            new_next_token_scores = manager.create(next_token_scores.getShape(), next_token_scores.getDataType());
+            next_token_scores.copyTo(new_next_token_scores);
+        }
+
+        NDArray probs = new_next_token_scores.softmax(-1);
+        NDArray next_tokens = probs.argMax(-1);
+
+        return next_tokens.expandDims(0);
+    }
+
+    /** Runs the encoder predictor on the given token ids. */
+    public NDArray encoder(long[] ids) throws TranslateException {
+        return encoderPredictor.predict(ids);
+    }
+
+    /** Runs the decoder predictor (no key/value cache in the input). */
+    public CausalLMOutput decoder(NDList input) throws TranslateException {
+        return decoderPredictor.predict(input);
+    }
+
+    /** Runs the decoder2 predictor (key/value cache included in the input). */
+    public CausalLMOutput decoder2(NDList input) throws TranslateException {
+        return decoder2Predictor.predict(input);
+    }
+}
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/TranslationCommonModel.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/TranslationCommonModel.java
new file mode 100644
index 0000000000000000000000000000000000000000..513331d0181e776e3a177e56208375fc6dc4bb34
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/TranslationCommonModel.java
@@ -0,0 +1,32 @@
+package cn.smartjavaai.translation.model.common;
+
+import cn.smartjavaai.translation.config.MachineTranslationModelConfig;
+
+/**
+ * Common interface of the machine translation models.
+ * @author lwx
+ * @date 2025/6/05
+ */
+public interface TranslationCommonModel {
+
+    /**
+     * Loads the model.
+     * @param config the model configuration
+     */
+    void loadModel(MachineTranslationModelConfig config); // load the model
+
+    /**
+     * Translates text; this default implementation always throws UnsupportedOperationException.
+     * @param input the text to translate
+     * @return the translated text
+     */
+    default String translate(String input) {
+        throw new UnsupportedOperationException("默认不支持该功能");
+    }
+
+
+
+
+
+
+}
diff --git
a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/translator/Decoder2Translator.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/translator/Decoder2Translator.java
new file mode 100644
index 0000000000000000000000000000000000000000..2c32518841e998a0cf0c0efcf8498cf0578bb9d4
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/translator/Decoder2Translator.java
@@ -0,0 +1,68 @@
+package cn.smartjavaai.translation.model.common.translator;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.translate.NoBatchifyTranslator;
+import ai.djl.translate.TranslatorContext;
+import cn.smartjavaai.translation.entity.CausalLMOutput;
+
+
+/**
+ * Decoder translator whose input includes pastKeyValues.
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class Decoder2Translator implements NoBatchifyTranslator<NDList, CausalLMOutput> {
+    // NOTE(review): 12 and 4 presumably are the decoder layer count and the cached tensors
+    // per layer of the traced model -- confirm against the model export.
+    private static final int KV_TUPLE_COUNT = 12;
+    private static final int KV_TUPLE_SIZE = 4;
+
+    private final String tupleName;
+
+    public Decoder2Translator() {
+        tupleName = "past_key_values(" + KV_TUPLE_COUNT + ',' + KV_TUPLE_SIZE + ')';
+    }
+
+    /**
+     * Appends a placeholder array named {@code module_method:decoder2} to the decoder inputs.
+     *
+     * @param ctx the translator context that supplies the NDManager for the placeholder
+     * @param input the decoder inputs; this list is left unmodified
+     * @return a new list holding the arrays of {@code input} followed by the placeholder
+     */
+    @Override
+    public NDList processInput(TranslatorContext ctx, NDList input) {
+        NDArray placeholder = ctx.getNDManager().create(0);
+        placeholder.setName("module_method:decoder2");
+
+        // Copy instead of appending to the caller's list, so passing the same NDList twice
+        // cannot accumulate placeholders.
+        NDList modelInput = new NDList(input);
+        modelInput.add(placeholder);
+        return modelInput;
+    }
+
+    /**
+     * Splits the raw output into the logits (element 0) and the key/value cache arrays
+     * that follow it, and names every cache array {@code past_key_values(12,4)}.
+     *
+     * @param ctx the translator context (unused)
+     * @param output the raw model output
+     * @return the logits together with the renamed cache arrays
+     */
+    @Override
+    public CausalLMOutput processOutput(TranslatorContext ctx, NDList output) {
+        NDArray logitsOutput = output.get(0);
+        NDList pastKeyValuesOutput =
+                output.subNDList(1, KV_TUPLE_COUNT * KV_TUPLE_SIZE + 1);
+
+        for (NDArray array : pastKeyValuesOutput) {
+            array.setName(tupleName);
+        }
+
+        return new CausalLMOutput(logitsOutput, pastKeyValuesOutput);
+    }
+}
\ No newline at end of file
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/translator/DecoderTranslator.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/translator/DecoderTranslator.java
new file mode 100644
index
0000000000000000000000000000000000000000..981b70823e9057a50633fc2b022ea3f4cd894d59 --- /dev/null +++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/translator/DecoderTranslator.java @@ -0,0 +1,45 @@ +package cn.smartjavaai.translation.model.common.translator; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.translate.NoBatchifyTranslator; +import ai.djl.translate.TranslatorContext; +import cn.smartjavaai.translation.entity.CausalLMOutput; + +/** + * 解碼器,參數沒有 pastKeyValues + * + * @author Calvin + * @mail 179209347@qq.com + * @website www.aias.top + */ +public class DecoderTranslator implements NoBatchifyTranslator { + private String tupleName; + + public DecoderTranslator() { + tupleName = "past_key_values(" + 12 + ',' + 4 + ')'; + } + + @Override + public NDList processInput(TranslatorContext ctx, NDList input) { + + NDArray placeholder = ctx.getNDManager().create(0); + placeholder.setName("module_method:decoder"); + + input.add(placeholder); + + return input; + } + + @Override + public CausalLMOutput processOutput(TranslatorContext ctx, NDList output) { + NDArray logitsOutput = output.get(0); + NDList pastKeyValuesOutput = output.subNDList(1, 12 * 4 + 1); + + for (NDArray array : pastKeyValuesOutput) { + array.setName(tupleName); + } + + return new CausalLMOutput(logitsOutput, pastKeyValuesOutput); + } +} \ No newline at end of file diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/translator/EncoderTranslator.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/translator/EncoderTranslator.java new file mode 100644 index 0000000000000000000000000000000000000000..8c9961963cb5f4af64e05ffd7216a40f098afbd0 --- /dev/null +++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/common/translator/EncoderTranslator.java @@ -0,0 +1,49 @@ +package cn.smartjavaai.translation.model.common.translator; + +import 
ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.translate.NoBatchifyTranslator; +import ai.djl.translate.TranslatorContext; + +import java.util.Arrays; + +/** + * 编码器前后处理 + * + * @author Calvin + * @mail 179209347@qq.com + * @website www.aias.top + */ +public class EncoderTranslator implements NoBatchifyTranslator { + + + public EncoderTranslator() { + } + + @Override + public NDList processInput(TranslatorContext ctx, long[] input) throws Exception { + NDManager manager = ctx.getNDManager(); + + NDArray inputIdArray = manager.create(input).expandDims(0); + inputIdArray.setName("input_ids"); + + long[] attentionMask = new long[input.length]; + Arrays.fill(attentionMask, 1); + NDArray attentionMaskArray = manager.create(attentionMask).expandDims(0); + attentionMaskArray.setName("attention_mask"); + + NDArray placeholder = ctx.getNDManager().create(0); + placeholder.setName("module_method:encoder"); + + return new NDList(inputIdArray, attentionMaskArray, placeholder); + } + + @Override + public NDArray processOutput(TranslatorContext ctx, NDList list) { + NDArray encoderHiddenStates = list.get(0); + encoderHiddenStates.detach(); + return encoderHiddenStates; + } + +} \ No newline at end of file diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/utils/TokenUtils.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/utils/TokenUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..582804ee9f8f78d17d3defea8b16c4c1cbcf87a8 --- /dev/null +++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/utils/TokenUtils.java @@ -0,0 +1,48 @@ +package cn.smartjavaai.translation.utils; + +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; +import ai.djl.ndarray.NDArray; +import cn.smartjavaai.translation.config.SearchConfig; + + +import java.util.ArrayList; + +/** + * + * + * @author lwx + * @date 2025/4/22 + */ +public final class TokenUtils { 
+ + private TokenUtils() { + } + + /** + * 语言解码 + * + * @param tokenizer + * @param output + * @return + */ + public static String decode(SearchConfig config, HuggingFaceTokenizer tokenizer, NDArray output) { + long[] outputIds = output.toLongArray(); + ArrayList outputIdsList = new ArrayList<>(); + + for (long id : outputIds) { + if (id == config.getEosTokenId() || id==config.getSrcLangId() || id==config.getForcedBosTokenId()) { + continue; + } + outputIdsList.add(id); + } + + Long[] objArr = outputIdsList.toArray(new Long[0]); + long[] ids = new long[objArr.length]; + for (int i = 0; i < objArr.length; i++) { + ids[i] = objArr[i]; + } + String text = tokenizer.decode(ids); + + return text; + } +} \ No newline at end of file