連鎖律(多変数関数の合成関数の微分)

連鎖律(チェインルール)とは,高校数学で習う合成関数の微分公式を多変数関数に拡張した公式です。例えば,2変数関数の場合,以下のようになります。

連鎖律(チェインルール)

(x,y)(x,y) から (u,v)(u,v) が定まり,(u,v)(u,v) から ff が定まるとき,

fx=fuux+fvvx\dfrac{\partial f}{\partial x}=\dfrac{\partial f}{\partial u}\dfrac{\partial u}{\partial x}+\dfrac{\partial f}{\partial v}\dfrac{\partial v}{\partial x}

fy=fuuy+fvvy\dfrac{\partial f}{\partial y}=\dfrac{\partial f}{\partial u}\dfrac{\partial u}{\partial y}+\dfrac{\partial f}{\partial v}\dfrac{\partial v}{\partial y}

この記事では,連鎖律の具体例,行列を使った表現,導出について解説します。

連鎖律について

高校数学で習う合成関数の微分(→合成関数の微分公式と例題7問)を多変数関数に拡張したのが連鎖律です。

連鎖律は数学ではもちろん,物理でも頻繁に登場します。また,機械学習におけるニューラルネットワークの逆誤差伝搬法を理解するためにも必要な公式です。

偏微分が大量に登場します。偏微分については偏微分の意味と計算例・応用をどうぞ。

例題

連鎖律を使って偏微分を計算してみます。この記事では,全ての偏微分係数が存在するとき,という条件はいちいち書かないことにします。

例題

f(x,y)=(x2+y2)sinxyf(x,y)=(x^2+y^2)\sin xy に対して,偏導関数 fx\dfrac{\partial f}{\partial x} を求めよ。

解答

u(x,y)=x2+y2u(x,y)=x^2+y^2v(x,y)=sinxyv(x,y)=\sin xy とおくと,f=uvf=uv であり,

fx=fuux+fvvx=v(2x)+u(ycosxy)=2xsinxy+(x2y+y3)cosxy\dfrac{\partial f}{\partial x}=\dfrac{\partial f}{\partial u}\dfrac{\partial u}{\partial x}+\dfrac{\partial f}{\partial v}\dfrac{\partial v}{\partial x}\\ =v(2x)+u(y\cos xy)\\ =2x\sin xy+(x^2y+y^3)\cos xy

連鎖律と行列

連鎖律を行列で表現してみます。

(x,y)(u,v)(x,y)\to (u,v) のヤコビ行列(偏導関数を並べたもの)を JAJ_A(u,v)f(u,v)\to f のヤコビ行列を JBJ_B とします。→ヤコビ行列,ヤコビアンの定義

つまり,

JA=(uxuyvxvy)J_A=\begin{pmatrix}\dfrac{\partial u}{\partial x}&\dfrac{\partial u}{\partial y}\\\dfrac{\partial v}{\partial x}&\dfrac{\partial v}{\partial y}\end{pmatrix}JB=(fufv)J_B=\begin{pmatrix}\dfrac{\partial f}{\partial u}&\dfrac{\partial f}{\partial v}\end{pmatrix}

です。このとき連鎖律は

(x,y)f(x,y)\to f のヤコビ行列 J=(fxfy)J=\begin{pmatrix}\dfrac{\partial f}{\partial x}&\dfrac{\partial f}{\partial y}\end{pmatrix}

がヤコビ行列の積 JBJAJ_BJ_A となることを表しています。

より一般に,以下が成立します。

連鎖律(一般形)

(x1,,xl)(x_1,\cdots,x_l) から (u1,,um)(u_1,\cdots,u_m) が定まり,(u1,,um)(u_1,\cdots,u_m) から (f1,,fn)(f_1,\cdots,f_n) が定まるとする。それぞれの変換のヤコビ行列を JA,JBJ_A,J_B とする。

このとき,(x1,,xl)(f1,,fn)(x_1,\cdots,x_l)\to (f_1,\cdots,f_n) のヤコビ行列は JBJAJ_BJ_A

例えば l=2,m=3,n=2l=2,m=3,n=2 のとき, (f1x1f1x2f2x1f2x2)=(f1u1f1u2f1u3f2u1f2u2f2u3)(u1x1u1x2u2x1u2x2u3x1u3x2)\begin{pmatrix}\dfrac{\partial f_1}{\partial x_1}&\dfrac{\partial f_1}{\partial x_2}\\\dfrac{\partial f_2}{\partial x_1}&\dfrac{\partial f_2}{\partial x_2}\end{pmatrix}=\begin{pmatrix}\dfrac{\partial f_1}{\partial u_1}&\dfrac{\partial f_1}{\partial u_2}&\dfrac{\partial f_1}{\partial u_3}\\\dfrac{\partial f_2}{\partial u_1}&\dfrac{\partial f_2}{\partial u_2}&\dfrac{\partial f_2}{\partial u_3}\end{pmatrix}\begin{pmatrix}\dfrac{\partial u_1}{\partial x_1}&\dfrac{\partial u_1}{\partial x_2}\\\dfrac{\partial u_2}{\partial x_1}&\dfrac{\partial u_2}{\partial x_2}\\\dfrac{\partial u_3}{\partial x_1}&\dfrac{\partial u_3}{\partial x_2}\end{pmatrix} という感じです。美しいですね!

連鎖律の導出

厳密な証明ではありませんが,イメージはつかみやすいと思います。

導出

(x1,,xl)(x_1,\cdots ,x_l)(x1+Δx1,,xl+Δxl)(x_1+\Delta x_1,\cdots ,x_l+\Delta x_l) に微小変化させたときの (u1,,um)(u_1,\cdots,u_m) の変化量 JA(Δx1,,Δxl)\fallingdotseq J_A(\Delta x_1,\cdots, \Delta x_l)

(u1,,um)(u_1,\cdots ,u_m)(u1+Δu1,,um+Δum)(u_1+\Delta u_1,\cdots ,u_m+\Delta u_m) に微小変化させたときの (f1,,fn)(f_1,\cdots,f_n) の変化量 JB(Δu1,,Δum)\fallingdotseq J_B(\Delta u_1,\cdots, \Delta u_m)

以上2式より,(x1,,xl)(x_1,\cdots ,x_l)(x1+Δx1,,xl+Δxl)(x_1+\Delta x_1,\cdots ,x_l+\Delta x_l) に微小変化させたときの (f1,,fn)(f_1,\cdots,f_n) の変化量 JBJA(Δx1,,Δxl)\fallingdotseq J_BJ_A(\Delta x_1,\cdots, \Delta x_l)

これは xfx\to f のヤコビ行列が JBJAJ_BJ_A であることを示している。

連鎖律のことを英語では chain rule(チェインルール)と言います。けっこうかっこいいですね。