Java AI 开发完全教程

一、AI 开发环境搭建

1.1 基础开发环境

复制代码
<!-- Maven 依赖 -->
<dependencies>
    <!-- 1. 深度学习框架选择 -->
    <!-- Deeplearning4j (Java 原生) -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-core</artifactId>
        <version>1.0.0-M2.1</version>
    </dependency>
    
    <!-- DJL (Deep Java Library) - 支持多引擎;此处仅为 API 模块,运行时还需引入具体引擎依赖(如 ai.djl.pytorch:pytorch-engine) -->
    <dependency>
        <groupId>ai.djl</groupId>
        <artifactId>api</artifactId>
        <version>0.23.0</version>
    </dependency>
    
    <!-- 2. 数据处理 -->
    <dependency>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-api</artifactId>
        <version>1.0.0-M2.1</version>
    </dependency>
    
    <!-- 3. 数学计算 -->
    <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>nd4j-native-platform</artifactId>
        <version>1.0.0-M2.1</version>
    </dependency>
    
    <!-- 4. 数学与统计库(Apache Commons Math,2.1 节的线性回归示例使用) -->
    <dependency>
        <groupId>org.apache.commons</groupId>
        <artifactId>commons-math3</artifactId>
        <version>3.6.1</version>
    </dependency>
</dependencies>

1.2 GPU 加速配置

复制代码
// 检查 GPU 支持
import org.nd4j.linalg.factory.Nd4j;

/**
 * Prints which ND4J backend is active and whether it is a non-CPU (GPU/CUDA) backend.
 */
public class GPUCheck {
    /**
     * Reports the backend class name and GPU availability, then sets two
     * JavaCPP/CUDA related system properties.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        System.out.println("后端类型: " + Nd4j.getBackend().getClass().getName());
        // isCPU() is true for the CPU backend, so "CUDA available" is its negation.
        // (The original printed isCPU() directly, which inverted the answer.)
        System.out.println("CUDA 可用: " + (!Nd4j.getBackend().isCPU()));
        
        // Properties intended to configure GPU usage.
        // NOTE(review): these are set after the backend has already been queried above;
        // confirm they still take effect at this point and that the second key is honoured.
        System.setProperty("org.bytedeco.javacpp.logger.debug", "true");
        System.setProperty("org.bytedeco.cuda.cudart.version", "11.0");
    }
}

二、机器学习入门

2.1 线性回归示例

复制代码
import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;
import org.apache.commons.math3.stat.regression.SimpleRegression;

/**
 * Fits a single-variable least-squares line with Commons Math
 * {@code SimpleRegression} and prints the fitted equation, R² and one prediction.
 */
public class LinearRegressionExample {

    /**
     * Feeds ten hard-coded (x, y) observations into the regression, prints the
     * slope/intercept/R², then prints the predicted value for x = 11.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        // Observations: inputs and targets are aligned by index.
        final double[] inputs = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
        final double[] targets = {1.2, 1.8, 3.1, 3.8, 5.2, 5.9, 7.1, 7.8, 9.0, 9.7};

        // Add every (input, target) pair to the model.
        final SimpleRegression model = new SimpleRegression();
        int index = 0;
        for (double input : inputs) {
            model.addData(input, targets[index]);
            index++;
        }

        // Report the fitted line and its goodness of fit.
        System.out.println("回归方程: y = " + model.getSlope() + "x + " + model.getIntercept());
        System.out.println("R² = " + model.getRSquare());

        // Extrapolate one step beyond the observed range.
        System.out.println("x=11 时的预测值: " + model.predict(11));
    }
}

2.2 决策树分类

复制代码
import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

/**
 * Trains a Weka J48 decision tree on an ARFF file and reports
 * 10-fold cross-validation results.
 */
public class DecisionTreeExample {

    /**
     * Loads the dataset, builds and prints the tree, then cross-validates it.
     *
     * @param args unused
     * @throws Exception if loading, training or evaluation fails
     */
    public static void main(String[] args) throws Exception {
        Instances dataset = loadWithClassAttribute("src/main/resources/iris.arff");

        // Configure J48 with its pruning options, then train on the full dataset.
        J48 classifier = new J48();
        classifier.setOptions(new String[] {"-C", "0.25", "-M", "2"});
        classifier.buildClassifier(dataset);

        System.out.println("决策树模型:");
        System.out.println(classifier);

        // 10-fold cross-validation with a fixed random seed of 1.
        weka.classifiers.Evaluation evaluation = new weka.classifiers.Evaluation(dataset);
        evaluation.crossValidateModel(classifier, dataset, 10, new java.util.Random(1));

        System.out.println("\n评估结果:");
        System.out.println("准确率: " + evaluation.pctCorrect() + "%");
        System.out.println("混淆矩阵:\n" + evaluation.toMatrixString());
    }

    /**
     * Reads a dataset and, when no class attribute is set yet, marks the last
     * attribute as the class.
     */
    private static Instances loadWithClassAttribute(String path) throws Exception {
        Instances dataset = new DataSource(path).getDataSet();
        if (dataset.classIndex() == -1) {
            dataset.setClassIndex(dataset.numAttributes() - 1);
        }
        return dataset;
    }
}

三、深度学习入门

3.1 使用 Deeplearning4j 构建神经网络

复制代码
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/**
 * Builds and trains a small fully-connected softmax classifier with Deeplearning4j.
 */
public class NeuralNetworkExample {

    /**
     * Creates and initialises a network with two ReLU dense layers
     * (numInput -> 128 -> 64) followed by a softmax output layer of size numOutput.
     *
     * @param numInput  number of input features
     * @param numOutput number of output classes
     * @return the initialised network with a score listener attached
     */
    public static MultiLayerNetwork createSimpleNN(int numInput, int numOutput) {
        // Individual layers, declared up front so the configuration chain stays short.
        DenseLayer firstDense = reluLayer(numInput, 128);
        DenseLayer secondDense = reluLayer(128, 64);
        OutputLayer classifier =
            new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .weightInit(WeightInit.XAVIER)
                .activation(Activation.SOFTMAX)
                .nIn(64)
                .nOut(numOutput)
                .build();

        // Global settings: fixed seed, SGD optimisation algorithm, Adam updater (lr 0.001).
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(123)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(new Adam(0.001))
            .list()
            .layer(0, firstDense)
            .layer(1, secondDense)
            .layer(2, classifier)
            .build();

        MultiLayerNetwork network = new MultiLayerNetwork(conf);
        network.init();
        network.setListeners(new ScoreIterationListener(100));
        return network;
    }

    /**
     * Calls {@code model.fit(trainData)} once per epoch and prints a line after each.
     *
     * @param model     network to train
     * @param trainData training data iterator
     * @param epochs    number of passes to run
     */
    public static void trainModel(MultiLayerNetwork model, DataSetIterator trainData, int epochs) {
        System.out.println("开始训练...");
        for (int epoch = 1; epoch <= epochs; epoch++) {
            model.fit(trainData);
            System.out.println("Epoch " + epoch + " 完成");
        }
    }

    /** Dense layer with ReLU activation and Xavier weight initialisation. */
    private static DenseLayer reluLayer(int inputs, int outputs) {
        return new DenseLayer.Builder()
            .nIn(inputs)
            .nOut(outputs)
            .activation(Activation.RELU)
            .weightInit(WeightInit.XAVIER)
            .build();
    }
}

3.2 手写数字识别(MNIST)

复制代码
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.File;

/**
 * Trains a LeNet-style convolutional network on MNIST, evaluates it on the
 * test split and saves the model to {@code mnist-model.zip}.
 */
public class MNISTExample {

    /**
     * Runs the full load / train / evaluate / save pipeline.
     *
     * @param args unused
     * @throws Exception if data loading, training or saving fails
     */
    public static void main(String[] args) throws Exception {
        final int batchSize = 64;
        final int numClasses = 10;

        // Train and test iterators share the batch size and seed.
        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
        DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);

        MultiLayerNetwork model = new MultiLayerNetwork(buildConfiguration(numClasses));
        model.init();
        model.setListeners(new ScoreIterationListener(100));

        // One fit() call per epoch, with a progress line after each.
        final int numEpochs = 5;
        for (int epoch = 1; epoch <= numEpochs; epoch++) {
            model.fit(mnistTrain);
            System.out.println("Epoch " + epoch + " 完成");
        }

        Evaluation eval = model.evaluate(mnistTest);
        System.out.println(eval.stats());

        // Second argument 'true' is passed through to MultiLayerNetwork.save.
        model.save(new File("mnist-model.zip"), true);
    }

    /**
     * Configuration: conv(5x5, 20) -> max-pool(2x2) -> conv(5x5, 50) -> max-pool(2x2)
     * -> dense(500, ReLU) -> softmax output, for flattened 28x28x1 input.
     */
    private static MultiLayerConfiguration buildConfiguration(int numClasses) {
        ConvolutionLayer firstConv = new ConvolutionLayer.Builder(5, 5)
            .nIn(1)  // single channel (grayscale)
            .stride(1, 1)
            .nOut(20)
            .activation(Activation.IDENTITY)
            .build();
        ConvolutionLayer secondConv = new ConvolutionLayer.Builder(5, 5)
            .stride(1, 1)
            .nOut(50)
            .activation(Activation.IDENTITY)
            .build();
        DenseLayer dense = new DenseLayer.Builder()
            .activation(Activation.RELU)
            .nOut(500)
            .build();
        OutputLayer output =
            new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nOut(numClasses)
                .activation(Activation.SOFTMAX)
                .build();

        return new NeuralNetConfiguration.Builder()
            .seed(12345)
            .updater(new Adam(0.001))
            .list()
            .layer(0, firstConv)
            .layer(1, maxPool2x2())
            .layer(2, secondConv)
            .layer(3, maxPool2x2())
            .layer(4, dense)
            .layer(5, output)
            .setInputType(InputType.convolutionalFlat(28, 28, 1))
            .build();
    }

    /** Max-pooling layer with a 2x2 kernel and 2x2 stride. */
    private static SubsamplingLayer maxPool2x2() {
        return new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
            .kernelSize(2, 2)
            .stride(2, 2)
            .build();
    }
}

四、自然语言处理(NLP)

4.1 情感分析

复制代码
import opennlp.tools.doccat.*;
import opennlp.tools.util.*;
import opennlp.tools.util.model.ModelUtil;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

/**
 * Trains a tiny OpenNLP document categorizer on a handful of Chinese sentences
 * and uses it to label new text as positive or negative.
 *
 * <p>Chinese text has no whitespace between words, so both training and
 * inference split text into single-character tokens via {@link #tokenize(String)}.
 * Without this, each training sentence would be one opaque token and could never
 * match unseen input.
 */
public class SentimentAnalysis {

    /** Labelled training sentences as {category, text} pairs. */
    private static final String[][] TRAINING_SAMPLES = {
        {"positive", "我喜欢这个电影"},
        {"positive", "这个产品很棒"},
        {"positive", "服务很好"},
        {"negative", "我不喜欢这个"},
        {"negative", "质量很差"},
        {"negative", "太糟糕了"},
    };

    /**
     * Trains the categorizer on {@link #TRAINING_SAMPLES}.
     *
     * @return the trained model
     * @throws IOException if reading the in-memory training data or training fails
     */
    public static DoccatModel trainSentimentModel() throws IOException {
        // One line per sample: "<category> <token> <token> ...", the format
        // DocumentSampleStream parses (whitespace-separated, category first).
        StringBuilder corpus = new StringBuilder();
        for (String[] sample : TRAINING_SAMPLES) {
            corpus.append(sample[0])
                  .append(' ')
                  .append(String.join(" ", tokenize(sample[1])))
                  .append('\n');
        }
        final byte[] corpusBytes = corpus.toString().getBytes(StandardCharsets.UTF_8);

        InputStreamFactory dataIn = new InputStreamFactory() {
            @Override
            public InputStream createInputStream() throws IOException {
                return new ByteArrayInputStream(corpusBytes);
            }
        };

        // Cutoff 0 keeps every feature; the corpus is far too small to prune.
        TrainingParameters params = ModelUtil.createDefaultTrainingParameters();
        params.put(TrainingParameters.CUTOFF_PARAM, 0);

        // try-with-resources closes the sample stream (and the line stream it wraps).
        try (ObjectStream<DocumentSample> sampleStream =
                 new DocumentSampleStream(
                     new PlainTextByLineStream(dataIn, StandardCharsets.UTF_8))) {
            return DocumentCategorizerME.train("zh", sampleStream, params, new DoccatFactory());
        }
    }

    /**
     * Categorizes {@code text} and prints the best category, its score and the
     * score of every category.
     *
     * @param text  text to classify
     * @param model trained categorizer model
     */
    public static void analyzeSentiment(String text, DoccatModel model) {
        DocumentCategorizerME categorizer = new DocumentCategorizerME(model);
        // The categorizer takes a token array; tokenize the same way as in training.
        double[] outcomes = categorizer.categorize(tokenize(text));
        String category = categorizer.getBestCategory(outcomes);
        
        System.out.println("文本: " + text);
        System.out.println("情感: " + category);
        System.out.println("置信度: " + outcomes[categorizer.getIndex(category)]);
        
        for (int i = 0; i < categorizer.getNumberOfCategories(); i++) {
            System.out.println(categorizer.getCategory(i) + ": " + outcomes[i]);
        }
    }

    /**
     * Trains the model and classifies two sample sentences.
     *
     * @param args unused
     * @throws IOException if training fails
     */
    public static void main(String[] args) throws IOException {
        DoccatModel model = trainSentimentModel();
        
        analyzeSentiment("这个电影太精彩了", model);
        analyzeSentiment("质量很差,不推荐", model);
    }

    /** Splits text into one token per non-whitespace Unicode code point. */
    private static String[] tokenize(String text) {
        String[] buffer = new String[text.length()];
        int count = 0;
        int offset = 0;
        while (offset < text.length()) {
            int codePoint = text.codePointAt(offset);
            if (!Character.isWhitespace(codePoint)) {
                buffer[count++] = new String(Character.toChars(codePoint));
            }
            offset += Character.charCount(codePoint);
        }
        return Arrays.copyOf(buffer, count);
    }
}

4.2 中文分词

复制代码
import com.hankcs.hanlp.HanLP;
import com.hankcs.hanlp.seg.common.Term;
import com.hankcs.hanlp.tokenizer.NLPTokenizer;
import com.hankcs.hanlp.tokenizer.StandardTokenizer;

import java.util.List;

/**
 * Demonstrates HanLP segmentation, keyword extraction, summarisation and a
 * simple part-of-speech based entity filter on Chinese text.
 */
public class ChineseNLPExample {

    /**
     * Runs the five demonstrations in order and prints their results.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        final String text = "自然语言处理是人工智能的重要方向";

        // 1. Standard tokenizer
        System.out.println("=== 标准分词 ===");
        printTerms(StandardTokenizer.segment(text));

        // 2. NLP tokenizer
        System.out.println("\n=== NLP 分词 ===");
        printTerms(NLPTokenizer.segment(text));

        // 3. Keyword extraction (up to 5 keywords requested)
        System.out.println("\n=== 关键词提取 ===");
        List<String> keywords = HanLP.extractKeyword(text, 5);
        System.out.println("关键词: " + keywords);

        // 4. Summary extraction (2 sentences requested)
        System.out.println("\n=== 文本摘要 ===");
        final String document = "自然语言处理是人工智能的一个重要分支。" +
                                "它研究人与计算机之间用自然语言进行有效通信的理论和方法。" +
                                "目标是让计算机能够理解和生成人类语言。";
        List<String> summary = HanLP.extractSummary(document, 2);
        System.out.println("摘要: " + summary);

        // 5. Entity filter: keep terms whose nature tag starts with nr / ns / nt
        //    (person / place / organisation per the original comments).
        System.out.println("\n=== 命名实体识别 ===");
        for (Term term : StandardTokenizer.segment("张三在北京的清华大学读书")) {
            final String tag = term.nature.toString();
            final boolean isEntity =
                tag.startsWith("nr") || tag.startsWith("ns") || tag.startsWith("nt");
            if (isEntity) {
                System.out.printf("[%s: %s] ", term.word, term.nature);
            }
        }
    }

    /** Prints each term as "word/nature " on one line, followed by a newline. */
    private static void printTerms(List<Term> terms) {
        for (Term term : terms) {
            System.out.printf("%s/%s ", term.word, term.nature);
        }
        System.out.println();
    }
}

五、计算机视觉

5.1 使用 DJL 进行图像分类

复制代码
import ai.djl.*;
import ai.djl.inference.*;
import ai.djl.modality.cv.*;
import ai.djl.modality.cv.util.*;
import ai.djl.modality.cv.transform.*;
import ai.djl.modality.cv.translator.*;
import ai.djl.repository.zoo.*;
import ai.djl.translate.*;
import ai.djl.training.util.*;

import java.nio.file.*;
import java.util.*;

/**
 * Classifies a local image with a pre-trained ResNet50 model loaded through DJL
 * and prints the five most probable classes.
 */
public class ImageClassification {

    /**
     * Builds the model criteria, loads the image, runs prediction and prints the top 5.
     *
     * @param args unused
     * @throws Exception if the model or image cannot be loaded, or prediction fails
     */
    public static void main(String[] args) throws Exception {
        final String modelUrl = "https://resources.djl.ai/mlrepo/model/cv/" +
                                "image_classification/ai/djl/zoo/resnet/0.0.1/" +
                                "resnet50_v1.zip";

        // Pre-processing: resize to 224x224, convert to tensor; softmax applied on output.
        ImageClassificationTranslator translator = ImageClassificationTranslator.builder()
            .addTransform(new Resize(224, 224))
            .addTransform(new ToTensor())
            .optApplySoftmax(true)
            .build();

        Criteria<Image, Classifications> criteria = Criteria.builder()
            .setTypes(Image.class, Classifications.class)
            .optModelUrls(modelUrl)
            .optTranslator(translator)
            .optProgress(new ProgressBar())
            .build();

        // The image is read before the model is loaded, as in the original flow.
        Image img = ImageFactory.getInstance()
            .fromFile(Paths.get("src/main/resources/cat.jpg"));

        // Model and predictor are both closed by try-with-resources.
        try (ZooModel<Image, Classifications> model = criteria.loadModel();
             Predictor<Image, Classifications> predictor = model.newPredictor()) {

            List<Classifications.Classification> best = predictor.predict(img).topK(5);

            System.out.println("图像分类结果:");
            for (Classifications.Classification candidate : best) {
                String label = candidate.getClassName();
                double percent = candidate.getProbability() * 100;
                System.out.printf("%s: %.2f%%\n", label, percent);
            }
        }
    }
}

5.2 目标检测

复制代码
import org.bytedeco.opencv.opencv_core.*;
import org.bytedeco.opencv.opencv_objdetect.*;
import org.bytedeco.javacv.*;
import org.bytedeco.javacpp.Loader;

import javax.swing.*;
import java.awt.image.BufferedImage;

import static org.bytedeco.opencv.global.opencv_core.*;
import static org.bytedeco.opencv.global.opencv_imgproc.*;
import static org.bytedeco.opencv.global.opencv_objdetect.*;
import static org.bytedeco.opencv.global.opencv_imgcodecs.*;

/**
 * Detects faces in an image with an OpenCV Haar cascade (via JavaCV), draws a
 * labelled rectangle around each one, saves the result and shows it in a window.
 */
public class FaceDetection {

    /** Cascade definition file, resolved relative to the working directory. */
    private static final String CASCADE_FILE = "haarcascade_frontalface_default.xml";

    /**
     * Runs face detection on one image.
     *
     * @param imagePath path of the image to analyse
     * @throws IllegalStateException    if the cascade file cannot be loaded
     * @throws IllegalArgumentException if the image cannot be read
     */
    public static void detectFaces(String imagePath) {
        // 1. Load the cascade classifier; load() reports failure via its return value.
        CascadeClassifier faceDetector = new CascadeClassifier();
        if (!faceDetector.load(CASCADE_FILE)) {
            throw new IllegalStateException("Failed to load cascade classifier: " + CASCADE_FILE);
        }
        
        // 2. Read the image; imread signals an unreadable file with an empty Mat,
        //    which would otherwise only fail later inside cvtColor.
        Mat image = imread(imagePath);
        if (image == null || image.empty()) {
            throw new IllegalArgumentException("Cannot read image: " + imagePath);
        }
        Mat grayImage = new Mat();
        cvtColor(image, grayImage, COLOR_BGR2GRAY);
        
        // 3. Detect faces on the grayscale copy.
        RectVector faceDetections = new RectVector();
        faceDetector.detectMultiScale(grayImage, faceDetections);
        
        System.out.println("检测到 " + faceDetections.size() + " 张人脸");
        
        // 4. Mark each face on the colour image.
        Scalar green = new Scalar(0, 255, 0, 0);
        for (int i = 0; i < faceDetections.size(); i++) {
            Rect rect = faceDetections.get(i);
            rectangle(image, rect, green, 2, 0, 0);
            
            // Label slightly above the rectangle. All optional arguments
            // (thickness, lineType, bottomLeftOrigin) are spelled out so the
            // call binds to the full putText overload.
            putText(image, "Face " + (i + 1), 
                   new Point(rect.x(), rect.y() - 10), 
                   FONT_HERSHEY_SIMPLEX, 0.5, 
                   green, 2, LINE_8, false);
        }
        
        // 5. Save the annotated image.
        imwrite("output_with_faces.jpg", image);
        
        // Show the annotated image; closing the window exits the JVM.
        OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
        CanvasFrame canvas = new CanvasFrame("人脸检测结果");
        canvas.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        canvas.showImage(converter.convert(image));
    }
    
    /**
     * Runs detection on {@code test_photo.jpg}.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        // The cascade file must be downloaded beforehand from:
        // https://github.com/opencv/opencv/blob/master/data/haarcascades/haarcascade_frontalface_default.xml
        detectFaces("test_photo.jpg");
    }
}

六、强化学习

6.1 Q-Learning 示例

复制代码
import java.util.*;

/**
 * Tabular Q-Learning on a 3x4 grid world with two obstacle cells and one goal cell.
 */
public class QLearningExample {
    
    /**
     * The grid world. States are numbered row-major from 0 to rows*cols-1.
     * In {@code grid}, 1 marks the goal cell and -1 marks an obstacle.
     */
    static class Environment {
        int[][] grid = {
            {0, 0, 0, 1},
            {0, -1, 0, -1},
            {0, 0, 0, 0}
        };
        
        int rows = grid.length;
        int cols = grid[0].length;
        int state = 0;  // current state index, 0-11
        int goal = 3;   // state index of the goal cell
        
        /** Puts the agent back on state 0 and returns that state. */
        public int reset() {
            state = 0;
            return state;
        }
        
        /**
         * Applies one action (0 = up, 1 = down, 2 = left, 3 = right). Moves that
         * would leave the grid keep the agent on the border cell.
         *
         * @param action action index; other values leave the position unchanged
         * @return new state, reward (100 goal, -10 obstacle, -1 otherwise) and done flag
         */
        public StepResult step(int action) {
            int row = state / cols;
            int col = state % cols;
            
            switch (action) {
                case 0: // up
                    row = Math.max(0, row - 1);
                    break;
                case 1: // down
                    row = Math.min(rows - 1, row + 1);
                    break;
                case 2: // left
                    col = Math.max(0, col - 1);
                    break;
                case 3: // right
                    col = Math.min(cols - 1, col + 1);
                    break;
            }
            
            int newState = row * cols + col;
            int reward = grid[row][col];
            boolean done = (newState == goal);
            
            if (reward == -1) {  // hit an obstacle: penalise and send back to the start
                newState = 0;
                reward = -10;
            } else if (done) {
                reward = 100;
            } else {
                reward = -1;  // cost of every ordinary step
            }
            
            state = newState;
            return new StepResult(newState, reward, done);
        }
        
        /** Outcome of one {@link #step(int)} call. */
        static class StepResult {
            int state;
            int reward;
            boolean done;
            
            StepResult(int state, int reward, boolean done) {
                this.state = state;
                this.reward = reward;
                this.done = done;
            }
        }
    }

    /**
     * Largest value in one row of the Q table. Seeded with the first element
     * rather than 0 so that an all-negative row yields its true maximum.
     *
     * @param actionValues Q values of one state; must be non-empty
     * @return the maximum element
     */
    static double maxQ(double[] actionValues) {
        double best = actionValues[0];
        for (int a = 1; a < actionValues.length; a++) {
            best = Math.max(best, actionValues[a]);
        }
        return best;
    }

    /**
     * Index of the largest value in one row of the Q table; ties go to the
     * lowest index.
     *
     * @param actionValues Q values of one state; must be non-empty
     * @return index of the first maximum
     */
    static int greedyAction(double[] actionValues) {
        int best = 0;
        for (int a = 1; a < actionValues.length; a++) {
            if (actionValues[a] > actionValues[best]) {
                best = a;
            }
        }
        return best;
    }
    
    /**
     * Trains a Q table with epsilon-greedy exploration, prints it, then prints
     * the greedy path from the start state (at most 20 steps).
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Environment env = new Environment();
        int numStates = env.rows * env.cols;
        int numActions = 4;  // up, down, left, right
        
        // Q table; Java initialises every entry to 0.0.
        double[][] Q = new double[numStates][numActions];
        
        // Hyper-parameters
        double alpha = 0.1;     // learning rate
        double gamma = 0.9;     // discount factor
        double epsilon = 0.1;   // exploration rate
        int episodes = 1000;
        
        Random random = new Random(42);
        
        // Training
        for (int episode = 0; episode < episodes; episode++) {
            int state = env.reset();
            boolean done = false;
            
            while (!done) {
                // Epsilon-greedy: explore with probability epsilon, otherwise exploit.
                int action;
                if (random.nextDouble() < epsilon) {
                    action = random.nextInt(numActions);
                } else {
                    action = greedyAction(Q[state]);
                }
                
                Environment.StepResult result = env.step(action);
                int nextState = result.state;
                double reward = result.reward;
                
                // Bootstrap from the best next-state value; a terminal transition
                // has no future return. The original started the max at 0, which
                // clamped the target while all next-state values were negative.
                double maxNextQ = result.done ? 0.0 : maxQ(Q[nextState]);
                
                // Q-Learning update rule
                Q[state][action] = Q[state][action] + 
                                   alpha * (reward + gamma * maxNextQ - Q[state][action]);
                
                state = nextState;
                done = result.done;
            }
        }
        
        // Print the learned table
        System.out.println("训练后的 Q 表:");
        for (int s = 0; s < numStates; s++) {
            System.out.printf("状态 %2d: ", s);
            for (int a = 0; a < numActions; a++) {
                System.out.printf("%6.2f ", Q[s][a]);
            }
            System.out.println();
        }
        
        // Follow the greedy policy from the start; capped at 20 steps so a
        // policy that loops cannot hang the program.
        System.out.println("\n最优路径:");
        int currentState = env.reset();
        List<Integer> path = new ArrayList<>();
        
        for (int step = 0; step < 20; step++) {
            path.add(currentState);
            
            Environment.StepResult result = env.step(greedyAction(Q[currentState]));
            currentState = result.state;
            
            if (result.done) {
                path.add(currentState);
                break;
            }
        }
        
        System.out.println("路径: " + path);
    }
}

七、模型部署

7.1 Spring Boot 部署 AI 服务

复制代码
// 1. 模型服务类
/**
 * Spring service that loads a DL4J network and a DJL image-classification model
 * at start-up and exposes prediction methods on top of them.
 *
 * <p>NOTE(review): the REST controller in this tutorial also calls
 * {@code analyzeSentiment(String)}, and {@link #predict} calls
 * {@code convertToINDArray(...)}; neither is defined in this class as shown.
 */
@Service
public class AIService {
    
    private MultiLayerNetwork model;
    private ZooModel<Image, Classifications> imageModel;
    
    /**
     * Loads both models once the bean has been constructed.
     *
     * @throws Exception if either model fails to load
     */
    @PostConstruct
    public void init() throws Exception {
        loadModels();
    }
    
    /** Loads the DL4J network from {@code model.bin} and the DJL model from {@code resnet50}. */
    private void loadModels() throws Exception {
        // DL4J network (the original comment describes it as a text-classification model).
        model = MultiLayerNetwork.load(
            new File("model.bin"), true);
        
        // DJL image-classification model
        Criteria<Image, Classifications> criteria = Criteria.builder()
            .setTypes(Image.class, Classifications.class)
            .optModelPath(Paths.get("resnet50"))
            .build();
        imageModel = criteria.loadModel();
    }
    
    /**
     * Runs the DL4J network on the features of {@code input}.
     *
     * @param input request payload carrying the features
     * @return the network output wrapped in a {@code PredictionResult}
     */
    public PredictionResult predict(InputData input) {
        INDArray features = convertToINDArray(input.getFeatures());
        INDArray output = model.output(features);
        
        return new PredictionResult(output);
    }
    
    /**
     * Classifies an uploaded image and returns its three most probable classes.
     *
     * @param file uploaded image
     * @return class name to probability, iterated in the ranking order of {@code topK(3)}
     * @throws Exception if the upload cannot be read or prediction fails
     */
    public Map<String, Double> classifyImage(MultipartFile file) throws Exception {
        Image img = ImageFactory.getInstance()
            .fromInputStream(file.getInputStream());
        
        // A predictor is created per call and closed by try-with-resources.
        try (Predictor<Image, Classifications> predictor = 
             imageModel.newPredictor()) {
            
            Classifications classifications = predictor.predict(img);
            // LinkedHashMap keeps insertion order, so the ranking produced by
            // topK survives into the returned map (a HashMap would scramble it).
            Map<String, Double> result = new java.util.LinkedHashMap<>();
            
            for (Classifications.Classification c : classifications.topK(3)) {
                result.put(c.getClassName(), c.getProbability());
            }
            
            return result;
        }
    }
}

// 2. REST 控制器
/**
 * REST endpoints under {@code /api/ai} that delegate to {@code AIService}.
 * Cross-origin requests are allowed from any origin.
 */
@RestController
@RequestMapping("/api/ai")
@CrossOrigin(origins = "*")
public class AIController {
    
    @Autowired
    private AIService aiService;
    
    /**
     * POST /predict — runs the service's prediction on the request's data.
     *
     * @param request JSON body carrying the input data
     * @return 200 with the prediction result
     */
    @PostMapping("/predict")
    public ResponseEntity<PredictionResult> predict(
            @RequestBody PredictionRequest request) {
        return ResponseEntity.ok(aiService.predict(request.getData()));
    }
    
    /**
     * POST /classify — classifies an uploaded image.
     *
     * @param file multipart upload named {@code file}
     * @return 200 with class name to probability, or an empty 500 response on any exception
     */
    @PostMapping("/classify")
    public ResponseEntity<Map<String, Double>> classifyImage(
            @RequestParam("file") MultipartFile file) {
        Map<String, Double> scores;
        try {
            scores = aiService.classifyImage(file);
        } catch (Exception e) {
            // Any failure is reported as a bare 500; the exception is not surfaced.
            return ResponseEntity.status(500).build();
        }
        return ResponseEntity.ok(scores);
    }
    
    /**
     * POST /sentiment — sentiment analysis of the raw request body.
     *
     * <p>NOTE(review): relies on {@code AIService.analyzeSentiment(String)}, which the
     * AIService snippet in this tutorial does not define — confirm it exists.
     *
     * @param text request body text
     * @return 200 with the sentiment result
     */
    @PostMapping("/sentiment")
    public ResponseEntity<SentimentResult> analyzeSentiment(
            @RequestBody String text) {
        SentimentResult sentiment = aiService.analyzeSentiment(text);
        return ResponseEntity.ok(sentiment);
    }
}

// 3. 配置类
/**
 * Spring configuration that provides the thread pool used for AI tasks.
 */
@Configuration
public class AIConfig {

    /**
     * Creates an initialised executor: core size 4, max size 8, queue capacity 100,
     * thread names prefixed with {@code ai-thread-}.
     *
     * @return the configured executor bean
     */
    @Bean
    public ThreadPoolTaskExecutor aiTaskExecutor() {
        final ThreadPoolTaskExecutor pool = new ThreadPoolTaskExecutor();
        pool.setThreadNamePrefix("ai-thread-");
        pool.setCorePoolSize(4);
        pool.setMaxPoolSize(8);
        pool.setQueueCapacity(100);
        // All settings are applied before initialize() builds the underlying pool.
        pool.initialize();
        return pool;
    }
}

7.2 模型性能优化

复制代码
/**
 * Sketches of three model-optimisation techniques: quantization, magnitude
 * pruning and knowledge distillation.
 *
 * NOTE(review): this snippet has no imports and references types and helpers
 * that are not defined here (GraphOptimizationConfig, QuantizationConfig,
 * computeDistillationLoss, computeStudentLoss) — confirm they exist before use.
 */
public class ModelOptimizer {
    
    // 1. Model quantization
    /**
     * Converts the network to a ComputationGraph and sets an INT8 quantization
     * config on it. The graph is a local variable and is not returned, so the
     * caller receives nothing from this method.
     *
     * @param model network to convert
     */
    public static void quantizeModel(MultiLayerNetwork model) {
        // Intended to turn the floating-point model into an integer model
        ComputationGraph graph = model.toComputationGraph();
        
        // Build the quantization settings (quantize = true, target type INT8)
        GraphOptimizationConfig config = new GraphOptimizationConfig();
        config.setQuantizationConfig(QuantizationConfig.builder()
            .quantize(true)
            .targetDataType(DataType.INT8)
            .build());
        
        graph.setOptimizationConfig(config);
    }
    
    // 2. Model pruning
    /**
     * Zeroes every parameter whose absolute value is strictly below the
     * (pruningRate * 100)-th percentile of absolute parameter values.
     *
     * @param model       network whose parameters are pruned
     * @param pruningRate fraction used to pick the percentile; multiplied by 100 below
     * @return the same {@code model} instance that was passed in
     */
    public static MultiLayerNetwork pruneModel(
            MultiLayerNetwork model, 
            double pruningRate) {
        
        // presumably params is a live view of the model's weights, so the
        // putScalar calls below modify the model — TODO confirm
        INDArray params = model.params();
        // Absolute values are taken on a copy so params itself is untouched here
        INDArray absParams = params.dup().abs();
        
        // Magnitude threshold
        double threshold = absParams.percentile(pruningRate * 100);
        
        // Set weights below the threshold to 0, one element at a time
        for (int i = 0; i < params.length(); i++) {
            if (Math.abs(params.getDouble(i)) < threshold) {
                params.putScalar(i, 0.0);
            }
        }
        
        return model;
    }
    
    // 3. Knowledge distillation
    /**
     * Iterates over the training data for 10 epochs, computing teacher and
     * student outputs and a combined loss for each batch, then calls
     * {@code studentModel.fit()} on the batch's features and labels.
     *
     * NOTE(review): {@code totalLoss} is computed but never read afterwards, so
     * as written the fit step does not use the teacher output — confirm intent.
     *
     * @param teacherModel model whose outputs serve as soft targets
     * @param studentModel model being trained
     * @param trainData    training batches; reset after every epoch
     * @param temperature  passed to computeDistillationLoss
     * @param alpha        weight of the distillation loss; (1 - alpha) weights the student loss
     * @return the same {@code studentModel} instance that was passed in
     */
    public static MultiLayerNetwork knowledgeDistillation(
            MultiLayerNetwork teacherModel,
            MultiLayerNetwork studentModel,
            DataSetIterator trainData,
            double temperature,
            double alpha) {
        
        // Number of passes is fixed here rather than taken as a parameter
        int epochs = 10;
        for (int epoch = 0; epoch < epochs; epoch++) {
            while (trainData.hasNext()) {
                DataSet batch = trainData.next();
                
                // Teacher prediction
                INDArray teacherOutput = teacherModel.output(batch.getFeatures());
                
                // Student prediction
                INDArray studentOutput = studentModel.output(batch.getFeatures());
                
                // Distillation loss and student loss (helpers not defined in this class)
                INDArray distillationLoss = computeDistillationLoss(
                    teacherOutput, studentOutput, temperature);
                INDArray studentLoss = computeStudentLoss(
                    studentOutput, batch.getLabels());
                
                // Weighted combination: alpha * distillation + (1 - alpha) * student
                INDArray totalLoss = distillationLoss.mul(alpha)
                    .add(studentLoss.mul(1 - alpha));
                
                // Training step on the batch's own features and labels
                studentModel.setInput(batch.getFeatures());
                studentModel.setLabels(batch.getLabels());
                studentModel.fit();
            }
            // Rewind the iterator for the next epoch
            trainData.reset();
        }
        
        return studentModel;
    }
}

八、实用工具和技巧

8.1 数据预处理工具类

复制代码
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.condition.ConditionOp;
import org.datavec.api.transform.condition.column.DoubleColumnCondition;
import org.datavec.api.transform.filter.Filter;
import org.datavec.api.transform.filter.ConditionFilter;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

import java.io.File;
import java.util.Arrays;
import java.util.List;

/**
 * Static helpers for preparing ND4J arrays: standardisation, one-hot encoding,
 * image augmentation and missing-value imputation.
 */
public class DataPreprocessor {
    
    // 1. Standardisation
    /**
     * Subtracts the mean along dimension 0 and divides by the standard deviation
     * along dimension 0.
     *
     * @param data input array
     * @return the standardised array
     */
    public static INDArray normalize(INDArray data) {
        INDArray mean = data.mean(0);
        INDArray std = data.std(0).add(1e-8);  // epsilon avoids division by zero
        
        return data.subRowVector(mean).divRowVector(std);
    }
    
    // 2. One-hot encoding
    /**
     * Builds a {@code labels.length x numClasses} array with a 1.0 in column
     * {@code labels[i]} of row {@code i} and 0 elsewhere.
     *
     * @param labels     class index of each sample
     * @param numClasses number of columns in the result
     * @return the encoded array
     */
    public static INDArray oneHotEncode(int[] labels, int numClasses) {
        INDArray encoded = Nd4j.zeros(labels.length, numClasses);
        
        for (int i = 0; i < labels.length; i++) {
            encoded.putScalar(new int[]{i, labels[i]}, 1.0);
        }
        
        return encoded;
    }
    
    // 3. Augmentation
    /**
     * Returns an augmented copy of {@code image}. Only flipping is implemented.
     *
     * @param image    image array; the original code read dimensions 1/2/3 as
     *                 height/width/channels, so dimension 2 is treated as width
     * @param rotation currently ignored (not implemented)
     * @param zoom     currently ignored (not implemented)
     * @param flip     when true, the copy is reversed along dimension 2
     * @return the augmented copy; {@code image} itself is not modified here
     */
    public static INDArray augmentImage(INDArray image, double rotation, 
                                        double zoom, boolean flip) {
        INDArray augmented = image.dup();
        
        // TODO: rotation is not implemented; a non-zero 'rotation' has no effect.
        // TODO: zoom is not implemented; a non-zero 'zoom' has no effect.
        
        if (flip) {
            augmented = Nd4j.reverse(augmented, 2);  // horizontal flip
        }
        
        return augmented;
    }
    
    // 4. Missing values
    /**
     * Replaces every NaN in a 2-D array with the mean of the non-NaN values of
     * its column.
     *
     * <p>The mean is accumulated by hand over the observed values: a mean taken
     * over the whole column would itself be NaN as soon as the column contains
     * one. A column with no observed values is left unchanged.
     *
     * @param data 2-D input array; not modified
     * @return a copy of {@code data} with NaN cells filled in
     */
    public static INDArray handleMissingValues(INDArray data) {
        INDArray result = data.dup();
        long[] shape = data.shape();
        long rows = shape[0];
        long cols = shape[1];
        
        for (long j = 0; j < cols; j++) {
            // First pass: mean of the observed (non-NaN) entries of this column.
            double sum = 0.0;
            long observed = 0;
            for (long i = 0; i < rows; i++) {
                double value = data.getDouble(i, j);
                if (!Double.isNaN(value)) {
                    sum += value;
                    observed++;
                }
            }
            
            // Nothing to fill, or nothing to fill with.
            if (observed == rows || observed == 0) {
                continue;
            }
            double mean = sum / observed;
            
            // Second pass: fill the gaps.
            for (long i = 0; i < rows; i++) {
                if (Double.isNaN(data.getDouble(i, j))) {
                    result.putScalar(i, j, mean);
                }
            }
        }
        
        return result;
    }
}

九、学习资源推荐

9.1 学习路径

  1. 基础阶段(1-2个月)

    • Java 基础

    • 线性代数和概率论

    • 机器学习基础

  2. 中级阶段(2-3个月)

    • 深度学习理论

    • Deeplearning4j/DJL 框架

    • 完成2-3个项目实践

  3. 高级阶段(3-6个月)

    • 模型优化和部署

    • 大规模数据处理

    • 参与开源项目

9.2 推荐资源

  • 书籍:《Java深度学习》、《Deep Learning with Java》

  • 在线课程:Coursera 的吴恩达机器学习课程

  • GitHub 项目

    • Deeplearning4j 官方示例

    • DJL 示例代码库

    • Java-ML 项目

9.3 项目实践建议

  1. 从简单的线性回归开始

  2. 尝试 MNIST 手写数字识别

  3. 实现一个情感分析系统

  4. 构建图像分类应用

  5. 部署为 Web 服务

十、常见问题

10.1 性能问题

复制代码
// 内存优化
/**
 * Collection of memory-related settings shown as a single example method.
 */
public class MemoryOptimizer {
    /**
     * Applies the settings below in order. Steps 3 and 4 have no runtime effect:
     * step 3 only declares a local variable and step 4 is commented out.
     */
    public static void optimize() {
        // 1. Set ND4J's data type to HALF.
        // NOTE(review): the original comment called this mixed-precision training;
        // the call as written sets a single data type — confirm the intent.
        Nd4j.setDataType(DataType.HALF);
        
        // 2. Request a garbage collection
        System.gc();
        
        // 3. Batch size; tune to the available GPU memory (declared but not used here)
        int batchSize = 32;
        
        // 4. Example of wrapping a data iterator in AsyncDataSetIterator (left commented out).
        // NOTE(review): the original comment described this as memory-mapped file
        // handling for large data — confirm that matches what the wrapper does.
        // DataSetIterator iterator = new AsyncDataSetIterator(
        //     new MnistDataSetIterator(batchSize, true, 12345), 2);
    }
}

10.2 调试技巧

复制代码
/**
 * Debugging helpers for DL4J training: gradient inspection, training listeners
 * and checkpoint saving.
 */
public class DebugUtils {
    
    // 1. Gradient check
    /**
     * Computes the gradient of {@code model} on {@code dataSet} and prints the
     * L2 norm of the flattened gradient.
     *
     * <p>The input and labels are set and the gradient is computed explicitly;
     * the original read {@code model.gradient()} without computing anything from
     * {@code dataSet}, and called {@code norm2Number()} on the per-variable map.
     *
     * @param model   network to inspect
     * @param dataSet batch on which the gradient is computed
     */
    public static void checkGradients(MultiLayerNetwork model, 
                                      DataSet dataSet) {
        model.setInput(dataSet.getFeatures());
        model.setLabels(dataSet.getLabels());
        model.computeGradientAndScore();

        Gradient gradient = model.gradient();
        // gradient() is the single flattened gradient array.
        System.out.println("梯度范数: " + gradient.gradient().norm2Number());
    }
    
    // 2. Training monitoring
    /**
     * Attaches a score listener, an end-of-epoch evaluation listener on the
     * MNIST test split, and a histogram listener.
     *
     * @param model network to attach the listeners to
     * @throws java.io.UncheckedIOException if the MNIST test data cannot be loaded
     */
    public static void addTrainingListeners(MultiLayerNetwork model) {
        // Creating the MNIST iterator can fail with a checked IOException; it is
        // rethrown unchecked so the method signature stays unchanged.
        MnistDataSetIterator testData;
        try {
            testData = new MnistDataSetIterator(100, false, 12345);
        } catch (java.io.IOException e) {
            throw new java.io.UncheckedIOException("Failed to load MNIST test data", e);
        }

        // NOTE(review): HistogramIterationListener is kept from the original;
        // confirm it is available in the DL4J version in use.
        model.setListeners(
            new ScoreIterationListener(100),  // print the score every 100 iterations
            new EvaluativeListener(
                testData, 
                1,  // evaluate every 1 epoch
                InvocationType.EPOCH_END
            ),
            new HistogramIterationListener(100)  // histograms
        );
    }
    
    // 3. Checkpoints
    /**
     * Saves the model to {@code <path>/checkpoint_iter<iteration>.zip}.
     *
     * @param model     network to save
     * @param path      directory that receives the checkpoint file
     * @param iteration iteration number embedded in the file name
     * @throws java.io.UncheckedIOException if writing the file fails
     */
    public static void saveCheckpoint(MultiLayerNetwork model, 
                                      String path, 
                                      int iteration) {
        File file = new File(path, "checkpoint_iter" + iteration + ".zip");
        try {
            model.save(file, true);
        } catch (java.io.IOException e) {
            throw new java.io.UncheckedIOException("Failed to save checkpoint to " + file, e);
        }
    }
}
相关推荐
Ztopcloud极拓云视角2 小时前
实战:GPT-6 + Gemma 4 端云混合 AI 调用架构设计
大数据·人工智能·gpt
测绘第一深情2 小时前
MapQR:自动驾驶在线矢量化高精地图构建的端到端 SOTA 方法
数据结构·人工智能·python·神经网络·算法·机器学习·自动驾驶
Magic--2 小时前
C++ 智能指针
开发语言·c++·算法
墨雪遗痕2 小时前
工程架构认知(三):从传统Web系统到AI大模型驱动系统
前端·人工智能·架构
高洁012 小时前
AI算法实战:逻辑回归在风控场景中的应用
人工智能·python·深度学习·transformer
Timer@2 小时前
LangChain 教程 05|模型配置:AI 的大脑与推理引擎
人工智能·算法·langchain
sali-tec2 小时前
C# 基于OpenCv的视觉工作流-章50-霍夫找圆
图像处理·人工智能·opencv·算法·计算机视觉
_童年的回忆_2 小时前
【Java】宝塔下安装Adoptium Temurin (免费JDK)
java·开发语言
想带你从多云到转晴2 小时前
04、数据结构与算法---双向链表
java·数据结构·算法·链表