# This file is part of tad-dftd4.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Dispersion Methods: Base
========================
Base classes and interfaces for dispersion terms.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import ClassVar
import torch
from tad_mctc.convert import any_to_tensor
from tad_mctc.typing import DD, Any, CNFunc, Tensor, TensorLike
from ..cutoff import Cutoff
from ..damping import Damping, Param
from ..model import ModelInst, ModelKey
class DispTerm(TensorLike, ABC):
    """
    Abstract interface that every dispersion term has to implement.

    A term stores the damping function it uses together with a flag telling
    whether atomic charges are needed for its evaluation. Concrete terms
    provide the actual energy expression in :meth:`calculate`.

    Parameters
    ----------
    damping_fn : Damping
        Damping function to be used for the dispersion term.
    charge_dependent : bool
        Whether the term is charge-dependent, i.e., requires atomic charges
        for the calculation.
    device : torch.device, optional
        Device on which the term is calculated.
    dtype : torch.dtype, optional
        Data type of the term's tensors.
    """

    damping_fn: Damping
    """Damping function to be used for the dispersion term."""

    charge_dependent: bool
    """
    Whether the term is charge-dependent, i.e., requires atomic charges for
    the calculation.
    """

    __slots__ = ("damping_fn", "charge_dependent")

    def __init__(
        self,
        damping_fn: Damping,
        charge_dependent: bool,
        *,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """Store the damping function and the charge-dependence flag."""
        super().__init__(device=device, dtype=dtype)
        self.damping_fn = damping_fn
        self.charge_dependent = charge_dependent

    def __eq__(self, other: Any):
        """
        Compare two dispersion terms.

        Terms are equal only if they are of exactly the same class and both
        the damping function and the charge-dependence flag compare equal.
        """
        if other.__class__ is self.__class__:
            same_damping = self.damping_fn == other.damping_fn
            return (
                same_damping
                and self.charge_dependent == other.charge_dependent
            )

        # Different (sub)classes never compare equal.
        return False

    @abstractmethod
    def calculate(
        self,
        numbers: Tensor,
        positions: Tensor,
        param: Param,
        cn: Tensor,
        model: ModelInst,
        q: Tensor | None,
        r4r2: Tensor,
        rvdw: Tensor,
        cutoff: Cutoff,
    ) -> Tensor:
        """Evaluate the energy for the dispersion term."""
class Disp(TensorLike):
    """
    Base class for DFT-D dispersion calculations.

    The class holds the list of registered dispersion terms, the
    coordination number function (plus its keyword arguments) and either a
    model key or a ready-made model instance. :meth:`calculate` prepares all
    shared quantities and sums up the contributions of the registered terms.
    """

    terms: list[DispTerm]
    """List of dispersion terms for which the calculation is performed."""

    cn_fn: CNFunc
    """Coordination number function."""

    cn_fn_kwargs: dict[str, Any]
    """Keyword arguments for the coordination number function."""

    _model_key: ModelKey
    """Key for the DFT-D model, e.g., 'd4'."""

    _model_kwargs: dict[str, Any]
    """Keyword arguments for the DFT-D model."""

    _model_instance: ModelInst | None
    """Instance of the DFT-D model, if provided."""

    __slots__ = (
        "terms",
        "cn_fn",
        "cn_fn_kwargs",
        "_model_key",
        "_model_instance",
        "_model_kwargs",
    )

    _ALLOWED_MODELS = ("d3", "d4", "d4s", "d5")
    """Allowed DFT-D models for the calculation."""

    TERMS: ClassVar[list[tuple[type[DispTerm], dict[str, Any] | None]]] = []
    """List of dispersion terms to be registered in the constructor."""

    def __init__(
        self,
        model: ModelKey | ModelInst = "d4",
        model_kwargs: dict[str, Any] | None = None,
        cn_fn: CNFunc | None = None,
        cn_fn_kwargs: dict[str, Any] | None = None,
        *,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """
        Set up the model, the coordination number function and the terms.

        Parameters
        ----------
        model : ModelKey | ModelInst, optional
            Either a (case-insensitive) model key out of ``_ALLOWED_MODELS``
            or an already constructed model instance. Defaults to ``"d4"``.
        model_kwargs : dict[str, Any] | None, optional
            Keyword arguments forwarded to the model constructor in
            :meth:`get_model`. Ignored (with a ``RuntimeWarning``) if
            ``model`` is an instance.
        cn_fn : CNFunc | None, optional
            Coordination number function. If ``None``, it is selected from
            the digit found in the model key or in the class name of the
            model instance.
        cn_fn_kwargs : dict[str, Any] | None, optional
            Keyword arguments passed to ``cn_fn`` on every call.
        device : torch.device | None, optional
            Device forwarded to the base class.
        dtype : torch.dtype | None, optional
            Data type forwarded to the base class.

        Raises
        ------
        ValueError
            If ``model`` is a string that is not an allowed model key, or if
            no default coordination number function can be derived.
        """
        super().__init__(device=device, dtype=dtype)

        if isinstance(model, str):
            key = model.casefold()
            if key not in self._ALLOWED_MODELS:
                raise ValueError(
                    f"Unknown model '{key}'. "
                    f"Please use {', '.join(self._ALLOWED_MODELS)}."
                )

            self._model_key = key
            self._model_instance = None
            self._model_kwargs = model_kwargs or {}
        else:
            # sentinel, never queried
            self._model_key = "instance"  # type: ignore[assignment]
            self._model_instance = model
            self._model_kwargs = {}

            if model_kwargs:
                # pylint: disable=import-outside-toplevel
                from warnings import warn

                warn(
                    "`model` is an instance - `model_kwargs` were ignored.",
                    RuntimeWarning,
                    stacklevel=2,
                )

        if cn_fn is None:
            # The default CN function is derived from the digit in the model
            # key or, for instances, in the name of the model class.
            name = (
                self._model_instance.__class__.__name__
                if self._model_instance is not None
                else self._model_key
            )

            # The branch order matters for names containing several digits.
            if "3" in name:
                # pylint: disable=import-outside-toplevel
                from tad_mctc.ncoord import cn_d3

                cn_fn = cn_d3
            elif "4" in name:
                # pylint: disable=import-outside-toplevel
                from tad_mctc.ncoord import cn_d4

                cn_fn = cn_d4
            elif "5" in name:
                # pylint: disable=import-outside-toplevel
                from tad_mctc.ncoord import cn_d3

                cn_fn = cn_d3
            else:  # pragma: no cover
                # Report the name that was actually inspected. For model
                # instances, `_model_key` only holds the "instance" sentinel.
                raise ValueError(
                    f"Unknown model '{name}'. "
                    "Please use 'd3', 'd4', 'd4s', or 'd5', "
                    "or pass `cn_fn` explicitly."
                )

        self.cn_fn = cn_fn
        self.cn_fn_kwargs = cn_fn_kwargs if cn_fn_kwargs is not None else {}

        self.terms: list[DispTerm] = []
        for term_cls, kw in self.TERMS:
            # Copy so that the class-level defaults are never mutated.
            kw = {} if kw is None else kw.copy()
            self.register(term_cls(**kw))

    def register(self, term: DispTerm) -> None:
        """Append ``term`` to the list of evaluated dispersion terms."""
        self.terms.append(term)

    def deregister(self, term: DispTerm) -> None:
        """
        Remove the first registered term that compares equal to ``term``.

        Raises
        ------
        ValueError
            If no such term is registered (raised by ``list.remove``).
        """
        self.terms.remove(term)

    def get_model(self, numbers: Tensor) -> ModelInst:
        """
        Get the DFT-D model for the given atomic numbers.

        Parameters
        ----------
        numbers : Tensor
            Atomic numbers of the atoms in the system.

        Returns
        -------
        ModelInst
            The model instance passed to the constructor, if any. Otherwise,
            a new model selected by the model key and initialized with the
            atomic numbers, the stored model keyword arguments and ``self.dd``.

        Raises
        ------
        ValueError
            If no model is implemented for the stored model key.
        """
        if self._model_instance is not None:
            return self._model_instance

        key = self._model_key.casefold()

        if key == "d3":
            # pylint: disable=import-outside-toplevel
            from tad_dftd4.model.d3 import D3Model

            return D3Model(numbers=numbers, **self._model_kwargs, **self.dd)

        if key == "d4":
            # pylint: disable=import-outside-toplevel
            from tad_dftd4.model.d4 import D4Model

            return D4Model(numbers=numbers, **self._model_kwargs, **self.dd)

        if key == "d4s":
            # pylint: disable=import-outside-toplevel
            from tad_dftd4.model.d4s import D4SModel

            return D4SModel(numbers=numbers, **self._model_kwargs, **self.dd)

        # NOTE(review): "d5" passes the validation in the constructor but has
        # no branch here, so it currently ends up in this error as well.
        raise ValueError(
            f"Unknown model '{self._model_key}'. "
            "Please use 'd3', 'd4', 'd4s', or 'd5'."
        )

    # Radii

    def get_rcov(self, numbers: Tensor) -> Tensor:
        """Return the ``COV_D3`` table entries indexed by ``numbers``."""
        # pylint: disable=import-outside-toplevel
        from tad_mctc.data import COV_D3

        return COV_D3(**self.dd)[numbers]

    def get_r4r2(self, numbers: Tensor) -> Tensor:
        """Return the ``R4R2`` table entries indexed by ``numbers``."""
        # pylint: disable=import-outside-toplevel
        from tad_dftd4.data import R4R2

        return R4R2(**self.dd)[numbers]

    def get_rvdw(self, numbers: Tensor) -> Tensor:
        """
        Return the ``VDW_PAIRWISE`` table entries for all atom pairs.

        The table is indexed with ``numbers`` expanded along the last and
        the second to last dimension, i.e., one entry per pair of atoms.
        """
        # pylint: disable=import-outside-toplevel
        from tad_mctc.data import VDW_PAIRWISE

        return VDW_PAIRWISE(**self.dd)[
            numbers.unsqueeze(-1), numbers.unsqueeze(-2)
        ]

    # Calculation

    def calculate(
        self,
        numbers: Tensor,
        positions: Tensor,
        charge: Tensor | float | int,
        param: Param,
        *,
        cutoff: Cutoff | None = None,
        q: Tensor | None = None,
        rcov: Tensor | None = None,
        r4r2: Tensor | None = None,
        rvdw: Tensor | None = None,
    ) -> Tensor:
        """
        Evaluate the DFT-D dispersion energy for a (batch of) molecule(s).

        The dispersion model is obtained from :meth:`get_model` and the
        coordination numbers from the ``cn_fn`` configured in the
        constructor. The result is the sum over all registered terms.

        Parameters
        ----------
        numbers : Tensor
            Atomic numbers for all atoms in the system of shape ``(..., nat)``.
        positions : Tensor
            Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
        charge : Tensor | float | int
            Total charge of the system. Converted to a tensor with the device
            and dtype of ``positions``.
        param : Param
            Damping parameters, passed on to every term.
        cutoff : Cutoff | None, optional
            Collection of real-space cutoffs. Defaults to ``None``, i.e.,
            :class:`tad_dftd4.cutoff.Cutoff` is initialized with its defaults.
        q : Tensor | None, optional
            Atomic partial charges. Defaults to ``None``, i.e., EEQ charges are
            calculated using the total ``charge`` if at least one term is
            charge-dependent.
        rcov : Tensor | None, optional
            Covalent radii of the atoms in the system. Accepted for backwards
            compatibility, but the coordination number function now handles
            radii internally. Only the shape is validated.
        r4r2 : Tensor | None, optional
            r⁴ over r² expectation values of the atoms in the system. Defaults
            to ``None``, i.e., default values are used.
        rvdw : Tensor | None, optional
            Pairwise van der Waals radii. Defaults to ``None``, i.e., the
            values from :meth:`get_rvdw` are used.

        Returns
        -------
        Tensor
            Atom-resolved dispersion energy.

        Raises
        ------
        ValueError
            Shape inconsistencies between ``numbers`` and ``positions``,
            ``rcov``, ``r4r2``, ``rvdw`` or ``q``.
        RuntimeError
            If atomic charges are explicitly provided, but no term requires
            them.
        """
        dd: DD = {"device": positions.device, "dtype": positions.dtype}
        charge = any_to_tensor(charge, **dd)

        if numbers.shape != positions.shape[:-1]:
            raise ValueError(
                f"Shape of positions ({positions.shape}) is not consistent "
                f"with atomic numbers ({numbers.shape}).",
            )

        if cutoff is None:
            cutoff = Cutoff(**dd)

        # 1) dispersion model
        model = self.get_model(numbers=numbers)

        # 2) radii defaults
        if rcov is not None:
            # Only validated, not used any further (see docstring).
            if numbers.shape != rcov.shape:
                raise ValueError(
                    f"Shape of covalent radii ({rcov.shape}) is not consistent "
                    f"with atomic numbers ({numbers.shape}).",
                )

        if r4r2 is None:
            r4r2 = self.get_r4r2(numbers)
        if numbers.shape != r4r2.shape:
            raise ValueError(
                f"Shape of expectation values r4r2 ({r4r2.shape}) is not "
                f"consistent with atomic numbers ({numbers.shape}).",
            )

        if rvdw is None:
            rvdw = self.get_rvdw(numbers)
        if numbers.shape != rvdw.shape[:-1]:
            raise ValueError(
                f"Shape of van der Waals radii ({rvdw.shape}) is not "
                f"consistent with atomic numbers ({numbers.shape}).",
            )

        # 3) Coordination numbers
        cn = self.cn_fn(numbers, positions, **self.cn_fn_kwargs)

        # 4) charges if any term demands them
        needs_charges = any(t.charge_dependent for t in self.terms)

        if q is not None and not needs_charges:
            raise RuntimeError(
                "Atomic charges are explicitly provided, but no term "
                "requires them. Please remove the `q` argument or "
                "provide a term that requires atomic charges.",
            )

        if q is None and needs_charges:
            # pylint: disable=import-outside-toplevel
            from tad_multicharge import get_eeq_charges

            q = get_eeq_charges(
                numbers, positions, charge, cutoff=cutoff.cn_eeq
            )

        if q is not None:
            if numbers.shape != q.shape:
                raise ValueError(
                    f"Shape of atomic charges ({q.shape}) is not consistent "
                    f"with atomic numbers ({numbers.shape}).",
                )

        # 5) delegate to the registered terms and accumulate
        energy = torch.zeros_like(numbers, **dd)
        for term in self.terms:
            energy = energy + term.calculate(
                numbers=numbers,
                positions=positions,
                param=param,
                cn=cn,
                model=model,
                q=q,
                r4r2=r4r2,
                rvdw=rvdw,
                cutoff=cutoff,
            )

        return energy