From e0bb00179141d57636df41d4b5aa0460e211e9b9 Mon Sep 17 00:00:00 2001
From: majunwang <844234020@qq.com>
Date: Mon, 25 Aug 2025 15:19:17 +0000
Subject: [PATCH 1/8] Create behaviour_and_multi_task
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 behaviour_and_multi_task/.keep | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 behaviour_and_multi_task/.keep

diff --git a/behaviour_and_multi_task/.keep b/behaviour_and_multi_task/.keep
new file mode 100644
index 000000000..e69de29bb
-- 
Gitee

From 9fa6c4ee7ab6be0dfa7c4b95ae35b3b0b4c86ef3 Mon Sep 17 00:00:00 2001
From: majunwang <844234020@qq.com>
Date: Mon, 25 Aug 2025 15:19:56 +0000
Subject: [PATCH 2/8] add behaviour_and_multi_task/utils.

Signed-off-by: majunwang <844234020@qq.com>
---
 behaviour_and_multi_task/utils | 144 +++++++++++++++++++++++++++++++++
 1 file changed, 144 insertions(+)
 create mode 100644 behaviour_and_multi_task/utils

diff --git a/behaviour_and_multi_task/utils b/behaviour_and_multi_task/utils
new file mode 100644
index 000000000..ed2ab3bde
--- /dev/null
+++ b/behaviour_and_multi_task/utils
@@ -0,0 +1,144 @@
+import os
+import json
+import stat
+import glob
+import tensorflow as tf
+from torch.utils.data import Dataset, DataLoader
+import torch
+import torch.nn as nn
+
+from sklearn.metrics import roc_auc_score
+
+class TFRecordDataset(Dataset):
+    def __init__(self, params, filepath, feature_description, spec, mode, shuffle=False):
+        self.params = params
+        self.spec = spec
+        self.filepath = filepath
+        self.feature_description = feature_description
+        self.dataset = tf.data.TFRecordDataset(self.filepath)
+        if shuffle:
+            self.dataset.shuffle(buffer_size=500000)
+        self.dataset = self.dataset.repeat(params.num.epochs).batch(self.params.batch_size, drop_remainder=True).map(
+            self.parse_example, num_parallel_calls=tf.data.AUTOTUNE).prefetch(100)
+        self.iterator = tf.compat.v1.data.make_one_shot_iterator(self.dataset)
+        self.length = spec['dataset_size'][mode] // self.params.batch_size
+
+    def parse_example(self, example):
+        parsed_example = tf.io.parse_example(example, self.feature_description)
+        input_data = {}
+        target = {"y": parsed_example["y"], "z": parsed_example["z"]}
+        for index,key in enumerate(self.spec["one_hot_fielfs"]):
+            input_data[key] = parsed_example["one_hot_fields"][:,index]
+        for key in self.spec["multi_hot_fileds"]:
+            input_data[key] = parsed_example[key]
+        for key in self.spec["special_fileds"]:
+            input_data[key] = parsed_example[key]
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, idx):
+        batch_features, batch_labels = self.iterator.get_next()
+        for key, val in batch_features.items():
+            batch_features[key] = torch.tensor(val.numpy(), dtype=torch.long).to(self.params.device)
+        for key, val in batch_labels.items():
+            batch_labels[key] = torch.tensor(val.numpy(), dtype=torch.float32).to(self.params.device)
+        return batch_features, batch_labels
+
+def cal_auc(pred, labels):
+    return roc_auc_score(labels,pred)
+
+def build_feature_descriptions(params):
+    if params.mode == 'infer':
+        root_path = os.path.abspath(__file__)
+        root_path = os.path.sep.join(root_path.split(os.path.sep)[:-2])
+        spec_json_path = os.path.join(root_path, "spec.json")
+    else:
+        spec_json_path = os.path.join(params.data_dir, "spec.json")
+    local_spec = json_file_load("spec", spec_json_path)
+
+    local_feature_descriptions = {}
+    mode = ['train', 'val', 'test']
+    for mode_type in mode:
+        feature_description = {
+            'y': tf.io.FixedLenFeature([], tf.float32),
+            'z': tf.io.FixedLenFeature([], tf.float32),
+            'one_hot_fields': tf.io.FixedLenFeature([len(local_spec["one_hot_fields"])], tf.int64)
+        }
+        for mul_fields in local_spec["multi_hot_fields"]:
+            feature_description[mul_fields] = tf.io.FixedLenFeature(
+                [local_spec.get(f"{mode_type}_max_length").get(mul_fields)],
+                tf.int64)
+        for mul_fields in local_spec["special_fields"]:
+            feature_description[mul_fields] = tf.io.FixedLenFeature(
+                [local_spec.get(f"{mode_type}_max_length").get(mul_fields)],
+                tf.int64)
+        local_feature_descriptions[mode_type] = feature_description
+
+    return local_spec, local_feature_descriptions
+
+def json_file_load(json_name: str, json_path: str) -> dict:
+    """
+    Load a JSON file from the specified path.
+    """
+    flags = os.O_RDONLY
+    modes = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IROTH
+    try:
+        with os.fdopen(os.open(json_path, flags, modes), "r") as fp:
+            json_re = json.load(fp)
+    except FileNotFoundError as e:
+        raise FileNotFoundError(f"{json_name} file not found: {e}") from e
+    except Exception as e:
+        raise RuntimeError(f"Error loading {json_name} file: {e}") from e
+
+    return json_re
+
+def load_data(params):
+    spec, feature_desc = build_feature_descriptions(params)
+    root_path = os.path.abspath(__file__)
+    root_path = os.path.sep.join(root_path.split(os.path.sep)[:-2])
+    train_order = json_file_load("order", os.path.join(root_path, "order.json"))
+    tr_files = []
+    for index in train_order["reading_order"]:
+        tr_files.append(os.path.join(params.data_dir, 'train', 'data_train.csv.tfrecord.{}'.format(index)))
+    va_files=glob.glob(os.path.join(params.data_dir, 'val', 'data_val.csv.tfrecord.*'))
+    te_files=glob.glob(os.path.join(params.data_dir, 'test', 'data_test.csv.tfrecord.*'))
+
+    train_dataset = TFRecordDataset(params, tr_files, feature_description=feature_desc.get('train'),
+                                    spec=spec, mode = 'train',shuffle=True)
+    test_dataset = TFRecordDataset(params, te_files, feature_description=feature_desc.get('test'),
+                                   spec=spec, mode = 'test')
+    val_dataset = TFRecordDataset(params, va_files, feature_description=feature_desc.get('val'),
+                                  spec=spec, mode = 'val')
+
+    # batch_size and num_workers are handled inside the dataset
+    collect_fn = lambda x:x[0]
+    train_loader = DataLoader(train_dataset, batch_size=1, collect_fn=collect_fn)
+    test_loader = DataLoader(test_dataset, batch_size=1, collect_fn=collect_fn)
+    val_loader = DataLoader(val_dataset, batch_size=1, collect_fn=collect_fn)
+    return train_loader, test_loader, val_loader, spec
+
+def load_generate_data(params):
+    spec_json_path = os.path.join(params.data_dir, "spec.json")
+    local_spec = json_file_load("spec", spec_json_path)
+    return local_spec
+
+def generate_data(params, device,spec):
+    features = {}
+    for key in ["101", "121", "122", "124", "125", "126", "127","128", "129",
+                "205", "206", "207", "216", "508", "509", "702", "301"]:
+        features[key] = torch.randint(low=0,high=spec['vocab_length'][key], size=(params.batch_size, 1)).to(device)
+
+    for key in ["109_14", "110_14", "127_14", "150_14"]:
+        features[key] = torch.randint(low=0,high=spec['vocab_length'][key], size=(params.batch_size, 50)).to(device)
+
+    for key in ["210", "853"]:
+        features[key] = torch.randint(low=0,high=spec['vocab_length'][key], size=(params.batch_size, 38)).to(device)
+
+def infer_with_generate_data(params, model, spec):
+    model.eval()
+    features = generate_data(params, params.device, spec)
+    pred = model(features,spec)
+    print(f'pred result:{pred}')
+    return pred
\ No newline at end of file
-- 
Gitee
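
Note on the loader design in PATCH 2/8: TFRecordDataset batches internally (each __getitem__ call pulls a full batch from the TF one-shot iterator), so the PyTorch DataLoader is meant to run with batch_size=1 plus a collate function that unwraps the singleton list. A minimal driver sketch of that intent, assuming the final state of the file after PATCH 8/8; the params object is hypothetical (the series never defines one), and spec.json / order.json must exist under the paths load_data expects:

    from types import SimpleNamespace
    import torch

    # Every field below is an assumption inferred from how the utils read params.
    params = SimpleNamespace(
        data_dir="./data",        # assumed location of spec.json and the tfrecord shards
        mode="train",
        batch_size=512,
        num_epochs=1,
        device=torch.device("cpu"),
    )

    # load_data is the function added above; it returns one loader per split
    # plus the parsed spec dict.
    train_loader, test_loader, val_loader, spec = load_data(params)
    for batch_features, batch_labels in train_loader:
        # batch_features: dict of LongTensors keyed by field id
        # batch_labels: {"y": float32 tensor, "z": float32 tensor}
        break  # one batch is enough for a smoke test
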
From b2a74de14e40f4b06c7161a13143120fe9de704d Mon Sep 17 00:00:00 2001
From: majunwang <844234020@qq.com>
Date: Mon, 25 Aug 2025 15:20:04 +0000
Subject: [PATCH 3/8] Delete file behaviour_and_multi_task/.keep
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 behaviour_and_multi_task/.keep | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 behaviour_and_multi_task/.keep

diff --git a/behaviour_and_multi_task/.keep b/behaviour_and_multi_task/.keep
deleted file mode 100644
index e69de29bb..000000000
-- 
Gitee

From 2fe7f09b44ded3477d3a568ea204887d090ab62a Mon Sep 17 00:00:00 2001
From: majunwang <844234020@qq.com>
Date: Mon, 25 Aug 2025 15:20:25 +0000
Subject: [PATCH 4/8] Create feature_interation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 feature_interation/.keep | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 feature_interation/.keep

diff --git a/feature_interation/.keep b/feature_interation/.keep
new file mode 100644
index 000000000..e69de29bb
-- 
Gitee

From 3190ed4748de0fe8a5ba394ba249e885f1dbee5b Mon Sep 17 00:00:00 2001
From: majunwang <844234020@qq.com>
Date: Mon, 25 Aug 2025 15:20:57 +0000
Subject: [PATCH 5/8] add feature_interation/utils.py.

Signed-off-by: majunwang <844234020@qq.com>
---
 feature_interation/utils.py | 134 ++++++++++++++++++++++++++++++++++++
 1 file changed, 134 insertions(+)
 create mode 100644 feature_interation/utils.py

diff --git a/feature_interation/utils.py b/feature_interation/utils.py
new file mode 100644
index 000000000..b61fb2396
--- /dev/null
+++ b/feature_interation/utils.py
@@ -0,0 +1,134 @@
+import os
+import random
+import json
+import stat
+import glob
+import tensorflow as tf
+from torch.utils.data import Dataset, DataLoader
+import torch
+
+from sklearn.metrics import roc_auc_score
+
+CRITEO_NUM = {
+    'train': 33003326,
+    'val': 8250124,
+    'test': 4587176
+}
+
+class TFRecordDataset(Dataset):
+    def __init__(self, params, filepath, mode, shuffle=False):
+        self.params = params
+        self.filepath = filepath
+        self.feature_description = {
+            'label': tf.io.FixedLenFeature(shape=(), dtype=tf.float32),
+            'ids': tf.io.FixedLenFeature(shape=(self.params.field_size,), dtype=tf.int64),
+            'values': tf.io.FixedLenFeature(shape=(self.params.field_size,), dtype=tf.float32),
+        }
+        self.dataset = tf.data.TFRecordDataset(self.filepath)
+        if shuffle:
+            self.dataset.shuffle(buffer_size=500000)
+        self.dataset = self.dataset.repeat(params.num.epochs).batch(self.params.batch_size, drop_remainder=True).map(
+            self.parse_example, num_parallel_calls=tf.data.AUTOTUNE).prefetch(100)
+        self.iterator = tf.compat.v1.data.make_one_shot_iterator(self.dataset)
+        self.length = CRITEO_NUM[mode] // self.params.batch_size
+
+    def parse_example(self, example):
+        sample = tf.io.parse_example(example, self.feature_description)
+        sample['ids'] = tf.cast(sample['ids'], dtype=tf.int32)
+        return {"feat_ids": sample['ids'], "feat_vals": sample['values']}, sample['label']
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, idx):
+        batch_features, batch_labels = self.iterator.get_next()
+        batch_features['feat_ids'] = torch.tensor(batch_features['feat_ids'].numpy(), dtype=torch.int32).to(self.params.device)
+        batch_features['feat_vals'] = torch.tensor(batch_features['feat_vals'].numpy(), dtype=torch.int32).to(self.params.device)
+
+        batch_labels = torch.tensor(batch_labels.numpy(), dtype=torch.float32).to(self.params.device)
+        return batch_features, batch_labels
+
+def cal_auc(pred, labels):
+    return roc_auc_score(labels,pred)
+
+def json_file_load(json_name: str, json_path: str) -> dict:
+    """
+    Load a JSON file from the specified path.
+    """
+    flags = os.O_RDONLY
+    modes = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IROTH
+    try:
+        with os.fdopen(os.open(json_path, flags, modes), "r") as fp:
+            json_re = json.load(fp)
+    except FileNotFoundError as e:
+        raise FileNotFoundError(f"{json_name} file not found: {e}") from e
+    except Exception as e:
+        raise RuntimeError(f"Error loading {json_name} file: {e}") from e
+
+    return json_re
+
+def load_data(params):
+    tr_files=glob.glob(os.path.join(params.data_dir, "tr*tfrecords"))
+    random.shuffle(tr_files)
+    va_files=glob.glob(os.path.join(params.data_dir, "va*tfrecords"))
+    te_files=glob.glob(os.path.join(params.data_dir, "te*tfrecords"))
+
+    train_dataset = TFRecordDataset(params, tr_files, model ='train',shuffle=True)
+    test_dataset = TFRecordDataset(params, te_files, mode = 'test')
+    val_dataset = TFRecordDataset(params, va_files, mode = 'val')
+
+    # batch_size and num_workers are handled inside the dataset
+    collect_fn = lambda x:x[0]
+    train_loader = DataLoader(train_dataset, batch_size=1, num_workers=0, collect_fn=collect_fn)
+    test_loader = DataLoader(test_dataset, batch_size=1, num_workers=0, collect_fn=collect_fn)
+    val_loader = DataLoader(val_dataset, batch_size=1, num_workers=0, collect_fn=collect_fn)
+    return train_loader, test_loader, val_loader
+
+def generate_data(params, device,spec):
+    features = {}
+    features['feat_ids'] = torch.randint(0, 32, (params.batch_size, params.field_size)).to(device)
+    features['feat_vals'] = torch.randint((params.batch_size, params.field_size)).to(device)
+    labels = torch.randint(0, 2, (params.batch_size, 1)).to(device)
+    return features, labels
+
+def infer_with_generate_data(params, model, spec):
+    model.eval()
+    features = generate_data(params, params.device, spec)
+    pred = model(features,spec)
+    print(f'pred result:{pred}')
+    return pred
+
+if __name__== '__main__':
+    import os
+    from tqdm import tqdm
+    data_path = "D://dataset/criteo"
+    # ------init Envs-------------
+    tr_files=glob.glob(os.path.join(data_path, "tr*tfrecords"))
+    va_files=glob.glob(os.path.join(data_path, "va*tfrecords"))
+    te_files=glob.glob(os.path.join(data_path, "te*tfrecords"))
+    files_dict = {
+        'train': tr_files,
+        'val': va_files,
+        'test': te_files,
+    }
+    def parse_example(example):
+        features = {
+            # extract features using the keys set during creation
+            'label': tf.io.FixedLenFeature(shape=(), dtype=tf.float32),
+            'ids': tf.io.FixedLenFeature(shape=(39,), dtype=tf.int64),
+            'valuse': tf.io.FixedLenFeature(shape=(39,), dtype=tf.float32),
+        }
+        sample = tf.io.parse_example(example, features)
+        sample['ids'] = tf.cast(sample['ids'], dtype=tf.int32)
+        return {"feat_ids": sample['ids'], "feat_val": sample['values']}, sample['label']
+
+    for key, files in files_dict.items():
+        dataset = tf.data.TFRecordDataset(files)
+        dataset = dataset.batch(4096).map(parse_example).prefetch(100)
+        iterator = tf.compat.v1.data.make_one_shot_iterator(dataset):
+        for batch_features, batch_labels in tqdm(iterator):
+            batch_features['feat_ids'] = torch.tensor(batch_features['feat_ids'].numpy(), dtype=torch.int32)
+            batch_features['geat_vals'] = torch.tensor(batch_features['feat_vals'].numpy(), dtype=torch.float32)
+            batch_labels = torch.tensor(batch_labels.numpy(), dtype=torch.float32)
+    print('over')
\ No newline at end of file
-- 
Gitee

From 79bbd3db054fa15138eb621bf51cad5379e05e23 Mon Sep 17 00:00:00 2001
From: majunwang <844234020@qq.com>
Date: Mon, 25 Aug 2025 15:21:08 +0000
Subject: [PATCH 6/8] Delete file feature_interation/.keep
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 feature_interation/.keep | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 feature_interation/.keep

diff --git a/feature_interation/.keep b/feature_interation/.keep
deleted file mode 100644
index e69de29bb..000000000
-- 
Gitee

From b1cc73462cbaa386eb516649ff06c67a75de26d6 Mon Sep 17 00:00:00 2001
From: majunwang <844234020@qq.com>
Date: Mon, 25 Aug 2025 15:27:59 +0000
Subject: [PATCH 7/8] update feature_interation/utils.py.

Signed-off-by: majunwang <844234020@qq.com>
---
 feature_interation/utils.py | 49 ++++++++++++++++++++++++-------------------------
 1 file changed, 24 insertions(+), 25 deletions(-)

diff --git a/feature_interation/utils.py b/feature_interation/utils.py
index b61fb2396..f6598c830 100644
--- a/feature_interation/utils.py
+++ b/feature_interation/utils.py
@@ -27,7 +27,7 @@ class TFRecordDataset(Dataset):
         self.dataset = tf.data.TFRecordDataset(self.filepath)
         if shuffle:
-            self.dataset.shuffle(buffer_size=500000)
-        self.dataset = self.dataset.repeat(params.num.epochs).batch(self.params.batch_size, drop_remainder=True).map(
+            self.dataset = self.dataset.shuffle(buffer_size=500000)
+        self.dataset = self.dataset.repeat(params.num_epochs).batch(self.params.batch_size, drop_remainder=True).map(
             self.parse_example, num_parallel_calls=tf.data.AUTOTUNE).prefetch(100)
         self.iterator = tf.compat.v1.data.make_one_shot_iterator(self.dataset)
         self.length = CRITEO_NUM[mode] // self.params.batch_size
@@ -43,13 +43,12 @@ class TFRecordDataset(Dataset):
     def __getitem__(self, idx):
         batch_features, batch_labels = self.iterator.get_next()
         batch_features['feat_ids'] = torch.tensor(batch_features['feat_ids'].numpy(), dtype=torch.int32).to(self.params.device)
-        batch_features['feat_vals'] = torch.tensor(batch_features['feat_vals'].numpy(), dtype=torch.int32).to(self.params.device)
-
+        batch_features['feat_vals'] = torch.tensor(batch_features['feat_vals'].numpy(), dtype=torch.float32).to(self.params.device)
         batch_labels = torch.tensor(batch_labels.numpy(), dtype=torch.float32).to(self.params.device)
         return batch_features, batch_labels
 
 def cal_auc(pred, labels):
-    return roc_auc_score(labels,pred)
+    return roc_auc_score(labels, pred)
 
 def json_file_load(json_name: str, json_path: str) -> dict:
     """
@@ -68,33 +67,33 @@ def json_file_load(json_name: str, json_path: str) -> dict:
     return json_re
 
 def load_data(params):
-    tr_files=glob.glob(os.path.join(params.data_dir, "tr*tfrecords"))
+    tr_files = glob.glob(os.path.join(params.data_dir, "tr*tfrecords"))
     random.shuffle(tr_files)
-    va_files=glob.glob(os.path.join(params.data_dir, "va*tfrecords"))
-    te_files=glob.glob(os.path.join(params.data_dir, "te*tfrecords"))
+    va_files = glob.glob(os.path.join(params.data_dir, "va*tfrecords"))
+    te_files = glob.glob(os.path.join(params.data_dir, "te*tfrecords"))
 
-    train_dataset = TFRecordDataset(params, tr_files, model ='train',shuffle=True)
+    train_dataset = TFRecordDataset(params, tr_files, mode ='train',shuffle=True)
     test_dataset = TFRecordDataset(params, te_files, mode = 'test')
     val_dataset = TFRecordDataset(params, va_files, mode = 'val')
 
     # batch_size and num_workers are handled inside the dataset
-    collect_fn = lambda x:x[0]
-    train_loader = DataLoader(train_dataset, batch_size=1, num_workers=0, collect_fn=collect_fn)
-    test_loader = DataLoader(test_dataset, batch_size=1, num_workers=0, collect_fn=collect_fn)
-    val_loader = DataLoader(val_dataset, batch_size=1, num_workers=0, collect_fn=collect_fn)
+    collect_fn = lambda x: x[0]
+    train_loader = DataLoader(train_dataset, batch_size=1, num_workers=0, collate_fn=collect_fn)
+    test_loader = DataLoader(test_dataset, batch_size=1, num_workers=0, collate_fn=collect_fn)
+    val_loader = DataLoader(val_dataset, batch_size=1, num_workers=0, collate_fn=collect_fn)
     return train_loader, test_loader, val_loader
 
-def generate_data(params, device,spec):
+def generate_data(params, device):
     features = {}
     features['feat_ids'] = torch.randint(0, 32, (params.batch_size, params.field_size)).to(device)
-    features['feat_vals'] = torch.randint((params.batch_size, params.field_size)).to(device)
+    features['feat_vals'] = torch.rand((params.batch_size, params.field_size)).to(device)
     labels = torch.randint(0, 2, (params.batch_size, 1)).to(device)
     return features, labels
 
-def infer_with_generate_data(params, model, spec):
+def infer_with_generate_data(params, model):
     model.eval()
-    features = generate_data(params, params.device, spec)
-    pred = model(features,spec)
+    features, _ = generate_data(params, params.device)
+    pred = model(features)
     print(f'pred result:{pred}')
     return pred
 
@@ -103,9 +102,9 @@ if __name__== '__main__':
     from tqdm import tqdm
     data_path = "D://dataset/criteo"
     # ------init Envs-------------
-    tr_files=glob.glob(os.path.join(data_path, "tr*tfrecords"))
-    va_files=glob.glob(os.path.join(data_path, "va*tfrecords"))
-    te_files=glob.glob(os.path.join(data_path, "te*tfrecords"))
+    tr_files = glob.glob(os.path.join(data_path, "tr*tfrecords"))
+    va_files = glob.glob(os.path.join(data_path, "va*tfrecords"))
+    te_files = glob.glob(os.path.join(data_path, "te*tfrecords"))
     files_dict = {
         'train': tr_files,
         'val': va_files,
@@ -116,19 +115,19 @@ if __name__== '__main__':
             # extract features using the keys set during creation
             'label': tf.io.FixedLenFeature(shape=(), dtype=tf.float32),
             'ids': tf.io.FixedLenFeature(shape=(39,), dtype=tf.int64),
-            'valuse': tf.io.FixedLenFeature(shape=(39,), dtype=tf.float32),
+            'values': tf.io.FixedLenFeature(shape=(39,), dtype=tf.float32),
         }
         sample = tf.io.parse_example(example, features)
         sample['ids'] = tf.cast(sample['ids'], dtype=tf.int32)
-        return {"feat_ids": sample['ids'], "feat_val": sample['values']}, sample['label']
+        return {"feat_ids": sample['ids'], "feat_vals": sample['values']}, sample['label']
 
     for key, files in files_dict.items():
         dataset = tf.data.TFRecordDataset(files)
         dataset = dataset.batch(4096).map(parse_example).prefetch(100)
-        iterator = tf.compat.v1.data.make_one_shot_iterator(dataset):
+        iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
         for batch_features, batch_labels in tqdm(iterator):
             batch_features['feat_ids'] = torch.tensor(batch_features['feat_ids'].numpy(), dtype=torch.int32)
-            batch_features['geat_vals'] = torch.tensor(batch_features['feat_vals'].numpy(), dtype=torch.float32)
+            batch_features['feat_vals'] = torch.tensor(batch_features['feat_vals'].numpy(), dtype=torch.float32)
             batch_labels = torch.tensor(batch_labels.numpy(), dtype=torch.float32)
     print('over')
\ No newline at end of file
-- 
Gitee

From 8ad84c4b0e7ee552719e2d2849140fd7d22b3660 Mon Sep 17 00:00:00 2001
From: majunwang <844234020@qq.com>
Date: Mon, 25 Aug 2025 15:30:59 +0000
Subject: [PATCH 8/8] update behaviour_and_multi_task/utils.

Signed-off-by: majunwang <844234020@qq.com>
---
 behaviour_and_multi_task/utils | 26 +++++++++++++++-----------
 1 file changed, 15 insertions(+), 11 deletions(-)

diff --git a/behaviour_and_multi_task/utils b/behaviour_and_multi_task/utils
index ed2ab3bde..64bb28b6b 100644
--- a/behaviour_and_multi_task/utils
+++ b/behaviour_and_multi_task/utils
@@ -18,7 +18,7 @@ class TFRecordDataset(Dataset):
         self.dataset = tf.data.TFRecordDataset(self.filepath)
         if shuffle:
-            self.dataset.shuffle(buffer_size=500000)
-        self.dataset = self.dataset.repeat(params.num.epochs).batch(self.params.batch_size, drop_remainder=True).map(
+            self.dataset = self.dataset.shuffle(buffer_size=500000)
+        self.dataset = self.dataset.repeat(params.num_epochs).batch(self.params.batch_size, drop_remainder=True).map(
             self.parse_example, num_parallel_calls=tf.data.AUTOTUNE).prefetch(100)
         self.iterator = tf.compat.v1.data.make_one_shot_iterator(self.dataset)
         self.length = spec['dataset_size'][mode] // self.params.batch_size
@@ -27,11 +27,12 @@ class TFRecordDataset(Dataset):
         parsed_example = tf.io.parse_example(example, self.feature_description)
         input_data = {}
         target = {"y": parsed_example["y"], "z": parsed_example["z"]}
-        for index,key in enumerate(self.spec["one_hot_fielfs"]):
+        for index,key in enumerate(self.spec["one_hot_fields"]):
             input_data[key] = parsed_example["one_hot_fields"][:,index]
-        for key in self.spec["multi_hot_fileds"]:
+        for key in self.spec["multi_hot_fields"]:
             input_data[key] = parsed_example[key]
-        for key in self.spec["special_fileds"]:
+        for key in self.spec["special_fields"]:
             input_data[key] = parsed_example[key]
+        return input_data, target
 
     def __len__(self):
@@ -46,7 +47,7 @@ class TFRecordDataset(Dataset):
         return batch_features, batch_labels
 
 def cal_auc(pred, labels):
-    return roc_auc_score(labels,pred)
+    return roc_auc_score(labels, pred)
 
 def build_feature_descriptions(params):
     if params.mode == 'infer':
@@ -77,6 +78,7 @@ def build_feature_descriptions(params):
 
     return local_spec, local_feature_descriptions
 
+
 def json_file_load(json_name: str, json_path: str) -> dict:
     """
     Load a JSON file from the specified path.
@@ -95,6 +97,7 @@ def json_file_load(json_name: str, json_path: str) -> dict:
 
 def load_data(params):
     spec, feature_desc = build_feature_descriptions(params)
+
     root_path = os.path.abspath(__file__)
     root_path = os.path.sep.join(root_path.split(os.path.sep)[:-2])
     train_order = json_file_load("order", os.path.join(root_path, "order.json"))
@@ -113,9 +116,9 @@ def load_data(params):
 
     # batch_size and num_workers are handled inside the dataset
     collect_fn = lambda x:x[0]
-    train_loader = DataLoader(train_dataset, batch_size=1, collect_fn=collect_fn)
-    test_loader = DataLoader(test_dataset, batch_size=1, collect_fn=collect_fn)
-    val_loader = DataLoader(val_dataset, batch_size=1, collect_fn=collect_fn)
+    train_loader = DataLoader(train_dataset, batch_size=1, collate_fn=collect_fn)
+    test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=collect_fn)
+    val_loader = DataLoader(val_dataset, batch_size=1, collate_fn=collect_fn)
     return train_loader, test_loader, val_loader, spec
 
 def load_generate_data(params):
@@ -125,7 +128,7 @@ def load_generate_data(params):
 
 def generate_data(params, device,spec):
     features = {}
-    for key in ["101", "121", "122", "124", "125", "126", "127","128", "129",
+    for key in ["101", "121", "122", "124", "125", "126", "127", "128", "129",
                 "205", "206", "207", "216", "508", "509", "702", "301"]:
         features[key] = torch.randint(low=0,high=spec['vocab_length'][key], size=(params.batch_size, 1)).to(device)
 
@@ -135,9 +138,10 @@ def generate_data(params, device,spec):
     for key in ["210", "853"]:
         features[key] = torch.randint(low=0,high=spec['vocab_length'][key], size=(params.batch_size, 38)).to(device)
+    return features
 
 def infer_with_generate_data(params, model, spec):
     model.eval()
     features = generate_data(params, params.device, spec)
-    pred = model(features,spec)
+    pred = model(features, spec)
     print(f'pred result:{pred}')
     return pred
\ No newline at end of file
-- 
Gitee
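
For reference, a minimal end-to-end sketch against the final feature_interation/utils.py (after PATCH 7/8). The params namespace and DummyModel are hypothetical stand-ins (the series defines neither a model nor an argument parser), and field_size=39 matches the Criteo layout used in the file:

    from types import SimpleNamespace
    import torch
    import torch.nn as nn

    params = SimpleNamespace(
        data_dir="./criteo_tfrecords",  # assumed location of the tr*/va*/te* tfrecords
        batch_size=4096,
        num_epochs=1,
        field_size=39,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )

    class DummyModel(nn.Module):
        """Stand-in scorer: embeds feat_ids and weights them by feat_vals."""
        def __init__(self, vocab_size=1000000, dim=16):  # vocab_size is a guess
            super().__init__()
            self.emb = nn.Embedding(vocab_size, dim)
            self.out = nn.Linear(dim, 1)

        def forward(self, features):
            # feat_ids: (B, 39) int32; feat_vals: (B, 39) float32
            emb = self.emb(features["feat_ids"].long())           # (B, 39, dim)
            weighted = emb * features["feat_vals"].unsqueeze(-1)  # (B, 39, dim)
            return torch.sigmoid(self.out(weighted.sum(dim=1)))   # (B, 1)

    model = DummyModel().to(params.device)

    # Smoke-test the forward pass on random data, then read real shards.
    infer_with_generate_data(params, model)
    train_loader, test_loader, val_loader = load_data(params)
    loss_fn = nn.BCELoss()
    for batch_features, batch_labels in train_loader:
        pred = model(batch_features)
        loss = loss_fn(pred, batch_labels.view(-1, 1))
        loss.backward()  # optimizer step omitted in this sketch
        break  # one step is enough for a smoke test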