diff --git a/models/chatglm6b/modeling_chatglm_layer.py b/models/chatglm6b/modeling_chatglm_layer.py index 80d9cd1b95a459f9f5a7a9b8875e95e0109c2d54..f1d84dd81ed56e43af120c94b165446e9ce99b8f 100644 --- a/models/chatglm6b/modeling_chatglm_layer.py +++ b/models/chatglm6b/modeling_chatglm_layer.py @@ -31,8 +31,6 @@ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaL from .configuration_chatglm import ChatGLMConfig -import time - ACLTRANSFORMER_HOME_PATH = os.environ.get("ACLTRANSFORMER_HOME_PATH") if ACLTRANSFORMER_HOME_PATH is None: raise RuntimeError( @@ -1022,11 +1020,6 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): bias=False, dtype=torch.half ) - - self.count = 0 - self.total = 0 - self.cur_time = 0 - self.first = 0 def get_output_embeddings(self): return self.lm_head @@ -1328,22 +1321,12 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): model_inputs = self.prepare_inputs_for_generation( input_ids, **model_kwargs) # forward pass to get next token - torch.npu.synchronize() - start = time.time() outputs = self( **model_inputs, return_dict=True, output_attentions=False, output_hidden_states=False, ) - torch.npu.synchronize() - end = time.time() - self.count += 1 - self.cur_time = (end - start) * 1000 - if self.count == 1: - self.first = self.cur_time - else: - self.total += self.cur_time next_token_logits = outputs.logits[:, -1, :] diff --git a/models/chatglm6b/modeling_chatglm_layer_performance.py b/models/chatglm6b/modeling_chatglm_layer_performance.py new file mode 100644 index 0000000000000000000000000000000000000000..31f224abdb83f046ff374c242d6b294e96e4d6fe --- /dev/null +++ b/models/chatglm6b/modeling_chatglm_layer_performance.py @@ -0,0 +1,1376 @@ +""" PyTorch ChatGLM model. """ + +import math +import copy +import os +import warnings +import json + +import torch +import torch.utils.checkpoint +import torch.nn.functional as F +from torch import nn +from torch.nn import CrossEntropyLoss, LayerNorm +from torch.nn.utils import skip_init +from typing import Optional, Tuple, Union, List, Callable + +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + BaseModelOutputWithPastAndCrossAttentions, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig + +from .configuration_chatglm import ChatGLMConfig + +import time + +ACLTRANSFORMER_HOME_PATH = os.environ.get("ACLTRANSFORMER_HOME_PATH") +if ACLTRANSFORMER_HOME_PATH is None: + raise RuntimeError( + "env ACLTRANSFORMER_HOME_PATH not exist, source set_env.sh") + +LIB_PATH = os.path.join(ACLTRANSFORMER_HOME_PATH, + "examples/libacltransformer_torch.so") +torch.classes.load_library(LIB_PATH) +acl_layer = torch.classes.LayerTorch.LayerTorch("ChatGlm6BLayer") +acl_layer.set_workspace(1020 * 1024 * 1000) + +# flags required to enable jit fusion kernels +torch._C._jit_set_profiling_mode(False) +torch._C._jit_set_profiling_executor(False) +torch._C._jit_override_can_fuse_on_cpu(True) +torch._C._jit_override_can_fuse_on_gpu(True) + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM-6B" +_CONFIG_FOR_DOC = "ChatGLM6BConfig" + +CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ + 
"THUDM/chatglm-6b", + # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm +] + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 20005] = 5e4 + return scores + + +def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", + "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +# @torch.jit.script +# def gelu_impl(x): +# """OpenAI's gelu implementation.""" +# return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * +# (1.0 + 0.044715 * x * x))) +def gelu_impl(x): + """OpenAI's gelu implementation.""" + return torch.fast_gelu(x) + + +def gelu(x): + return gelu_impl(x) + + +class RotaryEmbedding(torch.nn.Module): + def __init__(self, dim, base=10000, precision=torch.half, learnable=False): + super().__init__() + inv_freq = 1. 
/ (base ** (torch.arange(0, dim, 2).float() / dim)) + inv_freq = inv_freq.half() + self.learnable = learnable + if learnable: + self.inv_freq = torch.nn.Parameter(inv_freq) + self.max_seq_len_cached = None + else: + self.register_buffer('inv_freq', inv_freq) + self.max_seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + self.precision = precision + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + pass + + def forward(self, x, seq_dim=1, seq_len=None): + if seq_len is None: + seq_len = x.shape[seq_dim] + if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): + self.max_seq_len_cached = None if self.learnable else seq_len + # t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) + t = torch.arange(seq_len, device='cpu').npu().half() + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + if self.precision == torch.bfloat16: + emb = emb.float() + + # [sx, 1 (b * np), hn] + # cos_cached = emb.cos()[:, None, :] + # sin_cached = emb.sin()[:, None, :] + cos_cached = emb.cos().unsqueeze(1) + sin_cached = emb.sin().unsqueeze(1) + if self.precision == torch.bfloat16: + cos_cached = cos_cached.bfloat16() + sin_cached = sin_cached.bfloat16() + if self.learnable: + return cos_cached, sin_cached + self.cos_cached, self.sin_cached = cos_cached, sin_cached + return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] + + +inv_freq_global = 1. / \ + (10000 ** (torch.arange(0, 64, 2).float() / 64)).npu().half() +temp_global = torch.arange(2049, device='cpu').npu().half() +freqs_global = torch.einsum('i,j->ij', temp_global, inv_freq_global) +emb_global = torch.cat((freqs_global, freqs_global), dim=-1) +cosTable = emb_global.cos().unsqueeze(1) +sinTable = emb_global.sin().unsqueeze(1) + + +def rotate_half(x): + x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + # dim=-1 triggers a bug in earlier torch versions + return torch.cat((-x2, x1), dim=x1.ndim - 1) + + +@torch.jit.script +def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): + # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn] + cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ + F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) + q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + \ + (rotate_half(k) * sin) + return q, k + + +def attention_fn( + self, + query_layer, + key_layer, + value_layer, + attention_mask, + hidden_size_per_partition, + layer_id, + layer_past=None, + scaling_attention_score=True, + use_cache=False, +): + if layer_past is not None: + past_key, past_value = layer_past + idScal = layer_id.item() + key_layer = torch.cat((past_key, key_layer), dim=0) + value_layer = torch.cat((past_value, value_layer), dim=0) + + # seqlen, batch, num_attention_heads, hidden_size_per_attention_head + seq_len, b, nh, hidden_size = key_layer.shape + + if use_cache: + present = (key_layer, value_layer) + else: + present = None + + query_key_layer_scaling_coeff = float(layer_id + 1) + if scaling_attention_score: + query_layer = query_layer / \ + (math.sqrt(hidden_size) * query_key_layer_scaling_coeff) + + # =================================== + # Raw attention scores. 
[b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(1), query_layer.size(2), + query_layer.size(0), key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view( + output_size[2], output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view( + output_size[3], output_size[0] * output_size[1], -1) + + matmul_result = torch.empty( + output_size[0] * output_size[1], + output_size[2], + output_size[3], + dtype=query_layer.dtype, + device=query_layer.device, + ) + + # matmul_result = torch.baddbmm( + # matmul_result, + # query_layer.transpose(0, 1), # [b * np, sq, hn] + # key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + # beta=0.0, + # alpha=1.0, + # ) + matmul_result = torch.bmm(query_layer.transpose( + 0, 1), key_layer.permute(1, 2, 0)) + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + if self.scale_mask_softmax: + self.scale_mask_softmax.scale = query_key_layer_scaling_coeff + attention_probs = self.scale_mask_softmax( + attention_scores, attention_mask.contiguous()) + else: + if not (attention_mask == 0).all(): + # if auto-regressive, skip + attention_scores.masked_fill_(attention_mask, -10000.0) + dtype = attention_scores.type() + attention_scores = attention_scores.float() + attention_scores = attention_scores * query_key_layer_scaling_coeff + + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = attention_probs.type(dtype) + # try: + # if dtype in ["torch.xla.HalfTensor","torch.npu.HalfTensor"]: + # attention_probs = attention_probs.to(torch.float16) + # else: + # attention_probs = attention_probs.type(dtype) + # except Exception as e: + # print(e,dtype) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. 
+ # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), value_layer.size(2), + query_layer.size(0), value_layer.size(3)) + + # change view [sk, b * np, hn] + value_layer = value_layer.view(value_layer.size( + 0), output_size[0] * output_size[1], -1) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view( + output_size[0] * output_size[1], output_size[2], -1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size( + )[:-2] + (hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, present, attention_probs) + + return outputs + + +class SelfAttention(torch.nn.Module): + def __init__(self, hidden_size, num_attention_heads, + layer_id, hidden_size_per_attention_head=None, bias=True, + params_dtype=torch.float, position_encoding_2d=True): + super(SelfAttention, self).__init__() + + self.layer_id = layer_id + self.hidden_size = hidden_size + self.hidden_size_per_partition = hidden_size + self.num_attention_heads = num_attention_heads + self.num_attention_heads_per_partition = num_attention_heads + self.position_encoding_2d = position_encoding_2d + self.rotary_emb = RotaryEmbedding( + self.hidden_size // (self.num_attention_heads * 2) + if position_encoding_2d + else self.hidden_size // self.num_attention_heads, + base=10000, + precision=torch.half, + learnable=False, + ) + + self.scale_mask_softmax = None + + if hidden_size_per_attention_head is None: + self.hidden_size_per_attention_head = hidden_size // num_attention_heads + else: + self.hidden_size_per_attention_head = hidden_size_per_attention_head + + self.inner_hidden_size = num_attention_heads * \ + self.hidden_size_per_attention_head + + # Strided linear layer. + self.query_key_value = skip_init( + torch.nn.Linear, + hidden_size, + 3 * self.inner_hidden_size, + bias=bias, + dtype=params_dtype, + ) + + self.dense = skip_init( + torch.nn.Linear, + self.inner_hidden_size, + hidden_size, + bias=bias, + dtype=params_dtype, + ) + + @staticmethod + def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + def split_tensor_along_last_dim(self, tensor, num_partitions, + contiguous_split_chunks=False): + """Split a tensor along its last dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. 
+ if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + def forward( + self, + hidden_states: torch.Tensor, + position_ids, + attention_mask: torch.Tensor, + layer_id, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + """ + hidden_states: [seq_len, batch, hidden_size] + attention_mask: [(1, 1), seq_len, seq_len] + """ + + # [seq_len, batch, 3 * hidden_size] + mixed_raw_layer = self.query_key_value(hidden_states) + + # [seq_len, batch, 3 * hidden_size] --> [seq_len, batch, num_attention_heads, 3 * hidden_size_per_attention_head] + new_tensor_shape = mixed_raw_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape) + + # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head] + (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim( + mixed_raw_layer, 3) + + if self.position_encoding_2d: + q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) + k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) + cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) + position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \ + position_ids[:, 1, :].transpose(0, 1).contiguous() + q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids) + q2, k2 = apply_rotary_pos_emb_index( + q2, k2, cos, sin, block_position_ids) + query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1)) + key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1)) + else: + position_ids = position_ids.transpose(0, 1) + cos, sin = self.rotary_emb( + value_layer, seq_len=position_ids.max() + 1) + # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head] + query_layer, key_layer = apply_rotary_pos_emb_index( + query_layer, key_layer, cos, sin, position_ids) + + # [seq_len, batch, hidden_size] + context_layer, present, attention_probs = attention_fn( + self=self, + query_layer=query_layer, + key_layer=key_layer, + value_layer=value_layer, + attention_mask=attention_mask, + hidden_size_per_partition=self.hidden_size_per_partition, + layer_id=layer_id, + layer_past=layer_past, + use_cache=use_cache + ) + + output = self.dense(context_layer) + + outputs = (output, present) + + if output_attentions: + outputs += (attention_probs,) + + return outputs # output, present, attention_probs + + +class GEGLU(torch.nn.Module): + def __init__(self): + super().__init__() + self.activation_fn = F.gelu + + def forward(self, x): + # dim=-1 breaks in jit for pt<1.10 + x1, x2 = x.chunk(2, dim=(x.ndim - 1)) + return x1 * self.activation_fn(x2) + + +class GLU(torch.nn.Module): + def __init__(self, hidden_size, inner_hidden_size=None, + layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float): + super(GLU, self).__init__() + self.layer_id = layer_id + self.activation_func = activation_func + + # Project to 4h. + self.hidden_size = hidden_size + if inner_hidden_size is None: + inner_hidden_size = 4 * hidden_size + self.inner_hidden_size = inner_hidden_size + self.dense_h_to_4h = skip_init( + torch.nn.Linear, + self.hidden_size, + self.inner_hidden_size, + bias=bias, + dtype=params_dtype, + ) + # Project back to h. 
+ self.dense_4h_to_h = skip_init( + torch.nn.Linear, + self.inner_hidden_size, + self.hidden_size, + bias=bias, + dtype=params_dtype, + ) + + def forward(self, hidden_states): + """ + hidden_states: [seq_len, batch, hidden_size] + """ + + # [seq_len, batch, inner_hidden_size] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + + intermediate_parallel = self.activation_func(intermediate_parallel) + + output = self.dense_4h_to_h(intermediate_parallel) + + return output + + +class GLMBlock(torch.nn.Module): + def __init__( + self, + hidden_size, + num_attention_heads, + layernorm_epsilon, + layer_id, + inner_hidden_size=None, + hidden_size_per_attention_head=None, + layernorm=LayerNorm, + use_bias=True, + params_dtype=torch.float, + num_layers=28, + position_encoding_2d=True + ): + super(GLMBlock, self).__init__() + # Set output layer initialization if not provided. + + self.layer_id = layer_id + + # Layernorm on the input data. + self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) + self.layernorm_epsilon = layernorm_epsilon + + self.position_encoding_2d = position_encoding_2d + + # Self attention. + self.attention = SelfAttention( + hidden_size, + num_attention_heads, + layer_id, + hidden_size_per_attention_head=hidden_size_per_attention_head, + bias=use_bias, + params_dtype=params_dtype, + position_encoding_2d=self.position_encoding_2d + ) + + # Layernorm on the input data. + self.post_attention_layernorm = layernorm( + hidden_size, eps=layernorm_epsilon) + + self.num_layers = num_layers + + # GLU + self.mlp = GLU( + hidden_size, + inner_hidden_size=inner_hidden_size, + bias=use_bias, + layer_id=layer_id, + params_dtype=params_dtype, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids, + attention_mask: torch.Tensor, + layer_id, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + """ + hidden_states: [seq_len, batch, hidden_size] + attention_mask: [(1, 1), seq_len, seq_len] + """ + + test_glmBlockOut = None + test_presentKey = None + test_presentValue = None + + outputs = None + + # Layer norm at the begining of the transformer layer. + # [seq_len, batch, hidden_size] + if layer_past is None: + attention_input = self.input_layernorm(hidden_states) + + # Self attention. + attention_outputs = self.attention( + attention_input, + position_ids, + attention_mask=attention_mask, + layer_id=layer_id, + layer_past=layer_past, + use_cache=use_cache, + output_attentions=output_attentions + ) + + attention_output = attention_outputs[0] + + outputs = attention_outputs[1:] + + # Residual connection. + alpha = (2 * self.num_layers) ** 0.5 + hidden_states = attention_input * alpha + attention_output + + mlp_input = self.post_attention_layernorm(hidden_states) + + # MLP. + mlp_output = self.mlp(mlp_input) + + # Second residual connection. 
+ output = mlp_input * alpha + mlp_output + + if use_cache: + outputs = (output,) + outputs + else: + outputs = (output,) + outputs[1:] + + else: + global cosTable + global sinTable + pastKey, pastValue = layer_past + inputs = [hidden_states] + weights = list(self.state_dict().values()) + del weights[2] + inputs.extend(weights) + inputs.append(position_ids) + inputs.append(cosTable) + inputs.append(sinTable) + inputs.append(attention_mask) + inputs.append(pastKey) + inputs.append(pastValue) + acl_layer.set_param(json.dumps({"transKey": True, "dk": 128, "headNum": 32, "layerId": self.layer_id, + "layerNormEps": self.layernorm_epsilon, "ResidualAddScale": math.sqrt(2 * self.num_layers)})) + + test_glmBlockOut, test_presentKey, test_presentValue = acl_layer.execute( + inputs) + + outputs = (test_glmBlockOut, (test_presentKey, test_presentValue)) + + + + return outputs # hidden_states, present, attentions + + +class ChatGLMPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + is_parallelizable = False + supports_gradient_checkpointing = False + config_class = ChatGLMConfig + base_model_prefix = "transformer" + _no_split_modules = ["GLM6BBlock"] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + return + + +CHATGLM_6B_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general + usage and behavior. + + Parameters: + config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +CHATGLM_6B_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`ChatGLM6BTokenizer`]. + See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. + Selected in the range `[0, config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. 
Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert *input_ids* indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top.", + CHATGLM_6B_START_DOCSTRING, +) +class ChatGLMModel(ChatGLMPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well + as a decoder, in which case a layer of cross-attention is added between + the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, + Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the + `is_decoder` argument of the configuration set to `True`. + To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` + argument and `add_cross_attention` set to `True`; an + `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + def __init__(self, config: ChatGLMConfig): + super().__init__(config) + + # recording parameters + self.max_sequence_length = config.max_sequence_length + self.hidden_size = config.hidden_size + self.params_dtype = torch.half + self.num_attention_heads = config.num_attention_heads + self.vocab_size = config.vocab_size + self.num_layers = config.num_layers + self.layernorm_epsilon = config.layernorm_epsilon + self.inner_hidden_size = config.inner_hidden_size + self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads + self.position_encoding_2d = config.position_encoding_2d + + self.word_embeddings = skip_init( + torch.nn.Embedding, + num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, + dtype=self.params_dtype + ) + + def get_layer(layer_id): + return GLMBlock( + self.hidden_size, + self.num_attention_heads, + self.layernorm_epsilon, + layer_id, + inner_hidden_size=self.inner_hidden_size, + hidden_size_per_attention_head=self.hidden_size_per_attention_head, + layernorm=LayerNorm, + use_bias=True, + params_dtype=self.params_dtype, + position_encoding_2d=self.position_encoding_2d, + ) + + self.layers = torch.nn.ModuleList( + [get_layer(layer_id) for layer_id in range(self.num_layers)] + ) + + # Final layer norm before output. 
+ self.final_layernorm = LayerNorm( + self.hidden_size, eps=self.layernorm_epsilon) + + def get_input_embeddings(self): + return self.word_embeddings + + def set_input_embeddings(self, new_embeddings: torch.Tensor): + self.word_embeddings = new_embeddings + + def get_masks(self, seq, device): + context_length = seq.index(self.config.bos_token_id) + 1 + + attention_mask = torch.ones((1, len(seq), len(seq)), device=device) + attention_mask.tril_() + attention_mask[..., :context_length - 1] = 1 + attention_mask.unsqueeze_(1) + attention_mask = (attention_mask < 0.5).bool() + + return attention_mask + + def get_position_ids(self, seq, mask_position, device, gmask=False): + context_length = seq.index(self.config.bos_token_id) + 1 + if self.position_encoding_2d: + seq_length = seq.index(self.config.bos_token_id) + position_ids = torch.arange( + context_length, dtype=torch.long, device=device) + if not gmask: + position_ids[seq_length:] = mask_position + block_position_ids = torch.cat(( + torch.zeros(seq_length, dtype=torch.long, device=device), + torch.arange(context_length - seq_length, + dtype=torch.long, device=device) + 1 + )) + position_ids = torch.stack( + (position_ids, block_position_ids), dim=0) + else: + position_ids = torch.arange( + context_length, dtype=torch.long, device=device) + if not gmask: + position_ids[context_length - 1:] = mask_position + + position_ids = position_ids.unsqueeze(0) + + return position_ids + + @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, + torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]: + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape[:2] + else: + raise ValueError( + "You have to specify either input_ids or inputs_embeds") + + if past_key_values is None: + past_key_values = tuple([None] * len(self.layers)) + seq = input_ids[0].tolist() + + if attention_mask is None: + attention_mask = self.get_masks( + seq=seq, + device=input_ids.device + ) + + if position_ids is None: + MASK, gMASK = 150000, 150001 + mask_token = MASK if MASK in input_ids else gMASK + use_gmask = False if MASK in input_ids else gMASK + + mask_position = seq.index(mask_token) + position_ids = self.get_position_ids( + seq=seq, + 
mask_position=mask_position, + device=input_ids.device, + gmask=use_gmask + ) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + # [seq_len, batch, hidden_size] + hidden_states = inputs_embeds.transpose(0, 1) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[0] + seq_length_with_past = seq_length_with_past + past_key_values_length + if attention_mask is None: + attention_mask = torch.zeros(1, 1, device=input_ids.device).bool() + + else: + attention_mask = attention_mask.to(input_ids.device) + + for i, layer in enumerate(self.layers): + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_ret = layer( + hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + layer_id=torch.tensor(i), + layer_past=past_key_values[i], + use_cache=use_cache, + output_attentions=output_attentions + ) + + hidden_states = layer_ret[0] + + if use_cache: + presents = presents + (layer_ret[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + \ + (layer_ret[2 if use_cache else 1],) + + # Final layer norm. + hidden_states = self.final_layernorm(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + # self.hidden_size = config.hidden_size + # self.params_dtype = torch.half + # self.vocab_size = config.vocab_size + self.max_sequence_length = config.max_sequence_length + + self.position_encoding_2d = config.position_encoding_2d + + self.transformer = ChatGLMModel(config) + + self.lm_head = skip_init( + nn.Linear, + config.hidden_size, + config.vocab_size, + bias=False, + dtype=torch.half + ) + + self.count = 0 + self.total = 0 + self.cur_time = 0 + self.first = 0 + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_masks_and_position_ids(self, seq, mask_position, context_length, device, gmask=False): + attention_mask = torch.ones( + (1, context_length, context_length), device=device) + attention_mask.tril_() + attention_mask[..., :context_length - 1] = 1 + attention_mask.unsqueeze_(1) + attention_mask = (attention_mask < 0.5).bool() + + if self.position_encoding_2d: + seq_length = seq.index(self.config.bos_token_id) + position_ids = torch.arange( + context_length, dtype=torch.long, device=device) + if not gmask: + position_ids[seq_length:] = mask_position + block_position_ids = torch.cat(( + torch.zeros(seq_length, dtype=torch.long, device=device), + torch.arange(context_length - seq_length, + dtype=torch.long, device=device) + 1 + )) + position_ids = torch.stack( + (position_ids, block_position_ids), dim=0) + else: + position_ids = torch.arange( + context_length, dtype=torch.long, device=device) + if not gmask: + position_ids[context_length - 1:] = mask_position 
+ + position_ids = position_ids.unsqueeze(0) + + return attention_mask, position_ids + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past: Optional[torch.Tensor] = None, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs + ) -> dict: + + MASK, gMASK = 150000, 150001 + mask_token = MASK if MASK in input_ids else gMASK + use_gmask = False if MASK in input_ids else gMASK + seq = input_ids[0].tolist() + mask_position = seq.index(mask_token) + + if mask_token not in seq: + raise ValueError( + "You have to add either [MASK] or [gMASK] in your input") + + # only last token for input_ids if past is not None + if past is not None or past_key_values is not None: + context_length = seq.index(self.config.bos_token_id) + last_token = input_ids[:, -1].unsqueeze(-1) + if self.position_encoding_2d: + position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long, + device=input_ids.device) + else: + position_ids = torch.tensor( + [[mask_position]], dtype=torch.long, device=input_ids.device) + + if past is None: + past = past_key_values + return { + "input_ids": last_token, + "past_key_values": past, + "position_ids": position_ids, + } + else: + attention_mask, position_ids = self.get_masks_and_position_ids( + seq=seq, + mask_position=mask_position, + context_length=len(seq), + device=input_ids.device, + gmask=use_gmask + ) + + return { + "input_ids": input_ids, + "past_key_values": past, + "position_ids": position_ids, + "attention_mask": attention_mask + } + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states).permute(1, 0, 2).contiguous() + + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past: 
Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + + Output shares the same memory storage as `past`. + """ + return tuple( + ( + layer_past[0].index_select( + 1, beam_idx.to(layer_past[0].device)), + layer_past[1].index_select( + 1, beam_idx.to(layer_past[1].device)), + ) + for layer_past in past + ) + + @torch.no_grad() + def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1, + do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, + "temperature": temperature, "logits_processor": logits_processor, **kwargs} + if not history: + prompt = query + else: + prompt = "" + for i, (old_query, response) in enumerate(history): + prompt += "[Round {}]\n问:{}\n答:{}\n".format( + i, old_query, response) + prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) + input_ids = tokenizer([prompt], return_tensors="pt", padding=True) + # print(input_ids) + input_ids = input_ids.to(self.device) + outputs = self.generate(**input_ids, **gen_kwargs) + outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):] + # print(outputs) + response = tokenizer.decode(outputs) + response = response.strip() + response = response.replace("[[训练时间]]", "2023年") + history = history + [(query, response)] + return response, history + + @torch.no_grad() + def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, + do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, + "temperature": temperature, "logits_processor": logits_processor, **kwargs} + if not history: + prompt = query + else: + prompt = "" + for i, (old_query, response) in enumerate(history): + prompt += "用户:{}\n机器人:{}\n".format(old_query, response) + prompt += "用户:{}\n机器人:".format(query) + input_ids = tokenizer([prompt], return_tensors="pt", padding=True) + input_ids = input_ids.to(self.device) + for outputs in self.stream_generate(**input_ids, **gen_kwargs): + outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):] + response = tokenizer.decode(outputs) + # response = response.strip() + # response = response.replace("[[训练时间]]", "2023年") + new_history = history + [(query, response)] + yield response, new_history + + @torch.no_grad() + def stream_generate( + self, + input_ids, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[ + int, torch.Tensor], List[int]]] = None, + **kwargs, + ): + batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] + + if generation_config is None: 
+ generation_config = self.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + + has_default_max_length = kwargs.get( + "max_length") is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + if not has_default_max_length: + logger.warn( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", + UserWarning, + ) + + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." + ) + + # 2. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + logits_warper = self._get_logits_warper(generation_config) + + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + scores = None + while True: + model_inputs = self.prepare_inputs_for_generation( + input_ids, **model_kwargs) + # forward pass to get next token + torch.npu.synchronize() + start = time.time() + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + torch.npu.synchronize() + end = time.time() + self.count += 1 + self.cur_time = (end - start) * 1000 + if self.count == 1: + self.first = self.cur_time + else: + self.total += self.cur_time + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + if generation_config.do_sample: + next_tokens = torch.multinomial( + probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + + # update generated ids, model inputs, and length for next step 
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + unfinished_sequences = unfinished_sequences.mul( + (sum(next_tokens != i for i in eos_token_id)).long()) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + break + yield input_ids + + def quantize(self, bits: int): + from .quantization import quantize + self.transformer = quantize(self.transformer, bits) + return self
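
Usage note (illustration, not part of the patch): the performance variant added by this diff accumulates per-forward timings in stream_generate via self.first (first forward pass, i.e. prompt prefill, in ms), self.total (sum of the remaining decode passes) and self.count (number of forward passes). Below is a minimal sketch of reading those counters back after a streaming generation run; the report_latency helper and the commented driver are assumptions for illustration only, not part of the model code.

def report_latency(model):
    # model.first: latency of the first forward pass (prefill), in ms.
    # model.total: accumulated latency of all subsequent decode passes, in ms.
    # model.count: total number of forward passes executed so far.
    decode_steps = model.count - 1
    avg_decode_ms = model.total / decode_steps if decode_steps > 0 else 0.0
    print(f"first token latency: {model.first:.2f} ms")
    print(f"avg decode latency:  {avg_decode_ms:.2f} ms over {decode_steps} steps")

# Hypothetical driver, assuming the performance model and a matching tokenizer
# have been loaded (e.g. via from_pretrained with trust_remote_code=True) and
# that generation goes through stream_chat, which is the path instrumented above:
#
#     for _response, _history in model.stream_chat(tokenizer, "你好", history=[]):
#         pass
#     report_latency(model)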