# %% MACEFittingParameters
# Fitting setup for fine-tuning a MACE foundation model with replay data.
# The model, dataset and training settings are built separately, bundled into
# a single MACEFittingParameters object, and written to disk with nlsave.

# Foundation model file that the fit starts from.
model_parameters = MACEModelParameters(
    foundation_model_path='/home/MHeadFineTune/AB_PAW/Model2/mace-mp-0b3-medium.model',
)

# Replay settings: which replay data file to read, how many samples to take
# from it, and the sub-selection / filtering options to apply.
replay_finetuning_settings = MACEReplayFinetuningSettings(
    replay_data_filepath='/home/MHeadFineTune/AB_PAW/Model2/mp_traj_combined.xyz',
    number_of_samples=10000,
    replay_subselect_method=MLParameterOptions.REPLAY_SUBSELECT.RANDOM,
    replay_filtering_type=MLParameterOptions.REPLAY_FILTERING.COMBINATIONS,
)

dataset_parameters = ForceFieldDatasetParameters(
    dataset_name=None,
    validation_fraction=0.2,
    isolated_atom_energies=None,
    # Keys under which the reference energy / forces / stress are stored.
    energy_key='REF_energy',
    forces_key='REF_forces',
    stress_key='REF_stress',
    # Relative weights; forces are weighted 100x relative to energy and stress.
    energy_weight=1,
    forces_weight=100.0,
    stress_weight=1,
    compute_stress=True,
    forces_cap=None,
    replay_finetuning_settings=replay_finetuning_settings,
)

training_parameters = TrainingParameters(
    experiment_name='AB_replay_finetuning',
    # Run environment.
    device=Automatic,
    default_dtype=MLParameterOptions.DTYPE.FLOAT64,
    number_of_workers=0,
    random_seed=123,
    # Epoch budget and patience settings.
    batch_size=5,
    max_number_of_epochs=200,
    patience=50,
    scheduler_patience=5,
    # Optimizer settings.
    # NOTE(review): the training log further down in this file shows isolated
    # train-loss spikes (epochs 10, 13 and 19); it may be worth checking whether
    # learning_rate / gradient_clipping_threshold are too permissive - confirm.
    learning_rate=0.005,
    weight_decay=5e-07,
    gradient_clipping_threshold=100,
    # Checkpointing and output.
    restart_from_last_checkpoint=True,
    save_all_available_model_formats=True,
    additional_parameters=None,
)

# Bundle the three parameter groups and save them to an HDF5 file.
mace_fitting_parameters = MACEFittingParameters(
    model_parameters=model_parameters,
    dataset_parameters=dataset_parameters,
    training_parameters=training_parameters,
)
nlsave('GSiO2hBN_Train_model_with_MultiHFine.hdf5', mace_fitting_parameters)
# %% MachineLearnedForceFieldTrainer
# Build the trainer from mace_fitting_parameters and start the training run.
# combined_training_set_training_set_0 and
# cam_AB_training_set_lcao_calculator_0 are not defined in this part of the
# script; they must already exist when this section runs.
machine_learned_force_field_trainer = MachineLearnedForceFieldTrainer(
    training_sets=combined_training_set_training_set_0,
    calculator=cam_AB_training_set_lcao_calculator_0,
    fitting_parameters=mace_fitting_parameters,
    train_test_split=0.9,
    # NOTE(review): no seed is given here although TrainingParameters uses
    # random_seed=123 - presumably the train/test split is then not
    # reproducible between runs; confirm this is intended.
    random_seed=None,
    save_model_evaluator=True,
)
machine_learned_force_field_trainer.train()
Epoch | Train loss | dataset val. loss | foundation_replay val. loss | Combined val. loss
1 | 18.9857 | 9.0502 | 9.6832 | 9.5235
2 | 16.2392 | 9.5699 | 9.8213 | 9.7579
3 | 15.9793 | 8.4948 | 14.1380 | 12.7140
4 | 15.7344 | 7.0205 | 9.3637 | 8.7724
5 | 15.2174 | 7.4290 | 9.1267 | 8.6983
6 | 15.1447 | 7.6615 | 9.5313 | 9.0595
7 | 13.9269 | 7.7902 | 9.3499 | 8.9563
8 | 14.3858 | 6.9831 | 8.4594 | 8.0869
9 | 14.6542 | 6.1258 | 9.1843 | 8.4125
10 | 67.0137 | 7.2914 | 9.0968 | 8.6412
11 | 14.6468 | 7.3211 | 9.6342 | 9.0505
12 | 13.7790 | 6.7579 | 72.1496 | 55.6489
13 | 3348.1013 | 6.4292 | 9.5679 | 8.7759
14 | 13.6505 | 6.6999 | 13.4243 | 11.7275
15 | 18.6541 | 4.8157 | 8.1373 | 7.2992
16 | 10.8260 | 4.5590 | 9.3064 | 8.1084
17 | 10.2169 | 5.3469 | 8.0607 | 7.3759
18 | 10.8128 | 4.8280 | 7.9056 | 7.1290
19 | 44.6345 | 4.8770 | 8.8867 | 7.8749
20 | 13.1652 | 6.0382 | 8.0510 | 7.5431
21 | 9.9637 | 5.0135 | 8.2736 | 7.4510
22 | 9.9831 | 5.1497 | 7.8262 | 7.1508
23 | 9.9076 | 5.5356 | 8.3363 | 7.6296
24 | 9.4363 | 5.0831 | 8.0683 | 7.3150
25 | 9.2744 | 4.0919 | 8.0174 | 7.0269
26 | 8.5802 | 4.0783 | 7.9753 | 6.9919
27 | 8.5956 | 3.7595 | 7.9519 | 6.8940
28 | 7.9154 | 4.1679 | 8.0973 | 7.1057
29 | 7.9012 | 4.1561 | 8.0263 | 7.0497
30 | 7.8318 | 3.7461 | 7.9410 | 6.8825
31 | 7.6374 | 3.7943 | 8.1701 | 7.0659
32 | 8.1747 | 3.6043 | 8.1010 | 6.9663
33 | 8.0958 | 3.6114 | 8.4863 | 7.2561
34 | 7.4335 | 3.9294 | 8.1361 | 7.0746
35 | 7.4023 | 4.0456 | 8.1316 | 7.1006
36 | 7.2933 | 3.7167 | 8.4425 | 7.2500
37 | 7.0873 | 3.3984 | 8.2840 | 7.0512
38 | 6.3931 | 3.4524 | 8.4852 | 7.2152
39 | 6.6867 | 3.2221 | 8.5646 | 7.2165
40 | 6.4043 | 3.3317 | 8.5464 | 7.2305
41 | 6.6129 | 3.1848 | 8.6768 | 7.2910
42 | 6.4041 | 3.3490 | 8.6733 | 7.3298
43 | 6.3622 | 3.0963 | 8.7256 | 7.3051
44 | 6.0477 | 3.0227 | 8.8481 | 7.3782
45 | 5.9473 | 3.0272 | 8.8475 | 7.3789
46 | 5.9198 | 3.0129 | 8.8019 | 7.3411