diff --git a/MindSPONGE/src/sponge/potential/energy/lj.py b/MindSPONGE/src/sponge/potential/energy/lj.py index 0eeb5d35c6838f257b70ef0b661494f87a424794..5993bef3fcf8469942102d7fd2e5fbaab248295a 100644 --- a/MindSPONGE/src/sponge/potential/energy/lj.py +++ b/MindSPONGE/src/sponge/potential/energy/lj.py @@ -55,6 +55,7 @@ class LennardJonesEnergy(NonbondEnergy): Args: + epsilon (Union[Tensor, ndarray, List[float]]): Parameter :math:`\epsilon` for LJ potential. The shape of array is `(B, A)`, and the data type is float. @@ -84,6 +85,7 @@ class LennardJonesEnergy(NonbondEnergy): name (str): Name of the energy. Default: 'vdw' Supported Platforms: + ``Ascend`` ``GPU`` Note: @@ -120,7 +122,7 @@ class LennardJonesEnergy(NonbondEnergy): energy_unit=energy_unit, ) self._kwargs = get_arguments(locals(), kwargs) - if 'exclude_index' in self._kwargs.keys(): + if 'exclude_index' in self._kwargs: self._kwargs.pop('exclude_index') if parameters is not None: @@ -131,7 +133,7 @@ class LennardJonesEnergy(NonbondEnergy): self.units.set_units(length_unit, energy_unit) self._use_pbc = system.use_pbc - epsilon, sigma, mean_c6 = self.get_parameters(system, parameters) + epsilon, sigma, mean_c6 = self._get_parameters(system, parameters) sigma = get_ms_array(sigma, ms.float32) epsilon = get_ms_array(epsilon, ms.float32) @@ -168,7 +170,7 @@ class LennardJonesEnergy(NonbondEnergy): self.disp_corr = self._calc_disp_corr() @staticmethod - def get_parameters(system: Molecule, parameters: dict) -> Tuple[ndarray]: + def _get_parameters(system: Molecule, parameters: dict) -> Tuple[ndarray]: r"""get the force field parameters for the system ['H','HO','HS','HC','H1','H2','H3','HP','HA','H4', @@ -203,11 +205,15 @@ class LennardJonesEnergy(NonbondEnergy): type_set = list(set(type_list)) count = np.array([type_list.count(i) for i in type_set], np.int32) - sigma_set = [] - eps_set = [] - for params in itemgetter(*type_set)(vdw_params): - sigma_set.append(params[sigma_index]) - eps_set.append(params[eps_index]) + if len(type_set) == 1: + sigma_set = [vdw_params[type_set[0]][sigma_index]] + eps_set = [vdw_params[type_set[0]][eps_index]] + else: + sigma_set = [] + eps_set = [] + for params in itemgetter(*type_set)(vdw_params): + sigma_set.append(params[sigma_index]) + eps_set.append(params[eps_index]) sigma_set = np.array(sigma_set) eps_set = np.array(eps_set) @@ -264,8 +270,6 @@ class LennardJonesEnergy(NonbondEnergy): """ inv_neigh_dis = msnp.reciprocal(neighbour_distance * self.input_unit_scale) - if neighbour_mask is not None: - inv_neigh_dis = msnp.where(neighbour_mask, inv_neigh_dis, inv_neigh_dis) epsilon = self.identity(self.epsilon) sigma = self.identity(self.sigma) @@ -296,6 +300,11 @@ class LennardJonesEnergy(NonbondEnergy): energy = ene_acoeff - ene_bcoeff # (B,A) + energy = ms.ops.select( + neighbour_mask, + energy, + msnp.zeros_like(energy) + ) energy = F.reduce_sum(energy, -1) # (B,1) energy = func.keepdims_sum(energy, -1) * 0.5