Source code for tad_dftd4.model.base

# This file is part of tad-dftd4.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Model: Base
===========

This module contains the definition of the base dispersion model for the
evaluation of C6 coefficients.

Upon instantiation, the reference polarizabilities are calculated for all
atoms of the molecule(s) and used to compute the reference C6 coefficients,
which are stored in the model instance.


Example
-------
>>> import torch
>>> import tad_dftd4 as d4
>>>
>>> numbers = torch.tensor([14, 1, 1, 1, 1]) # SiH4
>>> model = d4.D4Model(numbers)
>>>
>>> # calculate Gaussian weights, optionally pass CN and partial charges
>>> gw = model.weight_references()
>>> c6 = model.get_atomic_c6(gw)
"""

from __future__ import annotations

from abc import abstractmethod

import torch
from tad_mctc import storch
from tad_mctc.convert import any_to_tensor
from tad_mctc.math import einsum
from tad_mctc.typing import Literal, Tensor, TensorLike, overload

from .. import data
from ..utils import trapzd

__all__ = ["BaseModel", "WF_DEFAULT"]


# Default maximum charge scaling height (``ga``) used for the partial
# charge extrapolation of the reference polarizabilities.
GA_DEFAULT: float = 3.0

# Default charge scaling steepness (``gc``) used for the partial charge
# extrapolation of the reference polarizabilities.
GC_DEFAULT: float = 2.0

# Default weighting factor for coordination number interpolation. It is not
# referenced in this module; presumably concrete models return it from
# ``_get_wf`` -- verify against the subclasses.
WF_DEFAULT: float = 6.0


class BaseModel(TensorLike):
    """
    Base class for D4-type dispersion models.

    Concrete models implement the weighting factor (:meth:`_get_wf`), the
    weighting of the reference systems (:meth:`weight_references`), the
    atomic C6 coefficients (:meth:`get_atomic_c6`) and the weighted
    polarizabilities (:meth:`get_weighted_pols`).
    """

    numbers: Tensor
    """Atomic numbers of all atoms in the system."""

    ga: float
    """
    Maximum charge scaling height for partial charge extrapolation.

    :default: :data:`.GA_DEFAULT`
    """

    gc: float
    """
    Charge scaling steepness for partial charge extrapolation.

    :default: :data:`.GC_DEFAULT`
    """

    wf: Tensor
    """
    Weighting factor for coordination number interpolation.

    :default: ``None`` (model-dependent, set upon instantiation)
    """

    ref_charges: Literal["eeq", "gfn2"]
    """
    Reference charges to use for the model.

    :default: ``"eeq"``
    """

    rc6: Tensor
    """
    Reference C6 coefficients of all atoms.

    :default: ``None`` (calculated upon instantiation)
    """

    __slots__ = ("numbers", "ga", "gc", "wf", "ref_charges", "rc6")

    def __init__(
        self,
        numbers: Tensor,
        ga: float = GA_DEFAULT,
        gc: float = GC_DEFAULT,
        wf: Tensor | float | None = None,
        ref_charges: Literal["eeq", "gfn2"] = "eeq",
        rc6: Tensor | None = None,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """
        Instantiate the dispersion model.

        Parameters
        ----------
        numbers : Tensor
            Atomic numbers of all atoms in the system.
        ga : float, optional
            Maximum charge scaling height for partial charge extrapolation.
            Defaults to `GA_DEFAULT`.
        gc : float, optional
            Charge scaling steepness for partial charge extrapolation.
            Defaults to `GC_DEFAULT`.
        wf : Tensor | float | None, optional
            Weighting factor for coordination number interpolation. It is
            converted to a tensor with the model's device and dtype.
            Defaults to `None`, in which case the model-specific value
            returned by :meth:`_get_wf` is used.
        ref_charges : Literal["eeq", "gfn2"], optional
            Reference charges to use for the model. Defaults to `"eeq"`.
        rc6 : Tensor | None, optional
            Precomputed reference C6 coefficients of all atoms. They are
            moved to the model's device and dtype. Defaults to `None`, in
            which case they are calculated from the reference
            polarizabilities.
        device : torch.device | None, optional
            Pytorch device for calculations. Defaults to `None`.
        dtype : torch.dtype | None, optional
            Pytorch dtype for calculations. Defaults to `None`.
        """
        super().__init__(device, dtype)
        self.numbers = numbers
        self.ga = ga
        self.gc = gc
        self.ref_charges = ref_charges

        self.wf = self._get_wf() if wf is None else any_to_tensor(wf, **self.dd)

        # `_get_refc6` (through `_get_alpha`/`_zeta`) reads `numbers`, `ga`,
        # `gc` and `ref_charges`, so those must be set before this line.
        # A user-supplied `rc6` is cast like `wf`, so that it matches the
        # device/dtype of all other tensors created by the model.
        self.rc6 = self._get_refc6() if rc6 is None else rc6.to(**self.dd)

    ####################
    # Abstract methods #
    ####################

    @abstractmethod
    def _get_wf(self) -> Tensor:
        """
        Get the weighting factor for the Gaussian weights.

        Returns
        -------
        Tensor
            Weighting factor for the Gaussian weights.
        """

    @abstractmethod
    def get_atomic_c6(self, gw: Tensor) -> Tensor:
        """
        Calculate atomic C6 dispersion coefficients.

        Parameters
        ----------
        gw : Tensor
            Weights for the atomic reference systems of shape
            `(..., nat, nref)`.

        Returns
        -------
        Tensor
            C6 coefficients for all atom pairs of shape `(..., nat, nat)`.
        """

    @overload
    @abstractmethod
    def weight_references(
        self,
        cn: Tensor | None = None,
        q: Tensor | None = None,
        *,
        with_dgwdq: Literal[False] = ...,
        with_dgwdcn: Literal[False] = ...,
    ) -> Tensor: ...

    @overload
    @abstractmethod
    def weight_references(
        self,
        cn: Tensor | None = None,
        q: Tensor | None = None,
        *,
        with_dgwdq: Literal[True],
        with_dgwdcn: Literal[False] = ...,
    ) -> tuple[Tensor, Tensor]: ...

    @overload
    @abstractmethod
    def weight_references(
        self,
        cn: Tensor | None = None,
        q: Tensor | None = None,
        *,
        with_dgwdq: Literal[False] = ...,
        with_dgwdcn: Literal[True],
    ) -> tuple[Tensor, Tensor]: ...

    @overload
    @abstractmethod
    def weight_references(
        self,
        cn: Tensor | None = None,
        q: Tensor | None = None,
        *,
        with_dgwdq: Literal[True],
        with_dgwdcn: Literal[True],
    ) -> tuple[Tensor, Tensor, Tensor]: ...

    @abstractmethod
    def weight_references(
        self,
        cn: Tensor | None = None,
        q: Tensor | None = None,
        *,
        with_dgwdq: bool = False,
        with_dgwdcn: bool = False,
    ) -> Tensor | tuple[Tensor, Tensor] | tuple[Tensor, Tensor, Tensor]:
        """
        Calculate the weights of the reference system
        (shape: ``(..., nat, nref)``).

        Parameters
        ----------
        cn : Tensor | None, optional
            Coordination number of every atom. Defaults to `None` (0).
        q : Tensor | None, optional
            Partial charge of every atom. Defaults to `None` (0).
        with_dgwdq : bool, optional
            Whether to also calculate the derivative of the weights with
            respect to the partial charges. Defaults to `False`.
        with_dgwdcn : bool, optional
            Whether to also calculate the derivative of the weights with
            respect to the coordination numbers. Defaults to `False`.

        Returns
        -------
        Tensor | tuple[Tensor, Tensor] | tuple[Tensor, Tensor, Tensor]
            Weights for the atomic reference systems
            (shape: ``(..., nat, nref)``). If ``with_dgwdq`` is ``True``,
            also returns the derivative of the weights with respect to the
            partial charges. If ``with_dgwdcn`` is ``True``, also returns
            the derivative of the weights with respect to the coordination
            numbers.
        """

    @abstractmethod
    def get_weighted_pols(self, gw: Tensor) -> Tensor:
        """
        Calculate the weighted polarizabilities for each atom and frequency.

        Parameters
        ----------
        gw : Tensor
            Weights for the atomic reference systems of shape
            ``(..., nat, nref)``.

        Returns
        -------
        Tensor
            Weighted polarizabilities of shape ``(..., nat, 23)``.
        """

    ##################
    # Public methods #
    ##################

    def get_polarizabilities(self, weights: Tensor) -> Tensor:
        """
        Calculate static polarizabilities for all atoms.

        Parameters
        ----------
        weights : Tensor
            Weights for the atomic reference systems of shape
            ``(..., nat, nref)``.

        Returns
        -------
        Tensor
            Polarizabilities of shape ``(..., nat)``.
        """
        # Index 0 of the frequency axis of the reference polarizabilities
        # is used as the static polarizability.
        # (..., n, r) * (..., n, r) -> (..., n)
        return einsum("...nr,...nr->...n", weights, self._get_alpha()[..., 0])

    ###################
    # Private methods #
    ###################

    def _zeta(self, gam: Tensor, qref: Tensor, qmod: Tensor) -> Tensor:
        """
        Charge scaling function.

        Evaluates ``exp(ga * (1 - exp(gam * (1 - qref / qmod))))`` (with the
        denominator shifted by machine epsilon) where ``qmod > 0`` and
        ``exp(ga)`` elsewhere.

        Parameters
        ----------
        gam : Tensor
            Chemical hardness.
        qref : Tensor
            Reference charges.
        qmod : Tensor
            Modified charges.

        Returns
        -------
        Tensor
            Scaled charges.
        """
        eps = torch.tensor(torch.finfo(self.dtype).eps, **self.dd)
        ga = torch.tensor(self.ga, **self.dd)

        # NOTE(review): `eps` is subtracted from (not added to) `qmod`. As
        # this value is only selected where `qmod > 0`, the shift only
        # matters for 0 < qmod < eps -- confirm this is intended.
        scale = torch.exp(gam * (1.0 - qref / (qmod - eps)))
        return torch.where(
            qmod > 0.0,
            torch.exp(ga * (1.0 - scale)),
            torch.exp(ga),
        )

    def _dzeta(self, gam: Tensor, qref: Tensor, qmod: Tensor) -> Tensor:
        """
        Derivative of charge scaling function with respect to `qmod`.

        Evaluates ``-ga * gam * scale * zeta * qref / qmod**2`` where
        ``qmod > 0`` and zero elsewhere, with ``scale`` and ``zeta`` as in
        :meth:`_zeta`.

        Parameters
        ----------
        gam : Tensor
            Chemical hardness.
        qref : Tensor
            Reference charges.
        qmod : Tensor
            Modified charges.

        Returns
        -------
        Tensor
            Derivative of charges.
        """
        eps = torch.tensor(torch.finfo(self.dtype).eps, **self.dd)
        ga = torch.tensor(self.ga, **self.dd)

        # Same (eps-shifted) intermediate as in `_zeta`.
        scale = torch.exp(gam * (1.0 - qref / (qmod - eps)))
        zeta = torch.exp(ga * (1.0 - scale))

        return torch.where(
            qmod > 0.0,
            -ga * gam * scale * zeta * storch.divide(qref, qmod**2),
            torch.tensor(0.0, **self.dd),
        )

    def _get_alpha(self) -> Tensor:
        """
        Calculate reference polarizabilities.

        Returns
        -------
        Tensor
            Reference polarizabilities of shape `(..., nat, nref, 23)`.

        Raises
        ------
        ValueError
            If ``ref_charges`` is neither ``"eeq"`` nor ``"gfn2"``.
        """
        # pylint: disable=import-outside-toplevel
        from ..reference import d4 as reference

        zero = torch.tensor(0.0, **self.dd)

        # Per-atom reference data, selected via the atomic numbers.
        refsys = reference.refsys.to(self.device)[self.numbers]
        refascale = reference.refascale.to(**self.dd)[self.numbers]
        refalpha = reference.refalpha.to(**self.dd)[self.numbers]
        refscount = reference.refscount.to(**self.dd)[self.numbers]

        # Secondary-system data, indexed below with `refsys`.
        secscale = reference.secscale.to(**self.dd)
        secalpha = reference.secalpha.to(**self.dd)

        if self.ref_charges == "eeq":
            # pylint: disable=import-outside-toplevel
            from ..reference.d4.charge_eeq import clsh as _refsq

            refsq = _refsq.to(**self.dd)[self.numbers]
        elif self.ref_charges == "gfn2":
            # pylint: disable=import-outside-toplevel
            from ..reference.d4.charge_gfn2 import refh as _refsq

            refsq = _refsq.to(**self.dd)[self.numbers]
        else:
            raise ValueError(f"Unknown reference charges: {self.ref_charges}")

        # Entries with `refsys == 0` get no charge scaling (zeta = 0).
        mask = refsys > 0

        zeff = data.ZEFF(self.device)[refsys]
        gam = data.GAM(**self.dd)[refsys] * self.gc

        # charge scaling
        zeta = torch.where(
            mask,
            self._zeta(gam, zeff, refsq + zeff),
            zero,
        )

        # Scaled secondary polarizabilities are subtracted, weighted by
        # `refscount`, from the reference polarizabilities.
        aiw = secscale[refsys] * secalpha[refsys] * zeta.unsqueeze(-1)
        h = refalpha - refscount.unsqueeze(-1) * aiw
        alpha = refascale.unsqueeze(-1) * h

        # Negative polarizabilities are clamped to zero.
        # (..., n, r, 23)
        return torch.where(alpha > 0.0, alpha, zero)

    def _get_refc6(self) -> Tensor:
        """
        Calculate reference C6 dispersion coefficients.

        The reference C6 coefficients are not weighted by the Gaussian
        weights yet.

        Returns
        -------
        Tensor
            Reference C6 coefficients of shape
            ``(..., nat, nat, nref, nref)``.
        """
        # (..., n, r, 23) -> (..., n, n, r, r)
        return trapzd(self._get_alpha())

    ############
    # Printing #
    ############

    def __str__(self) -> str:  # pragma: no cover
        """Return a string representation of the model."""
        return (
            f"{self.__class__.__name__}(\n"
            f" ga={self.ga},\n"
            f" gc={self.gc},\n"
            f" wf={self.wf},\n"
            f" ref_charges={self.ref_charges},\n"
            f" rc6={self.rc6.shape},\n"
            f")"
        )

    def __repr__(self) -> str:  # pragma: no cover
        """Return a string representation of the model."""
        return str(self)