MLIP Base Classes
- class metatrain.utils.mlip.MLIPModel(hypers: Dict[str, Any], dataset_info: DatasetInfo)
Bases: ModelInterface
Base model class for MLIP-only architectures.
This class is a base class for MLIP-only models that predict only energies and forces. It provides:
Common forward pass logic with neighbor list processing
Automatic integration of CompositionModel for composition-based energy corrections
Automatic integration of Scaler for target scaling
Checkpoint saving/loading (get_checkpoint, load_checkpoint)
Model export to metatomic format (export)
Support for restarting training (restart)
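The neighbor-list processing in the forward pass turns each neighbor pair into an edge vector before compute_energy() is called. The following is a minimal plain-PyTorch sketch of that step. It assumes the common periodic convention r_ij = r_j - r_i + S_ij · cell, with integer cell shifts S_ij. The helper name `compute_edge_vectors` is hypothetical and is not part of metatrain.

```python
import torch


def compute_edge_vectors(
    positions: torch.Tensor,  # (N_atoms, 3) Cartesian positions
    cell: torch.Tensor,  # (3, 3), rows are the lattice vectors
    centers: torch.Tensor,  # (N_edges,) index of the center atom i
    neighbors: torch.Tensor,  # (N_edges,) index of the neighbor atom j
    cell_shifts: torch.Tensor,  # (N_edges, 3) integer periodic-image offsets
) -> torch.Tensor:
    """Hypothetical helper: r_ij = r_j - r_i + S_ij @ cell for every edge."""
    shifts = cell_shifts.to(positions.dtype) @ cell
    return positions[neighbors] - positions[centers] + shifts


# Two atoms in a cubic box of side 3; the second edge points to a periodic image.
positions = torch.tensor([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]])
cell = 3.0 * torch.eye(3)
centers = torch.tensor([0, 1])
neighbors = torch.tensor([1, 0])
cell_shifts = torch.tensor([[0, 0, 0], [1, 0, 0]])
edge_vectors = compute_edge_vectors(positions, cell, centers, neighbors, cell_shifts)
```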
Derived classes only need to implement the compute_energy() method. The base class automatically handles additive models and scaling at evaluation time, so the derived class only needs to compute the "raw" energy predictions.
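The evaluation-time combination of the raw energy with the scaler and the composition baseline can be illustrated with a simplified sketch. Three things are assumptions here and are not taken from metatrain's code: the single scalar `scale`, the per-species reference energies `per_species_energy`, and the order of operations (scale first, then add the baseline).

```python
import torch


def combine_energy(
    raw_energy: torch.Tensor,  # (N_systems,) output of compute_energy()
    species: torch.Tensor,  # (N_atoms,) atomic species indices
    system_indices: torch.Tensor,  # (N_atoms,) system each atom belongs to
    scale: float,  # hypothetical scalar from a target scaler
    per_species_energy: torch.Tensor,  # (N_species,) hypothetical composition weights
) -> torch.Tensor:
    # Composition baseline: sum the reference energy of every atom in each system.
    baseline = torch.zeros_like(raw_energy)
    baseline.index_add_(0, system_indices, per_species_energy[species])
    # Assumed order: scale the raw prediction first, then add the baseline.
    return raw_energy * scale + baseline


total = combine_energy(
    raw_energy=torch.tensor([0.1, -0.2]),
    species=torch.tensor([0, 1, 1]),
    system_indices=torch.tensor([0, 0, 1]),
    scale=2.0,
    per_species_energy=torch.tensor([-1.0, -3.0]),
)
```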
Example:
    import torch

    from metatrain.utils.mlip import MLIPModel


    class MyMLIPModel(MLIPModel):
        def compute_energy(
            self,
            edge_vectors: torch.Tensor,
            species: torch.Tensor,
            centers: torch.Tensor,
            neighbors: torch.Tensor,
            system_indices: torch.Tensor,
        ) -> torch.Tensor:
            # Implement your energy computation here, producing one
            # "raw" energy per system.
            energies = ...
            return energies  # shape: (N_systems,)
- Parameters:
dataset_info (DatasetInfo) – Information about the dataset, including atomic types and targets.
- forward(systems: List[System], outputs: Dict[str, ModelOutput], selected_atoms: Labels | None = None) → Dict[str, TensorMap]
Execute the model for the given systems, computing the requested outputs.
- Parameters:
systems (List[System]) – List of systems to evaluate the model on.
outputs (Dict[str, ModelOutput]) – Dictionary of outputs that the model should compute.
selected_atoms (Labels | None) – Optional Labels specifying a subset of atoms to compute the outputs for. If None, the outputs are computed for all atoms in each system.
- Returns:
A dictionary mapping each requested output name to the corresponding TensorMap containing the computed values.
- Return type:
Dict[str, TensorMap]
See also
metatomic.torch.ModelInterface for more explanation about the different arguments.
- abstractmethod compute_energy(edge_vectors: Tensor, species: Tensor, centers: Tensor, neighbors: Tensor, system_indices: Tensor) → Tensor
Compute the total energy given the edge vectors and other information.
- Parameters:
edge_vectors (Tensor) – Tensor of shape (N_edges, 3) containing the vectors between neighboring atoms.
species (Tensor) – Tensor of shape (N_atoms,) containing the atomic species indices.
centers (Tensor) – Tensor of shape (N_edges,) containing the indices of the center atoms for each edge.
neighbors (Tensor) – Tensor of shape (N_edges,) containing the indices of the neighbor atoms for each edge.
system_indices (Tensor) – Tensor of shape (N_atoms,) containing, for each atom, the index of the system it belongs to.
- Returns:
Tensor of shape (N_systems,) containing the total energy for each system.
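The following toy implementation makes the argument shapes above concrete. It is written as a free function with the same arguments as compute_energy(); in a subclass, the body would go inside the method. The Lennard-Jones pair potential and its `epsilon`/`sigma` parameters are illustrative choices only. The factor 0.5 assumes a full neighbor list, in which each pair appears once as i→j and once as j→i.

```python
import torch


def lennard_jones_energy(
    edge_vectors: torch.Tensor,  # (N_edges, 3)
    species: torch.Tensor,  # (N_atoms,)
    centers: torch.Tensor,  # (N_edges,)
    neighbors: torch.Tensor,  # (N_edges,) unused by this toy potential
    system_indices: torch.Tensor,  # (N_atoms,)
    epsilon: float = 1.0,
    sigma: float = 1.0,
) -> torch.Tensor:
    distances = torch.linalg.vector_norm(edge_vectors, dim=-1)  # (N_edges,)
    sr6 = (sigma / distances) ** 6
    # 0.5 avoids double counting, assuming a full neighbor list (i->j and j->i).
    pair_energy = 0.5 * 4.0 * epsilon * (sr6**2 - sr6)
    # Scatter edge energies onto their center atoms ...
    atom_energy = torch.zeros(
        species.shape[0], dtype=edge_vectors.dtype, device=edge_vectors.device
    )
    atom_energy.index_add_(0, centers, pair_energy)
    # ... then sum atom energies per system.
    n_systems = int(system_indices.max()) + 1
    energies = torch.zeros(n_systems, dtype=edge_vectors.dtype, device=edge_vectors.device)
    energies.index_add_(0, system_indices, atom_energy)
    return energies  # (N_systems,)


# Two atoms at the Lennard-Jones minimum r = 2**(1/6) * sigma.
d = 2.0 ** (1.0 / 6.0)
energies = lennard_jones_energy(
    edge_vectors=torch.tensor([[d, 0.0, 0.0], [-d, 0.0, 0.0]], dtype=torch.float64),
    species=torch.tensor([0, 0]),
    centers=torch.tensor([0, 1]),
    neighbors=torch.tensor([1, 0]),
    system_indices=torch.tensor([0, 0]),
)
```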
- Return type:
Tensor
- supported_outputs() → Dict[str, ModelOutput]
Get the outputs currently supported by this model.
- Returns:
Dictionary mapping output names to their ModelOutput definitions.
- Return type:
Dict[str, ModelOutput]
- restart(dataset_info: DatasetInfo) → MLIPModel
Restart training with a new dataset, potentially with new targets.
- Parameters:
dataset_info (DatasetInfo) – New dataset information.
- Returns:
Updated model instance.
- Return type:
MLIPModel
- classmethod load_checkpoint(checkpoint: Dict[str, Any], context: Literal['restart', 'finetune', 'export']) → MLIPModel
Load a model from a checkpoint.
- get_checkpoint() → Dict
Get a checkpoint dictionary for saving the model.
- Returns:
Checkpoint dictionary.
- Return type:
Dict
- export(metadata: ModelMetadata | None = None) → AtomisticModel
Export the model to a metatomic AtomisticModel.
- Parameters:
metadata (ModelMetadata | None) – Optional metadata to merge with the model’s metadata.
- Returns:
Exported AtomisticModel.
- Return type:
AtomisticModel
- classmethod upgrade_checkpoint(checkpoint: Dict[str, Any]) → Dict[str, Any]
Upgrade the checkpoint to the current version of the model.
This method should be implemented by derived classes if they need to upgrade checkpoints between versions. The base MLIPModel implementation is version 1 and does not require any upgrades yet.
- metatrain.utils.mlip.get_mlip_scheduler(optimizer: Optimizer, train_hypers: Dict[str, Any], steps_per_epoch: int) → LambdaLR
Get a CosineAnnealing learning-rate scheduler with warmup for MLIP trainers.
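A cosine-annealing schedule with linear warmup can be expressed as a LambdaLR whose multiplicative factor rises linearly and then follows half a cosine. The sketch below is one such construction. The argument names `warmup_steps`, `total_steps` and `min_lr_ratio` are hypothetical, because the exact keys read from train_hypers are not documented here. A real implementation would presumably derive step counts from steps_per_epoch.

```python
import math

import torch


def cosine_with_warmup(
    optimizer: torch.optim.Optimizer,
    warmup_steps: int,
    total_steps: int,
    min_lr_ratio: float = 0.0,
) -> torch.optim.lr_scheduler.LambdaLR:
    def lr_factor(step: int) -> float:
        if step < warmup_steps:
            # Linear warmup up to the base learning rate.
            return (step + 1) / warmup_steps
        # Fraction of the annealing phase completed, clipped to [0, 1].
        progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return min_lr_ratio + (1.0 - min_lr_ratio) * cosine

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_factor)


param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1.0)
scheduler = cosine_with_warmup(optimizer, warmup_steps=10, total_steps=110)
lrs = []
for _ in range(111):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()
    scheduler.step()
```

The learning rate peaks at the base value after warmup and decays to `min_lr_ratio` times the base value at `total_steps`.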
- class metatrain.utils.mlip.MLIPTrainer(hypers: Dict[str, Any])
Bases: TrainerInterface
Base trainer class for MLIP-only architectures.
This class is a base trainer for MLIP-only models. It implements the complete training loop and handles:
Distributed training
Data loading with optional rotational augmentation
Loss computation
Checkpointing
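For energy-and-force models, a typical loss obtains forces as the negative gradient of the predicted energy with respect to positions. It then combines the squared energy and force errors. This sketch assumes that common scheme rather than reproducing metatrain's loss. The function name and the `energy_weight`/`forces_weight` parameters are hypothetical.

```python
from typing import Callable

import torch


def energy_force_loss(
    energy_fn: Callable[[torch.Tensor], torch.Tensor],  # positions -> scalar energy
    positions: torch.Tensor,  # (N_atoms, 3)
    energy_ref: torch.Tensor,  # scalar reference energy
    forces_ref: torch.Tensor,  # (N_atoms, 3) reference forces
    energy_weight: float = 1.0,
    forces_weight: float = 1.0,
) -> torch.Tensor:
    positions = positions.detach().requires_grad_(True)
    energy = energy_fn(positions)
    # Forces are -dE/dr; create_graph=True lets the force loss be
    # back-propagated to the model parameters.
    (grad,) = torch.autograd.grad(energy, positions, create_graph=True)
    forces = -grad
    energy_loss = (energy - energy_ref) ** 2
    forces_loss = ((forces - forces_ref) ** 2).mean()
    return energy_weight * energy_loss + forces_weight * forces_loss


# Harmonic toy model E = 0.5 * k * |r|^2 with a trainable k.
k = torch.nn.Parameter(torch.tensor(2.0))
positions = torch.tensor([[1.0, 0.0, 0.0], [0.0, -1.0, 0.0]])
loss = energy_force_loss(
    lambda x: 0.5 * k * (x**2).sum(),
    positions,
    energy_ref=torch.tensor(1.0),
    forces_ref=-2.0 * positions,
)
loss.backward()
```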
Derived classes only need to implement the use_rotational_augmentation() method to specify whether rotational data augmentation should be used during training.
Note on rotational augmentation: You don't need rotational augmentation if rotational invariance is enforced in the neural network architecture itself (e.g., through equivariant message passing). However, if your architecture does not enforce rotational invariance, you should use rotational augmentation so that the model learns to make (approximately) rotationally invariant energy predictions.
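Rotational augmentation applies a random rotation to each training structure. The energy stays unchanged, while positions, cell vectors and forces rotate together. The sketch below uses one standard way to sample a uniformly random rotation (QR decomposition of a Gaussian matrix). It is not necessarily how metatrain samples rotations, and the helper names are hypothetical.

```python
from typing import Optional, Tuple

import torch


def random_rotation(
    generator: Optional[torch.Generator] = None, dtype: torch.dtype = torch.float64
) -> torch.Tensor:
    # QR of a Gaussian matrix gives a random orthogonal matrix.
    a = torch.randn(3, 3, generator=generator, dtype=dtype)
    q, r = torch.linalg.qr(a)
    # Fix the column signs so the distribution is uniform over O(3).
    q = q * torch.sign(torch.diagonal(r))
    # Flip one column if needed so that det = +1 (a proper rotation).
    if torch.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q


def rotate_structure(
    positions: torch.Tensor, cell: torch.Tensor, forces: torch.Tensor, rotation: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Row-vector convention: every 3-vector v becomes v @ R^T.
    # The energy is invariant, so it is not modified.
    return positions @ rotation.T, cell @ rotation.T, forces @ rotation.T


generator = torch.Generator().manual_seed(0)
rotation = random_rotation(generator)
```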
Example:
    from metatrain.utils.mlip import MLIPTrainer


    class MyMLIPTrainer(MLIPTrainer):
        def use_rotational_augmentation(self) -> bool:
            # Return True to use rotational augmentation, False otherwise
            return False
- abstractmethod use_rotational_augmentation() → bool
Specify whether the trainer should use rotational augmentation.
- Returns:
True if rotational augmentation should be used, False otherwise.
- Return type:
bool
- train(model: MLIPModel, dtype: dtype, devices: List[device], train_datasets: List[Dataset | Subset], val_datasets: List[Dataset | Subset], checkpoint_dir: str) → None
Train the MLIP model.
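- Parameters:
model (MLIPModel) – The model to train.
dtype (dtype) – Floating-point dtype used for training.
devices (List[device]) – Devices to train on.
train_datasets (List[Dataset | Subset]) – Training datasets.
val_datasets (List[Dataset | Subset]) – Validation datasets.
checkpoint_dir (str) – Directory where checkpoints are saved.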
- Return type:
None
- save_checkpoint(model: ModelInterface, path: str | Path) → None
Save a checkpoint of the model and trainer state.
- Parameters:
model (ModelInterface) – The model to save.
path (str | Path) – Path where the checkpoint will be saved.
- Return type:
None
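A checkpoint that captures both model and trainer state usually bundles the model weights with the optimizer and scheduler state, so training can resume where it stopped. The sketch below shows such a round trip with plain `torch.save`/`torch.load`. The dictionary keys are hypothetical, and metatrain's actual checkpoint layout may differ.

```python
import tempfile
from pathlib import Path

import torch

model = torch.nn.Linear(3, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# Hypothetical checkpoint layout: model weights plus trainer state.
checkpoint = {
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "scheduler_state_dict": scheduler.state_dict(),
    "epoch": 5,
}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "model.ckpt"
    torch.save(checkpoint, path)
    loaded = torch.load(path)

# Restore into fresh objects, as a restart would.
new_model = torch.nn.Linear(3, 1)
new_model.load_state_dict(loaded["model_state_dict"])
new_optimizer = torch.optim.Adam(new_model.parameters(), lr=1e-3)
new_optimizer.load_state_dict(loaded["optimizer_state_dict"])
```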
- classmethod load_checkpoint(checkpoint: Dict[str, Any], hypers: Dict[str, Any], context: Literal['restart', 'finetune']) → MLIPTrainer
Load trainer state from a checkpoint.
- classmethod upgrade_checkpoint(checkpoint: Dict) → Dict
Upgrade the checkpoint to the current version of the trainer.
This method should be implemented by derived classes if they need to upgrade checkpoints between versions. The base MLIPTrainer implementation is version 1 and does not require any upgrades yet.
- Parameters:
checkpoint (Dict) – Checkpoint’s state dictionary.
- Raises:
RuntimeError – if the checkpoint cannot be upgraded to the current version of the trainer.
- Returns:
The upgraded checkpoint.
- Return type:
Dict
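Derived classes that change their checkpoint format can chain one upgrade function per version step and raise RuntimeError when no path to the current version exists. The sketch below shows that pattern. The version key `trainer_ckpt_version`, the version numbers and the renamed fields are all hypothetical illustrations, not metatrain's actual format.

```python
from typing import Any, Callable, Dict

# Hypothetical current version and key names, for illustration only.
CURRENT_VERSION = 3


def _v1_to_v2(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    # Example change: a field was renamed in version 2.
    checkpoint["epoch"] = checkpoint.pop("current_epoch")
    return checkpoint


def _v2_to_v3(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    # Example change: a new field with a default value in version 3.
    checkpoint.setdefault("best_loss", None)
    return checkpoint


UPGRADES: Dict[int, Callable[[Dict[str, Any]], Dict[str, Any]]] = {
    1: _v1_to_v2,
    2: _v2_to_v3,
}


def upgrade_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    version = checkpoint.get("trainer_ckpt_version", 1)
    if version > CURRENT_VERSION:
        # A checkpoint from a newer release cannot be downgraded.
        raise RuntimeError(f"checkpoint version {version} is newer than {CURRENT_VERSION}")
    while version < CURRENT_VERSION:
        if version not in UPGRADES:
            raise RuntimeError(f"no upgrade path from checkpoint version {version}")
        checkpoint = UPGRADES[version](checkpoint)
        version += 1
        checkpoint["trainer_ckpt_version"] = version
    return checkpoint


upgraded = upgrade_checkpoint({"current_epoch": 7})
```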