The conditional probability distribution induced by a given network and loss in supervised learning

Any neural network together with a loss function in a supervised learning setting induces a conditional probability density. Explicitly working with this induced conditional probability is beneficial.

An artificial neural network can be described as a function $f(x)$ over some input set $\mathcal{X}$ with values in some output set $\mathcal{Y}$. In supervised learning, $\mathcal{X}$ is different from $\mathcal{Y}$, while in unsupervised learning they are the same. The computation of the network is carried out by many layers of activation units linked together by weights (Goodfellow et al., 2016). Denote the collection of all the weights of the network by the symbol $w$; $w$ is an element of $\mathcal{W}$, the set of all possible weight values the network can take at any one time. In large models, $w$ is a very large tuple. To reflect that different $w\in \mathcal{W}$ give different network realizations $f(x)$, we adopt the notation $f(x|w)$.

In this post, I show in detail how a neural network $f(\cdot|w)$ (together with an associated loss function $\ell$) induces a conditional probability distribution, and I briefly discuss the benefits of this view. In future posts, I extend this probabilistic view to other elements of supervised learning and show in detail how it is important to understanding deep learning.

1. The induced conditional probability distribution

To see how any neural network $f(\cdot|w)$ also defines a conditional probability density or mass function $p(y|x,w)$, suppose we are given a loss function $\ell(y,f(x|w))$. We can define the following probability density (or mass) function

$$
p(y|x,w) = \frac{e^{-\ell(y,f(x|w))}}{Z} \tag{1}
$$

where $Z$ is a normalizing constant such that $e^{-\ell(y,f(x|w))}$ integrates to $1$ for any network realization $f(x|w)$.

Equation (1) above is the well-known Gibbs density, where $\ell$ plays the role of the energy of the configuration $y$. For this to be a well-defined probability density (or mass) function, we need the normalizing constant $Z$ to be finite for all $w\in \mathcal{W}$. This is indeed the case under the usual assumptions:

  1. $\ell(y, f(x|w)) \geq 0$ for all $y, f(x|w)\in\mathcal{Y}$, with equality if $y=f(x|w)$.
  2. $\ell$ is an integrable function (w.r.t. the unknown true distribution $q(y|x)$ that generated the data). In other words, we require that the average loss

$$
L(w) = \int_{\mathcal{X}\times\mathcal{Y}} \ell(y,f(x|w))\,q(y|x)\,q(x)\,dy\,dx \tag{2}
$$

is finite for all $w$, where $q(x)$ denotes the (unknown) marginal distribution of the inputs. $L$ is sometimes called the expected loss in machine learning, or the risk function in decision theory.

While the integrability of $\ell$ cannot be verified in practice since $q(y|x)$ is unknown, typical losses like the squared error loss, absolute error loss, and Huber loss satisfy these properties under reasonable distributional assumptions. Let’s work out what happens in these important examples.
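
To make the construction concrete, here is a minimal numerical sketch in Python. The scalar "network" `f`, its weights, and the choice of the Huber loss are all hypothetical stand-ins; the point is only that equation (1) can be formed from any such pair by computing $Z$ numerically.

```python
import numpy as np
from scipy.integrate import quad

# A hypothetical scalar "network" f(x|w) with a two-component weight tuple w.
def f(x, w):
    return w[0] * np.tanh(w[1] * x)

# Any nonnegative loss over a one-dimensional continuous output; Huber as an example.
def huber_loss(y, y_hat, delta=1.0):
    r = np.abs(y - y_hat)
    return np.where(r <= delta, 0.5 * r ** 2, delta * (r - 0.5 * delta))

def induced_density(x, w, loss=huber_loss):
    """Return p(. | x, w) = exp(-loss(., f(x|w))) / Z, as in equation (1)."""
    y_hat = f(x, w)
    # Z integrates exp(-loss) over the output space; it is finite here because
    # the loss grows without bound as |y - f(x|w)| grows.
    Z, _ = quad(lambda y: np.exp(-loss(y, y_hat)), -np.inf, np.inf)
    return lambda y: np.exp(-loss(y, y_hat)) / Z

p = induced_density(x=0.3, w=(1.5, 0.7))
print(p(0.0))                        # density value at y = 0
print(quad(p, -np.inf, np.inf)[0])   # integrates to ~1
```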

1.1 Examples

In supervised learning, one is given a data set $\{(x_1, y_1),\cdots, (x_n, y_n)\}$ and the objective is to construct a neural network one can use to predict a future unseen value $y_{n+1}$ given a future seen or unseen value $x_{n+1}$ (Friedman, 1994). Based on the nature of the observables $x, y$, one constructs an appropriate neural network and chooses a loss function $\ell$ that is deemed appropriate. In the following examples we derive the conditional probability density (or mass) functions associated with a given network and loss function.

Example 1.1.1: Squared error loss

When $y$ is on a continuous scale (e.g. stock price, air temperature, etc.), modelled as a subset $\mathcal{Y}$ of $\mathbb{R}^k$, we could use the squared Euclidean norm $|\cdot|^2_2$ as a loss function.

$$
\begin{aligned} \ell_{\text{SE}}(y, f(x|w)) &\triangleq |y-f(x|w)|_2^2\\ &= \sum_{i=1}^k(y_i-f(x|w)_i)^2 \end{aligned}
$$

with $y=(y_1, \cdots, y_k)$, and $f(x|w)_i$ the $i$-th entry of $f(x|w)$. The neural network $f(\cdot|w)$ and the loss $\ell_{\text{SE}}(y,f(x|w))$ induce the parametric conditional probability density $\frac{1}{Z}e^{-|y-f(x|w)|_2^2}$, which one can immediately recognize as the $k$-dimensional normal distribution

$$
p(y|x,w) = \frac{1}{(2\pi)^{k/2}|\Sigma|^{1/2}}\,e^{-\frac{1}{2}(y-f(x|w))^T\Sigma^{-1} (y-f(x|w))}
$$

with mean $f(x|w)$ and covariance $\Sigma=\frac{1}{2}\mathbb{I}_{k\times k}$, where $(y-f(x|w))^T$ is the transpose of the column vector $(y-f(x|w))$, and $|\Sigma|$ is the determinant of $\Sigma$.
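
As a quick sanity check of this identification, the following sketch (scalar output, $k=1$, with an arbitrary value standing in for $f(x|w)$) compares the normalized density $\frac{1}{Z}e^{-(y-f(x|w))^2}$ against the normal density with mean $f(x|w)$ and variance $\frac{1}{2}$.

```python
import numpy as np
from scipy.integrate import quad
from scipy.stats import norm

mu = 2.3                                             # stands in for the network output f(x|w)
Z, _ = quad(lambda y: np.exp(-(y - mu) ** 2), -np.inf, np.inf)

ys = np.linspace(-1.0, 5.0, 7)
gibbs  = np.exp(-(ys - mu) ** 2) / Z                 # (1/Z) * exp(-squared error)
normal = norm.pdf(ys, loc=mu, scale=np.sqrt(0.5))    # N(f(x|w), 1/2)

print(Z, np.sqrt(np.pi))                             # Z equals sqrt(pi)
print(np.max(np.abs(gibbs - normal)))                # ~0 up to numerical error
```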

Example 1.1.2: Absolute error loss

Another loss function used in practice is the $L^1$ norm.

$$
\begin{aligned} \ell_{\text{AE}}(y, f(x|w)) &\triangleq L^1(y, f(x|w))\\ &=\sum_{i=1}^k|y_i-f(x|w)_i| \end{aligned}
$$

The induced conditional probability density is

$$
\begin{aligned} p(y|x,w) &= \frac{1}{Z}e^{-\sum_{i=1}^k|y_i-f(x|w)_i|}\\ &= \frac{1}{Z}\prod_{i=1}^k e^{-|y_i-f(x|w)_i|} \end{aligned}
$$

which is the product of $k$ independent Laplace probability densities with scale parameter $1$.
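
In this case the normalizing constant is explicit, since $\int_{\mathbb{R}} e^{-|t|}\,dt = 2$ gives $Z = 2^k$. A short sketch (scalar case, with an arbitrary value standing in for $f(x|w)_i$) confirming the match with the Laplace density:

```python
import numpy as np
from scipy.stats import laplace

mu = -0.8                                  # stands in for f(x|w)_i
ys = np.linspace(-4.0, 4.0, 9)

gibbs   = np.exp(-np.abs(ys - mu)) / 2.0   # (1/Z) * exp(-absolute error), with Z = 2
density = laplace.pdf(ys, loc=mu, scale=1.0)

print(np.max(np.abs(gibbs - density)))     # ~0
```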

Example 1.1.3: Cross entropy loss

When $y$ is on a categorical scale (e.g. dog vs cat vs bird, happy vs sad, a number in the set $\{0, \cdots, 9\}$), one typically uses a network whose number of output units matches the cardinality of $\mathcal{Y}$ and the cross entropy loss

$$
\ell_{\text{CE}}(y, f(x|w)) = -\sum_{i\in \mathcal{Y}} \delta[y=i]\,\log{\text{softmax}[f(x|w)]_i}
$$

where $\delta[y=i]$ is the Kronecker delta (the indicator that $y$ equals class $i$), and $\text{softmax}(f(x|w))_i$ is the $i$-th component of $\text{softmax}(f(x|w))$. The associated conditional probability mass function is in fact explicit:

$$
P(y=i|x,w) = \text{softmax}[f(x|w)]_i, \quad i\in\mathcal{Y}
$$

When the cardinality of $\mathcal{Y}$ is $2$, the cross-entropy reduces to the binary cross-entropy.
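
The sketch below (plain numpy, with arbitrary logits standing in for the output layer $f(x|w)$) checks both identities: the induced mass function is exactly the softmax of the logits, and the cross entropy of the observed class is its negative log-probability.

```python
import numpy as np
from scipy.special import logsumexp

logits = np.array([1.2, -0.3, 0.7])        # stands in for f(x|w), |Y| = 3

log_probs = logits - logsumexp(logits)     # log softmax, numerically stable
probs = np.exp(log_probs)                  # P(y = i | x, w)

y = 2                                      # observed class
cross_entropy = -log_probs[y]              # cross entropy loss for this sample

print(probs, probs.sum())                  # softmax of the logits, sums to 1
print(cross_entropy, -np.log(probs[y]))    # identical
```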

2. Why should we care?

Now that we have a good handle on the conditional distribution induced by our choice of the loss function and the nature of the output layer of the network, we can apply the tools of frequentist statistics, such as maximum likelihood, hypothesis testing, and asymptotic theory, to analyze supervised learning methods.

For instance, under the assumption that the pairs $(x_i, y_i)$ are independent, we can write the log-likelihood of the dataset (under our model) for different parameters $w$ as

$$
\log{\prod_{i=1}^n p(y_i|x_i,w)} = \sum_{i=1}^n \log{p(y_i|x_i,w)} \tag{3}
$$

yielding the maximum likelihood estimator everyone is familiar with from second-year undergraduate statistics:

$$
\hat{w}_{\text{mle}} = \underset{w\in \mathcal{W}}{\text{argmax}}\bigg\{\sum_{i=1}^n \log{p(y_i|x_i,w)}\bigg\}
$$

Maximizing the likelihood in equation (3) above is equivalent to minimizing the empirical loss

$$
L_n(w) = \frac{1}{n}\sum_{i=1}^n \ell(y_i, f(x_i|w)) \tag{4}
$$

since $\log{p(y|x,w)}=-\ell(y,f(x|w))-\log{Z}$, and, for the losses considered above, $Z$ does not depend on the parameter $w$.
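
A small numerical illustration of this equivalence, under the squared error loss of Example 1.1.1 (so the induced model is Gaussian with variance $\frac{1}{2}$) and a hypothetical one-parameter model: the log-likelihood and the empirical loss differ only by the constant $\log Z$ per observation, and they share the same optimizer.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
n = 50
xs = rng.normal(size=n)
ys = 1.7 * xs + rng.normal(scale=np.sqrt(0.5), size=n)    # toy data

def f(x, w):                       # hypothetical one-parameter "network"
    return w * x

ws = np.linspace(0.0, 3.0, 301)    # grid of candidate weights
emp_loss = np.array([np.mean((ys - f(xs, w)) ** 2) for w in ws])   # L_n(w), equation (4)
log_lik  = np.array([norm.logpdf(ys, loc=f(xs, w), scale=np.sqrt(0.5)).sum()
                     for w in ws])                                 # equation (3)

# Same optimizer: argmax of the likelihood == argmin of the empirical loss.
print(ws[np.argmax(log_lik)], ws[np.argmin(emp_loss)])

# Per observation, they differ by the constant log Z = log(sqrt(pi)).
print(np.max(np.abs(-log_lik / n - emp_loss - np.log(np.sqrt(np.pi)))))   # ~0
```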

One might argue that we did not gain much by characterizing the conditional probability density (or mass) function associated with a given loss and network. This is not exactly true. First, making the nature of the assumed noise in the model explicit provides us with more information about the model and ways to change it. For instance, in the probabilistic characterization one could choose a non-diagonal covariance matrix to represent known correlations in the noise. One could even compute the observed errors and check whether they conform with the assumed noise distribution (a standard statistical technique for assessing model fit).
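
As a sketch of such a residual check: under the squared error loss, the induced model says the residuals $y_i - f(x_i|\hat{w}_{\text{mle}})$ should look like draws from $\mathcal{N}(0, \frac{1}{2})$. Here the residuals are simulated stand-ins, and the comparison uses a Kolmogorov-Smirnov test.

```python
import numpy as np
from scipy.stats import kstest, norm

# Hypothetical residuals y_i - f(x_i | w_mle) collected after training.
rng = np.random.default_rng(1)
residuals = rng.normal(scale=np.sqrt(0.5), size=200)

# Compare against the noise distribution implied by the squared error loss: N(0, 1/2).
stat, p_value = kstest(residuals, norm(loc=0.0, scale=np.sqrt(0.5)).cdf)
print(stat, p_value)   # a small p-value would signal that the assumed noise model fits poorly
```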

Second, one can apply the tools of information theory to formally characterize what it means to be surprised when using our model to stand in for the unknown distribution $q(y|x)$ that generated the data. Many notions of surprise are available in practice. If we subscribe to the well-known notion of Shannon surprise, then the average surprise when using the model $p(y|x, w_{\text{mle}})$ instead of the true unknown distribution that generates the data is defined by the cross entropy

$$
H(q,p) \triangleq -\int_{\mathcal{X}\times\mathcal{Y}} q(y|x)\,q(x)\log{p(y|x,w_{\text{mle}})}\,dy\,dx \tag{5}
$$

The minimum average (Shannon) surprise when observing samples from $q(y|x)$ is

$$
H(q) \triangleq -\int_{\mathcal{X}\times\mathcal{Y}}q(y|x)\,q(x)\log{q(y|x)}\,dy\,dx,
$$

the entropy of $q(y|x)$. As a result, the average excess surprise when using our model $p(y|x,w)$ instead of the true distribution $q(y|x)$ is

$$
\begin{aligned} D_{\text{KL}}(q,p) &\triangleq \int_{\mathcal{X}\times\mathcal{Y}} q(y|x)\,q(x)\log{\frac{q(y|x)}{p(y|x,w_{\text{mle}})}}\,dy\,dx\\ &= H(q,p) - H(q) \end{aligned} \tag{6}
$$

which is the Kullback-Leibler divergence (Kullback & Leibler, 1951). This is one important reason why machine learning minimizes the cross entropy $H(q,p)$ (which is equivalent to minimizing the KL divergence from $q$ to $p$): if our model is any good, it should stand in for $q$ when making decisions related to our observables $x$ and $y$.
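
A discrete sketch (two arbitrary categorical distributions standing in for $q(y|x)$ and $p(y|x,w_{\text{mle}})$ at a fixed $x$) verifying the decomposition $D_{\text{KL}}(q,p) = H(q,p) - H(q)$:

```python
import numpy as np

q = np.array([0.5, 0.3, 0.2])            # stands in for the true q(y|x) at a fixed x
p = np.array([0.4, 0.4, 0.2])            # stands in for the model p(y|x, w_mle)

cross_entropy = -np.sum(q * np.log(p))   # H(q, p): average surprise under the model
entropy       = -np.sum(q * np.log(q))   # H(q): minimum achievable average surprise
kl            =  np.sum(q * np.log(q / p))

print(kl, cross_entropy - entropy)       # identical
```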

Last but not least, by understanding the conditional probability distribution induced by a network and loss, one can use tools unique to probability theory, such as conditional expectation and Bayes' rule, to study complex models and build powerful learning machines.

  1. Goodfellow, I., Bengio, Y., & Courville, A. (2016). Deep Learning. MIT Press.
  2. Friedman, J. H. (1994). An Overview of Predictive Learning and Function Approximation. In V. Cherkassky, J. H. Friedman, & H. Wechsler (Eds.), From Statistics to Neural Networks (pp. 1–61). Springer Berlin Heidelberg.
  3. Kullback, S., & Leibler, R. A. (1951). On Information and Sufficiency. The Annals of Mathematical Statistics, 22(1), 79–86. https://doi.org/10.1214/aoms/1177729694