DeepXDE source code reading notes (continuously updated)

2023.11.23

DeepXDE version read: 1.9.0

1. train_aux_vars, i.e. the third argument of the pde function

The meaning of this variable puzzled me for a long time. In the end it turned out to be simply the parameters of the PDE in operator learning.
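
To make the role of the third argument concrete, here is a minimal sketch in plain Python, not DeepXDE code. It assumes the antiderivative problem dy/dx = v(x) as the operator-learning task. The names `pde_residual`, `y_fn` and the step `h` are hypothetical, and a central finite difference stands in for the automatic differentiation a real backend would use.

```python
import math


def pde_residual(x, y_fn, v, h=1e-5):
    """Residual of dy/dx = v(x) at each collocation point.

    x    : collocation points (the coordinates, first pde argument)
    y_fn : candidate solution (stands in for the network output)
    v    : values of the input function at x (the auxiliary variables)
    """
    residuals = []
    for xi, vi in zip(x, v):
        dy_dx = (y_fn(xi + h) - y_fn(xi - h)) / (2 * h)  # central difference
        residuals.append(dy_dx - vi)
    return residuals


x = [0.0, 0.5, 1.0]
v = [math.cos(xi) for xi in x]  # input function sampled at the points

res_exact = pde_residual(x, math.sin, v)  # sin is an antiderivative of cos
res_wrong = pde_residual(x, math.cos, v)  # cos is not
```

The residual vanishes (up to finite-difference error) only when the candidate solution matches the sampled input function, which is why those samples have to reach the pde function as a separate argument.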

Trace: def pde aux_vars ->

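DeepXDE currently supports TF1 most fully, but it also supports other frameworks. The main maintainer of the repository, Lu Lu, probably collaborates with Baidu, and support for PaddlePaddle is currently improving. As a result, many functions are in fact written four times, once per framework, with essentially the same functionality but different details.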

The train_next_batch method of the Cartesian-product PDEoperator class in pdeoperator.py

(Note: I wrote these names from memory; they are not the real names in the source.)

python
    def train_next_batch(self, batch_size=None):
        if self.train_x is None:
            func_feats = self.func_space.random(self.num_func)
            func_vals = self.func_space.eval_batch(func_feats, self.eval_pts)
            vx = self.func_space.eval_batch(
                func_feats, self.pde.train_x[:, self.func_vars]
            )
            self.train_x = (func_vals, self.pde.train_x)
            self.train_aux_vars = vx

        if self.batch_size is None:
            return self.train_x, self.train_y, self.train_aux_vars

        indices = self.train_sampler.get_next(self.batch_size)
        traix_x = (self.train_x[0][indices], self.train_x[1])
        return traix_x, self.train_y, self.train_aux_vars[indices]
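
The batching rule in this method can be sketched with plain lists, assuming (as the code above suggests) that the first input holds one row per sampled function and the second holds the shared collocation points. `next_batch` and the toy data are hypothetical. Only the function-wise inputs and the auxiliary variables are indexed, while the points are passed through whole.

```python
def next_batch(train_x, train_y, aux_vars, indices=None):
    """Return (x, y, aux) for the selected functions.

    train_x  : (func_vals, points); func_vals has one row per function
    aux_vars : one row per function, evaluated at the collocation points
    indices  : function indices for this mini-batch, or None for full batch
    """
    if indices is None:
        return train_x, train_y, aux_vars
    func_vals, points = train_x
    batch_x = ([func_vals[i] for i in indices], points)  # points are shared
    batch_aux = [aux_vars[i] for i in indices]
    return batch_x, train_y, batch_aux


func_vals = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 functions, 2 sensors
points = [[0.0], [0.5], [1.0], [1.5]]             # 4 collocation points
aux = [[10, 11, 12, 13], [20, 21, 22, 23], [30, 31, 32, 33]]

x, y, a = next_batch((func_vals, points), None, aux, indices=[2, 0])
```

Because the auxiliary variables are indexed with the same `indices` as the function inputs, each row of `a` still belongs to the function in the matching row of `x[0]`.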

model.py

python
    def set_data_train(self, X_train, y_train, train_aux_vars=None):
        self.X_train = X_train
        self.y_train = y_train
        self.train_aux_vars = train_aux_vars
python
            self.train_state.set_data_train(
                *self.data.train_next_batch(self.batch_size)
            )
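
The hand-off above relies on argument unpacking: the data object returns a 3-tuple, and `*` spreads it over the parameters of `set_data_train`. A minimal stand-alone sketch, where `MiniTrainState` and `FakeData` are hypothetical stand-ins rather than DeepXDE classes:

```python
class MiniTrainState:
    """Holds the current training batch, like the snippet above."""

    def __init__(self):
        self.X_train = None
        self.y_train = None
        self.train_aux_vars = None

    def set_data_train(self, X_train, y_train, train_aux_vars=None):
        self.X_train = X_train
        self.y_train = y_train
        self.train_aux_vars = train_aux_vars


class FakeData:
    """Returns (inputs, targets, aux_vars); targets are None without labels."""

    def train_next_batch(self, batch_size=None):
        return [[0.0], [1.0]], None, [[0.5], [0.7]]


state = MiniTrainState()
# The third tuple element lands in train_aux_vars purely by position.
state.set_data_train(*FakeData().train_next_batch())
```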

Following the trail through these snippets pinned down what train_aux_vars is: the parameters of the PDEs.

python
        def outputs_losses(training, inputs, targets, auxiliary_vars, losses_fn):
            self.net.auxiliary_vars = auxiliary_vars
            # Don't call outputs() decorated by @tf.function above, otherwise the
            # gradient of outputs wrt inputs will be lost here.
            outputs_ = self.net(inputs, training=training)
            # Data losses
            losses = losses_fn(targets, outputs_, loss_fn, inputs, self)
            if not isinstance(losses, list):
                losses = [losses]
            # Regularization loss
            if self.net.regularizer is not None:
                losses += [tf.math.reduce_sum(self.net.losses)]
            losses = tf.convert_to_tensor(losses)
            # Weighted losses
            if loss_weights is not None:
                losses *= loss_weights
            return outputs_, losses
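
The tail of this function (wrap a single loss in a list, append the regularization term, then scale by `loss_weights`) can be sketched without TensorFlow. `combine_losses` is a hypothetical helper, plain Python floats replace tensors, and the sketch assumes one weight per term, including the regularization term.

```python
def combine_losses(losses, regularization=None, loss_weights=None):
    """Mimic the post-processing of the loss terms with plain floats."""
    if not isinstance(losses, list):
        losses = [losses]              # a single loss becomes a 1-element list
    else:
        losses = list(losses)          # copy so the caller's list is untouched
    if regularization is not None:
        losses.append(regularization)  # the regularization term goes last
    if loss_weights is not None:
        if len(loss_weights) != len(losses):
            raise ValueError("need one weight per loss term")
        losses = [w * l for w, l in zip(loss_weights, losses)]
    return losses


weighted = combine_losses(
    [0.5, 2.0], regularization=0.25, loss_weights=[1.0, 10.0, 2.0]
)
```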

It is worth noting that the author uses TrainState from start to finish to store the various variables involved in training.

python
class TrainState:
    def __init__(self):
        self.epoch = 0
        self.step = 0

        # Current data
        self.X_train = None
        self.y_train = None
        self.train_aux_vars = None
        self.X_test = None
        self.y_test = None
        self.test_aux_vars = None

        # Results of current step
        # Train results
        self.loss_train = None
        self.y_pred_train = None
        # Test results
        self.loss_test = None
        self.y_pred_test = None
        self.y_std_test = None
        self.metrics_test = None

        # The best results correspond to the min train loss
        self.best_step = 0
        self.best_loss_train = np.inf
        self.best_loss_test = np.inf
        self.best_y = None
        self.best_ystd = None
        self.best_metrics = None

    def set_data_train(self, X_train, y_train, train_aux_vars=None):
        self.X_train = X_train
        self.y_train = y_train
        self.train_aux_vars = train_aux_vars

    def set_data_test(self, X_test, y_test, test_aux_vars=None):
        self.X_test = X_test
        self.y_test = y_test
        self.test_aux_vars = test_aux_vars

    def update_best(self):
        if self.best_loss_train > np.sum(self.loss_train):
            self.best_step = self.step
            self.best_loss_train = np.sum(self.loss_train)
            self.best_loss_test = np.sum(self.loss_test)
            self.best_y = self.y_pred_test
            self.best_ystd = self.y_std_test
            self.best_metrics = self.metrics_test

    def disregard_best(self):
        self.best_loss_train = np.inf

From TrainState it is clear that the first two arguments of pde are the coordinate inputs and the solution outputs of the system of PDEs.

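DeepXDE is released under the LGPL, one of the more aggressive kinds of open-source license. Aggressive open-sourcing is a business strategy; see another article I wrote.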

2023-11-27

In the physics-informed setting, do not pass metrics to compile, because physics-informed training does not use labels.

python
model.compile(
    "adam",
    lr=1e-4,
    decay=("inverse time", 1, 1e-4),
    # metrics=["mean l2 relative error"],  # left out: no labels to compare against
)
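
For reference, a sketch of what the `decay` tuple in this call would do, assuming it is read as (name, decay_steps, decay_rate) and follows the usual inverse-time formula lr / (1 + decay_rate * step / decay_steps), as in the Keras `InverseTimeDecay` schedule. Whether every backend of this DeepXDE version supports the option is not checked here, and `inverse_time_lr` is a hypothetical helper.

```python
def inverse_time_lr(initial_lr, step, decay_steps, decay_rate):
    """Learning rate after `step` updates under inverse-time decay."""
    return initial_lr / (1 + decay_rate * step / decay_steps)


# Values from the compile call: lr=1e-4, decay=("inverse time", 1, 1e-4)
lr_start = inverse_time_lr(1e-4, 0, 1, 1e-4)     # no decay yet
lr_10k = inverse_time_lr(1e-4, 10_000, 1, 1e-4)  # about half the initial rate
lr_90k = inverse_time_lr(1e-4, 90_000, 1, 1e-4)  # about a tenth of it
```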

Looking into the source, you will find:

model.py

python
    @utils.timing
    def train(
        self,
        iterations=None,
        batch_size=None,
        display_every=1000,
        disregard_previous_best=False,
        callbacks=None,
        model_restore_path=None,
        model_save_path=None,
        epochs=None,
    ):
        """Trains the model.

        Args:
            iterations (Integer): Number of iterations to train the model, i.e., number
                of times the network weights are updated.
            batch_size: Integer, tuple, or ``None``.

                - If you solve PDEs via ``dde.data.PDE`` or ``dde.data.TimePDE``, do not use `batch_size`, and instead use
                  `dde.callbacks.PDEPointResampler
                  <https://deepxde.readthedocs.io/en/latest/modules/deepxde.html#deepxde.callbacks.PDEPointResampler>`_,
                  see an `example <https://github.com/lululxvi/deepxde/blob/master/examples/diffusion_1d_resample.py>`_.
                - For DeepONet in the format of Cartesian product, if `batch_size` is an Integer,
                  then it is the batch size for the branch input; if you want to also use mini-batch for the trunk net input,
                  set `batch_size` as a tuple, where the first number is the batch size for the branch net input
                  and the second number is the batch size for the trunk net input.
            display_every (Integer): Print the loss and metrics every this steps.
            disregard_previous_best: If ``True``, disregard the previous saved best
                model.
            callbacks: List of ``dde.callbacks.Callback`` instances. List of callbacks
                to apply during training.
            model_restore_path (String): Path where parameters were previously saved.
            model_save_path (String): Prefix of filenames created for the checkpoint.
            epochs (Integer): Deprecated alias to `iterations`. This will be removed in
                a future version.
        """
        if iterations is None and epochs is not None:
            print(
                "Warning: epochs is deprecated and will be removed in a future version."
                " Use iterations instead."
            )
            iterations = epochs
        self.batch_size = batch_size
        self.callbacks = CallbackList(callbacks=callbacks)
        self.callbacks.set_model(self)
        if disregard_previous_best:
            self.train_state.disregard_best()

        if backend_name == "tensorflow.compat.v1":
            if self.train_state.step == 0:
                self.sess.run(tf.global_variables_initializer())
                if config.hvd is not None:
                    bcast = config.hvd.broadcast_global_variables(0)
                    self.sess.run(bcast)
            else:
                utils.guarantee_initialized_variables(self.sess)

        if model_restore_path is not None:
            self.restore(model_restore_path, verbose=1)

        if config.rank == 0:
            print("Training model...\n")
        self.stop_training = False
        self.train_state.set_data_train(*self.data.train_next_batch(self.batch_size))
        self.train_state.set_data_test(*self.data.test())
        self._test()
        self.callbacks.on_train_begin()
        if optimizers.is_external_optimizer(self.opt_name):
            if backend_name == "tensorflow.compat.v1":
                self._train_tensorflow_compat_v1_scipy(display_every)
            elif backend_name == "tensorflow":
                self._train_tensorflow_tfp()
            elif backend_name == "pytorch":
                self._train_pytorch_lbfgs()
            elif backend_name == "paddle":
                self._train_paddle_lbfgs()
        else:
            if iterations is None:
                raise ValueError("No iterations for {}.".format(self.opt_name))
            self._train_sgd(iterations, display_every)
        self.callbacks.on_train_end()

        if config.rank == 0:
            print("")
            display.training_display.summary(self.train_state)
        if model_save_path is not None:
            self.save(model_save_path, verbose=1)
        return self.losshistory, self.train_state

    def _train_sgd(self, iterations, display_every):
        for i in range(iterations):
            self.callbacks.on_epoch_begin()
            self.callbacks.on_batch_begin()

            self.train_state.set_data_train(
                *self.data.train_next_batch(self.batch_size)
            )
            self._train_step(
                self.train_state.X_train,
                self.train_state.y_train,
                self.train_state.train_aux_vars,
            )

            self.train_state.epoch += 1
            self.train_state.step += 1
            if self.train_state.step % display_every == 0 or i + 1 == iterations:
                self._test()

            self.callbacks.on_batch_end()
            self.callbacks.on_epoch_end()

            if self.stop_training:
                break

The test call shows up frequently in there:

self._test()
