From 1352cc45f123df50852eeee27c0a20ca17f8072c Mon Sep 17 00:00:00 2001 From: wangqc <1160619743@qq.com> Date: Thu, 12 Jun 2025 14:09:54 +0800 Subject: [PATCH] feat: add the link of ckpt in test --- .../applications/crystalflow/test_crystalflow.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/MindChemistry/applications/crystalflow/test_crystalflow.py b/MindChemistry/applications/crystalflow/test_crystalflow.py index ffce4881f..18e0a9199 100644 --- a/MindChemistry/applications/crystalflow/test_crystalflow.py +++ b/MindChemistry/applications/crystalflow/test_crystalflow.py @@ -7,6 +7,7 @@ import mindspore.numpy as mnp from mindspore import nn, ops, Tensor, mint, load_checkpoint, load_param_into_net from mindchemistry.graph.loss import L2LossMask import numpy as np +import urllib.request from models.cspnet import CSPNet @@ -33,6 +34,13 @@ class SinusoidalTimeEmbeddings(nn.Cell): (ops.Sin()(embeddings), ops.Cos()(embeddings))) return embeddings +def download_file(url, filename): + try: + urllib.request.urlretrieve(url, filename) + print(f"File downloaded successfully: {filename}") + except Exception as e: + print(f"Failed to download file: {e}") + def test_cspnet(): """test cspnet.py""" ms.set_seed(1234) @@ -153,7 +161,8 @@ def test_loss(): cspnet = CSPNet(num_layers=6, hidden_dim=512, num_freqs=256) cspflow = CSPFlow(cspnet) - mindspore_ckpt = load_checkpoint("./torch2ms_ckpt/ms_flow.ckpt") + download_file('https://download-mindspore.osinfra.cn/mindscience/mindchemistry/crystalflow/ms_flow.ckpt', 'ms_flow.ckpt') + mindspore_ckpt = load_checkpoint("ms_flow.ckpt") load_param_into_net(cspflow, mindspore_ckpt) loss_func_mse = L2LossMask(reduction='mean') -- Gitee