CD-GPT:通过生物中心法则建模的基础模型

生物学基础模型

CD-GPT 是由腾讯 AI Lab 科学智能团队(Tencent AI4S)开发的一个创新性基础模型。其核心思想是借鉴生物学中的"中心法则"(Central Dogma),将生物大分子(如蛋白质、RNA、DNA)的复杂相互作用和功能生成过程,统一建模在一个连贯的、可学习的框架内。

该模型旨在解决生命科学领域的一个关键挑战:如何理解和预测从基因序列(DNA)到功能分子(如蛋白质)这一多层次、跨模态的生物信息流。传统方法通常孤立地处理不同分子类型或生物过程。CD-GPT 的创新之处在于,它尝试构建一个能够理解和模拟"DNA → RNA → 蛋白质"这一核心信息流及其调控关系的统一模型。

通过这种基于生物中心法则的建模方式,CD-GPT 有望在多个前沿生物计算任务上展现出强大潜力,例如:

  • 蛋白质设计:根据特定功能需求,逆向生成或优化蛋白质序列。
  • 非编码RNA功能预测:理解RNA分子的结构与其调控功能之间的关系。
  • 基因调控网络解析:更准确地模拟基因表达如何受序列和表观遗传因素调控。

总而言之,CD-GPT 代表了一种将深刻生物学原理与前沿人工智能技术相结合的新范式,为解码生命复杂系统、加速药物发现和合成生物学等领域的研究提供了有力的计算工具。

模型基座

基座模型 `CDGPT` 的 Python 实现如下:
class CDGPT(nn.Module):
    """Decoder-only transformer backbone of CD-GPT.

    Token embeddings (``wte``) pass through ``num_layers`` ``Block`` layers that
    receive a precomputed rotary-position table, then a final ``RMSNorm``
    (``ln_f``) and, when ``include_head`` is True, a linear head projecting back
    to ``vocab_size`` logits.
    """

    # Preset architectures selectable through ``cfg.model.type``.
    CONFIG = {
        "cdgpt-1b": dict(num_layers=12, num_heads=24, embedding_dim=2304),
        "cdgpt-7b": dict(num_layers=32, num_heads=32, embedding_dim=4096)
    }

    @classmethod
    def from_config(cls, cfg):
        """Translate a config object into keyword arguments for ``__init__``.

        If ``cfg.model.type`` is truthy it must name a preset in ``CONFIG``;
        otherwise depth/heads/width come from ``cfg.model.num_layers``,
        ``cfg.model.num_heads`` and ``cfg.model.num_hiddens``. The padding id
        is taken from a ``SentencePieceTokenizer`` built from ``cfg``.
        """
        model_type = cfg.model.type
        if model_type:
            mcfg = cls.CONFIG[model_type]
            num_layers, num_heads, embedding_dim = mcfg['num_layers'], mcfg['num_heads'], mcfg['embedding_dim']
        else:
            num_layers = cfg.model.num_layers
            num_heads = cfg.model.num_heads
            embedding_dim = cfg.model.num_hiddens
        pad_id = SentencePieceTokenizer(cfg).pad_id
        return {
            "vocab_size": cfg.tokenizer.vocab_size,
            "max_len": cfg.model.max_len,
            "embedding_dim": embedding_dim,
            "num_layers": num_layers,
            "num_heads": num_heads,
            "pad_id": pad_id
        }

    @configurable
    def __init__(self,
                 vocab_size: int,
                 max_len: int = 1024,
                 embedding_dim=2304,
                 num_layers: int = 12,
                 num_heads: int = 24,
                 bias=False,
                 eps=1e-5,
                 pad_id=None,
                 include_head=True):
        """Build the backbone.

        Args:
            vocab_size: number of token ids the embedding and LM head cover.
            max_len: longest sequence accepted by ``forward``; also the size of
                the RoPE table, the default causal mask and the KV caches.
            embedding_dim: hidden size; ``embedding_dim // num_heads`` is the
                per-head size used for RoPE and the KV caches.
            num_layers: number of ``Block`` layers.
            num_heads: attention heads per block.
            bias: whether the LM head has a bias term.
            eps: epsilon for the ``RMSNorm`` layers.
            pad_id: padding token id (stored only; not used by this class).
            include_head: if False, ``lm_head`` is None and ``forward`` returns
                the normalized hidden states instead of logits.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.max_len = max_len
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.embedding_dim = embedding_dim
        self.eps = eps
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(self.vocab_size, self.embedding_dim),
                h=nn.ModuleList([
                    Block(self.embedding_dim, self.num_heads, self.max_len, eps=self.eps) for _ in
                    range(self.num_layers)
                ]),
                ln_f=RMSNorm(self.embedding_dim, eps=self.eps),
            )
        )
        self.Block = Block
        self.lm_head = nn.Linear(embedding_dim, vocab_size, bias=bias) if include_head else None
        # Built lazily on the first forward call, on that call's device.
        self.rope_cache = None
        # Per-layer (key, value) buffers used only for incremental decoding.
        self.kv_caches = []
        self.pad_id = pad_id
        self.apply(self._init_weights)
        self.activation_checkpoint = False
        self.activation_checkpoint_func = checkpoint
        n_params = sum(p.numel() for p in self.parameters())
        print("number of parameters: %.2fM" % (n_params / 1e6,))

    def enable_activation_checkpoint(self, enabled=True):
        """Toggle activation checkpointing for the non-cached forward path."""
        self.activation_checkpoint = enabled

    def finetune_vocab(self):
        """Freeze every transformer block so only the other parameters train.

        The embedding (``wte``), final norm and LM head stay trainable.
        """
        # requires_grad_ on the ModuleList recurses into every block.
        self.transformer.h.requires_grad_(False)

    def _init_weights(self, module: nn.Module) -> None:
        """Normal(0, 0.02 / sqrt(2 * num_layers)) init for Linear and Embedding weights.

        Linear biases are left with PyTorch's default initialization.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.num_layers))
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.num_layers))

    def _make_casual_mask(self, device):
        """Return a lower-triangular (causal) boolean mask of shape [1, 1, max_len, max_len].

        The method name keeps its historical spelling ("casual") for callers.

        Args:
            device: device on which to allocate the mask.
        """
        ones = torch.ones((self.max_len, self.max_len), dtype=torch.bool, device=device)
        return torch.tril(ones)[None, None]

    def _make_rope_mask(self, device, dtype=torch.int64):
        """Precompute the rotary-embedding table for all ``max_len`` positions.

        NOTE(review): ``forward`` passes the token-id dtype here; presumably
        ``precompute_freqs_cis`` uses it only to choose an output precision —
        confirm.
        """
        return precompute_freqs_cis(
            seq_len=self.max_len,
            n_elem=self.embedding_dim // self.num_heads,
            dtype=dtype,
            device=device
        )

    def _forward_embedding_impl(self, input_ids):
        """Look up token embeddings (hook for subclasses)."""
        x = self.transformer.wte(input_ids)  # [bs, seq_len, hidden_dim]
        return x

    def _forward_head_impl(self, x):
        """Apply the LM head when present; otherwise return ``x`` unchanged."""
        if self.lm_head is not None:
            x = self.lm_head(x)  # (b, t, vocab_size)
        return x

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                pos_ids: Optional[torch.Tensor] = None):
        """Run the backbone.

        Args:
            input_ids: [bs, seq_len], input token indices.
            attention_mask: [bs, 1, seq_len, seq_len] attention mask. Without
                ``pos_ids`` it is cropped to ``seq_len`` (None is passed through
                to the blocks). With ``pos_ids`` it defaults to the causal mask
                and its rows are selected by ``pos_ids``, so a caller-supplied
                mask must cover all ``max_len`` positions.
            pos_ids: [seq_len or 1] absolute positions of ``input_ids``. Pass it
                during incremental generation, which enables the KV caches; keep
                it None for training.

        Returns:
            Logits [bs, seq_len, vocab_size], or hidden states
            [bs, seq_len, embedding_dim] when the model has no LM head.
        """
        bs, seq_len = input_ids.shape
        device = input_ids.device
        dtype = input_ids.dtype
        assert (
                seq_len <= self.max_len
        ), f"Cannot forward sequence of length {seq_len}, max length is only {self.max_len}"
        if self.rope_cache is None:
            self.rope_cache = self._make_rope_mask(device, dtype)  # [max_len, ...]

        if pos_ids is not None:
            # Incremental decoding: pick RoPE rows and mask rows for the given positions.
            rope = self.rope_cache.index_select(0, pos_ids)
            if attention_mask is None:
                attention_mask = self._make_casual_mask(device)
            attention_mask = attention_mask.index_select(2, pos_ids)
            attention_mask = attention_mask[:, :, :, :self.max_len]
        else:
            rope = self.rope_cache[:seq_len]
            if attention_mask is not None:
                attention_mask = attention_mask[:, :, :seq_len, :seq_len]

        x = self._forward_embedding_impl(input_ids)
        if pos_ids is None:
            for block in self.transformer.h:
                if self.activation_checkpoint:
                    x, _, _ = self.activation_checkpoint_func(block, x, rope, attention_mask)
                else:
                    x, _, _ = block(x, rope, attn_mask=attention_mask)
        else:
            if not self.kv_caches:
                head_dim = self.embedding_dim // self.num_heads
                cache_shape = (bs, self.num_heads, self.max_len, head_dim)
                # Pre-allocate one (key, value) pair of buffers per layer; they are
                # sized for this batch, so call reset_cache() before changing bs.
                self.kv_caches = [
                    (torch.zeros(cache_shape, device=x.device, dtype=x.dtype),
                     torch.zeros(cache_shape, device=x.device, dtype=x.dtype))
                    for _ in range(self.num_layers)
                ]
            for i, block in enumerate(self.transformer.h):
                x, self.kv_caches[i], _ = block(x, rope,
                                                attn_mask=attention_mask,
                                                pos_ids=pos_ids,
                                                kv_cache=self.kv_caches[i])
        x = self.transformer.ln_f(x)
        x = self._forward_head_impl(x)
        return x

    def get_embedding_pooling(self, input_ids):
        """Return the mean of the raw token embeddings over dim 0.

        Only the embedding table is used (no transformer layers), and padding
        tokens are included in the mean.

        NOTE(review): for a 1-D ``input_ids`` of shape [seq_len] this pools over
        the sequence; for a batched [bs, seq_len] input, dim 0 is the batch axis,
        so this averages across sequences instead — confirm that is intended.
        """
        x = self._forward_embedding_impl(input_ids)
        x = x.mean(dim=0)
        return x

    def reset_cache(self):
        """Drop the KV caches so the next incremental forward re-allocates them."""
        self.kv_caches.clear()

    @torch.no_grad()
    def generate(self,
                 token_ids,
                 max_new_tokens,
                 *,
                 top_k: int = 0,
                 top_p: float = 0.,
                 temperature: float = 1.0,
                 output_score: bool = True,
                 stop_ids: Any = None):
        """Autoregressively extend a prompt using the KV caches.

        Args:
            token_ids: 1-D tensor of prompt token ids. A 2-D tensor or a list is
                treated as a batch of prompts; each is generated independently
                and a list of results is returned.
            max_new_tokens: maximum number of tokens to append (negative values
                are treated as 0). The total length never exceeds ``max_len``.
            top_k, top_p, temperature: forwarded to ``sample``.
            output_score: if True, collect the softmax probability of every
                sampled token.
            stop_ids: optional container of token ids; generation stops right
                after one of them has been appended.

        Returns:
            ``GenerationOutput`` whose ``sequences`` holds the prompt followed by
            the generated tokens and whose ``scores`` holds the per-token
            probabilities (None when ``output_score`` is False).
        """
        # Check for a list first: a Python list has no ``.dim()`` method.
        if isinstance(token_ids, list) or token_ids.dim() == 2:
            return [self.generate(t,
                                  max_new_tokens,
                                  top_k=top_k,
                                  top_p=top_p,
                                  temperature=temperature,
                                  output_score=output_score,
                                  stop_ids=stop_ids) for t in token_ids]
        seq_len = token_ids.size(0)
        assert seq_len < self.max_len, "input token is too long"
        device, dtype = token_ids.device, token_ids.dtype
        # Clamp negative requests so the buffer is never shorter than the prompt.
        max_len = min(self.max_len, seq_len + max(max_new_tokens, 0))
        # Create an empty tensor of the final shape and copy the prompt into its head.
        empty = torch.empty(max_len, dtype=dtype, device=device)
        empty[:seq_len] = token_ids
        token_ids = empty
        scores = [] if output_score else None
        # The first step feeds the whole prompt; later steps feed only the newest position.
        input_pos = torch.arange(0, seq_len, device=device)
        # Count of valid tokens in ``token_ids``. It is kept as a Python int so the
        # final slice also works when nothing is generated (``input_pos`` then
        # still holds the whole prompt range and cannot be used as a slice bound).
        end = seq_len
        try:
            for cur_pos in range(seq_len, max_len):
                x = token_ids.index_select(0, input_pos)[None]
                logits = self(x, pos_ids=input_pos)[:, -1]
                idx_next = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)[0]
                input_pos = input_pos[-1:] + 1
                # Write the new token at position cur_pos.
                token_ids = token_ids.index_copy(0, input_pos, idx_next)
                end = cur_pos + 1

                if output_score:
                    scores.append(logits.softmax(dim=-1)[0, idx_next])

                if stop_ids is not None and idx_next.item() in stop_ids:
                    break
        finally:
            # Always free the caches, even if sampling raised, so a later call
            # (possibly with a different batch size) starts clean.
            self.reset_cache()
        return GenerationOutput(sequences=token_ids[:end], scores=scores)

Sequence预测任务

序列级预测模型 `CDGPTSequencePrediction` 的 Python 实现如下:
class CDGPTSequencePrediction(CDGPT):
    """CD-GPT backbone with a sequence-level prediction head.

    The language-model head is disabled. For each sequence, the hidden state of
    a single token is fed to ``SequencePredictionHead``: the last non-padding
    token, or the very last position when ``pad_id`` is None.
    """

    @classmethod
    def from_config(cls, cfg):
        """Extend the backbone's config kwargs with ``num_classes`` and ``pad_id``.

        NOTE(review): the backbone kwargs are merged last, so their ``pad_id``
        (taken from ``SentencePieceTokenizer``) overrides the
        ``cfg.tokenizer.pad_id`` read here — confirm the two always agree.
        """
        tokenizer_pad_id = cfg.tokenizer.pad_id
        kwargs = dict(num_classes=cfg.model.num_classes, pad_id=tokenizer_pad_id)
        kwargs.update(super().from_config(cfg))
        return kwargs

    @configurable
    def __init__(self,
                 num_classes: int,
                 vocab_size: int,
                 max_len: int = 2048,
                 embedding_dim=2304,
                 num_layers: int = 12,
                 num_heads: int = 24,
                 bias=False,
                 eps=1e-5,
                 pad_id=2,
                 dropout=0.0):
        """Build the headless backbone plus a ``SequencePredictionHead``.

        Args:
            num_classes: output size of the prediction head.
            pad_id: token id treated as padding when choosing the pooled token;
                None means "always use the last position".
            dropout: dropout rate handed to the prediction head.
            Remaining arguments are forwarded to ``CDGPT.__init__``.
        """
        super().__init__(vocab_size=vocab_size,
                         max_len=max_len,
                         embedding_dim=embedding_dim,
                         num_layers=num_layers,
                         num_heads=num_heads,
                         bias=bias,
                         eps=eps,
                         include_head=False)
        self.num_classes = num_classes
        self.pad_id = pad_id
        self.dropout = dropout
        self.cls_head = SequencePredictionHead(self.embedding_dim, self.num_classes, self.dropout)

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                pos_ids: Optional[torch.Tensor] = None):
        """Return ``{"output": head(hidden state of one token per sequence)}``.

        Args:
            input_ids: [bs, seq_len] token ids.
            attention_mask: forwarded to ``CDGPT.forward``.
            pos_ids: forwarded to ``CDGPT.forward``.
        """
        hidden_states = super().forward(input_ids, attention_mask, pos_ids)
        if self.pad_id is not None:
            # Count of non-pad tokens minus one: the last real token's index,
            # which is only correct when padding sits at the right end.
            token_index = input_ids.ne(self.pad_id).sum(-1) - 1
        else:
            # No padding id: classify/regress from the final position.
            token_index = -1
        rows = torch.arange(hidden_states.shape[0], device=hidden_states.device)
        pooled = hidden_states[rows, token_index]
        return {"output": self.cls_head(pooled)}

Token预测任务

逐 token 预测模型 `CDGPTTokenPrediction` 的 Python 实现如下:
class CDGPTTokenPrediction(CDGPT):
    """CD-GPT backbone with a per-token prediction head (``TokenPredictionHead``)."""

    @classmethod
    def from_config(cls, cfg):
        """Extend the backbone's config kwargs with ``num_classes`` and ``pad_id``.

        NOTE(review): the backbone kwargs are unpacked last, so their ``pad_id``
        (from ``SentencePieceTokenizer``) overrides ``cfg.tokenizer.pad_id`` —
        confirm the two always agree.
        """
        pad_id = cfg.tokenizer.pad_id
        num_classes = cfg.model.num_classes
        return {
            "num_classes": num_classes,
            "pad_id": pad_id,
            **super().from_config(cfg)
        }

    @configurable
    def __init__(self,
                 num_classes,
                 vocab_size: int,
                 max_len: int = 2048,
                 embedding_dim=2304,
                 num_layers: int = 12,
                 num_heads: int = 24,
                 bias=False,
                 eps=1e-5,
                 pad_id=2,
                 dropout=0.0):
        """Build the backbone plus a ``TokenPredictionHead``.

        Args:
            num_classes: number of per-token output classes.
            pad_id: padding token id (stored on the instance).
            dropout: dropout rate handed to the prediction head.
            Remaining arguments are forwarded to ``CDGPT.__init__``.

        NOTE(review): the LM head is created (``include_head=True``) but not used
        by ``forward``; presumably kept for checkpoint compatibility — confirm.
        """
        super().__init__(vocab_size=vocab_size,
                         max_len=max_len,
                         embedding_dim=embedding_dim,
                         num_layers=num_layers,
                         num_heads=num_heads,
                         bias=bias,
                         eps=eps,
                         include_head=True)
        self.num_classes = num_classes
        self.pad_id = pad_id
        self.cls_head = TokenPredictionHead(self.embedding_dim, self.num_classes, dropout, num_heads, max_len, eps)

    def forward(self, token_ids, pos_ids=None, attention_mask=None):
        """Return ``{"output": cls_head(normalized hidden states)}``.

        Args:
            token_ids: [bs, seq_len] token ids, ``seq_len <= max_len``.
            pos_ids: accepted for interface compatibility; not used here.
            attention_mask: optional mask, cropped to [:, :, :seq_len, :seq_len]
                before it reaches the blocks.
        """
        _, seq_len = token_ids.shape
        assert (
                seq_len <= self.max_len
        ), f"Cannot forward sequence of length {seq_len}, max length is only {self.max_len}"

        if self.rope_cache is None:
            self.rope_cache = self._make_rope_mask(token_ids.device, token_ids.dtype)  # [max_len, ...]

        rope = self.rope_cache[:seq_len]
        if attention_mask is not None:
            # Crop the caller's mask. (Previously this read ``self.attention_mask``,
            # an attribute that is never set, so any mask raised AttributeError.)
            attention_mask = attention_mask[:, :, :seq_len, :seq_len]

        x = self._forward_embedding_impl(token_ids)

        for block in self.transformer.h:
            if self.activation_checkpoint:
                x, _, _ = self.activation_checkpoint_func(block, x, rope, attention_mask, None, None, True)
            else:
                x, _, _ = block(x, rope, attn_mask=attention_mask, need_attn=True)

        x = self.transformer.ln_f(x)
        result = {}
        result["output"] = self.cls_head(x)
        return result

Residue-Pair预测任务

残基对(Residue-Pair)预测模型 `CDGPTResiduePairPrediction` 的 Python 实现如下:
class CDGPTResiduePairPrediction(CDGPT):
    """CD-GPT backbone that predicts residue-pair labels from attention maps.

    The attention weights of every block are stacked and passed to
    ``ResiduePairPredictionHead``; the LM-head logits are returned alongside.
    """

    @classmethod
    def from_config(cls, cfg):
        """Extend the backbone's config kwargs with ``num_classes`` and ``pad_id``.

        NOTE(review): the backbone kwargs are unpacked last, so their ``pad_id``
        (from ``SentencePieceTokenizer``) overrides ``cfg.tokenizer.pad_id`` —
        confirm the two always agree.
        """
        pad_id = cfg.tokenizer.pad_id
        num_classes = cfg.model.num_classes
        return {
            "num_classes": num_classes,
            "pad_id": pad_id,
            **super().from_config(cfg)
        }

    @configurable
    def __init__(self,
                 num_classes,
                 vocab_size: int,
                 max_len: int = 2048,
                 embedding_dim=2304,
                 num_layers: int = 12,
                 num_heads: int = 24,
                 bias=True,
                 eps=1e-5,
                 pad_id=2,
                 ):
        """Build the backbone (with LM head) plus a ``ResiduePairPredictionHead``.

        Args:
            num_classes: number of pair-level output classes.
            bias: used both for the LM head and for the pair head.
            Remaining arguments are forwarded to ``CDGPT.__init__``.

        The pair head is built with ``num_heads * num_layers`` input channels,
        matching one attention map per head per layer.
        """
        super().__init__(vocab_size=vocab_size,
                         max_len=max_len,
                         embedding_dim=embedding_dim,
                         num_layers=num_layers,
                         num_heads=num_heads,
                         bias=bias,
                         eps=eps,
                         include_head=True,
                         pad_id=pad_id)
        self.num_classes = num_classes
        self.contact_head = ResiduePairPredictionHead(num_heads * num_layers, self.num_classes, bias)

    def forward(self, token_ids, pos_ids=None, attention_mask=None):
        """Return ``{"output": pair predictions, "logits": LM-head logits}``.

        Args:
            token_ids: [bs, seq_len] token ids, ``seq_len <= max_len``.
            pos_ids: accepted for interface compatibility; not used here.
            attention_mask: optional mask, cropped to [:, :, :seq_len, :seq_len]
                before it reaches the blocks.
        """
        _, seq_len = token_ids.shape
        assert (
                seq_len <= self.max_len
        ), f"Cannot forward sequence of length {seq_len}, max length is only {self.max_len}"

        if self.rope_cache is None:
            self.rope_cache = self._make_rope_mask(token_ids.device, token_ids.dtype)  # [max_len, ...]

        rope = self.rope_cache[:seq_len]
        if attention_mask is not None:
            # Crop the caller's mask. (Previously this read ``self.attention_mask``,
            # an attribute that is never set, so any mask raised AttributeError.)
            attention_mask = attention_mask[:, :, :seq_len, :seq_len]

        x = self._forward_embedding_impl(token_ids)
        attn_weights = []
        for block in self.transformer.h:
            if self.activation_checkpoint:
                x, _, attn = self.activation_checkpoint_func(block, x, rope, attention_mask, None, None, True)
            else:
                x, _, attn = block(x, rope, attn_mask=attention_mask, need_attn=True)
            attn_weights.append(attn)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        result = {}
        # Stack per-layer attention maps along a new dim 1 (one entry per layer).
        attentions = torch.stack(attn_weights, 1)
        contact = self.contact_head(attentions)
        result["output"] = contact
        result["logits"] = logits

        return result

完整流程

请参考 GitHub 仓库 TencentAI4S/CD-GPT,以及论文《CD-GPT: Biological Foundation Model at Full-molecular Level》。