提取网络特征(extract features)

本文仅为记录代码

def main():

python 复制代码
if __name__ == '__main__':
    # setup random seed
    setup(seed=42)
    # Avoid the pylint warning.
    a = MolVocab
    # supress rdkit logger
    lg = RDLogger.logger()
    lg.setLevel(RDLogger.CRITICAL)

    # Initialize MolVocab
    mol_vocab = MolVocab

    args = parse_args()
    if args.parser_name == 'finetune':
        logger = create_logger(name='train', save_dir=args.save_dir, quiet=False)
        cross_validate(args, logger)
    elif args.parser_name == 'pretrain':
        logger = create_logger(name='pretrain', save_dir=args.save_dir)
        pretrain_model(args, logger)
    elif args.parser_name == "eval":
        logger = create_logger(name='eval', save_dir=args.save_dir, quiet=False)
        cross_validate(args, logger)
    elif args.parser_name == 'fingerprint':
        train_args = get_newest_train_args()
        logger = create_logger(name='fingerprint', save_dir=None, quiet=False)
        feas = generate_fingerprints(args, logger)
        np.savez_compressed(args.output_path, fps=feas)
    elif args.parser_name == 'predict':
        train_args = get_newest_train_args()
        avg_preds, test_smiles = make_predictions(args, train_args)
        write_prediction(avg_preds, test_smiles, args)

def generate_fingerprints:

python 复制代码
def generate_fingerprints(args: Namespace, logger: Logger = None) -> List[List[float]]:
    """
    Generate the fingerprints.

    :param logger:
    :param args: Arguments.
    :return: A list of lists of target fingerprints.
    """

    checkpoint_path = args.checkpoint_paths[0]
    if logger is None:
        logger = create_logger('fingerprints', quiet=False)
    print('Loading data')
    test_data = get_data(path=args.data_path,
                         args=args,
                         use_compound_names=False,
                         max_data_size=float("inf"),
                         skip_invalid_smiles=False)
    test_data = MoleculeDataset(test_data)

    logger.info(f'Total size = {len(test_data):,}')
    logger.info(f'Generating...')
    # Load model
    model = load_checkpoint(checkpoint_path, cuda=args.cuda, current_args=args, logger=logger)
    model_preds = do_generate(
        model=model,
        data=test_data,
        args=args
    )

    return model_preds

do_generate:

python 复制代码
def do_generate(model: nn.Module,
                data: MoleculeDataset,
                args: Namespace,
                ) -> List[List[float]]:
    """
    Do the fingerprint generation on a dataset using the pre-trained models.

    :param model: A model.
    :param data: A MoleculeDataset.
    :param args: A StandardScaler object fit on the training targets.
    :return: A list of fingerprints.
    """
    model.eval()
    args.bond_drop_rate = 0
    preds = []

    mol_collator = MolCollator(args=args, shared_dict={})

    num_workers = 4
    mol_loader = DataLoader(data,
                            batch_size=32,
                            shuffle=False,
                            num_workers=num_workers,
                            collate_fn=mol_collator)
    for item in mol_loader:
        _, batch, features_batch, _, _ = item
        with torch.no_grad():
            batch_preds = model(batch, features_batch)
            preds.extend(batch_preds.data.cpu().numpy())
    return preds
相关推荐
星浩AI3 分钟前
现在最需要被 PUA 的,其实是 AI
人工智能·后端·github
superior tigre8 分钟前
CUDA算子开发(LLM方向)常见的一些术语
人工智能·加速推理
weixin_4639234211 分钟前
知网更新后,这4种降AI方法已失效!
人工智能
WenGyyyL35 分钟前
ColBERT论文研读——NLP(IR)里程碑之作
人工智能·python·语言模型·自然语言处理
彩旗工作室38 分钟前
Cursor 全面深度指南:从诞生到实战,AI 编程时代的终极武器
人工智能·ai编程
新新学长搞科研43 分钟前
第五届电子、集成电路与通信技术国际学术会议(EICCT 2026)
运维·人工智能·自动化·集成测试·信号处理·集成学习·电气自动化
华奥系科技1 小时前
智慧经济新格局:解码社区、园区与城市一体化建设逻辑
大数据·人工智能·科技·物联网·安全
大模型真好玩1 小时前
大模型训练全流程实战指南工具篇(九)——LLamaFactory大模型训练工具使用指南
人工智能·agent·deepseek
大傻^1 小时前
SpringAI2.0 Tool Calling 进阶:动态模式、ToolContext 与隐式解析
人工智能·springai
阿达_优阅达1 小时前
告别手工对账:xSuite 如何帮助 SAP 企业实现财务全流程自动化?
服务器·数据库·人工智能·自动化·sap·企业数字化转型·xsuite