CONSTRAINED LOW-RANK MATRIX (AND ... - Laurent Risser

… with rank(Ŷ) = 1. $x_{PCA}$ (leading eigenvector of $Y$) estimates $x^*$ (up to a sign), where $x^*$ are the true values of the cards: $Y_{ij} = \frac{1}{\sqrt{N}} x^*_i x^*_j + Z_{ij}$, $Z_{ij} \sim \mathcal{N}(0, \Delta)$. BBP phase transition.
CONSTRAINED LOW-RANK MATRIX (AND TENSOR) ESTIMATION Lenka Zdeborová (IPhT, CEA Saclay, France)

with T. Lesieur, F. Krzakala; Proofs with J. Xu, J. Barbier, N. Macris, M. Dia, M. Lelarge, L. Miolane.

LET’S PLAY A GAME

[Figure: N = 15 people, each holding a card with value +1 or −1.]

• Each pair (ij) generates a random Gaussian variable $Z_{ij} \sim \mathcal{N}(0, \Delta)$ (zero mean and variance Δ).
• Each pair reports:
  ‣ $Y_{ij} = Z_{ij} + 1/\sqrt{N}$ if their cards are the same.
  ‣ $Y_{ij} = Z_{ij} - 1/\sqrt{N}$ if their cards are different.
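A minimal NumPy sketch of the game, handy for trying the PCA baseline ($x_{PCA}$ from the leading eigenvector of $Y$) on it. The names `play_game`, `pca_estimate` and `overlap` are hypothetical. Because the cards are ±1, the reporting rule is the same as $Y_{ij} = x_i x_j/\sqrt{N} + Z_{ij}$; drawing one $Z_{ij}$ per unordered pair and zeroing the diagonal are assumptions about details the slides leave open.

```python
import numpy as np


def play_game(n, delta, rng):
    """Hidden cards x_i in {+1, -1}; pair (i, j) reports x_i x_j / sqrt(n) + Z_ij."""
    x = rng.choice([-1.0, 1.0], size=n)
    # One Gaussian draw per unordered pair, mirrored so that Y is symmetric.
    z = np.triu(rng.normal(0.0, np.sqrt(delta), size=(n, n)), k=1)
    z = z + z.T
    y = np.outer(x, x) / np.sqrt(n) + z
    np.fill_diagonal(y, 0.0)  # nobody is paired with themselves
    return x, y


def pca_estimate(y):
    """Sign of the leading eigenvector of the symmetric matrix Y."""
    _, eigvecs = np.linalg.eigh(y)  # eigenvalues come in ascending order
    return np.sign(eigvecs[:, -1])


def overlap(x_hat, x):
    """|(1/N) sum_i x_hat_i x_i|; abs() removes the global sign symmetry."""
    return abs(np.mean(x_hat * x))


rng = np.random.default_rng(0)
for delta in (0.5, 3.0):
    x, y = play_game(1000, delta, rng)
    print(delta, overlap(pca_estimate(y), x))
```

Since $\mathbb{E}(x^2) = 1$ here, PCA is expected to correlate with the cards only for $\Delta < 1$ (the BBP transition): the overlap should be clearly positive for Δ = 0.5 and close to zero for Δ = 3, with exact values depending on the seed.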

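Collect $Y_{ij}$ for every pair $(ij)$. Goal: recover the cards (up to a global flip of all cards) purely from the knowledge of $Y = \{Y_{ij}\}_{i<j}$.

…

$$\Phi(M) = \mathbb{E}_{x,w}\left[\log Z\left(\frac{M}{\Delta}, \frac{M x}{\Delta} + \sqrt{\frac{M}{\Delta}}\, w\right)\right] - \frac{\mathrm{Tr}(M M^\top)}{4\Delta} = \text{replica symmetric free energy.}$$

Why is this useful?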

$x \sim P_X(x)$ with $x \in \mathbb{R}^r$; $w \sim \mathcal{N}(0, \mathbb{1}_r)$.

When $N \gg 1$, the $rN$-dimensional problem reduces to an $r$-dimensional one.

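THEOREMS:

Theorem 1: $\frac{1}{N}\log Z(Y)$ concentrates around the maximum over $M \in \mathbb{R}^{r \times r}$ of

$$\Phi(M) = \mathbb{E}_{x,w}\left[\log Z\left(\frac{M}{\Delta}, \frac{M x}{\Delta} + \sqrt{\frac{M}{\Delta}}\, w\right)\right] - \frac{\mathrm{Tr}(M M^\top)}{4\Delta}, \qquad x \sim P_X(x),\; w \sim \mathcal{N}(0, \mathbb{1}_r).$$

Theorem 2:

$$\mathrm{MMSE} = \mathrm{Tr}\left[\mathbb{E}_x(x x^\top) - \operatorname{argmax}_M \Phi(M)\right]$$

Proofs: Korada, Macris’10; Krzakala, Xu, LZ, ITW’16; Barbier, Dia, Macris, Krzakala, Lesieur, LZ, NIPS’16; more elegant: Lelarge, Miolane’16; El-Alaoui, Krzakala’17.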
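For $r = 1$ and a prior with finite support, $\Phi(M)$ can be evaluated with Gauss-Hermite quadrature over $w$ and an exact sum over $x$, which turns Theorem 2 into a one-dimensional maximisation. A sketch assuming NumPy; `log_z`, `free_energy` and `mmse` are hypothetical names, and scanning $M$ on a grid over $[0, \mathbb{E}(x^2)]$ is a simplifying choice (for this Bayes-optimal setting the maximiser lies in that range).

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

NODES, WEIGHTS = hermegauss(80)
WEIGHTS = WEIGHTS / np.sqrt(2.0 * np.pi)  # WEIGHTS @ g(NODES) approximates E[g(w)], w ~ N(0, 1)


def log_z(a, b, support, probs):
    """log Z(A, B) = log sum_x P_X(x) exp(B x - A x^2 / 2) for scalar x, one value per entry of b."""
    expo = np.log(probs) + np.outer(b, support) - 0.5 * a * support**2
    peak = expo.max(axis=1, keepdims=True)  # log-sum-exp for numerical stability
    return peak[:, 0] + np.log(np.exp(expo - peak).sum(axis=1))


def free_energy(m, delta, support, probs):
    """Phi(M) = E_{x,w} log Z(M/Delta, M x/Delta + sqrt(M/Delta) w) - M^2 / (4 Delta), r = 1."""
    a = m / delta
    avg = sum(p * (WEIGHTS @ log_z(a, a * x + np.sqrt(a) * NODES, support, probs))
              for x, p in zip(support, probs))  # exact average over the finite prior
    return avg - m**2 / (4.0 * delta)


def mmse(delta, support, probs, grid_size=1001):
    """Theorem 2 for r = 1: MMSE = E[x^2] - argmax_M Phi(M), scanning M on [0, E[x^2]]."""
    second_moment = float(np.dot(probs, support**2))
    grid = np.linspace(0.0, second_moment, grid_size)
    phi = [free_energy(m, delta, support, probs) for m in grid]
    return second_moment - grid[int(np.argmax(phi))]


cards = (np.array([-1.0, 1.0]), np.array([0.5, 0.5]))  # the +-1 cards of the game
for delta in (0.2, 0.5, 2.0):
    print(delta, mmse(delta, *cards))
```

For the ±1 cards the MMSE should equal 1 (no better than guessing) for Δ above 1 and decrease toward 0 as Δ shrinks.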

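FREE ENERGY FOR THE ASYMMETRIC CASE

$$P(u, v \mid Y) = \frac{1}{Z(Y)} \prod_{i=1}^{N} P_U(u_i) \prod_{j=1}^{M} P_V(v_j) \prod_{i,j} P_{out}\left(Y_{ij} \,\middle|\, \frac{u_i^\top v_j}{\sqrt{N}}\right)$$

$$\Phi(M_u, M_v) = \mathbb{E}_{u,w}\left[\log Z_u\left(\frac{\alpha M_v}{\Delta}, \frac{\alpha M_v u}{\Delta} + \sqrt{\frac{\alpha M_v}{\Delta}}\, w\right)\right] + \alpha\, \mathbb{E}_{v,w}\left[\log Z_v\left(\frac{M_u}{\Delta}, \frac{M_u v}{\Delta} + \sqrt{\frac{M_u}{\Delta}}\, w\right)\right] - \frac{\alpha\, \mathrm{Tr}(M_v M_u^\top)}{2\Delta}$$

Conjectured: Lesieur, Krzakala, LZ’15. Proof: Miolane’17.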
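For $r = 1$ with ±1 priors on both $u$ and $v$ (a hypothetical choice), setting the derivatives of $\Phi(M_u, M_v)$ to zero gives the coupled equations $M_u = \mathbb{E}[f_u(\alpha M_v/\Delta, \cdot)\,u]$ and $M_v = \mathbb{E}[f_v(M_u/\Delta, \cdot)\,v]$, with second arguments as in the asymmetric free energy and $f = \tanh$ for this prior. The sketch iterates them from a small initial overlap; reading $\alpha$ as the aspect ratio $M/N$ is an assumption, and the function names are hypothetical.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

NODES, WEIGHTS = hermegauss(80)
WEIGHTS = WEIGHTS / np.sqrt(2.0 * np.pi)  # WEIGHTS @ g(NODES) approximates E[g(w)], w ~ N(0, 1)


def pm1_overlap(a):
    """E_{x,w}[f(a, a x + sqrt(a) w) x] for a +-1 prior, where f(A, B) = tanh(B).

    By the x -> -x symmetry this equals E_w[tanh(a + sqrt(a) w)].
    """
    return float(WEIGHTS @ np.tanh(a + np.sqrt(a) * NODES))


def asymmetric_fixed_point(delta, alpha, m_init=1e-6, n_iter=2000):
    """Iterate M_u <- overlap(alpha M_v / Delta), M_v <- overlap(M_u / Delta) from a small start."""
    m_u = m_v = m_init
    for _ in range(n_iter):
        m_u = pm1_overlap(alpha * m_v / delta)
        m_v = pm1_overlap(m_u / delta)
    return m_u, m_v


for delta in (0.5, 2.0):
    print(delta, asymmetric_fixed_point(delta, alpha=1.0))
```

Linearising around zero, one sweep multiplies the overlaps by $\alpha/\Delta^2$, so for this prior the trivial fixed point is unstable when $\Delta < \sqrt{\alpha}$; with α = 1, Δ = 0.5 should give non-zero overlaps and Δ = 2 should collapse to zero.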

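FREE ENERGY FOR THE TENSOR CASE

$$P(x \mid Y) = \frac{1}{Z(Y)} \prod_{i=1}^{N} P_X(x_i) \prod_{i_1 < \dots < i_p} P_{out}\left(Y_{i_1 \dots i_p} \,\middle|\, \frac{\sqrt{(p-1)!}}{N^{(p-1)/2}}\, x_{i_1} \cdots x_{i_p}\right)$$

$$P(x; A, B) = \frac{1}{Z(A, B)}\, P_X(x) \exp\left(B^\top x - \frac{x^\top A x}{2}\right), \qquad f(A, B) \equiv \mathbb{E}_{P(x; A, B)}(x)$$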
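The $r$-dimensional measure $P(x; A, B)$, its normalisation $Z(A, B)$ and its mean $f(A, B)$ are the prior-dependent ingredients in which the free energy and state evolution are written. A NumPy sketch for a prior with finite support (an assumption; a continuous prior needs an integral instead), with `denoiser` a hypothetical name returning both $f(A, B)$ and $\log Z(A, B)$.

```python
import numpy as np


def denoiser(a, b, support, probs):
    """Mean f(A, B) and log Z(A, B) of P(x; A, B) ∝ P_X(x) exp(B^T x - x^T A x / 2).

    support: (K, r) array listing the values x can take; probs: (K,) prior weights;
    a: (r, r) symmetric matrix; b: (r,) vector.
    """
    quad = np.einsum("ki,ij,kj->k", support, a, support)  # x^T A x for every support point
    expo = np.log(probs) + support @ b - 0.5 * quad
    peak = expo.max()  # log-sum-exp for numerical stability
    weights = np.exp(expo - peak)
    log_z = peak + np.log(weights.sum())
    f = (weights / weights.sum()) @ support  # posterior mean, shape (r,)
    return f, log_z


# r = 1, cards +-1: x^2 = 1, so A drops out of the posterior and f(A, B) = tanh(B).
f, log_z = denoiser(np.array([[0.7]]), np.array([0.3]),
                    np.array([[-1.0], [1.0]]), np.array([0.5, 0.5]))
print(f, np.tanh(0.3))
```

For a ±1 prior with $r = 1$, `f` reduces to tanh(B); for a one-hot prior on $r$ groups with equal weights and $A = 0$, it is a softmax of $B$.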



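STATE EVOLUTION

$$M^t \equiv \frac{1}{N} \sum_{i=1}^{N} a_i^t\, (x_i^*)^\top \in \mathbb{R}^{r \times r},$$

where $a_i^t$ is the AMP estimate of $x_i^*$ at iteration $t$.

Characterisation of AMP via the matrix order parameter $M$:

$$M^{t+1} = \mathbb{E}_{x,w}\left[f\left(\frac{M^t}{\Delta}, \frac{M^t x}{\Delta} + \sqrt{\frac{M^t}{\Delta}}\, w\right) x^\top\right], \qquad x \sim P_X(x),\; w \sim \mathcal{N}(0, \mathbb{1}_r)$$

$$\mathrm{MSE}_{AMP} = \mathrm{Tr}\left[\mathbb{E}_x(x x^\top) - M_{AMP}\right]$$

Proof: Rangan, Fletcher’12; Javanmard, Montanari’12; Deshpande, Montanari’14.

Observation: stationary points of $\Phi(M)$ are fixed points of the state evolution.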
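For $r = 1$ the state-evolution recursion is a scalar map that can be iterated directly, starting from a small $M$ to mimic an uninformed AMP initialisation. A NumPy sketch with a finite-support prior and Gauss-Hermite quadrature; `posterior_mean`, `se_step` and `state_evolution` are hypothetical names, and the starting value and tolerance are arbitrary choices.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

NODES, WEIGHTS = hermegauss(80)
WEIGHTS = WEIGHTS / np.sqrt(2.0 * np.pi)  # WEIGHTS @ g(NODES) approximates E[g(w)], w ~ N(0, 1)


def posterior_mean(a, b, support, probs):
    """f(A, B) = E_{P(x; A, B)}[x] for scalar x with finite support, vectorised over b."""
    expo = np.log(probs) + np.outer(b, support) - 0.5 * a * support**2
    w = np.exp(expo - expo.max(axis=1, keepdims=True))
    return (w @ support) / w.sum(axis=1)


def se_step(m, delta, support, probs):
    """One state-evolution step M -> E_{x,w}[f(M/Delta, M x/Delta + sqrt(M/Delta) w) x]."""
    a = m / delta
    return sum(p * x * (WEIGHTS @ posterior_mean(a, a * x + np.sqrt(a) * NODES, support, probs))
               for x, p in zip(support, probs))


def state_evolution(delta, support, probs, m_init=1e-6, n_iter=1000, tol=1e-12):
    """Iterate SE from a small M (uninformed start); return (M_AMP, MSE_AMP) for r = 1."""
    m = m_init
    for _ in range(n_iter):
        m_new = se_step(m, delta, support, probs)
        converged = abs(m_new - m) < tol
        m = m_new
        if converged:
            break
    return m, float(np.dot(probs, support**2)) - m


cards = (np.array([-1.0, 1.0]), np.array([0.5, 0.5]))
for delta in (0.5, 2.0):
    print(delta, state_evolution(delta, *cards))
```

The returned $M$ is a fixed point of the recursion, hence a stationary point of $\Phi(M)$: for the ±1 cards it should be non-zero for Δ < 1 and vanish for Δ > 1.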

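BOTTOM LINE

$$\Phi(M) = \mathbb{E}_{x,w}\left[\log Z\left(\frac{M}{\Delta}, \frac{M x}{\Delta} + \sqrt{\frac{M}{\Delta}}\, w\right)\right] - \frac{\mathrm{Tr}(M M^\top)}{4\Delta}$$

[Figure: free energy as a function of M.]

The AMP MSE is given by the local maximum of the free energy reached by gradient ascent starting from small $M$ (large MSE):

$$\mathrm{MSE}_{AMP} = \mathrm{Tr}\left[\mathbb{E}_x(x x^\top) - M_{AMP}\right], \qquad M_{AMP} = \text{that local maximiser of } \Phi(M).$$

The MMSE is given by the global maximum of the free energy:

$$\mathrm{MMSE} = \mathrm{Tr}\left[\mathbb{E}_x(x x^\top) - \operatorname{argmax}_M \Phi(M)\right]$$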
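The bottom line can be read off numerically for $r = 1$: evaluate $\Phi(M)$ on a grid, climb from the smallest $M$ until the first local maximum (a discrete stand-in for gradient ascent, an assumption of this sketch), and compare with the global maximum. Function names are hypothetical; NumPy and a finite-support prior are assumed.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

NODES, WEIGHTS = hermegauss(80)
WEIGHTS = WEIGHTS / np.sqrt(2.0 * np.pi)  # WEIGHTS @ g(NODES) approximates E[g(w)], w ~ N(0, 1)


def free_energy(m, delta, support, probs):
    """Phi(M) for r = 1 and a prior with finite support."""
    a = m / delta
    total = 0.0
    for x, p in zip(support, probs):
        b = a * x + np.sqrt(a) * NODES
        expo = np.log(probs) + np.outer(b, support) - 0.5 * a * support**2
        peak = expo.max(axis=1, keepdims=True)  # log-sum-exp for stability
        log_z = peak[:, 0] + np.log(np.exp(expo - peak).sum(axis=1))
        total += p * (WEIGHTS @ log_z)
    return total - m**2 / (4.0 * delta)


def bottom_line(delta, support, probs, grid_size=2001):
    """Return (MSE_AMP, MMSE) read off the free-energy landscape on M in [0, E[x^2]]."""
    second_moment = float(np.dot(probs, support**2))
    grid = np.linspace(0.0, second_moment, grid_size)
    phi = np.array([free_energy(m, delta, support, probs) for m in grid])
    # Discrete hill climb from the smallest M: a grid stand-in for gradient ascent.
    i = 0
    while i + 1 < grid_size and phi[i + 1] > phi[i]:
        i += 1
    return second_moment - grid[i], second_moment - grid[int(np.argmax(phi))]


cards = (np.array([-1.0, 1.0]), np.array([0.5, 0.5]))
for delta in (0.5, 2.0):
    print(delta, bottom_line(delta, *cards))
```

By construction the first error is never below the second; a gap between them at some point of the phase diagram is what signals the hard phase.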

ZOOLOGY OF FIXED POINTS (FOR MATRIX ESTIMATION)

Zero-mean prior, $\mathbb{E}_X(x) = 0$: SE always has a “trivial” fixed point $M = 0$. Stability of the trivial fixed point (SE linearised around $M = 0$):

$$M^{t+1} = \frac{\Sigma M^t \Sigma}{\Delta}, \quad \Sigma = \mathbb{E}_X(x x^\top); \qquad (r = 1):\; M^{t+1} = \frac{[\mathbb{E}_X(x^2)]^2}{\Delta}\, M^t$$

This is the same as the spectral phase transition of the Fisher score matrix (Edwards’68, known as the BBP’05 transition).

Non-zero-mean priors, $\mathbb{E}_X(x) \neq 0$: the MMSE is always better than random guessing (spectral methods still have a phase transition). Multiple fixed points may still exist.
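The stability condition can be checked numerically: for a zero-mean prior, one state-evolution step from a tiny $M$ should multiply it by $[\mathbb{E}_X(x^2)]^2/\Delta$, so $M = 0$ is unstable exactly when $\Delta < [\mathbb{E}_X(x^2)]^2$. A sketch with the sparse ±1 prior used in the phase diagrams (ρ = 0.1 is an arbitrary choice); `se_step` and `critical_delta` are hypothetical names.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

NODES, WEIGHTS = hermegauss(80)
WEIGHTS = WEIGHTS / np.sqrt(2.0 * np.pi)  # WEIGHTS @ g(NODES) approximates E[g(w)], w ~ N(0, 1)


def posterior_mean(a, b, support, probs):
    """f(A, B) for scalar x with finite support, vectorised over b."""
    expo = np.log(probs) + np.outer(b, support) - 0.5 * a * support**2
    w = np.exp(expo - expo.max(axis=1, keepdims=True))
    return (w @ support) / w.sum(axis=1)


def se_step(m, delta, support, probs):
    """One SE step M -> E_{x,w}[f(M/Delta, M x/Delta + sqrt(M/Delta) w) x] for r = 1."""
    a = m / delta
    return sum(p * x * (WEIGHTS @ posterior_mean(a, a * x + np.sqrt(a) * NODES, support, probs))
               for x, p in zip(support, probs))


def critical_delta(support, probs):
    """[E_X(x^2)]^2: below this Delta the trivial fixed point M = 0 is unstable (r = 1)."""
    if abs(float(np.dot(probs, support))) > 1e-12:
        raise ValueError("M = 0 is a fixed point only for zero-mean priors")
    return float(np.dot(probs, support**2)) ** 2


rho = 0.1
support = np.array([-1.0, 0.0, 1.0])
probs = np.array([rho / 2, 1.0 - rho, rho / 2])
for delta in (0.005, 0.02):
    growth = se_step(1e-8, delta, support, probs) / 1e-8
    print(delta, growth, critical_delta(support, probs) / delta)
```

With ρ = 0.1, $[\mathbb{E}_X(x^2)]^2 = 0.01$, so the growth factor is about 2 for Δ = 0.005 (unstable) and about 0.5 for Δ = 0.02 (stable).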

From fixed points to phase transitions:

$$P_X(x_i) = \frac{\rho}{2}\left[\delta(x_i - 1) + \delta(x_i + 1)\right] + (1 - \rho)\, \delta(x_i)$$

[Figure: accuracy curves for this prior.]

ALGORITHMIC INTERPRETATION
• Easy: solved by approximate message passing.
• Impossible: information-theoretically.
• Hard phase: present when there is a first-order phase transition.

Conjecture: no polynomial-time algorithm works in the hard phase.
- Physically sensible.
- Mathematically wide open.
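Given the two errors from the free-energy landscape (the MMSE from the global maximum, the AMP error from the local maximum reached from small $M$), the three regions can be labelled mechanically. A hypothetical helper; the tolerance `tol` is an arbitrary numerical choice, and the trivial error (that of the best estimator ignoring $Y$) equals $\mathbb{E}(x^2)$ for a zero-mean prior.

```python
def classify_phase(mmse, mse_amp, trivial_error, tol=1e-3):
    """Label a point of the phase diagram as 'easy', 'hard' or 'impossible'.

    trivial_error: error of the best estimator that ignores Y (E[x^2] for a zero-mean prior).
    tol: hypothetical tolerance used when comparing errors.
    """
    if mmse >= trivial_error - tol:
        return "impossible"  # even the Bayes-optimal estimator cannot beat the trivial one
    if mse_amp <= mmse + tol:
        return "easy"  # AMP reaches the MMSE
    return "hard"  # the MMSE beats the trivial error, but AMP is stuck higher


print(classify_phase(mmse=0.05, mse_amp=0.9, trivial_error=1.0))  # prints "hard"
```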

Phase diagram for

$$P_X(x_i) = \frac{\rho}{2}\left[\delta(x_i - 1) + \delta(x_i + 1)\right] + (1 - \rho)\, \delta(x_i)$$

[Figure: accuracy vs. noise Δ, and the phase diagram divided into easy, hard and impossible regions.]

HARD PHASE IN NATURE

Metastable diamond = high error. Equilibrium graphite = low error. Algorithms stay stuck at high error for an exponentially long time.

MAIN QUESTIONS



What is the minimal achievable estimation error on x*? (Is it possible to do better than PCA?)

✅ What is the minimal efficiently achievable estimation error on x*?


OPTIMAL SPECTRAL ALGORITHMS

For zero-mean priors, there is a spectral method with the same phase transition as AMP, but AMP achieves a lower error. For noise that is not additive Gaussian, the spectral algorithm must be run on the Fisher score matrix to reach the optimal phase transition:

$$S_{ij} \equiv \left.\frac{\partial \log P_{out}(y_{ij} \mid w)}{\partial w}\right|_{y_{ij},\, w = 0}$$
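The Fisher score matrix needs only $\log P_{out}$, so a central finite difference in $w$ gives a generic sketch; `fisher_score` and the two example channels are hypothetical names, and the step `h` and the Gaussian variance are arbitrary choices. For additive Gaussian noise of variance Δ the score is $Y_{ij}/\Delta$, a rescaled $Y$, which is why plain PCA is already the right spectral method in that case; for the exponential channel it is $\mathrm{sign}(Y_{ij})$.

```python
import numpy as np


def fisher_score(y, log_p_out, h=1e-5):
    """S_ij = d/dw log P_out(y_ij | w) at w = 0, via a central finite difference in w."""
    return (log_p_out(y, h) - log_p_out(y, -h)) / (2.0 * h)


def gaussian_log_p(y, w, delta=0.5):
    """Additive Gaussian noise with variance delta (0.5 is a hypothetical value)."""
    return -((y - w) ** 2) / (2.0 * delta) - 0.5 * np.log(2.0 * np.pi * delta)


def exponential_log_p(y, w):
    """Exponential (Laplace) additive noise, P_out(y | w) = exp(-|y - w|) / 2."""
    return -np.abs(y - w) - np.log(2.0)


y = np.array([[0.0, 1.7, -0.4], [1.7, 0.0, 2.5], [-0.4, 2.5, 0.0]])  # toy data matrix
print(fisher_score(y, gaussian_log_p))     # matches y / 0.5
print(fisher_score(y, exponential_log_p))  # matches sign(y)
```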

OPTIMAL PRE-PROCESSING

Exponential additive noise:

$$P_{out}(y \mid w) = \frac{1}{2} e^{-|y - w|}, \qquad \text{Fisher score: } S_{ij} = \mathrm{sign}(Y_{ij})$$

Cauchy additive noise:

$$P_{out}(y \mid w) = \frac{1}{\pi\,[1 + (y - w)^2]}, \qquad \text{Fisher score: } S_{ij} = \frac{2 Y_{ij}}{1 + Y_{ij}^2} \propto \frac{Y_{ij}}{1 + Y_{ij}^2}$$
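A small experiment showing why the pre-processing matters for heavy-tailed noise: the spectral method on raw $Y$ versus on the Cauchy Fisher score $S_{ij} = 2Y_{ij}/(1 + Y_{ij}^2)$. The signal strength `lam = 3` multiplying $x_i x_j/\sqrt{N}$ is a hypothetical addition (with ±1 cards and no such factor, the effective noise of the Cauchy channel, the inverse Fisher information 2, would sit above the threshold $[\mathbb{E}(x^2)]^2 = 1$); all names are hypothetical.

```python
import numpy as np


def spiked_cauchy(n, lam, rng):
    """Y_ij = lam * x_i x_j / sqrt(n) + Cauchy noise, x_i = +-1, one noise draw per pair."""
    x = rng.choice([-1.0, 1.0], size=n)
    noise = np.triu(rng.standard_cauchy(size=(n, n)), k=1)
    y = lam * np.outer(x, x) / np.sqrt(n) + noise + noise.T
    np.fill_diagonal(y, 0.0)
    return x, y


def spectral_overlap(matrix, x):
    """|mean(sign(v) * x)| with v the eigenvector of the largest eigenvalue."""
    _, vecs = np.linalg.eigh(matrix)
    return abs(np.mean(np.sign(vecs[:, -1]) * x))


rng = np.random.default_rng(0)
x, y = spiked_cauchy(1000, 3.0, rng)
s = 2.0 * y / (1.0 + y**2)  # Cauchy Fisher score, applied entrywise
print("raw Y:", spectral_overlap(y, x), "score S:", spectral_overlap(s, x))
```

The top eigenvector of raw $Y$ localises on the few largest Cauchy entries, so its overlap with the cards should be near zero, while the score matrix should recover most of the cards.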

OTHER EXAMPLES OF PHASE DIAGRAMS

Non-zero mean prior:

$$P_X(x_i) = (1 - \rho)\, \delta(x_i) + \rho\, \delta(x_i - 1)$$

[Figure: accuracy curves for ρ = 0.2 and ρ = 0.01, marking the algorithmic threshold (Alg) and the hard region.]

Stochastic block model, r groups:

$$P_{out}\left(Y_{ij} = 1 \,\middle|\, \frac{x_i^\top x_j}{\sqrt{N}}\right) = p_{out} + \frac{\mu}{\sqrt{N}}\, x_i^\top x_j, \qquad \Delta = \frac{p_{out}(1 - p_{out})}{\mu^2}$$

[Figure, r = 15: MSE curves for AMP from the solution, the SE stable branch and the SE unstable branch, with thresholds labelled c, Alg, IT and Dyn.]

For r > 4 a hard phase exists.
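The effective noise for the block model is the inverse Fisher information of the Bernoulli channel at $w = 0$: $\mathcal{I} = p_{out}(\mu/p_{out})^2 + (1 - p_{out})(\mu/(1 - p_{out}))^2 = \mu^2/[p_{out}(1 - p_{out})]$, which inverts to the $\Delta$ above. A NumPy sketch that checks this with finite-difference scores; names and parameter values are hypothetical.

```python
import numpy as np


def sbm_delta(p_out, mu):
    """Effective noise Delta = p_out (1 - p_out) / mu^2 for the block model."""
    return p_out * (1.0 - p_out) / mu**2


def bernoulli_fisher(p_out, mu, h=1e-6):
    """Fisher information at w = 0 of the channel P_out(Y = 1 | w) = p_out + mu * w."""

    def log_p(y, w):
        p1 = p_out + mu * w
        return np.log(p1) if y == 1 else np.log(1.0 - p1)

    info = 0.0
    for y, prob in ((1, p_out), (0, 1.0 - p_out)):
        score = (log_p(y, h) - log_p(y, -h)) / (2.0 * h)  # d/dw log P_out at w = 0
        info += prob * score**2
    return info


print(sbm_delta(0.1, 0.5), 1.0 / bernoulli_fisher(0.1, 0.5))  # both close to 0.36
```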