Author Topic: Finetuning MACE  (Read 179 times)

0 Members and 1 Guest are viewing this topic.

Offline AsifShah

  • QuantumATK Guru
  • ****
  • Posts: 219
  • Country: in
  • Reputation: 4
    • View Profile
Finetuning MACE
« on: Today at 02:37 »
Dear Admin, I am fine-tuning (multi-head) a MACE model "mace-mp-0b3-medium.model" with interface  ~3000 configurations obtained from optimization of various interface structures between two materials. The caculator used to generate the data was LCAO-PAW. However, I observe that the training loss is more than validation loss and it does not converge easily below 2. I am using latest version of QuantumATK Y-2026.03
Code
# %% MACEFittingParameters

model_parameters = MACEModelParameters(
    foundation_model_path='/home/MHeadFineTune/AB_PAW/Model2/mace-mp-0b3-medium.model'
)

replay_finetuning_settings = MACEReplayFinetuningSettings(
    replay_data_filepath='/home/MHeadFineTune/AB_PAW/Model2/mp_traj_combined.xyz',
    number_of_samples=10000,
    replay_subselect_method=MLParameterOptions.REPLAY_SUBSELECT.RANDOM,
    replay_filtering_type=MLParameterOptions.REPLAY_FILTERING.COMBINATIONS,
)

dataset_parameters = ForceFieldDatasetParameters(
    dataset_name=None,
    validation_fraction=0.2,
    isolated_atom_energies=None,
    energy_key='REF_energy',
    forces_key='REF_forces',
    stress_key='REF_stress',
    energy_weight=1,
    forces_weight=100.0,
    stress_weight=1,
    compute_stress=True,
    forces_cap=None,
    replay_finetuning_settings=replay_finetuning_settings,
)

training_parameters = TrainingParameters(
    experiment_name='AB_replay_finetuning',
    batch_size=5,
    max_number_of_epochs=200,
    patience=50,
    device=Automatic,
    random_seed=123,
    number_of_workers=0,
    default_dtype=MLParameterOptions.DTYPE.FLOAT64,
    learning_rate=0.005,
    weight_decay=5e-07,
    restart_from_last_checkpoint=True,
    scheduler_patience=5,
    gradient_clipping_threshold=100,
    save_all_available_model_formats=True,
    additional_parameters=None,
)

mace_fitting_parameters = MACEFittingParameters(
    model_parameters=model_parameters,
    dataset_parameters=dataset_parameters,
    training_parameters=training_parameters,
)
nlsave('GSiO2hBN_Train_model_with_MultiHFine.hdf5', mace_fitting_parameters)


# %% MachineLearnedForceFieldTrainer

machine_learned_force_field_trainer = MachineLearnedForceFieldTrainer(
    fitting_parameters=mace_fitting_parameters,
    training_sets=combined_training_set_training_set_0,
    calculator=cam_AB_training_set_lcao_calculator_0,
    train_test_split=0.9,
    random_seed=None,
    save_model_evaluator=True,
)
machine_learned_force_field_trainer.train()
Log file
Code
Epoch |   Train loss | dataset val. loss | foundation_replay val. loss |
Combined val. loss
    1 |      18.9857 |       9.0502 |       9.6832 |       9.5235
    2 |      16.2392 |       9.5699 |       9.8213 |       9.7579
    3 |      15.9793 |       8.4948 |      14.1380 |      12.7140
    4 |      15.7344 |       7.0205 |       9.3637 |       8.7724
    5 |      15.2174 |       7.4290 |       9.1267 |       8.6983
    6 |      15.1447 |       7.6615 |       9.5313 |       9.0595
    7 |      13.9269 |       7.7902 |       9.3499 |       8.9563
    8 |      14.3858 |       6.9831 |       8.4594 |       8.0869
    9 |      14.6542 |       6.1258 |       9.1843 |       8.4125
   10 |      67.0137 |       7.2914 |       9.0968 |       8.6412
   11 |      14.6468 |       7.3211 |       9.6342 |       9.0505
   12 |      13.7790 |       6.7579 |      72.1496 |      55.6489
   13 |    3348.1013 |       6.4292 |       9.5679 |       8.7759
   14 |      13.6505 |       6.6999 |      13.4243 |      11.7275
   15 |      18.6541 |       4.8157 |       8.1373 |       7.2992
   16 |      10.8260 |       4.5590 |       9.3064 |       8.1084
   17 |      10.2169 |       5.3469 |       8.0607 |       7.3759
   18 |      10.8128 |       4.8280 |       7.9056 |       7.1290
   19 |      44.6345 |       4.8770 |       8.8867 |       7.8749
   20 |      13.1652 |       6.0382 |       8.0510 |       7.5431
   21 |       9.9637 |       5.0135 |       8.2736 |       7.4510
   22 |       9.9831 |       5.1497 |       7.8262 |       7.1508
   23 |       9.9076 |       5.5356 |       8.3363 |       7.6296
   24 |       9.4363 |       5.0831 |       8.0683 |       7.3150
   25 |       9.2744 |       4.0919 |       8.0174 |       7.0269
   26 |       8.5802 |       4.0783 |       7.9753 |       6.9919
   27 |       8.5956 |       3.7595 |       7.9519 |       6.8940
   28 |       7.9154 |       4.1679 |       8.0973 |       7.1057
   29 |       7.9012 |       4.1561 |       8.0263 |       7.0497
   30 |       7.8318 |       3.7461 |       7.9410 |       6.8825
   31 |       7.6374 |       3.7943 |       8.1701 |       7.0659
   32 |       8.1747 |       3.6043 |       8.1010 |       6.9663
   33 |       8.0958 |       3.6114 |       8.4863 |       7.2561
   34 |       7.4335 |       3.9294 |       8.1361 |       7.0746
   35 |       7.4023 |       4.0456 |       8.1316 |       7.1006
   36 |       7.2933 |       3.7167 |       8.4425 |       7.2500
   37 |       7.0873 |       3.3984 |       8.2840 |       7.0512
   38 |       6.3931 |       3.4524 |       8.4852 |       7.2152
   39 |       6.6867 |       3.2221 |       8.5646 |       7.2165
   40 |       6.4043 |       3.3317 |       8.5464 |       7.2305
   41 |       6.6129 |       3.1848 |       8.6768 |       7.2910
   42 |       6.4041 |       3.3490 |       8.6733 |       7.3298
   43 |       6.3622 |       3.0963 |       8.7256 |       7.3051
   44 |       6.0477 |       3.0227 |       8.8481 |       7.3782
   45 |       5.9473 |       3.0272 |       8.8475 |       7.3789
   46 |       5.9198 |       3.0129 |       8.8019 |       7.3411