জেলু অ্যাক্টিভেশন কি?


18

আমি বিইআরটি পেপার দিয়ে যাচ্ছিলাম যা জেলু (গাউসীয় ত্রুটি লিনিয়ার ইউনিট) ব্যবহার করে যা সমীকরণকে যা ঘুরে ফিরে

GELU(x)=xP(Xx)=xΦ(x).
0.5x(1+tanh[2/π(x+0.044715x3)])

আপনি কি সমীকরণটি সরল করতে এবং এটি কীভাবে আনুমানিক করা হয়েছে তা ব্যাখ্যা করতে পারেন।

উত্তর:


19

GELU ফাংশন

আমরা নীচের N ( 0 , 1 ) , অর্থাৎ Φ ( x ) এর ক্রমবর্ধমান বন্টনকেN(0,1) প্রসারিত করতে পারি : GELU ( x ) : = x P ( X x ) = x Φ ( x ) = 0.5 x ( 1 + erf ( এক্সΦ(x)

GELU(x):=xP(Xx)=xΦ(x)=0.5x(1+erf(x2))

দ্রষ্টব্য যে এটি একটি সংজ্ঞা , সমীকরণ (বা কোনও সম্পর্ক) নয়। লেখকরা এই প্রস্তাবটির জন্য কিছু ন্যায্যতা সরবরাহ করেছেন, যেমন একটি স্টোকাস্টিক উপমা , তবে গাণিতিকভাবে, এটি কেবল একটি সংজ্ঞা।

এখানে জেলু এর চক্রান্ত রয়েছে:

তানহ আনুমানিক

এই ধরণের সংখ্যাসূচক আনুমানিকতার জন্য মূল ধারণাটি হ'ল একটি অনুরূপ ফাংশন (প্রাথমিকভাবে অভিজ্ঞতার ভিত্তিতে) সন্ধান করা, এটি প্যারামিটারাইজাইজ করা এবং তারপরে এটি মূল ফাংশন থেকে পয়েন্টের একটি সেটে ফিট করে।

জেনে erf(x) খুব ঘনিষ্ঠ হয় tanh(x)

এবং এরফের প্রথম ডেরাইভেটিভ ( xerf(x2)তানহ() এরসাথে মিলে যায়tanh(2πx)x=0, যা2π , আমরাতান(fit) ফিট করতে এগিয়ে

tanh(2π(x+ax2+bx3+cx4+dx5))
(বা আরও শর্তাবলী সহ) পয়েন্টের একটি সেট(xi,erf(xi2))

আমি (1.5,1.5) ( এই সাইটটি ব্যবহার করে ) এর মধ্যে 20 টি নমুনায় এই ফাংশনটি ফিট করেছি এবং এখানে সহগ রয়েছে:

সেটিং দ্বারা a=c=d=0 , b হতে অনুমান করা হয় 0.04495641 । বিস্তৃত পরিসীমা থেকে আরও নমুনা সহ (সেই সাইটটি কেবলমাত্র 20 টি অনুমতি দিয়েছে), সহগ b কাগজের 0.044715 এর নিকটবর্তী হবে । শেষ পর্যন্ত আমরা পেতে

GELU(x)=xΦ(x)=0.5x(1+erf(x2))0.5x(1+tanh(2π(x+0.044715x3)))

x [ - - 10 , 10 ] এর জন্য গড় স্কোয়ার ত্রুটি 108x[10,10]

মনে রাখবেন যে আমরা যদি প্রথম ডেরাইভেটিভস, পদ √ এর মধ্যে সম্পর্কটি ব্যবহার না করি 2π

0.5x(1+tanh(0.797885x+0.035677x3))

সমতা কাজে লাগানো

erff(x)=f(x)tanhpol(x)tanhx

erf(x)tanh(pol(x))=tanh(pol(x))=tanh(pol(x))erf(x)

Previously, we were fortunate to end up with (almost) zero coefficients for even powers x2 and x4, however in general, this might lead to low quality approximations that, for example, have a term like 0.23x2 that is being cancelled out by extra terms (even or odd) instead of simply opting for 0x2.

Sigmoid approximation

A similar relationship holds between erf(x) and 2(σ(x)12) (sigmoid), which is proposed in the paper as another approximation, with mean squared error 104 for x[10,10].

Here is a Python code for generating data points, fitting the functions, and calculating the mean squared errors:

import math
import numpy as np
import scipy.optimize as optimize


def tahn(xs, a):
    return [math.tanh(math.sqrt(2 / math.pi) * (x + a * x**3)) for x in xs]


def sigmoid(xs, a):
    return [2 * (1 / (1 + math.exp(-a * x)) - 0.5) for x in xs]


print_points = 0
np.random.seed(123)
# xs = [-2, -1, -.9, -.7, 0.6, -.5, -.4, -.3, -0.2, -.1, 0,
#       .1, 0.2, .3, .4, .5, 0.6, .7, .9, 2]
# xs = np.concatenate((np.arange(-1, 1, 0.2), np.arange(-4, 4, 0.8)))
# xs = np.concatenate((np.arange(-2, 2, 0.5), np.arange(-8, 8, 1.6)))
xs = np.arange(-10, 10, 0.001)
erfs = np.array([math.erf(x/math.sqrt(2)) for x in xs])
ys = np.array([0.5 * x * (1 + math.erf(x/math.sqrt(2))) for x in xs])

# Fit tanh and sigmoid curves to erf points
tanh_popt, _ = optimize.curve_fit(tahn, xs, erfs)
print('Tanh fit: a=%5.5f' % tuple(tanh_popt))

sig_popt, _ = optimize.curve_fit(sigmoid, xs, erfs)
print('Sigmoid fit: a=%5.5f' % tuple(sig_popt))

# curves used in https://mycurvefit.com:
# 1. sinh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))/cosh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))
# 2. sinh(sqrt(2/3.141593)*(x+b*x^3))/cosh(sqrt(2/3.141593)*(x+b*x^3))
y_paper_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + 0.044715 * x**3))) for x in xs])
tanh_error_paper = (np.square(ys - y_paper_tanh)).mean()
y_alt_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + tanh_popt[0] * x**3))) for x in xs])
tanh_error_alt = (np.square(ys - y_alt_tanh)).mean()

# curve used in https://mycurvefit.com:
# 1. 2*(1/(1+2.718281828459^(-(a*x))) - 0.5)
y_paper_sigmoid = np.array([x * (1 / (1 + math.exp(-1.702 * x))) for x in xs])
sigmoid_error_paper = (np.square(ys - y_paper_sigmoid)).mean()
y_alt_sigmoid = np.array([x * (1 / (1 + math.exp(-sig_popt[0] * x))) for x in xs])
sigmoid_error_alt = (np.square(ys - y_alt_sigmoid)).mean()

print('Paper tanh error:', tanh_error_paper)
print('Alternative tanh error:', tanh_error_alt)
print('Paper sigmoid error:', sigmoid_error_paper)
print('Alternative sigmoid error:', sigmoid_error_alt)

if print_points == 1:
    print(len(xs))
    for x, erf in zip(xs, erfs):
        print(x, erf)

Output:

Tanh fit: a=0.04485
Sigmoid fit: a=1.70099
Paper tanh error: 2.4329173471294176e-08
Alternative tanh error: 2.698034519269613e-08
Paper sigmoid error: 5.6479106346814546e-05
Alternative sigmoid error: 5.704246564663601e-05

2
Why is the approximation needed? Couldn't they just use erf function?
SebiSebi

8

First note that

Φ(x)=12erfc(x2)=12(1+erf(x2))
by parity of erf. We need to show that
erf(x2)tanh(2π(x+ax3))
for a0.044715.

For large values of x, both functions are bounded in [1,1]. For small x, the respective Taylor series read

tanh(x)=xx33+o(x3)
and
erf(x)=2π(xx33)+o(x3).
Substituting, we get that
tanh(2π(x+ax3))=2π(x+(a23π)x3)+o(x3)
and
erf(x2)=2π(xx36)+o(x3).
Equating coefficient for x3, we find
a0.04553992412
close to the paper's 0.044715.

আমাদের সাইট ব্যবহার করে, আপনি স্বীকার করেছেন যে আপনি আমাদের কুকি নীতি এবং গোপনীয়তা নীতিটি পড়েছেন এবং বুঝতে পেরেছেন ।
Licensed under cc by-sa 3.0 with attribution required.