Pytorch风格迁移的Android实现

前言

使用神经网络训练后的模型可以实现图像风格的迁移,比如前一阵非常火爆的吉卜力风格。本文尝试在 Android 实现 PyTorch 风格迁移模型的部署, 实现基于固定风格的迁移。

什么是图像风格迁移

图像风格迁移(Neural Style Transfer, NST) 是一种基于深度学习的图像处理技术,通过结合内容图像的结构和风格图像的艺术特征(如笔触、色彩),生成兼具两者特点的新图像。其核心思想是利用卷积神经网络(如VGG19)提取图像的高层语义特征:

  • 内容保留:通过深层网络响应匹配内容图像的布局和物体轮廓。
  • 风格融合:通过Gram矩阵统计风格图像的纹理、色彩分布等低层特征。
  • 优化生成:以内容图像为初始输入,通过梯度下降最小化与目标内容和风格的损失函数,逐步合成新图像。

典型应用包括将照片转化为梵高、毕加索等艺术风格,或生成吉卜力动画风格的画面。技术变体包括实时风格迁移(如AdaIN)和基于扩散模型的生成方法。

基于 PyTorchLite 的风格迁移

在上一篇 PyTorch对抗生成网络模型及Android端的实现 中,我们通过使用 PyTorch Android 端的 SDK 实现了 GAN 在手机端的部署。但是由于算力和存储的限制,在手机上部署图像生成模型压力还是比较大的,存储和推理速度都是瓶颈。然而, 风格迁移的实现相对来说就比较简单了,算力和存储占用大幅降低的情况之下,也能收获相对来说不错的效果。

可以看到,以上两个示例的效果还是不错的。

PyTorch Lite 处理输入输出

我们先回顾一下 PyTorch Lite 的用法,上一节中对于基于 GAN 框架的生成模型来说,按照训练期间模型的定义,我们的输入是 1x100 的随机数,输出是 1x64x64x3 (64 像素大小彩色图片)。

kotlin 复制代码
    private fun genImage(): Bitmap {
        val zDim = intArrayOf(1, 100)
        val outDims = intArrayOf(64, 64, 3)
    
        val z = FloatArray(zDim[0] * zDim[1])
      
        val rand = Random()
        // 生成高斯随机数
        for (c in 0 until zDim[0] * zDim[1]) {
            z[c] = rand.nextGaussian().toFloat()
        }
    
        val tensor = Tensor.fromBlob(z, longArrayOf(1, 100))
        
        val resultArray = module.forward(IValue.from(tensor)).toTensor().dataAsFloatArray
        val resultImg = Array(outDims[0]) { Array(outDims[1]) { FloatArray(outDims[2]) { 0.0f } } }
        var index = 0
        // 根据输出的一维数组,解析生成的卡通图像
        ....
        val bitmap = Utils.getBitmap(resultImg, outDims)
        return bitmap
    }

因此,我们做的工作就是构建适用于模型输入结构的数据 inputTensor ,同时根据模型返回的数据 resultTensor,将数据转换为我们需要的格式,比如对于生成模型来说,转化为对应平台可以渲染的图像数据即可,对于Android 来说就是常用的 Bitmap。

而对于风格迁移模型来说,一般情况下输入有两个,内容图片和风格图片。这里简单起见,风格类型由模型固化,我们只处理内容图片的输入,生成固定风格的图片。

模型初始化

这里我们直接使用 PyTorch 官方示例中 fast_neural_style 风格迁移所用到的模型。具体模型可以从 这里 下载,转换为 PyTorch Lite 可用的模型即可,我们以风格比较鲜明的 mosaic 为例。

kotlin 复制代码
module = LiteModuleLoader.load(AndroidAssetsFileUtil.assetFilePath(this, "mosaic.pt"))

对 Bitmap 进行风格迁移

kotlin 复制代码
    fun transferStyleAsync(
        contentImage: Bitmap, scale: Float = 1.0f, cb: ((Bitmap) -> Unit)? = null
    ): Bitmap {
        
        // 1. Preprocess the content image (simple ToTensor + multiply by 255)
        val (transformedImage, width, height) = preprocessImage(contentImage, scale)

        // 2. Create input tensor
        val inputTensor = Tensor.fromBlob(
            transformedImage, longArrayOf(1, 3, height.toLong(), width.toLong())
        )
        

        // 3. Run the model
        val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()


        // 4. Postprocessor the output
        return postprocessingImage(outputTensor, width, height, cb)
    }

HWC 和 CHW

这里需要注意的是,Android 中标准的 Bitmap 其数据是 RGB 的格式进行存储的,而 PyTorch 中为了方便内存优化进行计算,是按照 CHW 的格式排列数据的,因此将传统的 Bitmap 传入 PyTorch 进行处理之前需要进行数据转换。同时,由于数据数据的宽高涉及到后续推理模型的计算,因此还需要返回内容图片的宽高值。

返回结果的处理

有了原始 Bitmap 数据和宽高,我们就可以调用模型进行推理了。

kotlin 复制代码
        val inputTensor = Tensor.fromBlob(
            transformedImage, longArrayOf(1, 3, height.toLong(), width.toLong())
        )
        val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()

返回结果依然是 Tensor 类型,我们需要从中获取所需要的 Bitmap 图像数据

kotlin 复制代码
    private fun postprocessingImage(
        outputTensor: Tensor, width: Int, height: Int): Bitmap {
        
        val outputData = outputTensor.dataAsFloatArray

        // Create output bitmap
        val outputBitmap = createBitmap(width, height)

        val pixels = IntArray(width * height)

        // Convert from CHW to ARGB format
        val channelSize = width * height
       
        for (i in 0 until channelSize) {
            // Get RGB values (scaled back from 0-255)
            val r = outputData[i].toInt().coerceIn(0, 255)
            val g = outputData[i + channelSize].toInt().coerceIn(0, 255)
            val b = outputData[i + 2 * channelSize].toInt().coerceIn(0, 255)

            // Combine into ARGB pixel
            pixels[i] = 0xFF shl 24 or (r shl 16) or (g shl 8) or b
        }

        // Set pixels to bitmap
        outputBitmap.setPixels(pixels, 0, width, 0, 0, width, height)
        return outputBitmap
    }

由于返回的 dataAsFloatArray 依然是 CHW 格式的数据,我们需要再执行一次逆向操作,将 CHW 格式的数据转换为 ARGB 格式的 Bitmap 数据。我们可以看一下效果(示例中最后一个就是 mosaic 风格)。

可以顺便看一眼耗时

java 复制代码
16:49:39.471 StyleTransferProcessor   D  transferStyle() called with: contentImage = android.graphics.Bitmap@3e8f24, scale = 0.5
16:49:39.471 StyleTransferProcessor   D  preprocessImage() called with: bitmap = android.graphics.Bitmap@3e8f24, scale = 0.5
16:49:39.477 StyleTransferProcessor   I  scale 468,832
16:49:39.511 StyleTransferProcessor   D  transferStyle() step2 done
16:49:41.332 StyleTransferProcessor   D  transferStyle() step3 done
16:49:41.332 StyleTransferProcessor   D  postprocessingImage() called with: outputTensor = Tensor([1, 3, 832, 468], dtype=torch.float32), width = 468, height = 832
16:49:41.354 StyleTransferProcessor   D  channelSize = 389376
16:49:41.466 StyleTransferProcessor   D  bitmap is ok


16:50:21.199 StyleTransferProcessor   D  transferStyle() called with: contentImage = android.graphics.Bitmap@3e8f24, scale = 0.5
16:50:21.199 StyleTransferProcessor   D  preprocessImage() called with: bitmap = android.graphics.Bitmap@3e8f24, scale = 0.5
16:50:21.202 StyleTransferProcessor   I  scale 468,832
16:50:21.228 StyleTransferProcessor   D  transferStyle() step2 done
16:50:23.159 StyleTransferProcessor   D  transferStyle() step3 done
16:50:23.160 StyleTransferProcessor   D  postprocessingImage() called with: outputTensor = Tensor([1, 3, 832, 468], dtype=torch.float32), width = 468, height = 832
16:50:23.185 StyleTransferProcessor   D  channelSize = 389376
16:50:23.307 StyleTransferProcessor   D  bitmap is ok


16:50:31.999 StyleTransferProcessor   D  transferStyle() called with: contentImage = android.graphics.Bitmap@3e8f24, scale = 0.5
16:50:31.999 StyleTransferProcessor   D  preprocessImage() called with: bitmap = android.graphics.Bitmap@3e8f24, scale = 0.5
16:50:32.002 StyleTransferProcessor   I  scale 468,832
16:50:32.028 StyleTransferProcessor   D  transferStyle() step2 done
16:50:33.716 StyleTransferProcessor   D  transferStyle() step3 done
16:50:33.716 StyleTransferProcessor   D  postprocessingImage() called with: outputTensor = Tensor([1, 3, 832, 468], dtype=torch.float32), width = 468, height = 832
16:50:33.743 StyleTransferProcessor   D  channelSize = 389376
16:50:33.856 StyleTransferProcessor   D  bitmap is ok

可以看到,对原始图片宽高按照 50% 的比例压缩之后,转换时间还是挺快的,2 秒基本上就可以完成一副图片的转换(测试手机为一加 8, 骁龙 865 ,8GB RAM)

总结

有兴趣的话,可以运行官方示例 fast_neural_style,对比一下各自的效果。还可以基于官方示例,训练自己的风格模型,尝试实现不同的风格。

在手机端除了由于内存的限制,以及转换后模型精度的变化,多多少少还是有一些损失。甚至对于一些像素比较高的图片,不进行压缩的话会出现内存不足无法运行的情况。当然,风格迁移的实现方式有很多种,现在很多手机自带的相册和相机都可以实时进行固定风格的迁移,技术总是在不断变化的进步,相信随着这一波生成式人工智能的发展,风格迁移模型可以变得更加强大。

相关推荐
moonless022214 分钟前
🌈Transformer说人话版(二)位置编码 【持续更新ing】
人工智能·llm
小爷毛毛_卓寿杰14 分钟前
基于大模型与知识图谱的对话引导意图澄清系统技术解析
人工智能·llm
聚客AI25 分钟前
解构高效提示工程:分层模型、文本扩展引擎与可视化调试全链路指南
人工智能·llm·掘金·日新计划
喝过期的拉菲32 分钟前
使用 Pytorch Lightning 时追踪指标和可视化指标
pytorch·可视化·lightning·指标追踪
摆烂工程师38 分钟前
Claude Code 落地实践的工作简易流程
人工智能·claude·敏捷开发
亚马逊云开发者40 分钟前
得心应手:探索 MCP 与数据库结合的应用场景
人工智能
移动开发者1号42 分钟前
ReLinker优化So库加载指南
android·kotlin
大明哥_1 小时前
100 个 Coze 精品案例 - 小红书爆款图文,单篇点赞 20000+,用 Coze 智能体一键生成有声儿童绘本!
人工智能
聚客AI1 小时前
🚀拒绝试错成本!企业接入MCP协议的避坑清单
人工智能·掘金·日新计划·mcp
山野万里__1 小时前
C++与Java内存共享技术:跨平台与跨语言实现指南
android·java·c++·笔记