Source code for tad_dftd4.dispersion.threebody

# This file is part of tad-dftd4.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Dispersion: 3-body terms
========================

Implementation of the 3-body Axilrod-Teller-Muto dispersion terms.

.. math::

    E_\text{disp}^{(3), \text{ATM}} &=
    \sum_\text{ABC} E^{\text{ABC}} f_\text{damp}\left(\overline{R}_\text{ABC}\right) \\
    E^{\text{ABC}} &=
    \dfrac{C^{\text{ABC}}_9
    \left(3 \cos\theta_\text{A} \cos\theta_\text{B} \cos\theta_\text{C} + 1 \right)}
    {\left(r_\text{AB} r_\text{BC} r_\text{AC} \right)^3} \\
    f_\text{damp} &=
    \dfrac{1}{1+ 6 \left(\overline{R}_\text{ABC}\right)^{-16}}
"""

from __future__ import annotations

from abc import ABC, abstractmethod

import torch
from tad_mctc import storch
from tad_mctc.batch import real_pairs, real_triples
from tad_mctc.convert import any_to_tensor
from tad_mctc.typing import DD, Tensor

from .. import defaults
from ..cutoff import Cutoff
from ..damping import Damping, Param, ZeroDamping
from ..model import ModelInst
from .base import DispTerm

__all__ = ["ATM", "get_atm_dispersion"]


def get_atm_dispersion(
    numbers: Tensor,
    positions: Tensor,
    c9: Tensor,
    radii: Tensor,
    cutoff: Tensor,
    damping_function: Damping = ZeroDamping(),
    s9: Tensor | float | int = defaults.S9,
    alp: Tensor | float | int = defaults.ALP,
) -> Tensor:
    """
    Axilrod-Teller-Muto dispersion term.

    Parameters
    ----------
    numbers : Tensor
        Atomic numbers for all atoms in the system of shape ``(..., nat)``.
    positions : Tensor
        Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
    c9 : Tensor
        C9 dispersion coefficients. They are multiplied element-wise with the
        triple-resolved energy, so they presumably have the shape
        ``(..., nat, nat, nat)`` -- TODO confirm against the callers.
    radii : Tensor
        Pairwise critical radii for all atom pairs (shape:
        ``(..., nat, nat)``).
    cutoff : Tensor
        Real-space cutoff. A triple only contributes if all three of its
        pair distances lie within the cutoff.
    damping_function : Damping, optional
        Damping function to use. Defaults to an instance of
        :class:`ZeroDamping`.
    s9 : Tensor | float | int, optional
        Scaling for dispersion coefficients. Defaults to ``defaults.S9``.
    alp : Tensor | float | int, optional
        Exponent of the damping function (passed on as ``alp / 3``).
        Defaults to ``defaults.ALP``.

    Returns
    -------
    Tensor
        Atom-resolved ATM dispersion energy (the two trailing atom
        dimensions of the triple-resolved energy are summed out).
    """
    dd: DD = {"device": positions.device, "dtype": positions.dtype}

    s9 = any_to_tensor(s9, **dd)
    alp = any_to_tensor(alp, **dd)

    cutoff2 = cutoff * cutoff

    mask_pairs = real_pairs(numbers, mask_diagonal=True)
    mask_triples = real_triples(numbers, mask_diagonal=True, mask_self=True)

    # filler values for masks
    eps = torch.tensor(torch.finfo(positions.dtype).eps, **dd)
    zero = torch.tensor(0.0, **dd)
    one = torch.tensor(1.0, **dd)

    # product of the three pairwise critical radii of every triple (i, j, k)
    r0ij = radii.unsqueeze(-1)
    r0ik = radii.unsqueeze(-2)
    r0jk = radii.unsqueeze(-3)
    r0 = r0ij * r0ik * r0jk

    # actually faster than other alternatives
    # very slow: (pos.unsqueeze(-2) - pos.unsqueeze(-3)).pow(2).sum(-1)
    distances = torch.pow(
        torch.where(
            mask_pairs,
            storch.cdist(positions, positions, p=2),
            eps,
        ),
        2.0,
    )

    # squared pair distances broadcast to the triple index space (i, j, k)
    r2ij = distances.unsqueeze(-1)
    r2ik = distances.unsqueeze(-2)
    r2jk = distances.unsqueeze(-3)
    r2 = r2ij * r2ik * r2jk
    r1 = torch.sqrt(r2)

    # add epsilon to avoid zero division later
    r3 = torch.where(mask_triples, r1 * r2, eps)
    r5 = torch.where(mask_triples, r2 * r3, eps)

    # to fix the previous mask, we mask again (not strictly necessary because
    # `ang` is also masked and we later multiply with `ang`)
    fdamp = torch.where(
        mask_triples,
        damping_function(
            r0,
            # dividing by tiny numbers leads to huge numbers, which result in
            # NaN's upon exponentiation in the subsequent step
            torch.where(mask_triples, r1, one),
            order=9,
            alp=alp / 3.0,
            only_damping=True,
        ),
        zero,
    )

    s = torch.where(
        mask_triples,
        (r2ij + r2jk - r2ik)
        * (r2ij - r2jk + r2ik)
        * (-r2ij + r2jk + r2ik),
        zero,
    )

    # A triple contributes only if *all three* pair distances are within the
    # cutoff. The previous code tested `r2jk` twice and never `r2ik`, which
    # kept some index permutations of a triple while dropping others.
    within_cutoff = (
        (r2ij <= cutoff2) & (r2ik <= cutoff2) & (r2jk <= cutoff2)
    )

    ang = torch.where(
        mask_triples & within_cutoff,
        0.375 * s / r5 + 1.0 / r3,
        zero,
    )

    energy = ang * fdamp * s9 * c9

    # every unordered triple appears in all six index permutations
    return torch.sum(energy, dim=(-2, -1)) / 6.0
class ThreeBodyTerm(DispTerm):
    """
    Common parent of all three-body dispersion terms.

    The class adds no behaviour of its own; it only forwards its
    configuration to the :class:`DispTerm` constructor.
    """

    def __init__(
        self,
        *,
        damping_fn: Damping = ZeroDamping(),
        charge_dependent: bool = False,
    ):
        """
        Set up the three-body term.

        Parameters
        ----------
        damping_fn : Damping, optional
            Damping function of the term. Defaults to an instance of
            :class:`ZeroDamping`.
        charge_dependent : bool, optional
            Flag handed to the base class; defaults to ``False``.
        """
        # Both settings are keyword-only here but are handed to the base
        # class positionally, in this order.
        super().__init__(damping_fn, charge_dependent)
class ATM(ThreeBodyTerm, ABC):
    r"""
    D4's Axilrod-Teller-Muto dispersion term.

    This abstract class only wires the pieces together: the C9 coefficients
    (:meth:`get_c9`) and the critical radii (:meth:`get_radii`) must be
    supplied by a concrete subclass, while :meth:`calculate` evaluates the
    energy with the damping function stored on the instance.

    .. math::

        E_\text{disp}^{(3), \text{ATM}} &=
        \sum_\text{ABC} E^{\text{ABC}}
        f_\text{damp}\left(\overline{R}_\text{ABC}\right) \\
        E^{\text{ABC}} &=
        \dfrac{C^{\text{ABC}}_9
        \left(3 \cos\theta_\text{A} \cos\theta_\text{B}
        \cos\theta_\text{C} + 1 \right)}
        {\left(r_\text{AB} r_\text{BC} r_\text{AC} \right)^3} \\
        f_\text{damp} &=
        \dfrac{1}{1+ 6 \left(\overline{R}_\text{ABC}\right)^{-16}}
    """

    @abstractmethod
    def get_c9(self, model: ModelInst, cn: Tensor, q: Tensor | None) -> Tensor:
        """Return the C9 coefficients (approximate or exact)."""

    @abstractmethod
    def get_radii(
        self,
        param: Param,
        r4r2: Tensor,
        rvdw: Tensor,
    ) -> Tensor:
        """Return the critical radii entering the damping function."""

    def calculate(
        self,
        numbers: Tensor,
        positions: Tensor,
        param: Param,
        cn: Tensor,
        model: ModelInst,
        q: Tensor | None,
        r4r2: Tensor,
        rvdw: Tensor,
        cutoff: Cutoff,
    ):
        """
        Evaluate the ATM dispersion energy.

        The term-specific C9 coefficients and critical radii are obtained
        from :meth:`get_c9` and :meth:`get_radii`; everything else is
        delegated to :func:`get_atm_dispersion`.

        Parameters
        ----------
        numbers : Tensor
            Atomic numbers, forwarded unchanged.
        positions : Tensor
            Cartesian coordinates, forwarded unchanged.
        param : Param
            Parameter mapping. ``"s9"`` and ``"alp"`` are read from it; a
            missing key falls back to ``defaults.S9`` / ``defaults.ALP``.
        cn : Tensor
            Passed to :meth:`get_c9`.
        model : ModelInst
            Passed to :meth:`get_c9`.
        q : Tensor | None
            Passed to :meth:`get_c9`.
        r4r2 : Tensor
            Passed to :meth:`get_radii`.
        rvdw : Tensor
            Passed to :meth:`get_radii`.
        cutoff : Cutoff
            Cutoff container; its ``disp3`` attribute is used.

        Returns
        -------
        Tensor
            Whatever :func:`get_atm_dispersion` returns for these inputs.
        """
        # quantities that depend on the concrete ATM flavour
        coefficients = self.get_c9(model, cn, q)
        critical_radii = self.get_radii(param, r4r2, rvdw)

        # remaining settings, with library defaults for absent parameters
        damping = self.damping_fn
        disp3_cutoff = cutoff.disp3
        scaling = param.get("s9", torch.tensor(defaults.S9, **self.dd))
        exponent = param.get("alp", torch.tensor(defaults.ALP, **self.dd))

        return get_atm_dispersion(
            numbers,
            positions,
            coefficients,
            critical_radii,
            damping_function=damping,
            cutoff=disp3_cutoff,
            s9=scaling,
            alp=exponent,
        )
# Radii mixins


class RadiiBJMixin:
    """Becke--Johnson critical radii: a1 * sqrt(3 * r4r2_i * r4r2_j) + a2."""

    def get_radii(
        self,
        param: Param,
        r4r2: Tensor,
        rvdw: Tensor,
    ) -> Tensor:
        """
        Build the pairwise Becke--Johnson radii.

        ``a1`` and ``a2`` are read from ``param``; a missing key falls back
        to ``defaults.A1`` / ``defaults.A2``. ``rvdw`` is not used here.
        """
        dd = self.dd  # type: ignore

        scale = param.get("a1", torch.tensor(defaults.A1, **dd))
        shift = param.get("a2", torch.tensor(defaults.A2, **dd))

        # outer product of `r4r2` with itself over its last dimension
        pair_product = 3.0 * r4r2.unsqueeze(-1) * r4r2.unsqueeze(-2)
        return scale * storch.sqrt(pair_product) + shift


class RadiiVDWMixin:
    """Scaled pair-wise VDW radii: rs9 * rvdw."""

    def get_radii(
        self,
        param: Param,
        r4r2: Tensor,
        rvdw: Tensor,
    ) -> Tensor:
        """
        Scale the van-der-Waals radii by ``rs9``.

        ``rs9`` is read from ``param`` and falls back to ``defaults.RS9``
        when absent. ``r4r2`` is not used here.
        """
        dd = self.dd  # type: ignore

        scale = param.get("rs9", torch.tensor(defaults.RS9, **dd))
        return scale * rvdw


# C9 mixins


class C9ExactMixin:
    """Exact C9 via Casimir–Polder integration of polarizabilities."""

    charge_dependent: bool
    """Whether the C9 coefficients depend on atomic charges."""

    def get_c9(self, model: ModelInst, cn: Tensor, q: Tensor | None) -> Tensor:
        r"""
        Compute C9 coefficients from the weighted polarizabilities.

        The product of the polarizabilities of the three atoms is formed
        for every triple and handed to ``trapzd_atm`` for integration.

        .. math::

            C_9^{ABC} \propto \int \alpha_A \alpha_B \alpha_C

        The charges ``q`` are only passed to the model if
        ``charge_dependent`` is set.
        """
        # pylint: disable=import-outside-toplevel
        from ..utils import trapzd_atm

        charges = q if self.charge_dependent else None
        weights = model.weight_references(cn, charges)
        aiw = model.get_weighted_pols(weights)

        # C9_ABC = integral (aiw_A * aiw_B * aiw_C)
        # NOTE(review): presumably `aiw` is (..., nat, nfreq), so the three
        # views below address atom A, B and C of a triple -- confirm.
        pol_a = aiw.unsqueeze(-2).unsqueeze(-2)
        pol_b = aiw.unsqueeze(-3).unsqueeze(-2)
        pol_c = aiw.unsqueeze(-3).unsqueeze(-3)

        return trapzd_atm(pol_a * pol_b * pol_c)


class C9ApproxMixin:
    """Approximate C9 coefficients from C6 coefficients."""

    charge_dependent: bool
    """Whether the C9 coefficients depend on atomic charges."""

    def get_c9(self, model: ModelInst, cn: Tensor, q: Tensor | None) -> Tensor:
        r"""
        Approximate C9 coefficients from C6 coefficients.

        .. math::

            C_9 = \sqrt{|C_{6}^{AB} C_{6}^{AC} C_{6}^{BC}|}

        The charges ``q`` are only passed to the model if
        ``charge_dependent`` is set.
        """
        charges = q if self.charge_dependent else None
        weights = model.weight_references(cn, charges)
        c6 = model.get_atomic_c6(weights)

        # C9_ABC = sqrt(|C6_AB * C6_AC * C6_BC|)
        triple_product = (
            c6.unsqueeze(-1) * c6.unsqueeze(-2) * c6.unsqueeze(-3)
        )
        return storch.sqrt(torch.abs(triple_product))