Source code for metatrain.utils.testing.torchscript

import copy
from typing import Any

import torch
from metatomic.torch import System

from metatrain.utils.abc import ModelInterface
from metatrain.utils.data import DatasetInfo
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists

from .architectures import ArchitectureTests


class TorchscriptTests(ArchitectureTests):
    """Test suite to check that architectures can be jit compiled with TorchScript."""

    float_hypers: list[str] = []
    """List of hyperparameter keys (dot-separated for nested keys) that are floats.
    A test will set these to integers to test that TorchScript compilation works
    in that case."""

    def jit_compile(self, model: ModelInterface) -> torch.jit.ScriptModule:
        """JIT compiles the given model.

        The default is to simply ``torch.jit.script`` the model, but architectures
        can override this method if some special compilation procedure is needed.

        :param model: Model to compile.
        :return: JIT compiled model.
        """
        return torch.jit.script(model)

    def _get_test_system(self, model: ModelInterface, dtype: torch.dtype) -> System:
        """Build the small system shared by the TorchScript tests.

        The system holds four atoms (types 6, 1, 8 and 7) placed on a line along
        the z axis with unit spacing, a zero cell and no periodicity. The neighbor
        lists requested by ``model`` are computed and attached to it.

        :param model: Model whose ``requested_neighbor_lists()`` are attached. The
            tests call this on the model *before* it is converted or compiled.
        :param dtype: Dtype used for the positions and the cell.
        :return: The system with the requested neighbor lists attached.
        """
        system = System(
            types=torch.tensor([6, 1, 8, 7]),
            positions=torch.tensor(
                [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]],
                dtype=dtype,
            ),
            cell=torch.zeros(3, 3, dtype=dtype),
            pbc=torch.tensor([False, False, False]),
        )
        return get_system_with_neighbor_lists(
            system, model.requested_neighbor_lists()
        )

    def test_torchscript(
        self, model_hypers: dict, dataset_info: DatasetInfo, dtype: torch.dtype
    ) -> None:
        """Tests that the model can be jitted.

        If this test fails it probably means that there is some code in the model
        that is not compatible with TorchScript. The exception raised by the test
        should indicate where the problem is.

        :param model_hypers: Hyperparameters to initialize the model.
        :param dataset_info: Dataset to initialize the model.
        :param dtype: Dtype to use for the model and inputs.
        """
        model = self.model_cls(model_hypers, dataset_info)
        system = self._get_test_system(model, dtype)

        # Convert to the target dtype *before* compiling.
        model = model.to(dtype)
        model = self.jit_compile(model)
        model(
            [system],
            model.outputs,
        )

    def test_torchscript_dtypechange(
        self, model_hypers: dict, dataset_info: DatasetInfo, dtype: torch.dtype
    ) -> None:
        """Tests that the model can be changed to a different dtype after jitting.

        If this test fails and ``test_torchscript`` passes, it probably means that
        your model is overwriting the `to()` method, which does not work in
        TorchScript. If ``test_torchscript`` also fails, then one should fix that
        one first.

        :param model_hypers: Hyperparameters to initialize the model.
        :param dataset_info: Dataset to initialize the model.
        :param dtype: Dtype to change the model to.
        """
        model = self.model_cls(model_hypers, dataset_info)
        system = self._get_test_system(model, dtype)

        # Unlike ``test_torchscript``, compile first and convert afterwards, so
        # that ``to()`` is exercised on the scripted module.
        model = self.jit_compile(model)
        model = model.to(dtype)
        model(
            [system],
            model.outputs,
        )

    def test_torchscript_spherical(
        self, model_hypers: dict, dataset_info_spherical: DatasetInfo
    ) -> None:
        """Tests that there is no problem with jitting with spherical targets.

        :param model_hypers: Hyperparameters to initialize the model.
        :param dataset_info_spherical: Dataset to initialize the model
            (containing spherical targets).
        """
        self.test_torchscript(
            model_hypers=model_hypers,
            dataset_info=dataset_info_spherical,
            dtype=torch.float32,
        )

    def test_torchscript_save_load(
        self, tmpdir: Any, model_hypers: dict, dataset_info: DatasetInfo
    ) -> None:
        """Tests that the model can be jitted, saved and loaded.

        :param tmpdir: Temporary directory where to save the model.
        :param model_hypers: Hyperparameters to initialize the model.
        :param dataset_info: Dataset to initialize the model.
        """
        model = self.model_cls(model_hypers, dataset_info)
        # Work inside the temporary directory so that "model.pt" does not leak
        # into the current working directory.
        with tmpdir.as_cwd():
            torch.jit.save(self.jit_compile(model), "model.pt")
            torch.jit.load("model.pt")

    def test_torchscript_integers(
        self,
        model_hypers: dict,
        dataset_info: DatasetInfo,
    ) -> None:
        """Tests that the model can be jitted when some float parameters are
        instead supplied as integers.

        Every key listed in ``float_hypers`` (dot-separated for nested keys) is
        converted with ``int()`` before building and compiling the model.

        :param model_hypers: Hyperparameters to initialize the model.
        :param dataset_info: Dataset to initialize the model.
        """
        # Deep copy so that the (possibly shared) fixture dict is not mutated.
        new_hypers = copy.deepcopy(model_hypers)
        for hyper in self.float_hypers:
            *parent_keys, leaf_key = hyper.split(".")
            sub_dict = new_hypers
            for key in parent_keys:
                sub_dict = sub_dict[key]
            sub_dict[leaf_key] = int(sub_dict[leaf_key])

        self.test_torchscript(
            model_hypers=new_hypers, dataset_info=dataset_info, dtype=torch.float32
        )