diff --git a/MindChemistry/applications/crystalflow/test_crystalflow.py b/MindChemistry/applications/crystalflow/test_crystalflow.py index ffce4881fb9075c96dc7b1eb0afdd56a5a0a8779..18e0a919913db4ffdaa32bf8e77f5b8936385dbc 100644 --- a/MindChemistry/applications/crystalflow/test_crystalflow.py +++ b/MindChemistry/applications/crystalflow/test_crystalflow.py @@ -7,6 +7,7 @@ import mindspore.numpy as mnp from mindspore import nn, ops, Tensor, mint, load_checkpoint, load_param_into_net from mindchemistry.graph.loss import L2LossMask import numpy as np +import urllib.request from models.cspnet import CSPNet @@ -33,6 +34,13 @@ class SinusoidalTimeEmbeddings(nn.Cell): (ops.Sin()(embeddings), ops.Cos()(embeddings))) return embeddings +def download_file(url, filename): + try: + urllib.request.urlretrieve(url, filename) + print(f"File downloaded successfully: {filename}") + except Exception as e: + print(f"Failed to download file: {e}") + def test_cspnet(): """test cspnet.py""" ms.set_seed(1234) @@ -153,7 +161,8 @@ def test_loss(): cspnet = CSPNet(num_layers=6, hidden_dim=512, num_freqs=256) cspflow = CSPFlow(cspnet) - mindspore_ckpt = load_checkpoint("./torch2ms_ckpt/ms_flow.ckpt") + download_file('https://download-mindspore.osinfra.cn/mindscience/mindchemistry/crystalflow/ms_flow.ckpt', 'ms_flow.ckpt') + mindspore_ckpt = load_checkpoint("ms_flow.ckpt") load_param_into_net(cspflow, mindspore_ckpt) loss_func_mse = L2LossMask(reduction='mean')