deepxde source code reading notes (continuously updated)

2023.11.23

deepxde version read: 1.9.0

1. train_aux_vars, i.e., the third argument of pde

The meaning of this variable puzzled me for a long time. It finally turns out to be the parameters of the PDEs in operator learning.
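Concretely, at the user level this is the third argument of the PDE residual in operator learning. A hedged sketch, loosely following the diffusion-reaction DeepONet example in the deepxde docs (D and k are illustrative constants, not values from this post):

```python
import deepxde as dde

D, k = 0.01, 0.01  # illustrative problem constants

def pde(x, y, v):
    # v is the source term f(x), i.e. the PDE "parameter": deepxde fills it
    # from train_aux_vars with the sampled function evaluated at x.
    dy_t = dde.grad.jacobian(y, x, j=1)
    dy_xx = dde.grad.hessian(y, x, i=0, j=0)
    return dy_t - D * dy_xx + k * y**2 - v
```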

The trail: the aux_vars argument of the user-defined pde function ->

deepxde currently supports tf1 (tensorflow.compat.v1) best, but the other backends are supported too. The repository's main maintainer, Lu Lu, appears to be collaborating with Baidu, and support for PaddlePaddle is being improved. As a result, many functions are effectively written four times, once per backend, with essentially the same functionality but different backend-specific details.

The train_next_batch method of the PDE operator Cartesian-product class in pdeoperator.py

(Note: the names here are written from memory and are not the exact names in the source.)

```python
    def train_next_batch(self, batch_size=None):
        if self.train_x is None:
            func_feats = self.func_space.random(self.num_func)
            func_vals = self.func_space.eval_batch(func_feats, self.eval_pts)
            vx = self.func_space.eval_batch(
                func_feats, self.pde.train_x[:, self.func_vars]
            )
            self.train_x = (func_vals, self.pde.train_x)
            self.train_aux_vars = vx

        if self.batch_size is None:
            return self.train_x, self.train_y, self.train_aux_vars

        indices = self.train_sampler.get_next(self.batch_size)
        train_x = (self.train_x[0][indices], self.train_x[1])
        return train_x, self.train_y, self.train_aux_vars[indices]
```
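At the user level these pieces are wired together roughly as follows. A hedged sketch following the aligned PINN-DeepONet example in the deepxde docs (here reduced to the antiderivative operator; argument names may differ slightly across versions):

```python
import numpy as np
import deepxde as dde

def pde(x, u, v):
    # u'(x) = v(x): learn the antiderivative operator of v
    return dde.grad.jacobian(u, x) - v

geom = dde.geometry.TimeDomain(0, 1)
pde_data = dde.data.PDE(geom, pde, [], num_domain=20, num_boundary=2)
func_space = dde.data.GRF(length_scale=0.2)    # random input functions v
eval_pts = np.linspace(0, 1, num=50)[:, None]  # branch-net sensor locations
data = dde.data.PDEOperatorCartesianProd(
    pde_data, func_space, eval_pts, 1000, function_variables=[0], batch_size=100
)
# Per the method above: data.train_x == (v at eval_pts, pde_data.train_x),
# and data.train_aux_vars == v evaluated at the collocation points.
```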

model.py

```python
def set_data_train(self, X_train, y_train, train_aux_vars=None):
    self.X_train = X_train
    self.y_train = y_train
    self.train_aux_vars = train_aux_vars
```
```python
self.train_state.set_data_train(
    *self.data.train_next_batch(self.batch_size)
)
```

Following these code fragments confirms what train_aux_vars is: the parameters of the PDEs.

```python
        def outputs_losses(training, inputs, targets, auxiliary_vars, losses_fn):
            self.net.auxiliary_vars = auxiliary_vars
            # Don't call outputs() decorated by @tf.function above, otherwise the
            # gradient of outputs wrt inputs will be lost here.
            outputs_ = self.net(inputs, training=training)
            # Data losses
            losses = losses_fn(targets, outputs_, loss_fn, inputs, self)
            if not isinstance(losses, list):
                losses = [losses]
            # Regularization loss
            if self.net.regularizer is not None:
                losses += [tf.math.reduce_sum(self.net.losses)]
            losses = tf.convert_to_tensor(losses)
            # Weighted losses
            if loss_weights is not None:
                losses *= loss_weights
            return outputs_, losses
```
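The pattern here is worth noting: the auxiliary variables are stashed on the network object so the loss function can reach them without threading an extra argument through every call. A stripped-down toy (not deepxde code) of the same pattern:

```python
class TinyNet:
    """Stand-in for a deepxde network with an auxiliary_vars slot."""

    def __init__(self):
        self.auxiliary_vars = None

    def __call__(self, inputs, training=False):
        return inputs  # placeholder forward pass


def outputs_losses(net, inputs, targets, auxiliary_vars, losses_fn):
    net.auxiliary_vars = auxiliary_vars  # same trick as above
    outputs = net(inputs, training=True)
    return outputs, losses_fn(targets, outputs, net.auxiliary_vars)


net = TinyNet()
_, loss = outputs_losses(
    net, 1.0, 2.0, 3.0, lambda t, o, aux: (t - o) ** 2  # aux available if needed
)
```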

It is worth noting that, from start to finish, the author uses TrainState to store the various variables involved in training.

```python
import numpy as np  # imported at the top of model.py; used here for np.inf


class TrainState:
    def __init__(self):
        self.epoch = 0
        self.step = 0

        # Current data
        self.X_train = None
        self.y_train = None
        self.train_aux_vars = None
        self.X_test = None
        self.y_test = None
        self.test_aux_vars = None

        # Results of current step
        # Train results
        self.loss_train = None
        self.y_pred_train = None
        # Test results
        self.loss_test = None
        self.y_pred_test = None
        self.y_std_test = None
        self.metrics_test = None

        # The best results correspond to the min train loss
        self.best_step = 0
        self.best_loss_train = np.inf
        self.best_loss_test = np.inf
        self.best_y = None
        self.best_ystd = None
        self.best_metrics = None

    def set_data_train(self, X_train, y_train, train_aux_vars=None):
        self.X_train = X_train
        self.y_train = y_train
        self.train_aux_vars = train_aux_vars

    def set_data_test(self, X_test, y_test, test_aux_vars=None):
        self.X_test = X_test
        self.y_test = y_test
        self.test_aux_vars = test_aux_vars

    def update_best(self):
        if self.best_loss_train > np.sum(self.loss_train):
            self.best_step = self.step
            self.best_loss_train = np.sum(self.loss_train)
            self.best_loss_test = np.sum(self.loss_test)
            self.best_y = self.y_pred_test
            self.best_ystd = self.y_std_test
            self.best_metrics = self.metrics_test

    def disregard_best(self):
        self.best_loss_train = np.inf

From TrainState it is also clear what the first two arguments of pde mean: the coordinate inputs and the solution outputs of the PDE system (X_train and y_train).

deepxde is licensed under the LGPL, a relatively aggressive class of open-source license. Aggressive open-sourcing is a business strategy; see another article I wrote.

2023.11.27

When training in the physics-informed setting, do not pass metrics to compile: physics-informed training uses no labels, so there is nothing to compute the metrics against.

```python
model.compile(
    "adam",
    lr=1e-4,
    decay=("inverse time", 1, 1e-4),
    # metrics=["mean l2 relative error"],  # metrics need labels; omit for PINNs
)
```

Digging into the source, one finds:

model.py

```python
    @utils.timing
    def train(
        self,
        iterations=None,
        batch_size=None,
        display_every=1000,
        disregard_previous_best=False,
        callbacks=None,
        model_restore_path=None,
        model_save_path=None,
        epochs=None,
    ):
        """Trains the model.

        Args:
            iterations (Integer): Number of iterations to train the model, i.e., number
                of times the network weights are updated.
            batch_size: Integer, tuple, or ``None``.

                - If you solve PDEs via ``dde.data.PDE`` or ``dde.data.TimePDE``, do not use `batch_size`, and instead use
                  `dde.callbacks.PDEPointResampler
                  <https://deepxde.readthedocs.io/en/latest/modules/deepxde.html#deepxde.callbacks.PDEPointResampler>`_,
                  see an `example <https://github.com/lululxvi/deepxde/blob/master/examples/diffusion_1d_resample.py>`_.
                - For DeepONet in the format of Cartesian product, if `batch_size` is an Integer,
                  then it is the batch size for the branch input; if you want to also use mini-batch for the trunk net input,
                  set `batch_size` as a tuple, where the first number is the batch size for the branch net input
                  and the second number is the batch size for the trunk net input.
            display_every (Integer): Print the loss and metrics every this many steps.
            disregard_previous_best: If ``True``, disregard the previous saved best
                model.
            callbacks: List of ``dde.callbacks.Callback`` instances. List of callbacks
                to apply during training.
            model_restore_path (String): Path where parameters were previously saved.
            model_save_path (String): Prefix of filenames created for the checkpoint.
            epochs (Integer): Deprecated alias to `iterations`. This will be removed in
                a future version.
        """
        if iterations is None and epochs is not None:
            print(
                "Warning: epochs is deprecated and will be removed in a future version."
                " Use iterations instead."
            )
            iterations = epochs
        self.batch_size = batch_size
        self.callbacks = CallbackList(callbacks=callbacks)
        self.callbacks.set_model(self)
        if disregard_previous_best:
            self.train_state.disregard_best()

        if backend_name == "tensorflow.compat.v1":
            if self.train_state.step == 0:
                self.sess.run(tf.global_variables_initializer())
                if config.hvd is not None:
                    bcast = config.hvd.broadcast_global_variables(0)
                    self.sess.run(bcast)
            else:
                utils.guarantee_initialized_variables(self.sess)

        if model_restore_path is not None:
            self.restore(model_restore_path, verbose=1)

        if config.rank == 0:
            print("Training model...\n")
        self.stop_training = False
        self.train_state.set_data_train(*self.data.train_next_batch(self.batch_size))
        self.train_state.set_data_test(*self.data.test())
        self._test()
        self.callbacks.on_train_begin()
        if optimizers.is_external_optimizer(self.opt_name):
            if backend_name == "tensorflow.compat.v1":
                self._train_tensorflow_compat_v1_scipy(display_every)
            elif backend_name == "tensorflow":
                self._train_tensorflow_tfp()
            elif backend_name == "pytorch":
                self._train_pytorch_lbfgs()
            elif backend_name == "paddle":
                self._train_paddle_lbfgs()
        else:
            if iterations is None:
                raise ValueError("No iterations for {}.".format(self.opt_name))
            self._train_sgd(iterations, display_every)
        self.callbacks.on_train_end()

        if config.rank == 0:
            print("")
            display.training_display.summary(self.train_state)
        if model_save_path is not None:
            self.save(model_save_path, verbose=1)
        return self.losshistory, self.train_state

    def _train_sgd(self, iterations, display_every):
        for i in range(iterations):
            self.callbacks.on_epoch_begin()
            self.callbacks.on_batch_begin()

            self.train_state.set_data_train(
                *self.data.train_next_batch(self.batch_size)
            )
            self._train_step(
                self.train_state.X_train,
                self.train_state.y_train,
                self.train_state.train_aux_vars,
            )

            self.train_state.epoch += 1
            self.train_state.step += 1
            if self.train_state.step % display_every == 0 or i + 1 == iterations:
                self._test()

            self.callbacks.on_batch_end()
            self.callbacks.on_epoch_end()

            if self.stop_training:
                break
```
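Following the docstring's advice for dde.data.PDE and dde.data.TimePDE problems: instead of batch_size, resample the collocation points with a callback. A minimal sketch, again assuming a compiled dde.Model named model:

```python
import deepxde as dde

# Resample the PDE collocation points every 100 iterations.
resampler = dde.callbacks.PDEPointResampler(period=100)
losshistory, train_state = model.train(iterations=20000, callbacks=[resampler])
```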

Note how often the evaluation call self._test() appears: it runs once before the training loop starts and then every display_every steps inside _train_sgd.
