From 36e1773ee467fe55f6b9c151d6675508a7bd273b Mon Sep 17 00:00:00 2001
From: xumingqian
Date: Mon, 12 Aug 2024 20:02:43 +0800
Subject: [PATCH] aicpu max table 128

---
 .../npu_bridge/embedding/embedding_service.py | 116 ++++++++++--------
 1 file changed, 68 insertions(+), 48 deletions(-)

diff --git a/tf_adapter/python/npu_bridge/embedding/embedding_service.py b/tf_adapter/python/npu_bridge/embedding/embedding_service.py
index cc1292f77..362fd7f51 100644
--- a/tf_adapter/python/npu_bridge/embedding/embedding_service.py
+++ b/tf_adapter/python/npu_bridge/embedding/embedding_service.py
@@ -1479,53 +1479,73 @@ class ESWorker:
         return output_slots
 
     def _call_feature_mapping_export_op(self, path, export_value):
-        table_name_list = []
-        offset_list = []
-        for i in range(len(self._small_table_variable_list)):
-            table_name_list.append(self._small_table_variable_list[i][:-2])
-            offset_list.append(0)
-        table_name_tensor = ops.convert_to_tensor(table_name_list)
-        feature_size = gen_npu_cpu_ops.embedding_feature_mapping_table_size(table_name=table_name_tensor)
-        feature_id, offset_id = gen_npu_cpu_ops.embedding_feature_mapping_find(table_name=table_name_tensor,
-                                                                               feature_size=feature_size,
-                                                                               num=len(table_name_list))
-        if export_value:
-            tvar = tf.trainable_variables()
-            for x in tvar:
-                if x.name[3:] in self._small_table_variable_list:
-                    idx = self._small_table_variable_list.index(x.name[3:])
-                    offset_list[idx] = tf.reshape(tf.gather(x, offset_id[idx]), [-1])
-            values = tf.concat(offset_list, axis=0)
-        else:
-            values = 0
-        feature_mapping_export = gen_npu_cpu_ops.embedding_feature_mapping_export(file_path=path,
-                                                                                  table_name=table_name_tensor,
-                                                                                  feature_id=feature_id,
-                                                                                  offset_id=offset_id,
-                                                                                  values=values,
-                                                                                  embedding_dim=
-                                                                                  self._small_table_variable_dim_list)
-        return tf.group([feature_mapping_export])
+        feature_mapping_export_list = []
+        num = len(self._small_table_variable_list)
+        index = 0
+        # aicpu can only handle 128 tables at a time
+        while index < num:
+            iter_max = min(index + 128, num)
+            table_name_list = []
+            offset_list = []
+            embedding_dim_list = []
+            while index < iter_max:
+                table_name_list.append(self._small_table_variable_list[index][:-2])
+                embedding_dim_list.append(self._small_table_variable_dim_list[index])
+                offset_list.append(0)
+                index += 1
+            table_name_tensor = ops.convert_to_tensor(table_name_list)
+            feature_size = gen_npu_cpu_ops.embedding_feature_mapping_table_size(table_name=table_name_tensor)
+            feature_id, offset_id = gen_npu_cpu_ops.embedding_feature_mapping_find(table_name=table_name_tensor,
+                                                                                   feature_size=feature_size,
+                                                                                   num=len(table_name_list))
+            if export_value:
+                tvar = tf.trainable_variables()
+                for x in tvar:
+                    if x.name[3:-2] in table_name_list:
+                        idx = table_name_list.index(x.name[3:-2])
+                        offset_list[idx] = tf.reshape(tf.gather(x, offset_id[idx]), [-1])
+                values = tf.concat(offset_list, axis=0)
+            else:
+                values = 0
+            feature_mapping_export = gen_npu_cpu_ops.embedding_feature_mapping_export(file_path=path,
+                                                                                      table_name=table_name_tensor,
+                                                                                      feature_id=feature_id,
+                                                                                      offset_id=offset_id,
+                                                                                      values=values,
+                                                                                      embedding_dim=
+                                                                                      embedding_dim_list)
+            feature_mapping_export_list.append(feature_mapping_export)
+        return feature_mapping_export_list
 
     def _call_feature_mapping_import_op(self, path):
-        table_name_list = []
-        for i in range(len(self._small_table_variable_list)):
-            table_name_list.append(self._small_table_variable_list[i][:-2])
-
-        feature_size = \
-            gen_npu_cpu_ops.embedding_feature_mapping_file_size(file_path=path,
-                                                                table_name=ops.convert_to_tensor(table_name_list),
-                                                                embedding_dim=self._small_table_variable_dim_list,
-                                                                only_offset_flag=True)
-        feature_id, offset_id, values = \
-            gen_npu_cpu_ops.embedding_feature_mapping_import(file_path=path,
-                                                             table_name=ops.convert_to_tensor(table_name_list),
-                                                             feature_size=feature_size,
-                                                             embedding_dim=self._small_table_variable_dim_list,
-                                                             only_offset_flag=True,
-                                                             num=len(self._small_table_variable_list))
-        feature_mapping_insert = \
-            gen_npu_cpu_ops.embedding_feature_mapping_insert(table_name=ops.convert_to_tensor(table_name_list),
-                                                             feature_id=feature_id,
-                                                             offset_id=offset_id)
-        return tf.group([feature_mapping_insert])
+        feature_mapping_import_list = []
+        num = len(self._small_table_variable_list)
+        index = 0
+        # aicpu can only handle 128 tables at a time
+        while index < num:
+            iter_max = min(index + 128, num)
+            table_name_list = []
+            embedding_dim_list = []
+            while index < iter_max:
+                table_name_list.append(self._small_table_variable_list[index][:-2])
+                embedding_dim_list.append(self._small_table_variable_dim_list[index])
+                index += 1
+
+            feature_size = \
+                gen_npu_cpu_ops.embedding_feature_mapping_file_size(file_path=path,
+                                                                    table_name=ops.convert_to_tensor(table_name_list),
+                                                                    embedding_dim=embedding_dim_list,
+                                                                    only_offset_flag=True)
+            feature_id, offset_id, values = \
+                gen_npu_cpu_ops.embedding_feature_mapping_import(file_path=path,
+                                                                 table_name=ops.convert_to_tensor(table_name_list),
+                                                                 feature_size=feature_size,
+                                                                 embedding_dim=embedding_dim_list,
+                                                                 only_offset_flag=True,
+                                                                 num=len(table_name_list))
+            feature_mapping_insert = \
+                gen_npu_cpu_ops.embedding_feature_mapping_insert(table_name=ops.convert_to_tensor(table_name_list),
+                                                                 feature_id=feature_id,
+                                                                 offset_id=offset_id)
+            feature_mapping_import_list.append(feature_mapping_insert)
+        return feature_mapping_import_list
-- 
Gitee
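
For reference, a minimal standalone sketch of the pattern this patch introduces: both rewritten methods walk the small-table list in chunks of at most 128 entries, because (per the in-code comment) the AICPU feature-mapping ops can handle at most 128 tables per invocation. The names run_in_batches, process_batch, and MAX_TABLES_PER_OP below are hypothetical and for illustration only; the real methods call gen_npu_cpu_ops inside the loop and collect the resulting ops into a list.

    # Illustrative sketch only (not part of the patch): chunking a table list
    # into groups of at most 128 so each group fits one op invocation.

    MAX_TABLES_PER_OP = 128  # AICPU limit, assumed from the patch's comment


    def run_in_batches(table_names, embedding_dims, process_batch):
        """Call process_batch with at most MAX_TABLES_PER_OP tables at a time."""
        if len(table_names) != len(embedding_dims):
            raise ValueError("table_names and embedding_dims must have equal length")
        results = []
        for start in range(0, len(table_names), MAX_TABLES_PER_OP):
            end = start + MAX_TABLES_PER_OP
            # Each slice carries the names and dims for one op invocation.
            results.append(process_batch(table_names[start:end],
                                         embedding_dims[start:end]))
        return results


    if __name__ == "__main__":
        names = ["table_%d" % i for i in range(300)]
        dims = [8] * 300
        print(run_in_batches(names, dims, lambda n, d: len(n)))  # [128, 128, 44]

Note that, as in the patch, chunking changes the return shape: the methods now return a list with one op per chunk instead of a single tf.group(...) op, so callers receive and run per-batch ops.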