MLIP Base Classes

class metatrain.utils.mlip.MLIPModel(hypers: Dict[str, Any], dataset_info: DatasetInfo)[source]

Bases: ModelInterface

Base model class for MLIP-only architectures.

Base class for MLIP-only models, i.e. models that predict only energies and forces. It provides:

  • Common forward pass logic with neighbor list processing

  • Automatic integration of CompositionModel for composition-based energy corrections

  • Automatic integration of Scaler for target scaling

  • Checkpoint saving/loading (get_checkpoint, load_checkpoint)

  • Model export to metatomic format (export)

  • Support for restarting training (restart)

Derived classes only need to implement the compute_energy() method.

The base class automatically handles additive models and scaling at evaluation time, so the derived class only needs to compute the “raw” energy predictions.

Example:

import torch

from metatrain.utils.mlip import MLIPModel


class MyMLIPModel(MLIPModel):
    def compute_energy(
        self,
        edge_vectors: torch.Tensor,
        species: torch.Tensor,
        centers: torch.Tensor,
        neighbors: torch.Tensor,
        system_indices: torch.Tensor,
    ) -> torch.Tensor:
        # Implement your energy computation here; it must define
        # `energies` with one entry per system.
        ...
        return energies  # shape: (N_systems,)
Parameters:
  • hypers (Dict[str, Any]) – Model hyperparameters.

  • dataset_info (DatasetInfo) – Information about the dataset, including atomic types and targets.

forward(systems: List[System], outputs: Dict[str, ModelOutput], selected_atoms: Labels | None = None) Dict[str, TensorMap][source]

Execute the model for the given systems, computing the requested outputs.

Parameters:
  • systems (List[System]) – List of systems to evaluate the model on.

  • outputs (Dict[str, ModelOutput]) – Dictionary of outputs that the model should compute.

  • selected_atoms (Labels | None) – Optional Labels specifying a subset of atoms to compute the outputs for. If None, the outputs are computed for all atoms in each system.

Returns:

A dictionary mapping each requested output name to the corresponding TensorMap containing the computed values.

Return type:

Dict[str, TensorMap]

See also

metatomic.torch.ModelInterface for more explanation about the different arguments.
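
Example (a minimal sketch of calling forward() on an instance of a derived class; the System and ModelOutput constructors follow metatomic.torch, and the "energy" output name is an assumption based on the energy-only scope of this class, so check both against your installed versions):

import torch
from metatomic.torch import ModelOutput, System

# A single water-like system; positions are random placeholders.
system = System(
    types=torch.tensor([8, 1, 1]),
    positions=torch.rand(3, 3, dtype=torch.float64),
    cell=10.0 * torch.eye(3, dtype=torch.float64),
    pbc=torch.tensor([True, True, True]),
)
# Neighbor lists matching the model's requests must typically be
# attached to the system before evaluation (e.g. with metatrain's
# neighbor-list utilities).

outputs = model.forward(
    systems=[system],
    outputs={"energy": ModelOutput(quantity="energy", per_atom=False)},
)
energy = outputs["energy"].block().values  # (1, 1): one value per system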

request_neighbor_list(cutoff: float) None[source]

Request that a neighbor list with the given cutoff radius be computed for the model's inputs.

Parameters:

cutoff (float) – Cutoff radius of the requested neighbor list.

Return type:

None

abstractmethod compute_energy(edge_vectors: Tensor, species: Tensor, centers: Tensor, neighbors: Tensor, system_indices: Tensor) Tensor[source]

Compute the total energy given the edge vectors and other information.

Parameters:
  • edge_vectors (Tensor) – Tensor of shape (N_edges, 3) containing the vectors between neighboring atoms.

  • species (Tensor) – Tensor of shape (N_atoms,) containing the atomic species indices.

  • centers (Tensor) – Tensor of shape (N_edges,) containing the indices of the center atoms for each edge.

  • neighbors (Tensor) – Tensor of shape (N_edges,) containing the indices of the neighbor atoms for each edge.

  • system_indices (Tensor) – Tensor of shape (N_atoms,) containing the indices of the systems each atom belongs to.

Returns:

Tensor of shape (N_systems,) containing the total energy for each system.

Return type:

Tensor
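
To make the expected tensor layout concrete, the following hypothetical subclass implements compute_energy() as a soft repulsive pair potential, summing exp(-r) over all edges. It ignores species and exists purely to illustrate the indexing:

import torch

from metatrain.utils.mlip import MLIPModel


class RepulsiveModel(MLIPModel):
    def compute_energy(
        self,
        edge_vectors: torch.Tensor,    # (N_edges, 3)
        species: torch.Tensor,         # (N_atoms,)
        centers: torch.Tensor,         # (N_edges,)
        neighbors: torch.Tensor,       # (N_edges,)
        system_indices: torch.Tensor,  # (N_atoms,)
    ) -> torch.Tensor:
        distances = edge_vectors.norm(dim=-1)  # (N_edges,)
        edge_energies = torch.exp(-distances)  # (N_edges,)

        # Accumulate per-edge energies onto their center atoms...
        atom_energies = torch.zeros(
            species.shape[0],
            dtype=edge_energies.dtype,
            device=edge_energies.device,
        )
        atom_energies.index_add_(0, centers, edge_energies)

        # ...then onto the systems the atoms belong to. (Inferring the
        # number of systems from the indices is a sketch-level shortcut;
        # it under-counts batches whose last systems are empty.)
        n_systems = int(system_indices.max()) + 1
        energies = torch.zeros(
            n_systems,
            dtype=atom_energies.dtype,
            device=atom_energies.device,
        )
        energies.index_add_(0, system_indices, atom_energies)
        return energies  # (N_systems,)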

supported_outputs() Dict[str, ModelOutput][source]

Get the outputs currently supported by this model.

Returns:

Dictionary mapping output names to their ModelOutput definitions.

Return type:

Dict[str, ModelOutput]

restart(dataset_info: DatasetInfo) MLIPModel[source]

Restart training with a new dataset, potentially with new targets.

Parameters:

dataset_info (DatasetInfo) – New dataset information.

Returns:

Updated model instance.

Return type:

MLIPModel

classmethod load_checkpoint(checkpoint: Dict[str, Any], context: Literal['restart', 'finetune', 'export']) MLIPModel[source]

Load a model from a checkpoint.

Parameters:
  • checkpoint (Dict[str, Any]) – Checkpoint dictionary.

  • context (Literal['restart', 'finetune', 'export']) – Context for loading (restart, finetune, or export).

Returns:

Loaded model instance.

Return type:

MLIPModel

get_checkpoint() Dict[source]

Get a checkpoint dictionary for saving the model.

Returns:

Checkpoint dictionary.

Return type:

Dict
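
Example (a sketch of a save/load round trip; MyMLIPModel and the file name are hypothetical, and torch.save/torch.load is one reasonable way, not the only way, to serialize the checkpoint dictionary):

import torch

# get_checkpoint() returns a plain dictionary that torch.save can handle.
checkpoint = model.get_checkpoint()
torch.save(checkpoint, "model.ckpt")

# weights_only=False because the checkpoint contains non-tensor objects.
checkpoint = torch.load("model.ckpt", weights_only=False)
restored = MyMLIPModel.load_checkpoint(checkpoint, context="export")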

export(metadata: ModelMetadata | None = None) AtomisticModel[source]

Export the model to a metatomic AtomisticModel.

Parameters:

metadata (ModelMetadata | None) – Optional metadata to merge with the model’s metadata.

Returns:

Exported AtomisticModel.

Return type:

AtomisticModel
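
Example (a minimal sketch; the ModelMetadata fields and the save() method of the exported AtomisticModel follow metatomic.torch's API, so verify them against your installed version):

from metatomic.torch import ModelMetadata

exported = model.export(
    metadata=ModelMetadata(name="my-mlip", description="an example MLIP")
)
exported.save("exported-model.pt")  # hypothetical output path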

classmethod upgrade_checkpoint(checkpoint: Dict[str, Any]) Dict[str, Any][source]

Upgrade the checkpoint to the current version of the model.

This method should be implemented by derived classes if they need to upgrade checkpoints between versions. The base MLIPModel implementation is version 1 and does not require any upgrades yet.

Parameters:

checkpoint (Dict[str, Any]) – Checkpoint’s state dictionary.

Raises:

RuntimeError – if the checkpoint cannot be upgraded to the current version of the model.

Returns:

The upgraded checkpoint.

Return type:

Dict[str, Any]

metatrain.utils.mlip.get_mlip_scheduler(optimizer: Optimizer, train_hypers: Dict[str, Any], steps_per_epoch: int) LambdaLR[source]

Get a cosine-annealing learning-rate scheduler with warmup for MLIP trainers.

Parameters:
  • optimizer (Optimizer) – The optimizer for which to create the scheduler.

  • train_hypers (Dict[str, Any]) – The training hyperparameters.

  • steps_per_epoch (int) – The number of steps per epoch.

Returns:

The learning rate scheduler.

Return type:

LambdaLR
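
Example (a sketch; the keys expected in train_hypers come from the architecture's training schema, and the names below are placeholders rather than the actual schema):

import torch

from metatrain.utils.mlip import get_mlip_scheduler

# A dummy parameter so the sketch is self-contained.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-3)

# Placeholder hyperparameters: substitute your architecture's real keys.
train_hypers = {"num_epochs": 10, "warmup_fraction": 0.01}

scheduler = get_mlip_scheduler(optimizer, train_hypers, steps_per_epoch=100)

for _ in range(10 * 100):  # one scheduler step per optimizer step
    optimizer.step()
    scheduler.step()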

class metatrain.utils.mlip.MLIPTrainer(hypers: Dict[str, Any])[source]

Bases: TrainerInterface

Base trainer class for MLIP-only architectures.

Base trainer for MLIP-only models. It implements the complete training loop and handles:

  • Distributed training

  • Data loading with optional rotational augmentation

  • Loss computation

  • Checkpointing

Derived classes only need to implement the use_rotational_augmentation() method to specify whether rotational data augmentation should be used during training.

Note on rotational augmentation: you do not need rotational augmentation if rotational invariance is enforced by the network architecture itself (e.g., through equivariant message passing). If your architecture does not enforce rotational invariance, you should enable rotational augmentation so that the model learns approximately rotationally invariant predictions.

Example:

from metatrain.utils.mlip import MLIPTrainer


class MyMLIPTrainer(MLIPTrainer):
    def use_rotational_augmentation(self) -> bool:
        # Return True to use rotational augmentation, False otherwise
        return False
Parameters:

hypers (Dict[str, Any]) – Training hyperparameters.

abstractmethod use_rotational_augmentation() bool[source]

Specify whether the trainer should use rotational augmentation.

Returns:

True if rotational augmentation should be used, False otherwise.

Return type:

bool

train(model: MLIPModel, dtype: dtype, devices: List[device], train_datasets: List[Dataset | Subset], val_datasets: List[Dataset | Subset], checkpoint_dir: str) None[source]

Train the MLIP model.

Parameters:
  • model (MLIPModel) – The MLIP model to train.

  • dtype (dtype) – The dtype to use for training.

  • devices (List[device]) – The devices to use for training.

  • train_datasets (List[Dataset | Subset]) – The training datasets.

  • val_datasets (List[Dataset | Subset]) – The validation datasets.

  • checkpoint_dir (str) – The directory to save checkpoints.

Return type:

None
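
Example (a sketch; MyMLIPTrainer is the hypothetical subclass from the class-level example, and model, train_hypers, and the datasets are assumed to be constructed elsewhere):

import torch

trainer = MyMLIPTrainer(train_hypers)
trainer.train(
    model=model,
    dtype=torch.float64,
    devices=[torch.device("cpu")],
    train_datasets=[train_dataset],
    val_datasets=[val_dataset],
    checkpoint_dir="checkpoints",
)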

save_checkpoint(model: ModelInterface, path: str | Path) None[source]

Save a checkpoint of the model and trainer state.

Parameters:
  • model (ModelInterface) – The model to save.

  • path (str | Path) – Path where the checkpoint will be saved.

Return type:

None

classmethod load_checkpoint(checkpoint: Dict[str, Any], hypers: Dict[str, Any], context: Literal['restart', 'finetune']) MLIPTrainer[source]

Load trainer state from a checkpoint.

Parameters:
  • checkpoint (Dict[str, Any]) – Dictionary containing the checkpoint data.

  • hypers (Dict[str, Any]) – Training hyperparameters.

  • context (Literal['restart', 'finetune']) – Loading context (‘restart’ or ‘finetune’).

Returns:

Initialized trainer with loaded state.

Return type:

MLIPTrainer

classmethod upgrade_checkpoint(checkpoint: Dict) Dict[source]

Upgrade the checkpoint to the current version of the trainer.

This method should be implemented by derived classes if they need to upgrade checkpoints between versions. The base MLIPTrainer implementation is version 1 and does not require any upgrades yet.

Parameters:

checkpoint (Dict) – Checkpoint’s state dictionary.

Raises:

RuntimeError – if the checkpoint cannot be upgraded to the current version of the trainer.

Returns:

The upgraded checkpoint.

Return type:

Dict