from torchtree.cli import PLUGIN_MANAGER
from torchtree.cli.evolution import (
create_alignment,
create_evolution_joint,
create_evolution_parser,
create_site_model_srd06_mus,
create_taxa,
)
from torchtree.cli.utils import make_unconstrained
[docs]
def create_map_parser(subprasers):
    """Register the ``map`` sub-command and return its parser.

    The shared evolution options are installed first through
    ``create_evolution_parser``; the optimizer-related flags and the
    mandatory ``--stem`` flag are added afterwards. The parser's ``func``
    default is set to :func:`build_optimizer`.

    :param subprasers: sub-parsers action (as returned by
        ``ArgumentParser.add_subparsers``) to which ``map`` is added.
    :return: the newly created ``map`` parser.
    """
    map_parser = subprasers.add_parser(
        'map', help='build a JSON file for maximum a posteriori inference'
    )
    create_evolution_parser(map_parser)

    # One (flag, add_argument keyword arguments) pair per option, listed in
    # the order in which they should show up in the help output.
    option_table = (
        ('--lr', dict(default=1.0, type=float, help="learning rate")),
        (
            '--max_iter',
            dict(
                type=int,
                default=20,
                help="maximal number of iterations per optimization step "
                "(default: 20)",
            ),
        ),
        (
            '--max_eval',
            dict(
                type=int,
                help="maximal number of function evaluations per "
                "optimization step\n(default: max_iter * 1.25)",
            ),
        ),
        (
            '--tolerance_grad',
            dict(
                type=float,
                default=1e-5,
                help="termination tolerance on first order optimality "
                "(default: 1e-5)",
            ),
        ),
        (
            '--tolerance_change',
            dict(
                type=float,
                default=1e-9,
                help="termination tolerance on function value/parameter "
                "changes\n(default: 1e-9)",
            ),
        ),
        (
            '--history_size',
            dict(type=int, default=100, help="update history size (default: 100)"),
        ),
        (
            '--line_search_fn',
            dict(type=str, help="either 'strong_wolfe' or None (default: None)"),
        ),
        ('--stem', dict(required=True, help="stem for output files")),
    )
    for flag, settings in option_table:
        map_parser.add_argument(flag, **settings)

    map_parser.set_defaults(func=build_optimizer)
    return map_parser
[docs]
def create_optimizer(joint, parameters, arg):
    """Build the JSON description of the L-BFGS optimizer.

    Besides the learning rate, the optimizer settings declared by
    ``create_map_parser`` (``max_eval``, ``tolerance_grad``,
    ``tolerance_change``, ``history_size`` and ``line_search_fn``) are
    forwarded through the ``"options"`` entry. A setting is only emitted
    when it is present on ``arg`` and is not ``None``, so an unset flag
    leaves the optimizer's own default in charge.

    :param joint: value stored under ``"loss"`` (the objective to maximize).
    :param parameters: iterable of dictionaries, each with an ``"id"`` key;
        only the ids are stored in the result.
    :param arg: parsed command-line namespace; ``lr`` and ``max_iter`` are
        required, the other optimizer settings are optional.
    :return: dictionary describing the optimizer.
    """
    options = {"lr": arg.lr}
    # NOTE(review): assumes "options" is handed to the torch.optim.LBFGS
    # constructor as keyword arguments -- confirm against the Optimizer class.
    for name in (
        "max_eval",
        "tolerance_grad",
        "tolerance_change",
        "history_size",
        "line_search_fn",
    ):
        # getattr keeps namespaces built without these attributes working.
        value = getattr(arg, name, None)
        if value is not None:
            options[name] = value

    return {
        "id": "bfgs",
        "type": "Optimizer",
        "algorithm": "torch.optim.LBFGS",
        "options": options,
        "maximize": True,
        "iterations": 10,
        # NOTE(review): kept at the top level exactly as before; whether it
        # should also be part of "options" depends on the consumer -- confirm.
        "max_iter": arg.max_iter,
        "loss": joint,
        "parameters": [parameter["id"] for parameter in parameters],
    }
[docs]
def create_logger(id_, parameters, arg):
    """Build the JSON description of the parameter logger.

    :param id_: identifier stored under ``"id"``.
    :param parameters: value stored unchanged under ``"parameters"``.
    :param arg: parsed command-line namespace; ``arg.stem`` is the prefix of
        the output file name, to which ``.csv`` is appended.
    :return: dictionary describing the logger.
    """
    csv_path = arg.stem + '.csv'
    return dict(
        id=id_,
        type="Logger",
        parameters=parameters,
        file_name=csv_path,
    )
[docs]
def create_tree_logger(id_, tree_id, arg):
    """Build the JSON description of the tree logger.

    :param id_: identifier stored under ``"id"``.
    :param tree_id: value stored under ``"tree_model"``.
    :param arg: parsed command-line namespace; ``arg.stem`` is the prefix of
        the output file name, to which ``.tree`` is appended.
    :return: dictionary describing the tree logger.
    """
    tree_path = arg.stem + '.tree'
    return dict(
        id=id_,
        type="TreeLogger",
        tree_model=tree_id,
        file_name=tree_path,
    )
[docs]
def build_optimizer(arg):
    """Assemble the list of JSON objects for the ``map`` sub-command.

    The list holds, in order: the taxa, the alignment, the SRD06 site-model
    entry (only when ``arg.model == 'SRD06'``), the joint distribution, the
    optimizer, the parameter logger and the tree logger. Every registered
    plugin is then given the namespace and the list through ``process_all``.

    :param arg: parsed command-line namespace.
    :return: list of JSON-serializable objects.
    """
    taxa_spec = create_taxa('taxa', arg)
    alignment_spec = create_alignment('alignment', 'taxa', arg)
    objects = [taxa_spec, alignment_spec]

    if arg.model == 'SRD06':
        objects.append(create_site_model_srd06_mus('srd06.mus'))

    objects.append(create_evolution_joint(taxa_spec, 'alignment', arg))

    # Must run once the model objects are in the list and before the
    # optimizer/loggers are added: the first result feeds the optimizer, the
    # second the logger.
    optimized_params, logged_params = make_unconstrained(objects)

    objects += [
        create_optimizer('joint', optimized_params, arg),
        create_logger('logger', logged_params, arg),
        create_tree_logger('tree.logger', 'tree', arg),
    ]

    for plugin in PLUGIN_MANAGER.plugins():
        plugin.process_all(arg, objects)

    return objects