文章出處

NDArray可以很方便的求解導數,比如下面的例子:(代碼主要參考自https://zh.gluon.ai/chapter_crashcourse/autograd.html

 用代碼實現如下:

 1 import mxnet.ndarray as nd
 2 import mxnet.autograd as ag
 3 x = nd.array([[1,2],[3,4]])
 4 print(x)
 5 x.attach_grad() #附加導數存放的空間
 6 with ag.record():
 7     y = 2*x**2
 8 y.backward() #求導
 9 z = x.grad #將導數結果(也是一個矩陣)賦值給z
10 print(z) #打印結果
[[ 1.  2.]
 [ 3.  4.]]
<NDArray 2x2 @cpu(0)>

[[  4.   8.]
 [ 12.  16.]]
<NDArray 2x2 @cpu(0)>

 

對控制流求導

NDArray還能對諸如if的控制分支進行求導,比如下面這段代碼:

1 def f(a):
2     if nd.sum(a).asscalar()<15: #如果矩陣a的元數和<15
3         b = a*2 #則所有元素*2
4     else:
5         b = a 
6     return b

數學公式等價于:

這樣就轉換成本文最開頭示例一樣,變成單一函數求導,顯然導數值就是x前的常數項,驗證一下:

import mxnet.ndarray as nd
import mxnet.autograd as ag

def f(a):
    if nd.sum(a).asscalar()<15: #如果矩陣a的元數和<15
        b = a*2 #則所有元素平方
    else:
        b = a 
    return b

#注:1+2+3+4<15,所以進入b=a*2的分支
x = nd.array([[1,2],[3,4]])
print("x1=")
print(x)
x.attach_grad()
with ag.record():
    y = f(x)
print("y1=")
print(y)
y.backward() #dy/dx = y/x 即:2
print("x1.grad=")
print(x.grad)


x = x*2
print("x2=")
print(x)
x.attach_grad()
with ag.record():
    y = f(x)
print("y2=")
print(y)
y.backward()
print("x2.grad=")
print(x.grad)
x1=
[[ 1.  2.]
 [ 3.  4.]]
<NDArray 2x2 @cpu(0)>

y1= [[ 2. 4.] [ 6. 8.]] <NDArray 2x2 @cpu(0)>
x1.grad= [[ 2. 2.] [ 2. 2.]] <NDArray 2x2 @cpu(0)>
x2= [[ 2. 4.] [ 6. 8.]] <NDArray 2x2 @cpu(0)>
y2= [[ 2. 4.] [ 6. 8.]] <NDArray 2x2 @cpu(0)>
x2.grad= [[ 1. 1.] [ 1. 1.]] <NDArray 2x2 @cpu(0)>

 

頭梯度

原文上講得很含糊,其實所謂頭梯度,就是一個求導結果前的乘法系數,見下面代碼:

 1 import mxnet.ndarray as nd
 2 import mxnet.autograd as ag
 3 
 4 x = nd.array([[1,2],[3,4]])
 5 print("x=")
 6 print(x)
 7 
 8 x.attach_grad()
 9 with ag.record():
10     y = 2*x*x
11 
12 head = nd.array([[10, 1.], [.1, .01]]) #所謂的"頭梯度"
13 print("head=")
14 print(head)
15 y.backward(head_gradient) #用頭梯度求導
16 
17 print("x.grad=")
18 print(x.grad) #打印結果
x=
[[ 1.  2.]
 [ 3.  4.]]
<NDArray 2x2 @cpu(0)>

head= [[ 10. 1. ] [ 0.1 0.01]] <NDArray 2x2 @cpu(0)>
x.grad= [[ 40. 8. ] [ 1.20000005 0.16 ]] <NDArray 2x2 @cpu(0)>

對比本文最開頭的求導結果,上面的代碼僅僅多了一個head矩陣,最終的結果,其實就是在常規求導結果的基礎上,再乘上head矩陣(指:數乘而非叉乘)

 

鏈式法則

先復習下數學

注:最后一行中所有變量x,y,z都是向量(即:矩形),為了不讓公式看上去很凌亂,就統一省掉了變量上的箭頭。NDArray對復合函數求導時,已經自動應用了鏈式法則,見下面的示例代碼:

 1 import mxnet.ndarray as nd
 2 import mxnet.autograd as ag
 3 
 4 x = nd.array([[1,2],[3,4]])
 5 print("x=")
 6 print(x)
 7 
 8 x.attach_grad()
 9 with ag.record():
10     y = x**2
11     z = y**2 + y
12 
13 z.backward()
14 
15 print("x.grad=")
16 print(x.grad) #打印結果
17 
18 print("w=")
19 w = 4*x**3 + 2*x
20 print(w) # 驗證結果
x=
[[ 1.  2.]
 [ 3.  4.]]
<NDArray 2x2 @cpu(0)>

x.grad= [[ 6. 36.] [ 114. 264.]] <NDArray 2x2 @cpu(0)>
w= [[ 6. 36.] [ 114. 264.]] <NDArray 2x2 @cpu(0)>

 


文章列表


不含病毒。www.avast.com
arrow
arrow
    全站熱搜
    創作者介紹
    創作者 大師兄 的頭像
    大師兄

    IT工程師數位筆記本

    大師兄 發表在 痞客邦 留言(0) 人氣()