deepxde source-code reading notes (continuously updated)

2023.11.23

deepxde version read: 1.9.0

1. train_aux_vars, i.e., the third argument of pde

The meaning of this variable puzzled me for a long time. It turns out to be simply the parameters of the PDEs in operator learning.

The trail: the third argument of the user-defined pde function → train_next_batch → set_data_train → outputs_losses, traced below.
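For concreteness, here is a minimal sketch of where that third argument surfaces in user code, modeled on the PI-DeepONet antiderivative example in the deepxde docs (the problem and names here are illustrative, not quoted from the repo):

```python
import deepxde as dde

# Operator learning: learn the antiderivative operator, i.e. find u with
# du/dx = v, where v is a function sampled from a function space.
# This third argument v is the "PDE parameter" that ends up in train_aux_vars.
def pde(x, u, v):
    du_x = dde.grad.jacobian(u, x)
    return du_x - v
```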

deepxde's most mature support is for tensorflow.compat.v1, but the other backends are supported as well. The main maintainer, Lu Lu, appears to collaborate with Baidu, and support for PaddlePaddle is currently being improved. As a result, many functions are effectively written four times, once per backend, with essentially the same functionality but different internal details.

The train_next_batch method of the PDE-operator Cartesian-product data class in pde_operator.py (PDEOperatorCartesianProd in the source):

(Note: the names below are written from memory and may not match the source exactly.)

```python
    def train_next_batch(self, batch_size=None):
        if self.train_x is None:
            # Sample num_func functions from the function space
            func_feats = self.func_space.random(self.num_func)
            # Evaluate them at the sensor points -> branch-net input
            func_vals = self.func_space.eval_batch(func_feats, self.eval_pts)
            # Evaluate them at the PDE training points -> auxiliary variables
            vx = self.func_space.eval_batch(
                func_feats, self.pde.train_x[:, self.func_vars]
            )
            self.train_x = (func_vals, self.pde.train_x)
            self.train_aux_vars = vx

        if self.batch_size is None:
            return self.train_x, self.train_y, self.train_aux_vars

        # Mini-batch over the sampled functions (branch input) only
        indices = self.train_sampler.get_next(self.batch_size)
        train_x = (self.train_x[0][indices], self.train_x[1])
        return train_x, self.train_y, self.train_aux_vars[indices]
```
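The function space is where these values come from. Below is a hedged sketch of the two calls used above, using dde.data.GRF (a Gaussian random field, one of the real function spaces in deepxde.data); the shapes are my reading of the code, not guaranteed:

```python
import numpy as np
import deepxde as dde

space = dde.data.GRF(T=1.0, kernel="RBF", length_scale=0.2, N=1000)
feats = space.random(5)              # features of 5 randomly sampled functions
xs = np.linspace(0, 1, 50)[:, None]  # evaluation points, shape (50, 1)
vals = space.eval_batch(feats, xs)   # shape (5, 50): one row per function
```

This is exactly the pattern in train_next_batch: eval_batch at the sensor points (self.eval_pts) gives the branch-net input func_vals, while eval_batch at the PDE's training points gives vx, which becomes train_aux_vars.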

model.py

```python
    def set_data_train(self, X_train, y_train, train_aux_vars=None):
        self.X_train = X_train
        self.y_train = y_train
        self.train_aux_vars = train_aux_vars
```
```python
            # called every step inside Model._train_sgd (see below)
            self.train_state.set_data_train(
                *self.data.train_next_batch(self.batch_size)
            )
```

Following these fragments pins down what train_aux_vars is: the parameters of the PDEs.

```python
        def outputs_losses(training, inputs, targets, auxiliary_vars, losses_fn):
            self.net.auxiliary_vars = auxiliary_vars
            # Don't call outputs() decorated by @tf.function above, otherwise the
            # gradient of outputs wrt inputs will be lost here.
            outputs_ = self.net(inputs, training=training)
            # Data losses
            losses = losses_fn(targets, outputs_, loss_fn, inputs, self)
            if not isinstance(losses, list):
                losses = [losses]
            # Regularization loss
            if self.net.regularizer is not None:
                losses += [tf.math.reduce_sum(self.net.losses)]
            losses = tf.convert_to_tensor(losses)
            # Weighted losses
            if loss_weights is not None:
                losses *= loss_weights
            return outputs_, losses
```
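The line self.net.auxiliary_vars = auxiliary_vars is the hand-off: the loss functions have fixed signatures, so the per-batch parameter values are stashed on the network object, where the data class's loss computation (and ultimately the user's pde) can reach them. A generic toy of this pattern, not deepxde's actual classes:

```python
# Toy illustration of "stash per-batch data on the model so a fixed-signature
# loss function can still see it", the pattern used by outputs_losses above.
class Net:
    def __init__(self):
        self.auxiliary_vars = None

    def __call__(self, x):
        return 2.0 * x  # stand-in forward pass

def residual(x, u, net):
    v = net.auxiliary_vars  # plays the role of pde's third argument
    return u - v

net = Net()
net.auxiliary_vars = 3.0    # what outputs_losses does before computing losses
print(residual(1.0, net(1.0), net))  # -1.0
```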

It is worth noting that, throughout, the author uses a TrainState object to store the various quantities produced during training.

```python
import numpy as np


class TrainState:
    def __init__(self):
        self.epoch = 0
        self.step = 0

        # Current data
        self.X_train = None
        self.y_train = None
        self.train_aux_vars = None
        self.X_test = None
        self.y_test = None
        self.test_aux_vars = None

        # Results of current step
        # Train results
        self.loss_train = None
        self.y_pred_train = None
        # Test results
        self.loss_test = None
        self.y_pred_test = None
        self.y_std_test = None
        self.metrics_test = None

        # The best results correspond to the min train loss
        self.best_step = 0
        self.best_loss_train = np.inf
        self.best_loss_test = np.inf
        self.best_y = None
        self.best_ystd = None
        self.best_metrics = None

    def set_data_train(self, X_train, y_train, train_aux_vars=None):
        self.X_train = X_train
        self.y_train = y_train
        self.train_aux_vars = train_aux_vars

    def set_data_test(self, X_test, y_test, test_aux_vars=None):
        self.X_test = X_test
        self.y_test = y_test
        self.test_aux_vars = test_aux_vars

    def update_best(self):
        if self.best_loss_train > np.sum(self.loss_train):
            self.best_step = self.step
            self.best_loss_train = np.sum(self.loss_train)
            self.best_loss_test = np.sum(self.loss_test)
            self.best_y = self.y_pred_test
            self.best_ystd = self.y_std_test
            self.best_metrics = self.metrics_test

    def disregard_best(self):
        self.best_loss_train = np.inf
```
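A quick sanity check of update_best as a plain value object (toy numbers of my own):

```python
ts = TrainState()
ts.step = 100
ts.loss_train = [0.5, 0.25]  # e.g. [PDE residual loss, BC loss]
ts.loss_test = [1.0, 0.5]
ts.update_best()
print(ts.best_step, ts.best_loss_train, ts.best_loss_test)  # 100 0.75 1.5
```

Note that "best" is defined by the minimum training loss; the test loss is merely recorded alongside it.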

TrainState also makes the first two arguments of pde clear: the coordinate inputs of the PDE system and the solution outputs.

deepxde is licensed under the LGPL, a copyleft license; copyleft terms are more aggressive than permissive licenses such as MIT (though the LGPL is weaker than the GPL). Aggressive open-sourcing is a business strategy; see another article I wrote.

2023.11.27

When training in physics-informed mode, do not pass metrics to compile: physics-informed training uses no labels, so there is nothing for a metric to compare against (unless a reference solution is supplied; see below).

```python
model.compile(
    "adam",
    lr=1e-4,
    decay=("inverse time", 1, 1e-4),
    # metrics=["mean l2 relative error"],  # no labels, so no metrics
)
```
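Conversely, metrics do work in the physics-informed setting whenever a reference solution is available, because dde.data.PDE accepts a solution argument that supplies labels for the test points only. A minimal sketch on a toy problem of my own (not from the repo):

```python
import deepxde as dde

# u'' = 2 on [-1, 1] with u(+-1) = 1; exact solution u(x) = x**2.
def pde(x, u):
    return dde.grad.hessian(u, x) - 2.0

def exact(x):
    return x ** 2

geom = dde.geometry.Interval(-1, 1)
bc = dde.icbc.DirichletBC(geom, exact, lambda x, on_boundary: on_boundary)
data = dde.data.PDE(geom, pde, bc, num_domain=64, num_boundary=2,
                    solution=exact, num_test=100)
model = dde.Model(data, dde.nn.FNN([1, 32, 32, 1], "tanh", "Glorot normal"))
# Training itself is still label-free; the metric uses `solution` at test time.
model.compile("adam", lr=1e-3, metrics=["l2 relative error"])
```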

Digging into the source, you find the following.

model.py

```python
    @utils.timing
    def train(
        self,
        iterations=None,
        batch_size=None,
        display_every=1000,
        disregard_previous_best=False,
        callbacks=None,
        model_restore_path=None,
        model_save_path=None,
        epochs=None,
    ):
        """Trains the model.

        Args:
            iterations (Integer): Number of iterations to train the model, i.e., number
                of times the network weights are updated.
            batch_size: Integer, tuple, or ``None``.

                - If you solve PDEs via ``dde.data.PDE`` or ``dde.data.TimePDE``, do not use `batch_size`, and instead use
                  `dde.callbacks.PDEPointResampler
                  <https://deepxde.readthedocs.io/en/latest/modules/deepxde.html#deepxde.callbacks.PDEPointResampler>`_,
                  see an `example <https://github.com/lululxvi/deepxde/blob/master/examples/diffusion_1d_resample.py>`_.
                - For DeepONet in the format of Cartesian product, if `batch_size` is an Integer,
                  then it is the batch size for the branch input; if you want to also use mini-batch for the trunk net input,
                  set `batch_size` as a tuple, where the first number is the batch size for the branch net input
                  and the second number is the batch size for the trunk net input.
            display_every (Integer): Print the loss and metrics every this steps.
            disregard_previous_best: If ``True``, disregard the previous saved best
                model.
            callbacks: List of ``dde.callbacks.Callback`` instances. List of callbacks
                to apply during training.
            model_restore_path (String): Path where parameters were previously saved.
            model_save_path (String): Prefix of filenames created for the checkpoint.
            epochs (Integer): Deprecated alias to `iterations`. This will be removed in
                a future version.
        """
        if iterations is None and epochs is not None:
            print(
                "Warning: epochs is deprecated and will be removed in a future version."
                " Use iterations instead."
            )
            iterations = epochs
        self.batch_size = batch_size
        self.callbacks = CallbackList(callbacks=callbacks)
        self.callbacks.set_model(self)
        if disregard_previous_best:
            self.train_state.disregard_best()

        if backend_name == "tensorflow.compat.v1":
            if self.train_state.step == 0:
                self.sess.run(tf.global_variables_initializer())
                if config.hvd is not None:
                    bcast = config.hvd.broadcast_global_variables(0)
                    self.sess.run(bcast)
            else:
                utils.guarantee_initialized_variables(self.sess)

        if model_restore_path is not None:
            self.restore(model_restore_path, verbose=1)

        if config.rank == 0:
            print("Training model...\n")
        self.stop_training = False
        self.train_state.set_data_train(*self.data.train_next_batch(self.batch_size))
        self.train_state.set_data_test(*self.data.test())
        self._test()
        self.callbacks.on_train_begin()
        if optimizers.is_external_optimizer(self.opt_name):
            if backend_name == "tensorflow.compat.v1":
                self._train_tensorflow_compat_v1_scipy(display_every)
            elif backend_name == "tensorflow":
                self._train_tensorflow_tfp()
            elif backend_name == "pytorch":
                self._train_pytorch_lbfgs()
            elif backend_name == "paddle":
                self._train_paddle_lbfgs()
        else:
            if iterations is None:
                raise ValueError("No iterations for {}.".format(self.opt_name))
            self._train_sgd(iterations, display_every)
        self.callbacks.on_train_end()

        if config.rank == 0:
            print("")
            display.training_display.summary(self.train_state)
        if model_save_path is not None:
            self.save(model_save_path, verbose=1)
        return self.losshistory, self.train_state

    def _train_sgd(self, iterations, display_every):
        for i in range(iterations):
            self.callbacks.on_epoch_begin()
            self.callbacks.on_batch_begin()

            self.train_state.set_data_train(
                *self.data.train_next_batch(self.batch_size)
            )
            self._train_step(
                self.train_state.X_train,
                self.train_state.y_train,
                self.train_state.train_aux_vars,
            )

            self.train_state.epoch += 1
            self.train_state.step += 1
            if self.train_state.step % display_every == 0 or i + 1 == iterations:
                self._test()

            self.callbacks.on_batch_end()
            self.callbacks.on_epoch_end()

            if self.stop_training:
                break
```

Note the evaluation call self._test() that appears repeatedly inside this method.
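From what I can tell, _test evaluates the current losses (and any metrics) on both the training and test points, updates the best-so-far records in TrainState via update_best, and appends the results to the loss history used for the final summary.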
