diff --git a/models/vision-language-understanding/Intern_VL/vllm/README.md b/models/vision-language-understanding/Intern_VL/vllm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0b09f06eaef8da47b0f9fced467f6d34770191e7 --- /dev/null +++ b/models/vision-language-understanding/Intern_VL/vllm/README.md @@ -0,0 +1,40 @@ +# InternVL2-4B + +## Description + +InternVL2-4B is a large-scale multimodal model developed by WeTab AI, designed to handle a wide range of tasks involving both text and visual data. With 4 billion parameters, it is capable of understanding and generating complex patterns in data, making it suitable for applications such as image recognition, natural language processing, and multimodal learning. + +## Setup + +### Instal + +In order to run the model smoothly, you need to get the sdk from [resource center](https://support.iluvatar.com/#/ProductLine?id=2) of Iluvatar CoreX official website. + +```bash +# Install libGL +## CentOS +yum install -y mesa-libGL +## Ubuntu +apt install -y libgl1-mesa-dev + +# Contact the iluvatar manager to get adapted install packages of vllm, triton, and ixformer +pip3 install vllm +pip3 install triton +pip3 install ixformer +``` +### Download + +-Model: + +```bash +cd ${DeepSparkInference}/models/vision-language-understanding/Intern_VL/vllm +mkdir -p data/intern_vl +ln -s /path/to/InternVL2-4B ./data/intern_vl +``` + +## Inference + +```bash +export CUDA_VISIBLE_DEVICES=0,1 +python3 offline_inference_vision_language.py --model ./data/intern_vl/InternVL2-4B --max-tokens 256 -tp 2 --temperature 0.0 --max-model-len 2048 +``` diff --git a/models/vision-language-understanding/Intern_VL/vllm/offline_inference_vision_language.py b/models/vision-language-understanding/Intern_VL/vllm/offline_inference_vision_language.py new file mode 100644 index 0000000000000000000000000000000000000000..b82c8ec0cffd95b30175162cee61c59bf8ff7083 --- /dev/null +++ b/models/vision-language-understanding/Intern_VL/vllm/offline_inference_vision_language.py @@ -0,0 +1,157 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +This example shows how to use vLLM for running offline inference +with the correct prompt format on vision language models. + +For most models, the prompt format should follow corresponding examples +on HuggingFace model repository. +""" +import sys +from pathlib import Path +import os +sys.path.append(str(Path(__file__).resolve().parent.parent.parent)) +import argparse +import dataclasses +import inspect +from vllm.assets.image import ImageAsset +from vllm.assets.video import VideoAsset + +from transformers import AutoTokenizer + +from vllm import LLM, EngineArgs, SamplingParams +from utils import sampling_add_cli_args + +# InternVL +def run_internvl(question, engine_params,model,modality): + assert modality == "image" + llm = LLM(**engine_params) + + tokenizer = AutoTokenizer.from_pretrained(model, + trust_remote_code=True) + messages = [{'role': 'user', 'content': f"\n{question}"}] + prompt = tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) + + # Stop tokens for InternVL + # models variants may have different stop tokens + # please refer to the model card for the correct "stop words": + # https://huggingface.co/OpenGVLab/InternVL2-2B#service + stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"] + stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] + return llm, prompt, stop_token_ids + + + +def get_multi_modal_input(args): + """ + return { + "data": image or video, + "question": question, + } + """ + if args.modality == "image": + # Input image and question + image = ImageAsset("cherry_blossom").pil_image.convert("RGB") + img_question = "What is the content of this image?" + + return { + "data": image, + "question": img_question, + } + + if args.modality == "video": + # Input video and question + video = VideoAsset(name="sample_demo_1.mp4", + num_frames=args.num_frames).np_ndarrays + vid_question = "Why is this video funny?" + + return { + "data": video, + "question": vid_question, + } + + msg = f"Modality {args.modality} is not supported." + raise ValueError(msg) + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--num-prompts', + type=int, + default=1, + help='Number of prompts to run.') + parser.add_argument('--modality', + type=str, + default="image", + help='Modality of the input.') + parser.add_argument('--num-frames', + type=int, + default=16, + help='Number of frames to extract from the video.') + parser = EngineArgs.add_cli_args(parser) + parser = sampling_add_cli_args(parser) + args = parser.parse_args() + engine_args = [attr.name for attr in dataclasses.fields(EngineArgs)] + sampling_args = [ + param.name + for param in list( + inspect.signature(SamplingParams).parameters.values() + ) + ] + engine_params = {attr: getattr(args, attr) for attr in engine_args} + sampling_params = { + attr: getattr(args, attr) for attr in sampling_args if args.__contains__(attr) + } + + modality = args.modality + mm_input = get_multi_modal_input(args) + data = mm_input["data"] + question = mm_input["question"] + + llm, prompt, stop_token_ids = run_internvl(question,engine_params,args.model,args.modality) + sampling_params['stop_token_ids'] = stop_token_ids + + # We set temperature to 0.2 so that outputs can be different + # even when all prompts are identical when running batch inference. + sampling_params = SamplingParams(**sampling_params) + + assert args.num_prompts > 0 + if args.num_prompts == 1: + # Single inference + inputs = { + "prompt": prompt, + "multi_modal_data": { + modality: data + }, + } + + else: + # Batch inference + inputs = [{ + "prompt": prompt, + "multi_modal_data": { + modality: data + }, + } for _ in range(args.num_prompts)] + + outputs = llm.generate(inputs, sampling_params=sampling_params) + + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) \ No newline at end of file diff --git a/models/vision-language-understanding/Intern_VL/vllm/utils.py b/models/vision-language-understanding/Intern_VL/vllm/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c6def85dedc08ef9c3a489ce9dc5b1ff4a5e48b0 --- /dev/null +++ b/models/vision-language-understanding/Intern_VL/vllm/utils.py @@ -0,0 +1,173 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import codecs +import logging +import argparse + + +def sampling_add_cli_args(args: argparse.ArgumentParser) -> argparse.ArgumentParser: + args.add_argument( + '--n', + type=int, + default=1, + help="Number of output sequences to return for the given prompt.") + args.add_argument( + '--best-of', + type=int, + default=None, + help="Number of output sequences that are generated from the prompt. " + "From these `best_of` sequences, the top `n` sequences are returned. " + "`best_of` must be greater than or equal to `n`. This is treated as " + "the beam width when `use_beam_search` is True. By default, `best_of`" + "is set to `n`.") + args.add_argument( + '--presence-penalty', + type=float, + default=0.0, + help="Float that penalizes new tokens based on whether they " + "appear in the generated text so far. Values > 0 encourage the model " + "to use new tokens, while values < 0 encourage the model to repeat " + "tokens.") + args.add_argument( + '--frequency-penalty', + type=float, + default=0.0, + help="Float that penalizes new tokens based on their " + " frequency in the generated text so far. Values > 0 encourage the " + " model to use new tokens, while values < 0 encourage the model to " + "repeat tokens.") + args.add_argument( + '--repetition-penalty', + type=float, + default=1.0, + help="Float that penalizes new tokens based on whether " + "they appear in the prompt and the generated text so far. Values > 1 " + "encourage the model to use new tokens, while values < 1 encourage " + "the model to repeat tokens.") + args.add_argument( + '--temperature', + type=float, + default=1.0, + help="Float that controls the randomness of the sampling. Lower " + "values make the model more deterministic, while higher values make " + "the model more random. Zero means greedy sampling.") + args.add_argument( + '--top-p', + type=float, + default=1.0, + help="Float that controls the cumulative probability of the top tokens " + "to consider. Must be in (0, 1]. Set to 1 to consider all tokens.") + args.add_argument( + '--top-k', + type=int, + default=-1, + help="Integer that controls the number of top tokens to consider. Set " + "to -1 to consider all tokens.") + args.add_argument( + '--min-p', + type=float, + default=0.0, + help="Float that represents the minimum probability for a token to be " + "considered, relative to the probability of the most likely token. " + "Must be in [0, 1]. Set to 0 to disable this.") + args.add_argument( + '--use-beam-search', + default=False, + action="store_true", + help="Whether to use beam search instead of sampling.") + args.add_argument( + '--length-penalty', + type=float, + default=1.0, + help="Float that penalizes sequences based on their length. Used in beam search.") + args.add_argument( + '--stop', + type=str, + default=None, + help="List of strings that stop the generation when they are generated. " + "The returned output will not contain the stop strings.") + args.add_argument( + '--stop-token-ids', + type=int, + default=None, + help="List of tokens that stop the generation when they are " + "generated. The returned output will contain the stop tokens unless " + "the stop tokens are special tokens.") + args.add_argument( + '--include-stop-str-in-output', + default=False, + action="store_true", + help="Whether to include the stop strings in output text. Defaults to False.") + args.add_argument( + '--ignore-eos', + default=False, + action="store_true", + help="Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.") + args.add_argument( + '--max-tokens', + type=int, + default=16, + help="Maximum number of tokens to generate per output sequence.") + args.add_argument( + '--logprobs', + type=int, + default=None, + help="NNumber of log probabilities to return per output token. " + "Note that the implementation follows the OpenAI API: The return " + "result includes the log probabilities on the `logprobs` most likely " + "tokens, as well the chosen tokens. The API will always return the " + "log probability of the sampled token, so there may be up to " + "`logprobs+1` elements in the response.") + args.add_argument( + '--prompt-logprobs', + type=int, + default=None, + help="Number of log probabilities to return per prompt token.") + args.add_argument( + '--skip-special-tokens', + default=True, + action="store_false", + help="Whether to skip special tokens in the output.") + args.add_argument( + '--spaces-between-special-tokens', + default=True, + action="store_false", + help="Whether to add spaces between special tokens in the output. Defaults to True.") + # early_stopping logits_processors seed + return args + + +def load_chat_template(tokenizer, chat_template): + if chat_template is not None: + try: + with open(chat_template, "r") as f: + tokenizer.chat_template = f.read() + except OSError: + # If opening a file fails, set chat template to be args to + # ensure we decode so our escape are interpreted correctly + tokenizer.chat_template = codecs.decode( + chat_template, "unicode_escape") + + logging.info( + f"Using supplied chat template:\n{tokenizer.chat_template}" + ) + elif tokenizer.chat_template is not None: + logging.info( + f"Using default chat template:\n{tokenizer.chat_template}. This May lead to unsatisfactory results. You can provide a template.jinja file for vllm." + ) + else: + logging.warning( + "No chat template provided. Chat API will not work. This May lead to unsatisfactory results. You can provide a template.jinja file for vllm.") diff --git a/models/vision-language-understanding/Intern_VL/vllm/vllm_public_assets/cherry_blossom.jpg b/models/vision-language-understanding/Intern_VL/vllm/vllm_public_assets/cherry_blossom.jpg new file mode 100644 index 0000000000000000000000000000000000000000..63173db0da7687d7841fe4d85239d8e277d81259 Binary files /dev/null and b/models/vision-language-understanding/Intern_VL/vllm/vllm_public_assets/cherry_blossom.jpg differ