# ViTOutput diff: TFViTMAEOutput (original side, 15 lines, 3 removals)
# compared against TFViTOutput (changed side, 16 lines, 4 additions).
# Created; the diff never expires.
class TFViTMAEOutput(tf.keras.layers.Layer):
    """Project to `config.hidden_size`, apply dropout, then add a residual input.

    Structurally identical to `TFViTOutput` below; only the config type differs.
    """

    def __init__(self, config: ViTMAEConfig, **kwargs):
        super().__init__(**kwargs)

        # Project to hidden_size so the result can be summed with `input_tensor` in `call`
        # (assumes input_tensor's last dim is config.hidden_size — TODO confirm at call sites).
        self.dense = tf.keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Apply dense projection, dropout, and a residual addition.

        Args:
            hidden_states: Input to the dense projection.
            input_tensor: Residual tensor added to the projected, dropped-out states.
            training: Forwarded to the dropout layer; dropout is only active when True.

        Returns:
            `dropout(dense(hidden_states)) + input_tensor`.
        """
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.dropout(inputs=hidden_states, training=training)
        # Residual connection: add the unprojected input back onto the block output.
        hidden_states = hidden_states + input_tensor

        return hidden_states


class TFViTOutput(tf.keras.layers.Layer):
    """Project to `config.hidden_size`, apply dropout, then add a residual input.

    Structurally identical to `TFViTMAEOutput` above; only the config type differs.
    """

    def __init__(self, config: ViTConfig, **kwargs):
        super().__init__(**kwargs)

        # Project to hidden_size so the result can be summed with `input_tensor` in `call`
        # (assumes input_tensor's last dim is config.hidden_size — TODO confirm at call sites).
        self.dense = tf.keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Apply dense projection, dropout, and a residual addition.

        Args:
            hidden_states: Input to the dense projection.
            input_tensor: Residual tensor added to the projected, dropped-out states.
            training: Forwarded to the dropout layer; dropout is only active when True.

        Returns:
            `dropout(dense(hidden_states)) + input_tensor`.
        """
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.dropout(inputs=hidden_states, training=training)
        # Residual connection: add the unprojected input back onto the block output.
        hidden_states = hidden_states + input_tensor

        return hidden_states