前言
使用神经网络训练后的模型可以实现图像风格的迁移,比如前一阵非常火爆的吉卜力风格。本文尝试在 Android 实现 PyTorch 风格迁移模型的部署, 实现基于固定风格的迁移。
什么是图像风格迁移
图像风格迁移(Neural Style Transfer, NST) 是一种基于深度学习的图像处理技术,通过结合内容图像的结构和风格图像的艺术特征(如笔触、色彩),生成兼具两者特点的新图像。其核心思想是利用卷积神经网络(如VGG19)提取图像的高层语义特征:
- 内容保留:通过深层网络响应匹配内容图像的布局和物体轮廓。
- 风格融合:通过Gram矩阵统计风格图像的纹理、色彩分布等低层特征。
- 优化生成:以内容图像为初始输入,通过梯度下降最小化与目标内容和风格的损失函数,逐步合成新图像。
典型应用包括将照片转化为梵高、毕加索等艺术风格,或生成吉卜力动画风格的画面。技术变体包括实时风格迁移(如AdaIN)和基于扩散模型的生成方法。
基于 PyTorchLite 的风格迁移
在上一篇 PyTorch对抗生成网络模型及Android端的实现 中,我们通过使用 PyTorch Android 端的 SDK 实现了 GAN 在手机端的部署。但是由于算力和存储的限制,在手机上部署图像生成模型压力还是比较大的,存储和推理速度都是瓶颈。然而, 风格迁移的实现相对来说就比较简单了,算力和存储占用大幅降低的情况之下,也能收获相对来说不错的效果。


可以看到,以上两个示例的效果还是不错的。
PyTorch Lite 处理输入输出
我们先回顾一下 PyTorch Lite 的用法,上一节中对于基于 GAN 框架的生成模型来说,按照训练期间模型的定义,我们的输入是 1x100 的随机数,输出是 1x64x64x3 (64 像素大小彩色图片)。
kotlin
private fun genImage(): Bitmap {
val zDim = intArrayOf(1, 100)
val outDims = intArrayOf(64, 64, 3)
val z = FloatArray(zDim[0] * zDim[1])
val rand = Random()
// 生成高斯随机数
for (c in 0 until zDim[0] * zDim[1]) {
z[c] = rand.nextGaussian().toFloat()
}
val tensor = Tensor.fromBlob(z, longArrayOf(1, 100))
val resultArray = module.forward(IValue.from(tensor)).toTensor().dataAsFloatArray
val resultImg = Array(outDims[0]) { Array(outDims[1]) { FloatArray(outDims[2]) { 0.0f } } }
var index = 0
// 根据输出的一维数组,解析生成的卡通图像
....
val bitmap = Utils.getBitmap(resultImg, outDims)
return bitmap
}
因此,我们做的工作就是构建适用于模型输入结构的数据 inputTensor
,同时根据模型返回的数据 resultTensor
,将数据转换为我们需要的格式,比如对于生成模型来说,转化为对应平台可以渲染的图像数据即可,对于Android 来说就是常用的 Bitmap。
而对于风格迁移模型来说,一般情况下输入有两个,内容图片和风格图片。这里简单起见,风格类型由模型固化,我们只处理内容图片的输入,生成固定风格的图片。
模型初始化
这里我们直接使用 PyTorch 官方示例中 fast_neural_style 风格迁移所用到的模型。具体模型可以从 这里 下载,转换为 PyTorch Lite 可用的模型即可,我们以风格比较鲜明的 mosaic
为例。
kotlin
module = LiteModuleLoader.load(AndroidAssetsFileUtil.assetFilePath(this, "mosaic.pt"))
对 Bitmap 进行风格迁移
kotlin
fun transferStyleAsync(
contentImage: Bitmap, scale: Float = 1.0f, cb: ((Bitmap) -> Unit)? = null
): Bitmap {
// 1. Preprocess the content image (simple ToTensor + multiply by 255)
val (transformedImage, width, height) = preprocessImage(contentImage, scale)
// 2. Create input tensor
val inputTensor = Tensor.fromBlob(
transformedImage, longArrayOf(1, 3, height.toLong(), width.toLong())
)
// 3. Run the model
val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()
// 4. Postprocessor the output
return postprocessingImage(outputTensor, width, height, cb)
}
HWC 和 CHW
这里需要注意的是,Android 中标准的 Bitmap 其数据是 RGB 的格式进行存储的,而 PyTorch 中为了方便内存优化进行计算,是按照 CHW 的格式排列数据的,因此将传统的 Bitmap 传入 PyTorch 进行处理之前需要进行数据转换。同时,由于数据数据的宽高涉及到后续推理模型的计算,因此还需要返回内容图片的宽高值。
返回结果的处理
有了原始 Bitmap 数据和宽高,我们就可以调用模型进行推理了。
kotlin
val inputTensor = Tensor.fromBlob(
transformedImage, longArrayOf(1, 3, height.toLong(), width.toLong())
)
val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()
返回结果依然是 Tensor 类型,我们需要从中获取所需要的 Bitmap 图像数据
kotlin
private fun postprocessingImage(
outputTensor: Tensor, width: Int, height: Int): Bitmap {
val outputData = outputTensor.dataAsFloatArray
// Create output bitmap
val outputBitmap = createBitmap(width, height)
val pixels = IntArray(width * height)
// Convert from CHW to ARGB format
val channelSize = width * height
for (i in 0 until channelSize) {
// Get RGB values (scaled back from 0-255)
val r = outputData[i].toInt().coerceIn(0, 255)
val g = outputData[i + channelSize].toInt().coerceIn(0, 255)
val b = outputData[i + 2 * channelSize].toInt().coerceIn(0, 255)
// Combine into ARGB pixel
pixels[i] = 0xFF shl 24 or (r shl 16) or (g shl 8) or b
}
// Set pixels to bitmap
outputBitmap.setPixels(pixels, 0, width, 0, 0, width, height)
return outputBitmap
}
由于返回的 dataAsFloatArray
依然是 CHW 格式的数据,我们需要再执行一次逆向操作,将 CHW 格式的数据转换为 ARGB 格式的 Bitmap 数据。我们可以看一下效果(示例中最后一个就是 mosaic
风格)。

可以顺便看一眼耗时
java
16:49:39.471 StyleTransferProcessor D transferStyle() called with: contentImage = android.graphics.Bitmap@3e8f24, scale = 0.5
16:49:39.471 StyleTransferProcessor D preprocessImage() called with: bitmap = android.graphics.Bitmap@3e8f24, scale = 0.5
16:49:39.477 StyleTransferProcessor I scale 468,832
16:49:39.511 StyleTransferProcessor D transferStyle() step2 done
16:49:41.332 StyleTransferProcessor D transferStyle() step3 done
16:49:41.332 StyleTransferProcessor D postprocessingImage() called with: outputTensor = Tensor([1, 3, 832, 468], dtype=torch.float32), width = 468, height = 832
16:49:41.354 StyleTransferProcessor D channelSize = 389376
16:49:41.466 StyleTransferProcessor D bitmap is ok
16:50:21.199 StyleTransferProcessor D transferStyle() called with: contentImage = android.graphics.Bitmap@3e8f24, scale = 0.5
16:50:21.199 StyleTransferProcessor D preprocessImage() called with: bitmap = android.graphics.Bitmap@3e8f24, scale = 0.5
16:50:21.202 StyleTransferProcessor I scale 468,832
16:50:21.228 StyleTransferProcessor D transferStyle() step2 done
16:50:23.159 StyleTransferProcessor D transferStyle() step3 done
16:50:23.160 StyleTransferProcessor D postprocessingImage() called with: outputTensor = Tensor([1, 3, 832, 468], dtype=torch.float32), width = 468, height = 832
16:50:23.185 StyleTransferProcessor D channelSize = 389376
16:50:23.307 StyleTransferProcessor D bitmap is ok
16:50:31.999 StyleTransferProcessor D transferStyle() called with: contentImage = android.graphics.Bitmap@3e8f24, scale = 0.5
16:50:31.999 StyleTransferProcessor D preprocessImage() called with: bitmap = android.graphics.Bitmap@3e8f24, scale = 0.5
16:50:32.002 StyleTransferProcessor I scale 468,832
16:50:32.028 StyleTransferProcessor D transferStyle() step2 done
16:50:33.716 StyleTransferProcessor D transferStyle() step3 done
16:50:33.716 StyleTransferProcessor D postprocessingImage() called with: outputTensor = Tensor([1, 3, 832, 468], dtype=torch.float32), width = 468, height = 832
16:50:33.743 StyleTransferProcessor D channelSize = 389376
16:50:33.856 StyleTransferProcessor D bitmap is ok
可以看到,对原始图片宽高按照 50% 的比例压缩之后,转换时间还是挺快的,2 秒基本上就可以完成一副图片的转换(测试手机为一加 8, 骁龙 865 ,8GB RAM)
总结
有兴趣的话,可以运行官方示例 fast_neural_style,对比一下各自的效果。还可以基于官方示例,训练自己的风格模型,尝试实现不同的风格。
在手机端除了由于内存的限制,以及转换后模型精度的变化,多多少少还是有一些损失。甚至对于一些像素比较高的图片,不进行压缩的话会出现内存不足无法运行的情况。当然,风格迁移的实现方式有很多种,现在很多手机自带的相册和相机都可以实时进行固定风格的迁移,技术总是在不断变化的进步,相信随着这一波生成式人工智能的发展,风格迁移模型可以变得更加强大。