import copy
from typing import Any
import torch
from metatomic.torch import System
from metatrain.utils.abc import ModelInterface
from metatrain.utils.data import DatasetInfo
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists
from .architectures import ArchitectureTests
class TorchscriptTests(ArchitectureTests):
    """Test suite to check that architectures can be jit compiled with
    TorchScript."""

    float_hypers: list[str] = []
    """List of hyperparameter keys (dot-separated for nested keys)
    that are floats. A test will set these to integers to test that
    TorchScript compilation works in that case."""

    def jit_compile(self, model: ModelInterface) -> torch.jit.ScriptModule:
        """JIT compiles the given model.

        The default is to simply ``torch.jit.script`` the model, but
        architectures can override this method if some special
        compilation procedure is needed.

        :param model: Model to compile.
        :return: JIT compiled model.
        """
        return torch.jit.script(model)

    def _make_test_system(self, model: ModelInterface, dtype: torch.dtype) -> System:
        """Builds the small system shared by the TorchScript tests.

        The system has four atoms with types ``[6, 1, 8, 7]`` placed along the
        z axis at a spacing of 1.0, a zero cell and no periodic boundary
        conditions. The neighbor lists requested by ``model`` are attached
        to it before it is returned.

        The name deliberately does not start with ``test`` so that it is
        not picked up as a test by pytest.

        :param model: Model whose ``requested_neighbor_lists()`` are attached
            to the system. It must still be the un-compiled Python model.
        :param dtype: Dtype used for the positions and the cell.
        :return: System with the requested neighbor lists attached.
        """
        system = System(
            types=torch.tensor([6, 1, 8, 7]),
            positions=torch.tensor(
                [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]],
                dtype=dtype,
            ),
            cell=torch.zeros(3, 3, dtype=dtype),
            pbc=torch.tensor([False, False, False]),
        )
        return get_system_with_neighbor_lists(
            system, model.requested_neighbor_lists()
        )

    def test_torchscript(
        self, model_hypers: dict, dataset_info: DatasetInfo, dtype: torch.dtype
    ) -> None:
        """Tests that the model can be jitted.

        If this test fails it probably means that there is some
        code in the model that is not compatible with TorchScript.
        The exception raised by the test should indicate where
        the problem is.

        :param model_hypers: Hyperparameters to initialize the model.
        :param dataset_info: Dataset to initialize the model.
        :param dtype: Dtype to use for the model and inputs.
        """
        model = self.model_cls(model_hypers, dataset_info)
        # The neighbor lists are requested from the Python model, before it
        # is converted and compiled.
        system = self._make_test_system(model, dtype)

        # Here the dtype is changed *before* compilation; the opposite order
        # is covered by ``test_torchscript_dtypechange``.
        model = model.to(dtype)
        model = self.jit_compile(model)

        # Only checks that the forward pass runs; the outputs are discarded.
        model([system], model.outputs)

    def test_torchscript_dtypechange(
        self, model_hypers: dict, dataset_info: DatasetInfo, dtype: torch.dtype
    ) -> None:
        """Tests that the model can be changed to a different dtype after jitting.

        If this test fails and ``test_torchscript`` passes, it probably means that
        your model is overwriting the ``to()`` method, which does not work in
        TorchScript. If ``test_torchscript`` also fails, then one should fix
        that one first.

        :param model_hypers: Hyperparameters to initialize the model.
        :param dataset_info: Dataset to initialize the model.
        :param dtype: Dtype to change the model to.
        """
        model = self.model_cls(model_hypers, dataset_info)
        # The neighbor lists are requested from the Python model, before it
        # is compiled.
        system = self._make_test_system(model, dtype)

        # Compile first, then change the dtype of the compiled module.
        model = self.jit_compile(model)
        model = model.to(dtype)

        # Only checks that the forward pass runs; the outputs are discarded.
        model([system], model.outputs)

    def test_torchscript_spherical(
        self, model_hypers: dict, dataset_info_spherical: DatasetInfo
    ) -> None:
        """Tests that there is no problem with jitting with spherical targets.

        This runs ``test_torchscript`` in ``torch.float32`` on the given
        dataset info.

        :param model_hypers: Hyperparameters to initialize the model.
        :param dataset_info_spherical: Dataset to initialize the model
            (containing spherical targets).
        """
        self.test_torchscript(
            model_hypers=model_hypers,
            dataset_info=dataset_info_spherical,
            dtype=torch.float32,
        )

    def test_torchscript_save_load(
        self, tmpdir: Any, model_hypers: dict, dataset_info: DatasetInfo
    ) -> None:
        """Tests that the model can be jitted, saved and loaded.

        :param tmpdir: Temporary directory where to save the
            model.
        :param model_hypers: Hyperparameters to initialize the model.
        :param dataset_info: Dataset to initialize the model.
        """
        model = self.model_cls(model_hypers, dataset_info)
        # Work inside the temporary directory so that ``model.pt`` is not
        # written to the directory the tests are launched from.
        with tmpdir.as_cwd():
            torch.jit.save(self.jit_compile(model), "model.pt")
            torch.jit.load("model.pt")

    def test_torchscript_integers(
        self,
        model_hypers: dict,
        dataset_info: DatasetInfo,
    ) -> None:
        """Tests that the model can be jitted when some float
        parameters are instead supplied as integers.

        Every key listed in ``float_hypers`` is converted with ``int()`` in
        a deep copy of the hyperparameters, which is then passed to
        ``test_torchscript`` in ``torch.float32``.

        :param model_hypers: Hyperparameters to initialize the model.
        :param dataset_info: Dataset to initialize the model.
        """
        # Deep copy so that the nested dictionaries of the caller's
        # hyperparameters are left untouched.
        new_hypers = copy.deepcopy(model_hypers)
        for hyper in self.float_hypers:
            # "a.b.c" -> walk down through "a" and "b", then replace "c".
            *parent_keys, leaf_key = hyper.split(".")
            sub_dict = new_hypers
            for key in parent_keys:
                sub_dict = sub_dict[key]
            sub_dict[leaf_key] = int(sub_dict[leaf_key])

        self.test_torchscript(
            model_hypers=new_hypers, dataset_info=dataset_info, dtype=torch.float32
        )