Gaussian Error Linear Units (GELUs)
- 1. Gaussian Error Linear Units (GELUs)
- 2. PyTorch GELU
- 3. TensorFlow GELU
  - 3.1. `tf.nn.gelu`
  - 3.2. `tf.keras.activations.gelu`
  - 3.3. `tf.keras.ops.gelu`
- 4. Compute a polynomial approximation of the error function
  - 4.1. `tensorflow/compiler/xla/client/lib/math.cc`
  - 4.2. `tensorflow/compiler/mlir/lite/tests/optimize.mlir`
- 5. Compute a rational approximation of the error function
  - 5.1. `third_party/xla/xla/client/lib/math.cc`
  - 5.2. `tensorflow/compiler/mlir/lite/tests/optimize.mlir`
- 6. TensorFlow Lite GELU
  - 6.1. `tensorflow/lite/kernels/activations.cc`
  - 6.2. `tensorflow/lite/kernels/internal/reference/gelu.h`
  - 6.3. `tensorflow/lite/delegates/gpu/common/tasks/elementwise.cc`
1. Gaussian Error Linear Units (GELUs)
https://arxiv.org/abs/1606.08415
The Gaussian Error Linear Unit (GELU) is a high-performing neural network activation function.
2. PyTorch GELU
https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
class torch.nn.GELU(approximate='none')
Applies the Gaussian Error Linear Units function.
$$\text{GELU}(x) = x * \Phi(x)$$

where $\Phi(x)$ is the Cumulative Distribution Function for the Gaussian Distribution.

When the approximate argument is `'tanh'`, GELU is estimated with:

$$\text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))$$
Args:
- approximate (str, optional): the gelu approximation algorithm to use: `'none'` | `'tanh'`. Default: `'none'`
Shape:
- Input: (*), where * means any number of dimensions.
- Output: (*), same shape as the input.
Examples:
import torch
from torch import nn

m = nn.GELU()
input = torch.randn(2)
output = m(input)
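Both formulas are easy to check numerically against `torch.nn.GELU`. The following is a minimal sketch (the input values are arbitrary), not part of the PyTorch documentation:

```python
import math

import torch
from torch import nn

x = torch.linspace(-3.0, 3.0, steps=7)

# Exact definition: GELU(x) = x * Phi(x), with Phi written via the error function.
gelu_exact = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

# Tanh approximation: 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).
gelu_tanh = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

print(torch.max(torch.abs(gelu_exact - nn.GELU()(x))))                   # ~0
print(torch.max(torch.abs(gelu_tanh - nn.GELU(approximate='tanh')(x))))  # ~0
```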
3. TensorFlow GELU
3.1. tf.nn.gelu
https://www.tensorflow.org/api_docs/python/tf/nn/gelu
3.2. tf.keras.activations.gelu
https://www.tensorflow.org/api_docs/python/tf/keras/activations/gelu
3.3. tf.keras.ops.gelu
https://www.tensorflow.org/api_docs/python/tf/keras/ops/gelu
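All three entry points expose the exact erf-based GELU and the tanh approximation through an `approximate` argument. A minimal sketch using `tf.nn.gelu` (input values are arbitrary; defaults may differ between the APIs, so the flag is passed explicitly):

```python
import tensorflow as tf

x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0])

# Exact form: 0.5 * x * (1 + erf(x / sqrt(2))).
y_exact = tf.nn.gelu(x, approximate=False)

# Tanh approximation: 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).
y_tanh = tf.nn.gelu(x, approximate=True)

# The two forms agree closely but not exactly.
print(tf.reduce_max(tf.abs(y_exact - y_tanh)))
```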
4. Compute a polynomial approximation of the error function
4.1. tensorflow/compiler/xla/client/lib/math.cc
v2.12.1 - tensorflow/compiler/xla/client/lib/math.cc
https://github.com/tensorflow/tensorflow/blob/v2.12.1/tensorflow/compiler/xla/client/lib/math.cc
// Compute a polynomial approximation of the error function.
// This is the same approximation used by Eigen.
static XlaOp ErfImpl32(XlaOp x) {
  static const std::array<float, 7> kAlpha{
      -2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
      -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
      -1.60960333262415e-02f,
  };
  static const std::array<float, 5> kBeta{
      -1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f,
      -7.37332916720468e-03f, -1.42647390514189e-02f,
  };
  x = Clamp(ScalarLike(x, -4.f), x, ScalarLike(x, 4.f));
  auto x2 = x * x;
  return x * EvaluatePolynomial<float>(x2, kAlpha) /
         EvaluatePolynomial<float>(x2, kBeta);
}

// Evaluate the polynomial given `x` and coefficients in decreasing order.
template <typename FP>
XlaOp EvaluatePolynomial(XlaOp x, absl::Span<const FP> coefficients) {
  static_assert(std::is_floating_point<FP>::value,
                "Template-argument 'FP' must be a floating-point type");
  XlaOp poly = ScalarLike(x, 0.0);
  for (FP c : coefficients) {
    poly = poly * x + ScalarLike(x, c);
  }
  return poly;
}
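To see what `ErfImpl32` computes, the same polynomial-ratio approximation can be sketched in NumPy with the coefficients above; this is only an illustration, not the XLA implementation:

```python
import math

import numpy as np

kAlpha = [-2.72614225801306e-10, 2.77068142495902e-08, -2.10102402082508e-06,
          -5.69250639462346e-05, -7.34990630326855e-04, -2.95459980854025e-03,
          -1.60960333262415e-02]
kBeta = [-1.45660718464996e-05, -2.13374055278905e-04, -1.68282697438203e-03,
         -7.37332916720468e-03, -1.42647390514189e-02]

def evaluate_polynomial(x, coefficients):
    # Horner's rule with coefficients in decreasing order, starting from 0.
    poly = np.zeros_like(x)
    for c in coefficients:
        poly = poly * x + c
    return poly

def erf_impl32(x):
    x = np.clip(x, -4.0, 4.0)
    x2 = x * x
    return x * evaluate_polynomial(x2, kAlpha) / evaluate_polynomial(x2, kBeta)

x = np.linspace(-3.0, 3.0, 13)
erf_ref = np.vectorize(math.erf)(x)
print(np.max(np.abs(erf_impl32(x) - erf_ref)))  # small (roughly float32-level error)
```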
4.2. tensorflow/compiler/mlir/lite/tests/optimize.mlir
v2.12.1 - tensorflow/compiler/mlir/lite/tests/optimize.mlir
https://github.com/tensorflow/tensorflow/blob/v2.12.1/tensorflow/compiler/mlir/lite/tests/optimize.mlir
func.func @gelu(%arg0: tensor<3xf32>) -> tensor<3xf32> {
  %cst = arith.constant dense<0.707106769> : tensor<f32>
  %cst_0 = arith.constant dense<5.000000e-01> : tensor<f32>
  %cst_1 = arith.constant dense<1.000000e+00> : tensor<f32>
  %0 = "tfl.mul"(%arg0, %cst_0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<f32>) -> tensor<3xf32>
  %1 = "tfl.mul"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<f32>) -> tensor<3xf32>
  %2 = "tf.Erf"(%1) : (tensor<3xf32>) -> tensor<3xf32>
  %3 = "tfl.add"(%2, %cst_1) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<f32>) -> tensor<3xf32>
  %4 = "tfl.mul"(%0, %3) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
  func.return %4 : tensor<3xf32>

// CHECK-LABEL: gelu
// CHECK: "tfl.gelu"(%arg0) {approximate = false} : (tensor<3xf32>) -> tensor<3xf32>
}
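The pattern matched by this test is simply the exact GELU, 0.5 * x * (1 + erf(x / sqrt(2))), spelled out with elementwise TFL ops (0.707106769 is the float32 value of 1/sqrt(2)), which is why the optimizer can fold it into a single `tfl.gelu` op with `approximate = false`. A small NumPy sketch of that equivalence (the helper names are made up for illustration):

```python
import math

import numpy as np

erf = np.vectorize(math.erf)

def decomposed_gelu(x):
    # Mirrors the op sequence in the test: mul, mul, erf, add, mul.
    half_x = x * 0.5                      # %0 = x * 0.5
    scaled = x * 0.707106769              # %1 = x * (1 / sqrt(2))
    return half_x * (erf(scaled) + 1.0)   # %4 = %0 * (%2 + 1)

def exact_gelu(x):
    return 0.5 * x * (1.0 + erf(x / math.sqrt(2.0)))

x = np.linspace(-4.0, 4.0, 9)
print(np.max(np.abs(decomposed_gelu(x) - exact_gelu(x))))  # ~0 (only the rounded constant differs)
```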
5. Compute a rational approximation of the error function
5.1. third_party/xla/xla/client/lib/math.cc
v2.17.0 - third_party/xla/xla/client/lib/math.cc
https://github.com/tensorflow/tensorflow/blob/v2.17.0/third_party/xla/xla/client/lib/math.cc
// Compute a rational approximation of the error function.
static XlaOp ErfImpl32(XlaOp x) {
  static const std::array<float, 5> kAlpha{
      0.00022905065861350646f, 0.0034082910107109506f, 0.050955695062380861f,
      0.18520832239976145f, 1.128379143519084f};
  static const std::array<float, 7> kBeta{-1.1791602954361697e-7,
                                          0.000023547966471313185f,
                                          0.0010179625278914885f,
                                          0.014070470171167667f,
                                          0.11098505178285362f,
                                          0.49746925110067538f,
                                          1.0f};
  // We clamp x to be within [-c;c] where c = erfinv(1-2^-23), outside of
  // which x should be +/-1.
  constexpr float kErfInvOneMinusHalfULP = 3.7439211627767994f;
  x = Clamp(ScalarLike(x, -kErfInvOneMinusHalfULP), x,
            ScalarLike(x, kErfInvOneMinusHalfULP));
  auto x2 = x * x;
  return (x * EvaluatePolynomial<float>(x2, kAlpha)) /
         EvaluatePolynomial<float>(x2, kBeta);
}

// Evaluate the polynomial given `x` and coefficients in decreasing order.
template <typename FP>
XlaOp EvaluatePolynomial(XlaOp x, absl::Span<const FP> coefficients) {
  static_assert(std::is_floating_point<FP>::value,
                "Template-argument 'FP' must be a floating-point type");
  if (coefficients.empty()) {
    return ScalarLike(x, FP(0.0));
  }
  XlaOp poly = ScalarLike(x, coefficients[0]);
  for (int i = 1; i < coefficients.size(); ++i) {
    FP c = coefficients[i];
    poly = poly * x + ScalarLike(x, c);
  }
  return poly;
}
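For comparison with the earlier polynomial version, the rational approximation can also be sketched in NumPy using the coefficients and clamp constant above (illustration only, not the XLA code):

```python
import math

import numpy as np

kAlpha = [0.00022905065861350646, 0.0034082910107109506, 0.050955695062380861,
          0.18520832239976145, 1.128379143519084]
kBeta = [-1.1791602954361697e-7, 0.000023547966471313185, 0.0010179625278914885,
         0.014070470171167667, 0.11098505178285362, 0.49746925110067538, 1.0]
kErfInvOneMinusHalfULP = 3.7439211627767994

def evaluate_polynomial(x, coefficients):
    # Horner's rule seeded with the leading coefficient, as in the v2.17.0 code.
    poly = np.full_like(x, coefficients[0])
    for c in coefficients[1:]:
        poly = poly * x + c
    return poly

def erf_impl32(x):
    # Outside [-c, c] with c = erfinv(1 - 2^-23), erf(x) saturates to +/-1 in float32.
    x = np.clip(x, -kErfInvOneMinusHalfULP, kErfInvOneMinusHalfULP)
    x2 = x * x
    return (x * evaluate_polynomial(x2, kAlpha)) / evaluate_polynomial(x2, kBeta)

x = np.linspace(-5.0, 5.0, 21)
erf_ref = np.vectorize(math.erf)(x)
print(np.max(np.abs(erf_impl32(x) - erf_ref)))  # small (roughly float32-level error)
```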
5.2. tensorflow/compiler/mlir/lite/tests/optimize.mlir
v2.17.0 - tensorflow/compiler/mlir/lite/tests/optimize.mlir
https://github.com/tensorflow/tensorflow/blob/v2.17.0/tensorflow/compiler/mlir/lite/tests/optimize.mlir
func.func @gelu(%arg0: tensor<3xf32>) -> tensor<3xf32> {
  %cst = arith.constant dense<0.707106769> : tensor<f32>
  %cst_0 = arith.constant dense<5.000000e-01> : tensor<f32>
  %cst_1 = arith.constant dense<1.000000e+00> : tensor<f32>
  %0 = "tfl.mul"(%arg0, %cst_0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<f32>) -> tensor<3xf32>
  %1 = "tfl.mul"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<f32>) -> tensor<3xf32>
  %2 = "tf.Erf"(%1) : (tensor<3xf32>) -> tensor<3xf32>
  %3 = "tfl.add"(%2, %cst_1) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<f32>) -> tensor<3xf32>
  %4 = "tfl.mul"(%0, %3) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
  func.return %4 : tensor<3xf32>

// CHECK-LABEL: gelu
// CHECK: "tfl.gelu"(%arg0) <{approximate = false}> : (tensor<3xf32>) -> tensor<3xf32>
}
6. TensorFlow Lite GELU
6.1. tensorflow/lite/kernels/activations.cc
https://github.com/tensorflow/tensorflow/blob/v2.17.0/tensorflow/lite/kernels/activations.cc
TfLiteStatus GeluPrepare(TfLiteContext* context, TfLiteNode* node);
TfLiteStatus GeluEval(TfLiteContext* context, TfLiteNode* node);
6.2. tensorflow/lite/kernels/internal/reference/gelu.h
https://github.com/tensorflow/tensorflow/blob/v2.17.0/tensorflow/lite/kernels/internal/reference/gelu.h
...
namespace gelu_internal {

constexpr float kSqrt2dPi = M_2_SQRTPI * M_SQRT1_2;  // sqrt( 2 / pi )

}  // namespace gelu_internal

// Plain implementations for GELU. Used for populating lookup table.
inline float GeluTransform(float in) {
  // Note: 0.5 * x * ( 1 + erf( x / sqrt( 2 ) ) ) is commonly used, but causes
  // catastrophic cancellation for large negative inputs. Rewriting the
  // expression via erfc avoids the numerical stability issues.
  return 0.5f * in * std::erfc(in * static_cast<float>(-M_SQRT1_2));
}

inline float GeluTransformApproximate(float in) {
  // 0.5 * x * ( 1 + tanh( sqrt( 2 / pi ) * ( x + 0.044715 * x^3 ) ) )
  return 0.5f * in *
         (1.f + std::tanh(gelu_internal::kSqrt2dPi *
                          // Note: Avoid std::pow for integer exponents
                          // as it leads to much slower performance.
                          (in + 0.044715f * in * in * in)));
}
...
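A quick Python check of the note about catastrophic cancellation (the input value is arbitrary): for a large negative x, the sum 1 + erf(x / sqrt(2)) rounds to exactly 0, while the erfc form keeps the small tail.

```python
import math

x = -9.0

# Naive form: 1 + erf(x / sqrt(2)) cancels to 0 for large negative x.
naive = 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

# erfc form used by GeluTransform: keeps the tiny positive tail of the Gaussian CDF.
stable = 0.5 * x * math.erfc(x * -(1.0 / math.sqrt(2.0)))

print(naive)   # -0.0
print(stable)  # roughly -1e-18: tiny, but nonzero and with the correct sign
```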
6.3. tensorflow/lite/delegates/gpu/common/tasks/elementwise.cc
https://github.com/tensorflow/tensorflow/blob/v2.17.0/tensorflow/lite/delegates/gpu/common/tasks/elementwise.cc
...
    case OperationType::GELU:
      // OpenCL has erfc and so it can use the more accurate gelu calculation
      // as compared to the OpenGL and Vulkan implementations.
      // gelu(x) = 0.5 * x * erfc(x * -sqrt(0.5))
      result =
          "$0 = INIT_FLT4(0.5f) * $1 * erfc($1 * "
          "INIT_FLT4(-0.70710678118654752440f));";
      break;
...