Source code for metatrain.utils.mlip

import copy
import logging
import math
from abc import abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union

import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatensor.torch.operations._add import _add_block_block
from metatomic.torch import (
    AtomisticModel,
    ModelCapabilities,
    ModelMetadata,
    ModelOutput,
    NeighborListOptions,
    System,
)
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, DistributedSampler

from .abc import ModelInterface, TrainerInterface
from .additive import CompositionModel, get_remove_additive_transform
from .augmentation import RotationalAugmenter
from .data import (
    CollateFn,
    CombinedDataLoader,
    Dataset,
    get_num_workers,
    unpack_batch,
    validate_num_workers,
)
from .data.dataset import DatasetInfo
from .distributed.distributed_data_parallel import DistributedDataParallel
from .distributed.slurm import DistributedEnvironment
from .dtype import dtype_to_str
from .evaluate_model import evaluate_model
from .io import check_file_extension
from .logging import ROOT_LOGGER, MetricLogger
from .loss import LossAggregator
from .metadata import merge_metadata
from .metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric
from .neighbor_lists import (
    get_requested_neighbor_lists,
    get_system_with_neighbor_lists_transform,
)
from .per_atom import average_by_num_atoms
from .scaler import Scaler, get_remove_scale_transform
from .transfer import batch_to


class MLIPModel(ModelInterface):
    """
    Base model class for MLIP-only architectures.

    This class is a base class for MLIP-only models that predict only energies and
    forces. It provides:

    - Common forward pass logic with neighbor list processing
    - Automatic integration of
      :py:class:`~metatrain.utils.additive.CompositionModel` for composition-based
      energy corrections
    - Automatic integration of :py:class:`~metatrain.utils.scaler.Scaler` for
      target scaling
    - Checkpoint saving/loading (``get_checkpoint``, ``load_checkpoint``)
    - Model export to metatomic format (``export``)
    - Support for restarting training (``restart``)

    Derived classes only need to implement the
    :py:meth:`~metatrain.utils.mlip.MLIPModel.compute_energy` method. The base class
    automatically handles additive models and scaling at evaluation time, so the
    derived class only needs to compute the "raw" energy predictions.

    Example:

    .. code-block:: python

        from metatrain.utils.mlip import MLIPModel


        class MyMLIPModel(MLIPModel):
            def compute_energy(
                self,
                edge_vectors: torch.Tensor,
                species: torch.Tensor,
                centers: torch.Tensor,
                neighbors: torch.Tensor,
                system_indices: torch.Tensor,
            ) -> torch.Tensor:
                # Implement your energy computation here
                ...
                return energies  # shape: (N_systems,)

    :param hypers: Model hyperparameters.
    :param dataset_info: Information about the dataset, including atomic types
        and targets.
    """

    __checkpoint_version__ = 1
    __supported_devices__ = ["cuda", "cpu"]
    __supported_dtypes__ = [torch.float64, torch.float32]
    __default_metadata__ = ModelMetadata()

    def __init__(self, hypers: Dict[str, Any], dataset_info: DatasetInfo) -> None:
        super().__init__(hypers, dataset_info, self.__default_metadata__)

        # Infer architecture name from module name
        # e.g., "metatrain.mlip_example.model" -> "mlip_example"
        module_parts = self.__class__.__module__.split(".")
        if len(module_parts) >= 2 and module_parts[0] == "metatrain":
            self._architecture_name = module_parts[1]
        else:
            self._architecture_name = "mlip"

        if len(dataset_info.targets) > 1:
            raise ValueError(
                "MLIPModel only supports datasets with a single target. "
                f"Found {len(dataset_info.targets)} targets."
            )
        self.target_name = list(dataset_info.targets.keys())[0]
        if dataset_info.targets[self.target_name].quantity != "energy":
            raise ValueError(
                "MLIPModel only supports datasets with an energy as target quantity. "
                f"Found '{dataset_info.targets[self.target_name].quantity}'."
            )
        if not dataset_info.targets[self.target_name].is_scalar:
            raise ValueError(
                "MLIPModel only supports datasets with a scalar target. "
                "Found a non-scalar target."
            )
        if dataset_info.targets[self.target_name].per_atom:
            raise ValueError(
                "MLIPModel only supports datasets with a total energy target. "
                "Found a per-atom target."
            )
        num_properties = len(
            dataset_info.targets[self.target_name].layout.block().properties
        )
        if num_properties > 1:
            raise ValueError(
                "MLIPModel only supports datasets with a single sub-target. "
                f"Found {num_properties} sub-targets."
            )

        self.atomic_types = dataset_info.atomic_types

        # Create outputs dictionary from targets
        self.outputs = {}
        for target_name, target_info in dataset_info.targets.items():
            self.outputs[target_name] = ModelOutput(
                quantity=target_info.quantity,
                unit=target_info.unit,
                per_atom=target_info.per_atom,
            )
            # Add position gradients (forces) as a supported output
            if "positions" in target_info.gradients:
                self.outputs[f"{target_name}_positions_gradients"] = ModelOutput(
                    unit=f"{target_info.unit}/{dataset_info.length_unit}",
                    per_atom=True,
                )

        # Additive models: these are handled by the trainer at training time
        # and they are added to the output at evaluation time
        composition_model = CompositionModel(
            hypers={},
            dataset_info=DatasetInfo(
                length_unit=dataset_info.length_unit,
                atomic_types=self.atomic_types,
                targets={
                    target_name: target_info
                    for target_name, target_info in dataset_info.targets.items()
                    if CompositionModel.is_valid_target(target_name, target_info)
                },
            ),
        )
        self.additive_models = torch.nn.ModuleList([composition_model])

        # Scaler: this is also handled by the trainer at training time
        self.scaler = Scaler(hypers={}, dataset_info=dataset_info)

        # Track whether new targets have been added (for restart/finetuning)
        self.has_new_targets = False

    def forward(
        self,
        systems: List[System],
        outputs: Dict[str, ModelOutput],
        selected_atoms: Optional[Labels] = None,
    ) -> Dict[str, TensorMap]:
        # Check that we're only being asked for supported outputs
        for output_name in outputs:
            if output_name not in self.outputs:
                raise ValueError(
                    f"Output '{output_name}' is not supported by this model. "
                    f"Supported outputs: {list(self.outputs.keys())}"
                )
        if selected_atoms is not None:
            raise ValueError(
                "MLIPModel does not support the 'selected_atoms' argument."
            )

        positions = []
        centers = []
        neighbors = []
        species = []
        cell_shifts = []
        cells = []

        node_counter = 0
        for system in systems:
            positions.append(system.positions)
            species.append(system.types)
            assert len(system.known_neighbor_lists()) == 1, "no neighbor list found"
            neighbor_list = system.get_neighbor_list(self.nl_options)
            nl_values = neighbor_list.samples.values
            centers.append(nl_values[:, 0] + node_counter)
            neighbors.append(nl_values[:, 1] + node_counter)
            cell_shifts.append(nl_values[:, 2:])
            cells.append(system.cell)
            node_counter += len(system.positions)

        positions = torch.cat(positions)
        centers = torch.cat(centers)
        neighbors = torch.cat(neighbors)
        species = torch.cat(species)
        cells = torch.stack(cells)
        cell_shifts = torch.cat(cell_shifts)

        system_indices = torch.concatenate(
            [
                torch.full(
                    (len(system),),
                    i_system,
                    device=positions.device,
                )
                for i_system, system in enumerate(systems)
            ],
        )

        # somehow the backward of this operation is very slow at evaluation,
        # where there is only one cell, therefore we simplify the calculation
        # for that case
        if len(cells) == 1:
            cell_contributions = cell_shifts.to(cells.dtype) @ cells[0]
        else:
            cell_contributions = torch.einsum(
                "ab, abc -> ac",
                cell_shifts.to(cells.dtype),
                cells[system_indices[centers]],
            )

        edge_vectors = positions[neighbors] - positions[centers] + cell_contributions

        energy_as_tensor = self.compute_energy(
            edge_vectors, species, centers, neighbors, system_indices
        )

        energy_as_tensor_map = TensorMap(
            keys=Labels(
                ["_"],
                torch.tensor([[0]], dtype=torch.int64, device=energy_as_tensor.device),
            ),
            blocks=[
                TensorBlock(
                    values=energy_as_tensor.unsqueeze(-1),
                    samples=Labels(
                        names=["structure"],
                        values=torch.arange(
                            len(energy_as_tensor),
                            device=energy_as_tensor.device,
                        ).unsqueeze(-1),
                    ),
                    components=[],
                    properties=Labels(
                        names=["energy"],
                        values=torch.tensor(
                            [[0]], dtype=torch.int64, device=energy_as_tensor.device
                        ),
                    ),
                )
            ],
        )

        return_dict = {self.target_name: energy_as_tensor_map}

        # At evaluation time, add the scaler and additive contributions
        if not self.training:
            return_dict = self.scaler(systems, return_dict)
            for additive_model in self.additive_models:
                outputs_for_additive_model: Dict[str, ModelOutput] = {}
                for name, output in outputs.items():
                    if name in additive_model.outputs:
                        outputs_for_additive_model[name] = output
                additive_contributions = additive_model(
                    systems, outputs_for_additive_model, selected_atoms
                )
                for name in additive_contributions:
                    # TODO: "manual" sparse sum: update to metatensor.torch.add after
                    # sparse sum is implemented in metatensor.operations
                    if name in return_dict:
                        output_blocks: List[TensorBlock] = []
                        for k, b in return_dict[name].items():
                            if k in additive_contributions[name].keys:
                                output_blocks.append(
                                    _add_block_block(
                                        b,
                                        additive_contributions[name]
                                        .block(k)
                                        .to(device=b.device, dtype=b.dtype),
                                    )
                                )
                            else:
                                output_blocks.append(b)
                        return_dict[name] = TensorMap(
                            return_dict[name].keys, output_blocks
                        )

        return return_dict

    def request_neighbor_list(self, cutoff: float) -> None:
        self.nl_options = NeighborListOptions(
            cutoff=cutoff,
            full_list=True,
            strict=True,
        )

        def requested_neighbor_lists() -> List[NeighborListOptions]:
            return [self.nl_options]

        self.requested_neighbor_lists = requested_neighbor_lists

    @abstractmethod
    def compute_energy(
        self,
        edge_vectors: torch.Tensor,
        species: torch.Tensor,
        centers: torch.Tensor,
        neighbors: torch.Tensor,
        system_indices: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the total energy given the edge vectors and other information.

        :param edge_vectors: Tensor of shape (N_edges, 3) containing the vectors
            between neighboring atoms.
        :param species: Tensor of shape (N_atoms,) containing the atomic species
            indices.
        :param centers: Tensor of shape (N_edges,) containing the indices of the
            center atoms for each edge.
        :param neighbors: Tensor of shape (N_edges,) containing the indices of the
            neighbor atoms for each edge.
        :param system_indices: Tensor of shape (N_atoms,) containing the indices of
            the systems each atom belongs to.
        :return: Tensor of shape (N_systems,) containing the total energy for each
            system.
        """
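
    # A minimal sketch of a possible ``compute_energy`` implementation (not part of
    # this module): a hypothetical pairwise repulsive potential that scatter-adds
    # per-edge energies into per-system totals. ``PairRepulsionModel`` and the toy
    # functional form are assumptions, used only to illustrate the expected shapes.
    #
    #     class PairRepulsionModel(MLIPModel):
    #         def compute_energy(
    #             self,
    #             edge_vectors: torch.Tensor,
    #             species: torch.Tensor,
    #             centers: torch.Tensor,
    #             neighbors: torch.Tensor,
    #             system_indices: torch.Tensor,
    #         ) -> torch.Tensor:
    #             # distances for each edge, shape (N_edges,)
    #             distances = torch.linalg.norm(edge_vectors, dim=-1)
    #             # toy pairwise energy; each pair appears twice in the full
    #             # neighbor list, hence the factor 0.5
    #             edge_energies = 0.5 / (distances**2 + 1.0)
    #             # map each edge to its system through its center atom
    #             edge_systems = system_indices[centers]
    #             n_systems = int(system_indices.max()) + 1
    #             energies = torch.zeros(
    #                 n_systems, dtype=edge_vectors.dtype, device=edge_vectors.device
    #             )
    #             # scatter-add per-edge contributions into per-system totals
    #             energies.index_add_(0, edge_systems, edge_energies)
    #             return energies  # shape: (N_systems,)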

    def supported_outputs(self) -> Dict[str, ModelOutput]:
        """
        Get the outputs currently supported by this model.

        :return: Dictionary mapping output names to their ModelOutput definitions.
        """
        return self.outputs

    def restart(self, dataset_info: DatasetInfo) -> "MLIPModel":
        """
        Restart training with a new dataset, potentially with new targets.

        :param dataset_info: New dataset information.
        :return: Updated model instance.
        """
        # Merge old and new dataset info
        merged_info = self.dataset_info.union(dataset_info)
        new_atomic_types = [
            at for at in merged_info.atomic_types if at not in self.atomic_types
        ]
        new_targets = {
            key: value
            for key, value in merged_info.targets.items()
            if key not in self.dataset_info.targets
        }
        self.has_new_targets = len(new_targets) > 0

        if len(new_atomic_types) > 0:
            raise ValueError(
                f"New atomic types found in the dataset: {new_atomic_types}. "
                "The MLIPModel does not support adding new atomic types."
            )
        if self.has_new_targets:
            raise ValueError(
                "New targets found in the dataset. "
                "The MLIPModel does not support adding new targets."
            )

        self.dataset_info = merged_info

        # Restart the composition model and scaler
        self.additive_models[0] = self.additive_models[0].restart(
            dataset_info=DatasetInfo(
                length_unit=dataset_info.length_unit,
                atomic_types=self.atomic_types,
                targets={
                    target_name: target_info
                    for target_name, target_info in dataset_info.targets.items()
                    if CompositionModel.is_valid_target(target_name, target_info)
                },
            ),
        )
        self.scaler = self.scaler.restart(dataset_info)

        return self

    @classmethod
    def load_checkpoint(
        cls,
        checkpoint: Dict[str, Any],
        context: Literal["restart", "finetune", "export"],
    ) -> "MLIPModel":
        """
        Load a model from a checkpoint.

        :param checkpoint: Checkpoint dictionary.
        :param context: Context for loading (restart, finetune, or export).
        :return: Loaded model instance.
        """
        if context == "restart":
            logging.info(f"Using latest model from epoch {checkpoint['epoch']}")
            model_state_dict = checkpoint["model_state_dict"]
        elif context in {"finetune", "export"}:
            logging.info(f"Using best model from epoch {checkpoint['best_epoch']}")
            model_state_dict = checkpoint["best_model_state_dict"]
        else:
            raise ValueError("Unknown context tag for checkpoint loading!")

        # Create the model
        model_data = checkpoint["model_data"]
        model = cls(
            hypers=model_data["model_hypers"],
            dataset_info=model_data["dataset_info"],
        )

        dtype = next(iter(model_state_dict.values())).dtype
        model.to(dtype).load_state_dict(model_state_dict)
        model.additive_models[0].sync_tensor_maps()
        model.scaler.sync_tensor_maps()

        # Loading the metadata from the checkpoint
        model.metadata = merge_metadata(model.metadata, checkpoint.get("metadata"))

        return model

    def get_checkpoint(self) -> Dict:
        """
        Get a checkpoint dictionary for saving the model.

        :return: Checkpoint dictionary.
        """
        checkpoint = {
            "architecture_name": self._architecture_name,
            "model_ckpt_version": self.__checkpoint_version__,
            "metadata": self.metadata,
            "model_data": {
                "model_hypers": self.hypers,
                "dataset_info": self.dataset_info,
            },
            "epoch": None,
            "best_epoch": None,
            "model_state_dict": self.state_dict(),
            "best_model_state_dict": self.state_dict(),
        }
        return checkpoint

    def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel:
        """
        Export the model to a metatomic AtomisticModel.

        :param metadata: Optional metadata to merge with the model's metadata.
        :return: Exported AtomisticModel.
        """
        dtype = next(self.parameters()).dtype
        if dtype not in self.__supported_dtypes__:
            raise ValueError(f"unsupported dtype {dtype} for MLIPModel")

        # Make sure the model is all in the same dtype
        self.to(dtype)

        # The composition model contains some TensorMaps that need to be moved
        self.additive_models[0].weights_to(torch.device("cpu"), torch.float64)

        # Get interaction range from neighbor list cutoff
        interaction_range = self.nl_options.cutoff

        capabilities = ModelCapabilities(
            outputs=self.outputs,
            atomic_types=self.atomic_types,
            interaction_range=interaction_range,
            length_unit=self.dataset_info.length_unit,
            supported_devices=self.__supported_devices__,
            dtype=dtype_to_str(dtype),
        )

        metadata = merge_metadata(self.metadata, metadata)

        return AtomisticModel(self.eval(), metadata, capabilities)
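
    # Illustrative usage (not part of this module), assuming a trained ``model``
    # instance and that the returned metatomic ``AtomisticModel`` exposes a
    # ``save`` method:
    #
    #     exported = model.export()
    #     exported.save("model.pt")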

    @classmethod
    def upgrade_checkpoint(cls, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
        """
        Upgrade the checkpoint to the current version of the model.

        This method should be implemented by derived classes if they need to upgrade
        checkpoints between versions. The base MLIPModel implementation is version 1
        and does not require any upgrades yet.

        :param checkpoint: Checkpoint's state dictionary.
        :raises RuntimeError: if the checkpoint cannot be upgraded to the current
            version of the model.
        :return: The upgraded checkpoint.
        """
        if checkpoint["model_ckpt_version"] != cls.__checkpoint_version__:
            raise RuntimeError(
                f"Unable to upgrade the checkpoint: the checkpoint is using model "
                f"version {checkpoint['model_ckpt_version']}, while the current "
                f"model version is {cls.__checkpoint_version__}."
            )
        return checkpoint


def get_mlip_scheduler(
    optimizer: torch.optim.Optimizer,
    train_hypers: Dict[str, Any],
    steps_per_epoch: int,
) -> LambdaLR:
    """
    Get a cosine-annealing learning-rate scheduler with linear warmup for MLIP
    trainers.

    :param optimizer: The optimizer for which to create the scheduler.
    :param train_hypers: The training hyperparameters.
    :param steps_per_epoch: The number of steps per epoch.
    :return: The learning rate scheduler.
    """
    total_steps = train_hypers["num_epochs"] * steps_per_epoch
    warmup_steps = int(train_hypers["warmup_fraction"] * total_steps)
    min_lr_ratio = 0.0  # hardcoded for now, could be made configurable in the future

    def lr_lambda(current_step: int) -> float:
        if current_step < warmup_steps:
            # Linear warmup
            return float(current_step) / float(max(1, warmup_steps))
        else:
            # Cosine decay
            progress = (current_step - warmup_steps) / float(
                max(1, total_steps - warmup_steps)
            )
            cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
            return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay

    scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
    return scheduler
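

# Worked example (illustrative only, not part of this module): with num_epochs=100,
# steps_per_epoch=10 and warmup_fraction=0.1, the scheduler above uses
# total_steps=1000 and warmup_steps=100, so the LR multiplier ramps linearly from 0
# to 1 over the first 100 steps and then follows a cosine decay to 0:
#
#     import math
#
#     total_steps, warmup_steps = 1000, 100
#
#     def multiplier(step: int) -> float:
#         if step < warmup_steps:
#             return step / warmup_steps
#         progress = (step - warmup_steps) / (total_steps - warmup_steps)
#         return 0.5 * (1.0 + math.cos(math.pi * progress))
#
#     multiplier(50)    # 0.5  (halfway through warmup)
#     multiplier(100)   # 1.0  (warmup finished)
#     multiplier(550)   # 0.5  (halfway through the cosine decay)
#     multiplier(1000)  # 0.0  (end of training)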


class MLIPTrainer(TrainerInterface):
    """
    Base trainer class for MLIP-only architectures.

    This class is a base trainer for MLIP-only models. It implements the complete
    training loop and handles:

    - Distributed training
    - Data loading with optional rotational augmentation
    - Loss computation
    - Checkpointing

    Derived classes only need to implement the
    :py:meth:`~metatrain.utils.mlip.MLIPTrainer.use_rotational_augmentation` method
    to specify whether rotational data augmentation should be used during training.

    Note on rotational augmentation: you don't need rotational augmentation if
    rotational invariance is enforced in the neural network architecture itself
    (e.g., through equivariant message passing). However, if your architecture does
    not enforce rotational invariance, you should use rotational augmentation to
    ensure the model learns rotationally invariant representations.

    Example:

    .. code-block:: python

        from metatrain.utils.mlip import MLIPTrainer


        class MyMLIPTrainer(MLIPTrainer):
            def use_rotational_augmentation(self) -> bool:
                # Return True to use rotational augmentation, False otherwise
                return False

    :param hypers: Training hyperparameters.
    """

    __checkpoint_version__ = 1

    def __init__(self, hypers: Dict[str, Any]) -> None:
        super().__init__(hypers)

        self.optimizer_state_dict: Optional[Dict[str, Any]] = None
        self.scheduler_state_dict: Optional[Dict[str, Any]] = None
        self.epoch: Optional[int] = None
        self.best_epoch: Optional[int] = None
        self.best_metric: Optional[float] = None
        self.best_model_state_dict: Optional[Dict[str, Any]] = None
        self.best_optimizer_state_dict: Optional[Dict[str, Any]] = None

    @abstractmethod
    def use_rotational_augmentation(self) -> bool:
        """
        Specify whether the trainer should use rotational augmentation.

        :return: True if rotational augmentation should be used, False otherwise.
        """

    def train(
        self,
        model: MLIPModel,
        dtype: torch.dtype,
        devices: List[torch.device],
        train_datasets: List[Union[Dataset, torch.utils.data.Subset]],
        val_datasets: List[Union[Dataset, torch.utils.data.Subset]],
        checkpoint_dir: str,
    ) -> None:
        """
        Train the MLIP model.

        :param model: The MLIP model to train.
        :param dtype: The dtype to use for training.
        :param devices: The devices to use for training.
        :param train_datasets: The training datasets.
        :param val_datasets: The validation datasets.
        :param checkpoint_dir: The directory to save checkpoints.
        """
        assert dtype in model.__supported_dtypes__

        is_distributed = self.hypers["distributed"]

        if is_distributed:
            if len(devices) > 1:
                raise ValueError(
                    "Requested distributed training with the `multi-gpu` device. "
                    "If you want to run distributed training with this MLIP model, "
                    "please set `device` to cuda."
                )
            # the calculation of the device number works both when GPUs on different
            # processes are not visible to each other and when they are
            distr_env = DistributedEnvironment(self.hypers["distributed_port"])
            device_number = distr_env.local_rank % torch.cuda.device_count()
            device = torch.device("cuda", device_number)
            torch.distributed.init_process_group(backend="nccl", device_id=device)
            world_size = torch.distributed.get_world_size()
            rank = torch.distributed.get_rank()
        else:
            rank = 0
            # only one device, as we don't support non-distributed multi-gpu for now
            device = devices[0]

        if is_distributed:
            logging.info(f"Training on {world_size} devices with dtype {dtype}")
        else:
            logging.info(f"Training on device {device} with dtype {dtype}")

        # Move the model to the device and dtype:
        model.to(device=device, dtype=dtype)
        # The additive models are always in float64 (to avoid numerical errors in
        # the composition weights, which can be very large).
        for additive_model in model.additive_models:
            additive_model.to(dtype=torch.float64)
        model.scaler.to(dtype=torch.float64)

        logging.info("Calculating composition weights")
        model.additive_models[0].train_model(  # this is the composition model
            train_datasets,
            model.additive_models[1:],
            self.hypers["batch_size"],
            is_distributed,
            self.hypers["fixed_composition_weights"],
        )

        if self.hypers["scale_targets"]:
            logging.info("Calculating scaling weights")
            model.scaler.train_model(
                train_datasets,
                model.additive_models,
                self.hypers["batch_size"],
                is_distributed,
                self.hypers["fixed_scaling_weights"],
            )

        logging.info("Setting up data loaders")

        if is_distributed:
            train_samplers = [
                DistributedSampler(
                    train_dataset,
                    num_replicas=world_size,
                    rank=rank,
                    shuffle=True,
                    drop_last=True,
                )
                for train_dataset in train_datasets
            ]
            val_samplers = [
                DistributedSampler(
                    val_dataset,
                    num_replicas=world_size,
                    rank=rank,
                    shuffle=False,
                    drop_last=False,
                )
                for val_dataset in val_datasets
            ]
        else:
            train_samplers = [None] * len(train_datasets)
            val_samplers = [None] * len(val_datasets)

        # Extract additive models and scaler and move them to CPU/float64 so they
        # can be used in the collate function
        model.additive_models[0].weights_to(device="cpu", dtype=torch.float64)
        additive_models = copy.deepcopy(
            model.additive_models.to(dtype=torch.float64, device="cpu")
        )
        model.additive_models.to(device)
        model.additive_models[0].weights_to(device=device, dtype=torch.float64)

        model.scaler.scales_to(device="cpu", dtype=torch.float64)
        scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu"))
        model.scaler.to(device)
        model.scaler.scales_to(device=device, dtype=torch.float64)

        # Create collate function(s):
        dataset_info = model.dataset_info
        train_targets = dataset_info.targets
        extra_data_info = dataset_info.extra_data
        requested_neighbor_lists = get_requested_neighbor_lists(model)

        # Check if rotational augmentation should be used
        use_augmentation = self.use_rotational_augmentation()

        if use_augmentation:
            # Create separate collate functions for train and validation
            rotational_augmenter = RotationalAugmenter(
                target_info_dict=train_targets, extra_data_info_dict=extra_data_info
            )
            collate_fn_train = CollateFn(
                target_keys=list(train_targets.keys()),
                callables=[
                    rotational_augmenter.apply_random_augmentations,
                    get_system_with_neighbor_lists_transform(requested_neighbor_lists),
                    get_remove_additive_transform(additive_models, train_targets),
                    get_remove_scale_transform(scaler),
                ],
            )
            collate_fn_val = CollateFn(
                target_keys=list(train_targets.keys()),
                callables=[
                    # no augmentation for validation
                    get_system_with_neighbor_lists_transform(requested_neighbor_lists),
                    get_remove_additive_transform(additive_models, train_targets),
                    get_remove_scale_transform(scaler),
                ],
            )
        else:
            # Use the same collate function for train and validation (no augmentation)
            collate_fn_train = CollateFn(
                target_keys=list(train_targets.keys()),
                callables=[
                    get_system_with_neighbor_lists_transform(requested_neighbor_lists),
                    get_remove_additive_transform(additive_models, train_targets),
                    get_remove_scale_transform(scaler),
                ],
            )
            collate_fn_val = collate_fn_train

        # Create dataloader for the training datasets:
        if self.hypers["num_workers"] is None:
            num_workers = get_num_workers()
            logging.info(
                "Number of workers for data-loading not provided and chosen "
                f"automatically. Using {num_workers} workers."
            )
        else:
            num_workers = self.hypers["num_workers"]
            validate_num_workers(num_workers)

        train_dataloaders = []
        for train_dataset, train_sampler in zip(
            train_datasets, train_samplers, strict=True
        ):
            if len(train_dataset) < self.hypers["batch_size"]:
                raise ValueError(
                    f"A training dataset has fewer samples "
                    f"({len(train_dataset)}) than the batch size "
                    f"({self.hypers['batch_size']}). "
                    "Please reduce the batch size."
                )
            train_dataloaders.append(
                DataLoader(
                    dataset=train_dataset,
                    batch_size=self.hypers["batch_size"],
                    sampler=train_sampler,
                    shuffle=(
                        # the sampler takes care of this (if present)
                        train_sampler is None
                    ),
                    drop_last=(
                        # the sampler takes care of this (if present)
                        train_sampler is None
                    ),
                    collate_fn=collate_fn_train,
                    num_workers=num_workers,
                )
            )
        train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True)

        # Create dataloader for the validation datasets:
        val_dataloaders = []
        for val_dataset, val_sampler in zip(val_datasets, val_samplers, strict=True):
            if len(val_dataset) < self.hypers["batch_size"]:
                raise ValueError(
                    f"A validation dataset has fewer samples "
                    f"({len(val_dataset)}) than the batch size "
                    f"({self.hypers['batch_size']}). "
                    "Please reduce the batch size."
                )
            val_dataloaders.append(
                DataLoader(
                    dataset=val_dataset,
                    batch_size=self.hypers["batch_size"],
                    sampler=val_sampler,
                    shuffle=False,
                    drop_last=False,
                    collate_fn=collate_fn_val,
                    num_workers=num_workers,
                )
            )
        val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False)

        if is_distributed:
            model = DistributedDataParallel(model, device_ids=[device])

        # Extract all the possible outputs and their gradients:
        train_targets = (
            model.module if is_distributed else model
        ).dataset_info.targets
        outputs_list = []
        for target_name, target_info in train_targets.items():
            outputs_list.append(target_name)
            for gradient_name in target_info.gradients:
                outputs_list.append(f"{target_name}_{gradient_name}_gradients")

        # Create a loss function:
        loss_hypers = self.hypers["loss"]
        loss_fn = LossAggregator(
            targets=train_targets,
            config=loss_hypers,
        )
        logging.info("Using the following loss functions:")
        for name, info in loss_fn.metadata.items():
            logging.info(f"{name}:")
            main = {k: v for k, v in info.items() if k != "gradients"}
            logging.info(main)
            if "gradients" not in info or len(info["gradients"]) == 0:
                continue
            logging.info("With gradients:")
            for grad, ginfo in info["gradients"].items():
                logging.info(f"\t{name}::{grad}: {ginfo}")

        # Create an optimizer:
        optimizer = torch.optim.Adam(
            model.parameters(), lr=self.hypers["learning_rate"]
        )
        if self.optimizer_state_dict is not None:
            # try to load the optimizer state dict, but this is only possible
            # if there are no new targets in the model (new parameters)
            if not (model.module if is_distributed else model).has_new_targets:
                optimizer.load_state_dict(self.optimizer_state_dict)

        # Create a learning rate scheduler
        lr_scheduler = get_mlip_scheduler(
            optimizer, self.hypers, len(train_dataloader)
        )
        if self.scheduler_state_dict is not None:
            # same as the optimizer, try to load the scheduler state dict
            if not (model.module if is_distributed else model).has_new_targets:
                lr_scheduler.load_state_dict(self.scheduler_state_dict)

        # per-atom targets:
        per_structure_targets = self.hypers["per_structure_targets"]

        # Log the initial learning rate:
        old_lr = optimizer.param_groups[0]["lr"]
        logging.info(f"Initial learning rate: {old_lr}")

        start_epoch = 0 if self.epoch is None else self.epoch + 1

        # Train the model:
        if self.best_metric is None:
            self.best_metric = float("inf")
        logging.info("Starting training")
        epoch = start_epoch
        for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]):
            if is_distributed:
                for train_sampler in train_samplers:
                    train_sampler.set_epoch(epoch)

            train_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"])
            val_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"])
            if self.hypers["log_mae"]:
                train_mae_calculator = MAEAccumulator(
                    self.hypers["log_separate_blocks"]
                )
                val_mae_calculator = MAEAccumulator(self.hypers["log_separate_blocks"])

            train_loss = 0.0
            for batch in train_dataloader:
                optimizer.zero_grad()

                systems, targets, extra_data = unpack_batch(batch)
                systems, targets, extra_data = batch_to(
                    systems, targets, extra_data, dtype=dtype, device=device
                )
                predictions = evaluate_model(
                    model,
                    systems,
                    {key: train_targets[key] for key in targets.keys()},
                    is_training=True,
                )

                # average by the number of atoms
                predictions = average_by_num_atoms(
                    predictions, systems, per_structure_targets
                )
                targets = average_by_num_atoms(targets, systems, per_structure_targets)

                train_loss_batch = loss_fn(predictions, targets, extra_data)

                if is_distributed:
                    # make sure all parameters contribute to the gradient calculation
                    # to make torch DDP happy
                    for param in model.parameters():
                        train_loss_batch += 0.0 * param.sum()

                train_loss_batch.backward()
                optimizer.step()
                lr_scheduler.step()

                if is_distributed:
                    # sum the loss over all processes
                    torch.distributed.all_reduce(train_loss_batch)
                train_loss += train_loss_batch.item()

                scaled_predictions = (model.module if is_distributed else model).scaler(
                    systems, predictions
                )
                scaled_targets = (model.module if is_distributed else model).scaler(
                    systems, targets
                )
                train_rmse_calculator.update(
                    scaled_predictions, scaled_targets, extra_data
                )
                if self.hypers["log_mae"]:
                    train_mae_calculator.update(
                        scaled_predictions, scaled_targets, extra_data
                    )

            finalized_train_info = train_rmse_calculator.finalize(
                not_per_atom=["positions_gradients"] + per_structure_targets,
                is_distributed=is_distributed,
                device=device,
            )
            if self.hypers["log_mae"]:
                finalized_train_info.update(
                    train_mae_calculator.finalize(
                        not_per_atom=["positions_gradients"] + per_structure_targets,
                        is_distributed=is_distributed,
                        device=device,
                    )
                )

            val_loss = 0.0
            for batch in val_dataloader:
                systems, targets, extra_data = unpack_batch(batch)
                systems, targets, extra_data = batch_to(
                    systems, targets, extra_data, dtype=dtype, device=device
                )
                predictions = evaluate_model(
                    model,
                    systems,
                    {key: train_targets[key] for key in targets.keys()},
                    is_training=False,
                )

                # average by the number of atoms
                predictions = average_by_num_atoms(
                    predictions, systems, per_structure_targets
                )
                targets = average_by_num_atoms(targets, systems, per_structure_targets)

                val_loss_batch = loss_fn(predictions, targets, extra_data)

                if is_distributed:
                    # sum the loss over all processes
                    torch.distributed.all_reduce(val_loss_batch)
                val_loss += val_loss_batch.item()

                scaled_predictions = (model.module if is_distributed else model).scaler(
                    systems, predictions
                )
                scaled_targets = (model.module if is_distributed else model).scaler(
                    systems, targets
                )
                val_rmse_calculator.update(
                    scaled_predictions, scaled_targets, extra_data
                )
                if self.hypers["log_mae"]:
                    val_mae_calculator.update(
                        scaled_predictions, scaled_targets, extra_data
                    )

            finalized_val_info = val_rmse_calculator.finalize(
                not_per_atom=["positions_gradients"] + per_structure_targets,
                is_distributed=is_distributed,
                device=device,
            )
            if self.hypers["log_mae"]:
                finalized_val_info.update(
                    val_mae_calculator.finalize(
                        not_per_atom=["positions_gradients"] + per_structure_targets,
                        is_distributed=is_distributed,
                        device=device,
                    )
                )

            # Now we log the information:
            finalized_train_info = {"loss": train_loss, **finalized_train_info}
            finalized_val_info = {"loss": val_loss, **finalized_val_info}

            if epoch == start_epoch:
                metric_logger = MetricLogger(
                    log_obj=ROOT_LOGGER,
                    dataset_info=(
                        model.module if is_distributed else model
                    ).dataset_info,
                    initial_metrics=[finalized_train_info, finalized_val_info],
                    names=["training", "validation"],
                )
            if epoch % self.hypers["log_interval"] == 0:
                metric_logger.log(
                    metrics=[finalized_train_info, finalized_val_info],
                    epoch=epoch,
                    rank=rank,
                    learning_rate=optimizer.param_groups[0]["lr"],
                )

            val_metric = get_selected_metric(
                finalized_val_info, self.hypers["best_model_metric"]
            )
            if val_metric < self.best_metric:
                self.best_metric = val_metric
                self.best_model_state_dict = copy.deepcopy(
                    (model.module if is_distributed else model).state_dict()
                )
                self.best_epoch = epoch
                self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict())

            if epoch % self.hypers["checkpoint_interval"] == 0:
                if is_distributed:
                    torch.distributed.barrier()
                self.optimizer_state_dict = optimizer.state_dict()
                self.scheduler_state_dict = lr_scheduler.state_dict()
                self.epoch = epoch
                if rank == 0:
                    self.save_checkpoint(
                        (model.module if is_distributed else model),
                        Path(checkpoint_dir) / f"model_{epoch}.ckpt",
                    )

        # prepare for the checkpoint that will be saved outside the function
        self.epoch = epoch
        self.optimizer_state_dict = optimizer.state_dict()
        self.scheduler_state_dict = lr_scheduler.state_dict()

        if is_distributed:
            torch.distributed.destroy_process_group()

    def save_checkpoint(self, model: ModelInterface, path: Union[str, Path]) -> None:
        """
        Save a checkpoint of the model and trainer state.

        :param model: The model to save.
        :param path: Path where the checkpoint will be saved.
        """
        checkpoint = model.get_checkpoint()
        checkpoint.update(
            {
                "train_hypers": self.hypers,
                "trainer_ckpt_version": self.__checkpoint_version__,
                "epoch": self.epoch,
                "optimizer_state_dict": self.optimizer_state_dict,
                "scheduler_state_dict": self.scheduler_state_dict,
                "best_epoch": self.best_epoch,
                "best_metric": self.best_metric,
                "best_model_state_dict": self.best_model_state_dict,
                "best_optimizer_state_dict": self.best_optimizer_state_dict,
            }
        )
        torch.save(
            checkpoint,
            check_file_extension(path, ".ckpt"),
        )

    @classmethod
    def load_checkpoint(
        cls,
        checkpoint: Dict[str, Any],
        hypers: Dict[str, Any],
        context: Literal["restart", "finetune"],  # not used at the moment
    ) -> "MLIPTrainer":
        """
        Load trainer state from a checkpoint.

        :param checkpoint: Dictionary containing the checkpoint data.
        :param hypers: Training hyperparameters.
        :param context: Loading context ('restart' or 'finetune').
        :return: Initialized trainer with loaded state.
        """
        trainer = cls(hypers)
        trainer.optimizer_state_dict = checkpoint["optimizer_state_dict"]
        trainer.scheduler_state_dict = checkpoint["scheduler_state_dict"]
        trainer.epoch = checkpoint["epoch"]
        trainer.best_epoch = checkpoint["best_epoch"]
        trainer.best_metric = checkpoint["best_metric"]
        trainer.best_model_state_dict = checkpoint["best_model_state_dict"]
        trainer.best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"]
        return trainer

    @classmethod
    def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict:
        """
        Upgrade the checkpoint to the current version of the trainer.

        This method should be implemented by derived classes if they need to upgrade
        checkpoints between versions. The base MLIPTrainer implementation is version 1
        and does not require any upgrades yet.

        :param checkpoint: Checkpoint's state dictionary.
        :raises RuntimeError: if the checkpoint cannot be upgraded to the current
            version of the trainer.
        :return: The upgraded checkpoint.
        """
        if checkpoint["trainer_ckpt_version"] != cls.__checkpoint_version__:
            raise RuntimeError(
                f"Unable to upgrade the checkpoint: the checkpoint is using trainer "
                f"version {checkpoint['trainer_ckpt_version']}, while the current "
                f"trainer version is {cls.__checkpoint_version__}."
            )
        return checkpoint
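

# A minimal sketch of how a derived architecture might combine the two base classes
# (illustrative only; the module layout, class names, and the ``cutoff``
# hyperparameter below are assumptions, not part of metatrain):
#
#     # my_architecture/model.py
#     class MyMLIPModel(MLIPModel):
#         def __init__(self, hypers, dataset_info):
#             super().__init__(hypers, dataset_info)
#             self.request_neighbor_list(cutoff=hypers["cutoff"])
#             ...  # build the network here
#
#         def compute_energy(
#             self, edge_vectors, species, centers, neighbors, system_indices
#         ):
#             ...  # return per-system energies, shape (N_systems,)
#
#     # my_architecture/trainer.py
#     class MyMLIPTrainer(MLIPTrainer):
#         def use_rotational_augmentation(self) -> bool:
#             # needed here because the (hypothetical) network is not rotationally
#             # invariant by construction
#             return True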