深度学习中显性特征组合的网络结构crossNet

crossNet

论文获取

1、原理

这个网络结构最开始提出来是在推荐领域,其核心作用就是在不增加太多参数与内存的下显式实现特征交叉组合 .其公式如下: x l x_{l} xl是第l层的输出, x 0 x_{0} x0是原始输入,那么第 x l + 1 x_{l+1} xl+1层就等于原始输入乘以上一层输出实现交叉后加上上一层输出;这样一层就是 x 2 x^2 x2的交叉,两层就有 x 3 x^3 x3的交叉,依次类推,这样就确保实现高阶的特征交叉组合.且其参数是线性增长的.

2、原理分析

  • 首先从公式看其维度是不发生变化的,所以其高阶信息是压缩过的
  • 注意到公式加上一层信息,因此不同层组合,展开其最后一层输出有保留低阶到高阶的的组合信息,但可能存在部分丢失问题
  • 该结构适合在原始输入端使用,代替特征交叉组合的特征工程
  • 该结构不建议过多层,容易导致过拟合,一般1~3层
  • 该结构适合特征本身有交叉关系的场景使用
  • 该结构使用稀疏与稠密的输入

3、实现

以torch 为例,当然实际应用通常希望加一点正则防止过拟合,此时可以通过优化器指定增加该部分结构的权重正则.

python 复制代码
crossnet = CrossNet(...)
crossnet_params = []
othernet = OtherNet(...)
#获取crossnet参数
for name, param in model.named_parameters():
    if "crossnet" in name:
        crossnet_params.append(param)
optimizer = torch.optim.Adam([
    {"params": crossnet.weights, "weight_decay":1e-4}, 
    {"params": crossnet.biases, "weight_decay":0}
], lr=0.001)
python 复制代码
class CrossNet(nn.Module):
    def __init__(self, in_features, num_layers):
        super(CrossNet, self).__init__()
        self.num_layers = num_layers
        self.in_features = in_features
        self.weights = nn.ParameterList([
            nn.Parameter(torch.randn(in_features, 1)) for _ in range(num_layers)
        ])
        self.biases = nn.ParameterList([
            nn.Parameter(torch.randn(in_features)) for _ in range(num_layers)
        ])
        self.reset_parameters()
	def forward(self, x0):
        x = x0
        for i in range(self.num_layers):
            # x: (batch, in_features)
            # x @ w: (batch, 1)
            xw = torch.matmul(x, self.weights[i])  # (batch, 1)
            # x0 * (x @ w) : broadcasting (batch, in_features)
            cross = x0 * xw  # broadcasting
            x = cross + self.biases[i] + x  # (batch, in_features)
        return x
    def reset_parameters(self):
        # 用合理的初始化方法初始化 weights 和 biases
        for w in self.weights:
            nn.init.xavier_uniform_(w)
        for b in self.biases:
            nn.init.zeros_(b)
相关推荐
陈橘又青2 分钟前
100% AI 写的开源项目三周多已获得 800 star 了
人工智能·后端·ai·restful·数据
松岛雾奈.23010 分钟前
深度学习--TensorFlow框架使用
深度学习·tensorflow·neo4j
中杯可乐多加冰20 分钟前
逻辑控制案例详解|基于smardaten实现OA一体化办公系统逻辑交互
人工智能·深度学习·低代码·oa办公·无代码·一体化平台·逻辑控制
IT_陈寒1 小时前
Redis实战:5个高频应用场景下的性能优化技巧,让你的QPS提升50%
前端·人工智能·后端
龙智DevSecOps解决方案1 小时前
Perforce《2025游戏技术现状报告》Part 1:游戏引擎技术的广泛影响以及生成式AI的成熟之路
人工智能·unity·游戏引擎·游戏开发·perforce
大佬,救命!!!1 小时前
更换适配python版本直接进行机器学习深度学习等相关环境配置(非仿真环境)
人工智能·python·深度学习·机器学习·学习笔记·详细配置
星空的资源小屋1 小时前
VNote:程序员必备Markdown笔记神器
javascript·人工智能·笔记·django
梵得儿SHI1 小时前
(第七篇)Spring AI 基础入门总结:四层技术栈全景图 + 三大坑根治方案 + RAG 进阶预告
java·人工智能·spring·springai的四大核心能力·向量维度·prompt模板化·向量存储检索
亚马逊云开发者1 小时前
Amazon Bedrock助力飞书深诺电商广告分类
人工智能
2301_823438021 小时前
解析论文《复杂海上救援环境中无人机群的双阶段协作路径规划与任务分配》
人工智能·算法·无人机