# -*- coding: utf-8 -*-
"""Multistage work chain."""
from copy import deepcopy
from aiida.common import AttributeDict
from aiida.engine import append_, while_, WorkChain, ToContext
from aiida.engine import calcfunction, workfunction
from aiida.orm import Dict, Int, Float, SinglefileData, Str, RemoteData, StructureData
from aiida.plugins import WorkflowFactory
from aiida_lsmo.utils import dict_merge, HARTREE2EV
from aiida_lsmo.utils.multiply_unitcell import check_resize_unit_cell_legacy, resize_unit_cell
from aiida_lsmo.utils.cp2k_utils import ot_has_small_bandgap, get_kinds_section, get_multiplicity_section
from .cp2k_multistage_protocols import load_isotherm_protocol, set_initial_conditions
# Base CP2K work chain class, obtained from WorkflowFactory under the 'cp2k.base' name.
# It is the process whose inputs/outputs are exposed and which is submitted for each stage.
Cp2kBaseWorkChain = WorkflowFactory('cp2k.base') # pylint: disable=invalid-name
@workfunction
def get_initial_magnetization(structure, protocol, with_ghost_atoms=None):
    """Prepare structure with correct initial magnetization.

    Returns modified StructureData (possibly with specific atomic kinds for different initial magnetizations)
    as well as the corresponding cp2k parameters dict.

    If the protocol sets ``initial_magnetization`` to ``'oxidation_state'``, the oxidation states are computed
    first and forwarded to ``apply_initial_magnetization``; otherwise ``apply_initial_magnetization`` is called
    without oxidation states.

    :param structure: AiiDA StructureData
    :param protocol: AiiDA Dict with appropriate cp2k parameters (kinds and multiplicity)
    :param with_ghost_atoms: if true, add ghost atoms for BSSE counterpoise correction (optional)
    :returns: {'structure': StructureData, 'cp2k_param': Dict }
    """
    protocol_dict = protocol.get_dict()

    if protocol_dict['initial_magnetization'] == 'oxidation_state':
        # Imported here so that the oxidation state machinery is only loaded when it is requested.
        from aiida_lsmo.calcfunctions.oxidation_state import compute_oxidation_states
        oxidation_states = compute_oxidation_states(structure)
        return apply_initial_magnetization(structure, protocol, oxidation_states, with_ghost_atoms=with_ghost_atoms)

    return apply_initial_magnetization(structure, protocol, with_ghost_atoms=with_ghost_atoms)
@calcfunction
def apply_initial_magnetization(structure, protocol, oxidation_states=None, with_ghost_atoms=None):
    """Prepare structure with correct initial magnetization.

    Returns modified StructureData (possibly with specific atomic kinds for different initial magnetizations)
    as well as the corresponding cp2k parameters dict.

    Note: AiiDA does not allow one calcfunction to call another, which forces this split between workfunction
    and calcfunction.

    :param structure: AiiDA StructureData
    :param protocol: AiiDA Dict with appropriate cp2k parameters (kinds and multiplicity)
    :param oxidation_states: Oxidation state computed with oximachine (optional)
    :param with_ghost_atoms: if true, add ghost atoms for BSSE counterpoise correction (optional)
    :returns: {'structure': StructureData, 'cp2k_param': Dict }
    """
    atoms = structure.get_ase()
    protocol_dict = protocol.get_dict()

    # Only forward the oxidation states when they were actually provided.
    if oxidation_states is not None:
        atoms = set_initial_conditions(atoms=atoms,
                                       initial_magnetization=protocol_dict['initial_magnetization'],
                                       oxidation_states=oxidation_states)
    else:
        atoms = set_initial_conditions(atoms=atoms, initial_magnetization=protocol_dict['initial_magnetization'])

    # Build the kinds section first, then merge the multiplicity section into the same dictionary.
    cp2k_param = get_kinds_section(atoms=atoms, protocol=protocol_dict, with_ghost_atoms=bool(with_ghost_atoms))
    dict_merge(cp2k_param, get_multiplicity_section(atoms=atoms, protocol=protocol_dict))

    return {'structure': StructureData(ase=atoms), 'cp2k_param': Dict(dict=cp2k_param)}
class Cp2kMultistageWorkChain(WorkChain):
    """Submits Cp2kBase workchains for ENERGY, GEO_OPT, CELL_OPT and MD jobs iteratively.

    The protocol_yaml file contains a series of settings_x and stage_x:
    the workchain starts running the settings_0/stage_0 calculation and, in case of a failure, changes the settings
    until the SCF of stage_0 converges. Then it uses the same settings to run the next stages (i.e., stage_1, etc.).
    """

    @classmethod
    def define(cls, spec):
        """Declare the inputs, the outline, the exit codes and the outputs of the work chain."""
        super().define(spec)

        # Inputs
        spec.expose_inputs(Cp2kBaseWorkChain,
                           namespace='cp2k_base',
                           exclude=['cp2k.structure', 'cp2k.parameters', 'cp2k.metadata.options.parser_name'])
        spec.input('structure', valid_type=StructureData, required=False, help='Input structure')
        spec.input('protocol_tag',
                   valid_type=Str,
                   default=lambda: Str('standard'),
                   required=False,
                   help='The tag of the protocol to be read from {tag}.yaml unless protocol_yaml input is specified')
        spec.input('protocol_yaml',
                   valid_type=SinglefileData,
                   required=False,
                   help='Specify a custom yaml file with the multistage settings (and ignore protocol_tag)')
        spec.input('protocol_modify',
                   valid_type=Dict,
                   default=lambda: Dict(dict={}),
                   required=False,
                   help='Specify custom settings that overvrite the yaml settings')
        spec.input('starting_settings_idx',
                   valid_type=Int,
                   default=lambda: Int(0),
                   required=False,
                   help='If idx>0 is chosen, jumps directly to overwrite settings_0 with settings_{idx}')
        spec.input('min_cell_size',
                   valid_type=Float,
                   default=lambda: Float(0.0),
                   required=False,
                   help='To avoid using k-points, extend the cell so that min(perp_width)>min_cell_size')
        spec.input('parent_calc_folder',
                   valid_type=RemoteData,
                   required=False,
                   help='Provide an initial parent folder that contains the wavefunction for restart')
        # The two ports below were excluded from the exposed namespace above and are re-declared here
        # with their own requirements/defaults.
        spec.input(
            'cp2k_base.cp2k.parameters',
            valid_type=Dict,
            required=False,
            help='Specify custom CP2K settings to overwrite the input dictionary just before submitting the CalcJob')
        spec.input('cp2k_base.cp2k.metadata.options.parser_name',
                   valid_type=str,
                   default='lsmo.cp2k_advanced_parser',
                   non_db=True,
                   help='Parser of the calculation: the default is cp2k_advanced_parser to get the necessary info')

        # Workchain outline
        spec.outline(
            cls.setup_multistage,
            while_(cls.should_run_stage0)(
                cls.run_stage,
                cls.inspect_and_update_settings_stage0,
            ),
            cls.inspect_and_update_stage,
            while_(cls.should_run_stage)(
                cls.run_stage,
                cls.inspect_and_update_stage,
            ),
            cls.results,
        )

        # Exit codes
        spec.exit_code(901, 'ERROR_MISSING_INITIAL_SETTINGS',
                       'Specified starting_settings_idx that is not existing, or any in between 0 and idx is missing')
        spec.exit_code(902, 'ERROR_NO_MORE_SETTINGS',
                       'Settings for Stage0 are not ok but there are no more robust settings to try')
        spec.exit_code(903, 'ERROR_PARSING_OUTPUT',
                       'Something important was not printed correctly and the parsing of the first calculation failed')

        # Outputs
        spec.expose_outputs(Cp2kBaseWorkChain, include=['remote_folder'])
        spec.output('output_structure',
                    valid_type=StructureData,
                    required=False,
                    help='Processed structure (missing if only ENERGY calculation is performed)')
        spec.output('last_input_parameters',
                    valid_type=Dict,
                    required=False,
                    help='CP2K input parameters used (and possibly working) used in the last stage')
        spec.output('output_parameters',
                    valid_type=Dict,
                    required=False,
                    help='Output CP2K parameters of all the stages, merged together')

    def setup_multistage(self):
        """Setup initial parameters.

        Loads the protocol, initializes the stage/settings counters in the context, optionally resizes the
        unit cell and assembles the CP2K parameters for the first calculation.

        :returns: the ERROR_MISSING_INITIAL_SETTINGS exit code if one of the settings_{idx} up to
            ``starting_settings_idx`` is not in the protocol, otherwise nothing.
        """
        # Store the workchain inputs in context (to be modified later)
        self.ctx.base_inp = AttributeDict(self.exposed_inputs(Cp2kBaseWorkChain, 'cp2k_base'))

        # Check if an input parent_calc_folder is provided
        if 'parent_calc_folder' in self.inputs:
            self.ctx.parent_calc_folder = self.inputs.parent_calc_folder
        else:
            self.ctx.parent_calc_folder = None

        # Read yaml file selected as SinglefileData or chosen with the tag, and overwrite with custom modifications
        if 'protocol_yaml' in self.inputs:
            self.ctx.protocol = load_isotherm_protocol(singlefiledata=self.inputs.protocol_yaml)
        else:
            self.ctx.protocol = load_isotherm_protocol(tag=self.inputs.protocol_tag.value)
        dict_merge(self.ctx.protocol, self.inputs.protocol_modify.get_dict())

        # Initialize
        self.ctx.settings_ok = False
        self.ctx.stage_idx = 0
        self.ctx.stage_tag = 'stage_{}'.format(self.ctx.stage_idx)
        self.ctx.settings_idx = 0
        self.ctx.settings_tag = 'settings_{}'.format(self.ctx.settings_idx)
        # NOTE(review): the 'structure' input is declared with required=False but is read unconditionally
        # here - confirm whether it should be required.
        self.ctx.structure = self.inputs.structure

        # Resize the unit cell if min(perp_width) < inputs.min_cell_size
        self.ctx.resize = check_resize_unit_cell_legacy(self.ctx.structure, self.inputs.min_cell_size)  # Dict
        if self.ctx.resize['nx'] > 1 or self.ctx.resize['ny'] > 1 or self.ctx.resize['nz'] > 1:
            resized_struct = resize_unit_cell(self.ctx.structure, self.ctx.resize)
            self.ctx.structure = resized_struct
            self.report('Unit cell resized by {}x{}x{} (StructureData<{}>)'.format(self.ctx.resize['nx'],
                                                                                   self.ctx.resize['ny'],
                                                                                   self.ctx.resize['nz'],
                                                                                   resized_struct.pk))
        else:
            self.report('Unit cell was NOT resized')

        # Generate input parameters and store them
        self.ctx.cp2k_param = deepcopy(self.ctx.protocol['settings_0'])
        while self.inputs.starting_settings_idx > self.ctx.settings_idx:
            # overwrite until the desired starting settings are obtained
            self.ctx.settings_idx += 1
            self.ctx.settings_tag = 'settings_{}'.format(self.ctx.settings_idx)
            if self.ctx.settings_tag in self.ctx.protocol:
                dict_merge(self.ctx.cp2k_param, self.ctx.protocol[self.ctx.settings_tag])
            else:
                return self.exit_codes.ERROR_MISSING_INITIAL_SETTINGS  # pylint: disable=no-member

        # handle starting magnetization
        results = get_initial_magnetization(self.ctx.structure, Dict(dict=self.ctx.protocol))
        self.ctx.structure = results['structure']
        dict_merge(self.ctx.cp2k_param, results['cp2k_param'].get_dict())

        # The stage_0 section is merged last, on top of the settings and the kinds/multiplicity sections
        dict_merge(self.ctx.cp2k_param, self.ctx.protocol['stage_0'])

    def should_run_stage0(self):
        """Returns True if it is the first iteration or the settings are not ok."""
        return not self.ctx.settings_ok

    def run_stage(self):
        """Check for restart, prepare input, submit and direct output to context."""
        # Update structure
        self.ctx.base_inp['cp2k']['structure'] = self.ctx.structure

        # Check if it is needed to restart the calculation and provide the parent folder and new structure
        if self.ctx.parent_calc_folder:
            self.ctx.base_inp['cp2k']['parent_calc_folder'] = self.ctx.parent_calc_folder
            self.ctx.cp2k_param['FORCE_EVAL']['DFT']['SCF']['SCF_GUESS'] = 'RESTART'
            self.ctx.cp2k_param['FORCE_EVAL']['DFT']['WFN_RESTART_FILE_NAME'] = './parent_calc/aiida-RESTART.wfn'
        else:
            self.ctx.cp2k_param['FORCE_EVAL']['DFT']['SCF']['SCF_GUESS'] = 'ATOMIC'

        # Overwrite the generated input with the custom cp2k/parameters
        custom_cp2k_inputs = self.exposed_inputs(Cp2kBaseWorkChain, 'cp2k_base')['cp2k']
        if 'parameters' in custom_cp2k_inputs:
            dict_merge(self.ctx.cp2k_param, AttributeDict(custom_cp2k_inputs['parameters'].get_dict()))
        self.ctx.base_inp['cp2k']['parameters'] = Dict(dict=self.ctx.cp2k_param).store()

        # Update labels
        self.ctx.base_inp['metadata'].update({
            'label': '{}_{}'.format(self.ctx.stage_tag, self.ctx.settings_tag),
            'call_link_label': 'run_{}_{}'.format(self.ctx.stage_tag, self.ctx.settings_tag),
        })
        # The calculation itself is labelled with the RUN_TYPE of the parameters just stored
        self.ctx.base_inp['cp2k']['metadata'].update(
            {'label': self.ctx.base_inp['cp2k']['parameters'].get_dict()['GLOBAL']['RUN_TYPE']})

        running_base = self.submit(Cp2kBaseWorkChain, **self.ctx.base_inp)
        self.report('submitted Cp2kBaseWorkChain for {}/{}'.format(self.ctx.stage_tag, self.ctx.settings_tag))
        return ToContext(stages=append_(running_base))

    def inspect_and_update_settings_stage0(self):  # pylint: disable=inconsistent-return-statements
        """Inspect the stage0/settings_{idx} calculation and check if it is
        needed to update the settings and resubmit the calculation.

        :returns: the ERROR_PARSING_OUTPUT exit code if the last calculation has no ``output_parameters``,
            the ERROR_NO_MORE_SETTINGS exit code if the settings are bad and the protocol has no further
            settings_{idx}, otherwise nothing.
        """
        # Assume the settings are fine until one of the checks below says otherwise
        self.ctx.settings_ok = True

        # Settings/structure are bad: there are problems in parsing the output file
        # and, most probably, the calculation didn't even start the scf cycles
        if 'output_parameters' in self.ctx.stages[-1].outputs:
            cp2k_out = self.ctx.stages[-1].outputs.output_parameters
        else:
            self.report('ERROR_PARSING_OUTPUT')
            return self.exit_codes.ERROR_PARSING_OUTPUT  # pylint: disable=no-member

        # Settings are bad: the SCF did not converge in the final step
        if not cp2k_out['motion_step_info']['scf_converged'][-1]:
            self.report('BAD SETTINGS: the SCF did not converge')
            self.ctx.settings_ok = False
            self.ctx.settings_idx += 1
        else:
            # SCF converged, but the computed bandgap needs to be checked
            self.report('Bandgaps spin1/spin2: {:.3f} and {:.3f} ev'.format(cp2k_out['bandgap_spin1_au'] * HARTREE2EV,
                                                                            cp2k_out['bandgap_spin2_au'] * HARTREE2EV))
            bandgap_thr_ev = self.ctx.protocol['bandgap_thr_ev']
            if ot_has_small_bandgap(self.ctx.cp2k_param, cp2k_out, bandgap_thr_ev):
                self.report('BAD SETTINGS: band gap is < {:.3f} eV'.format(bandgap_thr_ev))
                self.ctx.settings_ok = False
                self.ctx.settings_idx += 1

        # Update the settings tag, check if it is available and overwrite
        if not self.ctx.settings_ok:
            # The label still uses the tag of the settings that just failed
            cp2k_out.label = '{}_{}_discard'.format(self.ctx.stage_tag, self.ctx.settings_tag)
            next_settings_tag = 'settings_{}'.format(self.ctx.settings_idx)
            if next_settings_tag in self.ctx.protocol:
                self.ctx.settings_tag = next_settings_tag
                dict_merge(self.ctx.cp2k_param, self.ctx.protocol[self.ctx.settings_tag])
            else:
                return self.exit_codes.ERROR_NO_MORE_SETTINGS  # pylint: disable=no-member

    def inspect_and_update_stage(self):
        """Update geometry, parent folder and the new &MOTION settings."""
        last_stage = self.ctx.stages[-1]

        if 'output_structure' in last_stage.outputs:
            self.ctx.structure = last_stage.outputs.output_structure
            self.report('Structure updated for next stage')
        else:
            self.report('New structure NOT found and NOT updated for next stage')

        # The next stage restarts from the remote folder of the stage that just finished
        self.ctx.parent_calc_folder = last_stage.outputs.remote_folder
        last_stage.outputs.output_parameters.label = '{}_{}_valid'.format(self.ctx.stage_tag, self.ctx.settings_tag)

        # Move to the next stage, if the protocol defines one
        self.ctx.stage_idx += 1
        next_stage_tag = 'stage_{}'.format(self.ctx.stage_idx)
        if next_stage_tag in self.ctx.protocol:
            self.ctx.stage_tag = next_stage_tag
            self.ctx.next_stage_exists = True
            dict_merge(self.ctx.cp2k_param, self.ctx.protocol[self.ctx.stage_tag])
        else:
            self.ctx.next_stage_exists = False
            self.report('All stages computed, finishing...')

    def should_run_stage(self):
        """Return True if there is a new stage to compute."""
        return self.ctx.next_stage_exists

    def results(self):
        """Gather final outputs of the workchain."""
        # Gather all the output_parameters in a final Dict
        all_output_parameters = {
            'out_{}'.format(i): stage.outputs.output_parameters for i, stage in enumerate(self.ctx.stages)
        }
        self.out('output_parameters', extract_results(resize=self.ctx.resize, **all_output_parameters))

        # Output the final parameters that worked as a Dict
        self.out('last_input_parameters', self.ctx.base_inp['cp2k']['parameters'])

        # Output the final remote folder
        self.out_many(self.exposed_outputs(self.ctx.stages[-1], Cp2kBaseWorkChain))

        # Output the final structure only if it was modified (there is any MD or OPT stage)
        if 'output_structure' in self.ctx.stages[-1].outputs:
            self.out('output_structure', self.ctx.stages[-1].outputs.output_structure)
            self.report('Outputs: Dict<{}> and StructureData<{}>'.format(self.outputs['output_parameters'].pk,
                                                                         self.outputs['output_structure'].pk))
        else:
            self.report('Outputs: Dict<{}> and NO StructureData'.format(self.outputs['output_parameters'].pk))