This is an automated archive made by the Lemmit Bot.

The original was posted on /r/machinelearning by /u/RepresentativeWay0 on 2024-04-06 17:42:30.


In the code for current mixture-of-experts models, the router seems to use argmax with k=1 (picking only the top expert) to make its choice. Since argmax is non-differentiable, the gradient cannot flow to the other experts. So it seems to me that if the selected expert performs poorly, only that expert's weights will be updated. However, a different expert may in fact have been a better choice for the given input, and the router cannot know this because the gradient does not flow to the other experts.
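To make the concern concrete, here is a minimal sketch in plain Python of a toy top-1 router with three scalar "experts", using finite differences to see which parameters receive a nonzero gradient. Everything in it is hypothetical and for illustration only: the names `moe_loss`, `numeric_grad` and `all_grads`, the parameter values, and the `scale_by_gate` flag are not taken from any particular model's code. The flag is an assumption of mine: some implementations multiply the chosen expert's output by its softmax probability, and the sketch lets you compare the two cases.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def moe_loss(params, x, target, scale_by_gate):
    """Toy top-1 MoE: expert i computes w_i * x, router logit i is r_i * x."""
    logits = [r * x for r in params["router"]]
    probs = softmax(logits)
    # Hard argmax selection with k=1, as described in the post.
    k = max(range(len(logits)), key=lambda i: logits[i])
    y = params["experts"][k] * x
    if scale_by_gate:
        # Assumption: output is scaled by the chosen expert's gate probability.
        y = probs[k] * y
    return (y - target) ** 2

def numeric_grad(params, group, i, x, target, scale_by_gate, eps=1e-6):
    """Central finite difference of the loss w.r.t. params[group][i]."""
    orig = params[group][i]
    params[group][i] = orig + eps
    hi = moe_loss(params, x, target, scale_by_gate)
    params[group][i] = orig - eps
    lo = moe_loss(params, x, target, scale_by_gate)
    params[group][i] = orig
    return (hi - lo) / (2 * eps)

def all_grads(scale_by_gate):
    # Hypothetical values; expert 0 has the largest router logit and is selected.
    params = {"router": [0.5, -0.3, 0.1], "experts": [1.0, 2.0, -1.0]}
    x, target = 1.0, 3.0
    return {
        group: [numeric_grad(params, group, i, x, target, scale_by_gate)
                for i in range(3)]
        for group in params
    }

if __name__ == "__main__":
    for flag in (False, True):
        print("scale_by_gate =", flag, all_grads(flag))
```

In both settings the two unselected experts get a gradient of exactly zero, which matches the worry above. With `scale_by_gate=False` the router weights also get zero gradient, while with `scale_by_gate=True` they do not, because the loss then depends on the softmax probability of the chosen expert. Whether a given model does this scaling is something to check in its own code.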

How can the router learn that it has made a wrong choice and use a different expert next time?