Additive vs Vanilla attention — Diffchecker export (diff never expires).

Left side (removed, 7 lines / 20 total): `AdditiveAttention._calculate_scores`,
which scores as a nonlinear sum of query and key.
Right side (added, 24 lines / 28 total): `Attention._calculate_scores`,
which scores as a query–key dot product with an optional "concat" mode.

The reconstructed right-hand implementation follows.
def _calculate_scores(self, query, key):
    """Calculates attention scores as a query-key dot product.

    Args:
        query: Query tensor of shape `[batch_size, Tq, dim]`.
        key: Key tensor of shape `[batch_size, Tv, dim]`.

    Returns:
        Tensor of shape `[batch_size, Tq, Tv]`.
    """
    if self.score_mode == "dot":
        # Batched dot product between every query and every key timestep.
        scores = tf.matmul(query, key, transpose_b=True)
        if self.scale is not None:
            scores *= self.scale
    elif self.score_mode == "concat":
        # Reshape tensors to enable broadcasting.
        # Reshape into [batch_size, Tq, 1, dim].
        q_reshaped = tf.expand_dims(query, axis=-2)
        # Reshape into [batch_size, 1, Tv, dim].
        k_reshaped = tf.expand_dims(key, axis=-3)
        if self.scale is not None:
            # Scale inside the tanh, then reduce over the feature axis.
            scores = self.concat_score_weight * tf.reduce_sum(
                tf.tanh(self.scale * (q_reshaped + k_reshaped)), axis=-1
            )
        else:
            scores = self.concat_score_weight * tf.reduce_sum(
                tf.tanh(q_reshaped + k_reshaped), axis=-1
            )
    return scores