ガイドに戻る技術仕様書
SPEC DOCUMENT

Phoenix ML モデル仕様書

X (旧 Twitter) のレコメンデーションシステムの中核となる機械学習モデルの技術仕様

1.概要

Phoenix は X (旧 Twitter) のレコメンデーションシステムの中核となる機械学習モデルです。Grok ベースの Transformer アーキテクチャを採用し、ユーザーの過去の行動履歴と候補投稿から、最適なコンテンツをランキング・検索します。

1.1システム構成

Phoenix は以下の 2 つの主要コンポーネントで構成されています:

コンポーネント役割主要ファイル
Two-Tower Retrieval Model大規模候補プールからの高速検索recsys_retrieval_model.py
Ranking Transformer候補投稿の精密なランキングrecsys_model.py

1.2処理フロー

Phoenix 処理フロー
候補プール (数百万件)
|
v
Two-Tower Retrieval--- ANN検索 --->候補 (数百〜数千件)
|
v
Ranking Transformer--- スコアリング --->ランキング結果
|
v
ユーザーフィード表示

2.Two-Tower Retrieval Model

Two-Tower アーキテクチャは、ユーザーと候補アイテムを別々のタワー(エンコーダー)で埋め込み、内積類似度で高速検索を実現します。

2.1アーキテクチャ概要

Two-Tower Retrieval Model アーキテクチャ

User Tower

  • - Phoenix Transformer
  • - Mean Pooling
  • - L2 Normalization

Output: [B, D] normalized

Candidate Tower

  • - MLP Projection (2-layer)
  • - L2 Normalization

Output: [N, D] normalized

Dot Product Similarity

scores = user @ corpus.T

→ Top-K Selection

2.2User Tower

User Tower は Phoenix Transformer を使用してユーザー表現を生成します。

入力シーケンス構成

入力シーケンス
[User Embedding] + [History Embeddings (S tokens)]
     1 token              S tokens (default: 128)

処理フロー

User Tower 処理フローpython
# 1. ユーザー埋め込みの生成
user_embeddings, user_padding_mask = block_user_reduce(
    user_hashes,           # [B, num_user_hashes]
    user_embeddings,       # [B, num_user_hashes, D]
    num_user_hashes,       # 2
    emb_size,             # 128
)
# Output: [B, 1, D]

# 2. 履歴埋め込みの生成
history_embeddings, history_padding_mask = block_history_reduce(
    history_post_hashes,        # [B, S, num_item_hashes]
    history_post_embeddings,    # [B, S, num_item_hashes, D]
    history_author_embeddings,  # [B, S, num_author_hashes, D]
    history_product_surface_embeddings,  # [B, S, D]
    history_actions_embeddings,          # [B, S, D]
    num_item_hashes,   # 2
    num_author_hashes, # 2
)
# Output: [B, S, D]

# 3. シーケンス連結
embeddings = concat([user_embeddings, history_embeddings], axis=1)
# Shape: [B, 1+S, D]

# 4. Transformer エンコーディング(因果的アテンション)
model_output = transformer(embeddings, padding_mask)

# 5. Mean Pooling + L2 正規化
user_representation = mean_pool(model_output, padding_mask)
user_representation = l2_normalize(user_representation)
# Output: [B, D]

2.3Candidate Tower

Candidate Tower は軽量な MLP で候補投稿を埋め込み空間にマッピングします。

Candidate Tower アーキテクチャ
Input: [post_embeddings, author_embeddings]
concat → [B, C, (num_item_hashes + num_author_hashes) * D]

Linear (in → 2*D)

+ SiLU

Linear (2*D → D)

L2 Normalization
Output: [B, C, D]

正規化された埋め込みベクトル間の内積により、コサイン類似度を計算します。

Top-K 検索python
def _retrieve_top_k(
    self,
    user_representation: jax.Array,   # [B, D]
    corpus_embeddings: jax.Array,     # [N, D]
    top_k: int,
    corpus_mask: Optional[jax.Array] = None,
) -> Tuple[jax.Array, jax.Array]:
    """Top-K候補の検索

    Returns:
        top_k_indices: [B, K] 上位K件のインデックス
        top_k_scores: [B, K] 類似度スコア
    """
    # 内積による類似度計算
    scores = jnp.matmul(user_representation, corpus_embeddings.T)

    # マスク適用(無効な候補を除外)
    if corpus_mask is not None:
        scores = jnp.where(corpus_mask[None, :], scores, -INF)

    # Top-K選択
    top_k_scores, top_k_indices = jax.lax.top_k(scores, top_k)

    return top_k_indices, top_k_scores

Note

本番環境では、FAISS や ScaNN などの近似最近傍探索(ANN)ライブラリを使用して、数十億件規模のコーパスから高速に検索を行います。

3.Ranking Transformer

Ranking Transformer は、検索された候補投稿を精密にスコアリングし、最終的なランキングを決定します。

3.1アーキテクチャ概要

Ranking Transformer アーキテクチャ

Input Sequence:

[User (1)] + [History (S)] + [Candidates (C)]

Attention Mask

  • User+History: Causal Attention (下三角行列)
  • Candidates: User+History への attend + Self-attention
  • (他の候補には attend しない)

Phoenix Transformer Layers

h = h + LayerNorm(MHA(LayerNorm(h)))
h = h + LayerNorm(FFN(LayerNorm(h)))

Output Projection

logits = candidate_embeddings @ unembedding

Shape: [B, C, num_actions]

3.2Candidate Isolation(候補分離アテンションマスク)

各候補投稿が独立してスコアリングされることを保証するため、特別なアテンションマスクを使用します。

アテンションマスク構造
シーケンス: [user, h1, h2, c1, c2, c3]
位置:        0     1   2   3   4   5

アテンションマスク:
            Keys:  u   h1  h2  c1  c2  c3
Query u   :        1   0   0   0   0   0   <- 因果的
Query h1  :        1   1   0   0   0   0   <- 因果的
Query h2  :        1   1   1   0   0   0   <- 因果的
Query c1  :        1   1   1   1   0   0   <- user+history + 自己
Query c2  :        1   1   1   0   1   0   <- user+history + 自己
Query c3  :        1   1   1   0   0   1   <- user+history + 自己

3.3Multi-Action Prediction(マルチアクション予測)

Phoenix は 19 種類のユーザーアクションを同時に予測します。

インデックスアクション名説明
0favorite_scoreいいね確率
1reply_score返信確率
2repost_scoreリポスト確率
3photo_expand_score画像展開確率
4click_scoreクリック確率
5profile_click_scoreプロフィールクリック確率
6vqv_score動画視聴完了確率
7share_score共有確率
8share_via_dm_scoreDM共有確率
9share_via_copy_link_scoreリンクコピー確率
10dwell_score滞在確率
11quote_score引用確率
12quoted_click_score引用クリック確率
13follow_author_score著者フォロー確率
14not_interested_score興味なし確率
15block_author_score著者ブロック確率
16mute_author_score著者ミュート確率
17report_score通報確率
18dwell_time滞在時間予測

4.入力特徴量

4.1特徴量の種類

入力特徴量の構造

1. ユーザーハッシュ埋め込み

user_hashes: [B, num_user_hashes]user_embeddings: [B, num_user_hashes, D]

2. 履歴埋め込み

history_post_hashes: [B, S, num_item_hashes]history_post_embeddings: [B, S, num_item_hashes, D]history_author_embeddings: [B, S, num_author_hashes, D]history_actions: [B, S, num_actions]history_product_surface: [B, S]

3. 候補埋め込み

candidate_post_hashes: [B, C, num_item_hashes]candidate_post_embeddings: [B, C, num_item_hashes, D]candidate_author_embeddings: [B, C, num_author_hashes, D]candidate_product_surface: [B, C]

B: バッチサイズ | S: 履歴シーケンス長 (default: 128) | C: 候補数 (default: 32) | D: 埋め込み次元 (default: 128)

4.2Product Surface

Product Surface は、ユーザーがコンテンツとインタラクションした場所を示すカテゴリ特徴量です。

ホームタイムライン検索結果通知プロフィールページ引用リポストなど
python
product_surface_vocab_size: int = 16  # サーフェスの種類数

5.Hash-Based Embeddings

Phoenix は、メモリ効率とスケーラビリティのためにハッシュベースの埋め込みを採用しています。

Hash-Based Embedding System

従来の埋め込み:

entity_id → embedding_table[entity_id]

問題: 数十億のユニークIDに対して巨大なテーブルが必要

ハッシュベース埋め込み:

entity_id → [hash_1(id), hash_2(id), ...] → lookup → combine

利点: 固定サイズのテーブルで任意のIDを処理可能

5.1HashConfig

HashConfigpython
@dataclass
class HashConfig:
    """ハッシュベース埋め込みの設定"""

    num_user_hashes: int = 2     # ユーザーIDに使用するハッシュ関数の数
    num_item_hashes: int = 2     # 投稿IDに使用するハッシュ関数の数
    num_author_hashes: int = 2   # 著者IDに使用するハッシュ関数の数

6.モデル設定

6.1TransformerConfig

TransformerConfig (grok.py)python
@dataclass
class TransformerConfig:
    emb_size: int           # 埋め込み次元
    key_size: int           # アテンションキーの次元
    num_q_heads: int        # クエリヘッド数
    num_kv_heads: int       # キー/バリューヘッド数(GQA対応)
    num_layers: int         # Transformer レイヤー数
    widening_factor: float = 4.0          # FFN拡張係数
    attn_output_multiplier: float = 1.0   # アテンション出力スケール

    name: Optional[str] = None

    def make(self) -> "Transformer":
        return Transformer(
            num_q_heads=self.num_q_heads,
            num_kv_heads=self.num_kv_heads,
            widening_factor=self.widening_factor,
            key_size=self.key_size,
            attn_output_multiplier=self.attn_output_multiplier,
            num_layers=self.num_layers,
        )

6.2デフォルト設定例

run_ranker.py よりpython
emb_size = 128
num_actions = 19
history_seq_len = 32
candidate_seq_len = 8

hash_config = HashConfig(
    num_user_hashes=2,
    num_item_hashes=2,
    num_author_hashes=2,
)

recsys_model = PhoenixModelConfig(
    emb_size=emb_size,
    num_actions=num_actions,
    history_seq_len=history_seq_len,
    candidate_seq_len=candidate_seq_len,
    hash_config=hash_config,
    product_surface_vocab_size=16,
    model=TransformerConfig(
        emb_size=emb_size,
        widening_factor=2,
        key_size=64,
        num_q_heads=2,
        num_kv_heads=2,
        num_layers=2,
        attn_output_multiplier=0.125,
    ),
)

7.Transformer アーキテクチャ詳細

7.1全体構造

Transformer 全体構造

Input: embeddings [B, T, D], padding_mask [B, T]

for layer_idx in range(num_layers):

h = inputs

# Self-Attention Block

h_attn = MHABlock(RMSNorm(h), mask)

h_attn = RMSNorm(h_attn)

h = h + h_attn

# Feed-Forward Block

h_dense = DenseBlock(RMSNorm(h))

h_dense = RMSNorm(h_dense)

h = h + h_dense

Output: embeddings [B, T, D]

7.2Multi-Head Attention

Grouped Query Attention (GQA) 対応のマルチヘッドアテンションを使用しています。

Multi-Head Attention の主要処理python
# Q/K/V 射影
query_heads = self._linear_projection(query, self.key_size, self.num_q_heads)
key_heads = self._linear_projection(key, self.key_size, self.num_kv_heads)
value_heads = self._linear_projection(value, self.value_size, self.num_kv_heads)

# Rotary Position Embedding (RoPE) 適用
rotate = RotaryEmbedding(dim=self.key_size)
key_heads = rotate(key_heads, seq_dim=1, offset=0)
query_heads = rotate(query_heads, seq_dim=1, offset=0)

# GQA: query heads をグループ化
query_heads = query_heads.reshape((b, t, kv_h, h // kv_h, d))

# アテンションスコア計算
attn_logits = jnp.einsum("...thHd,...Thd->...hHtT", query_heads, key_heads)
attn_logits *= self.attn_output_multiplier

# Soft-capping(数値安定性のため)
max_attn_val = 30.0
attn_logits = max_attn_val * jnp.tanh(attn_logits / max_attn_val)

# マスク適用 + Softmax
attn_logits = jnp.where(mask, attn_logits, -1e30)
attn_weights = jax.nn.softmax(attn_logits)

7.3Feed-Forward Network (DenseBlock)

SwiGLU 活性化を使用した FFNpython
@dataclass
class DenseBlock(hk.Module):
    """SwiGLU 活性化を使用した FFN"""

    widening_factor: float = 4.0

    def __call__(self, inputs: jax.Array) -> jax.Array:
        _, _, model_size = inputs.shape

        # SwiGLU: gate * GELU(x)
        h_v = Linear(ffn_size)(inputs)           # Value branch
        h_w1 = jax.nn.gelu(Linear(ffn_size)(inputs))  # Gate branch

        h_dense = Linear(model_size)(h_w1 * h_v)  # Output projection

        return h_dense

7.4Rotary Position Embedding (RoPE)

位置情報を回転行列として埋め込み、相対位置関係を効率的にエンコードします。

RoPE 実装python
class RotaryEmbedding(hk.Module):
    """回転位置埋め込み(RoPE)
    参考: https://arxiv.org/abs/2104.09864
    """

    def __init__(self, dim: int, base_exponent: int = 10000):
        self.dim = dim
        self.base_exponent = base_exponent

    def __call__(self, x: jax.Array, seq_dim: int, offset: jax.Array) -> jax.Array:
        # 周波数の計算
        exponents = jnp.arange(0, self.dim, 2, dtype=jnp.float32)
        inv_freq = 1.0 / (self.base_exponent ** (exponents / self.dim))

        # 位置インデックス
        t = jnp.arange(x.shape[seq_dim]) + offset

        # 位相角の計算
        phase = jnp.einsum("bi,j->bij", t, inv_freq)
        phase = jnp.tile(phase, reps=(1, 2))

        # 回転の適用
        x = x * jnp.cos(phase) + rotate_half(x) * jnp.sin(phase)

        return x

8.推論パイプライン

8.1ランキング推論

run_ranker.py の使用例python
# 1. モデル設定
recsys_model = PhoenixModelConfig(...)

# 2. 推論ランナーの作成と初期化
inference_runner = RecsysInferenceRunner(
    runner=ModelRunner(model=recsys_model, bs_per_device=0.125),
    name="recsys_local",
)
inference_runner.initialize()

# 3. バッチの作成
batch, embeddings = create_example_batch(
    batch_size=1,
    emb_size=emb_size,
    history_len=history_seq_len,
    num_candidates=candidate_seq_len,
    num_actions=num_actions,
    ...
)

# 4. ランキング実行
ranking_output = inference_runner.rank(batch, embeddings)

# 5. 結果の取得
scores = ranking_output.scores        # [B, C, num_actions]
ranked_indices = ranking_output.ranked_indices  # [B, C]
p_favorite = ranking_output.p_favorite_score    # [B, C]

8.2検索推論

run_retrieval.py の使用例python
# 1. モデル設定
retrieval_model_config = PhoenixRetrievalModelConfig(...)

# 2. 推論ランナーの作成と初期化
inference_runner = RecsysRetrievalInferenceRunner(
    runner=RetrievalModelRunner(model=retrieval_model_config, bs_per_device=0.125),
    name="retrieval_local",
)
inference_runner.initialize()

# 3. コーパスの設定
corpus_embeddings, corpus_post_ids = create_example_corpus(
    corpus_size=1000,
    emb_size=emb_size,
)
inference_runner.set_corpus(corpus_embeddings, corpus_post_ids)

# 4. バッチの作成
batch, embeddings = create_example_batch(...)

# 5. 検索実行
retrieval_output = inference_runner.retrieve(
    batch,
    embeddings,
    top_k=10,
)

# 6. 結果の取得
top_k_indices = retrieval_output.top_k_indices  # [B, K]
top_k_scores = retrieval_output.top_k_scores    # [B, K]
user_representation = retrieval_output.user_representation  # [B, D]

9.パラメータサマリー

パラメータデフォルト値説明
emb_size128埋め込み次元
num_layers2Transformer レイヤー数
num_q_heads2クエリヘッド数
num_kv_heads2キー/バリューヘッド数
key_size64アテンションキー次元
widening_factor2.0FFN拡張係数
attn_output_multiplier0.125アテンション出力スケール
history_seq_len128最大履歴シーケンス長
candidate_seq_len32最大候補数
num_actions19予測アクション数
num_user_hashes2ユーザーハッシュ数
num_item_hashes2アイテムハッシュ数
num_author_hashes2著者ハッシュ数
product_surface_vocab_size16プロダクトサーフェス種類数
fprop_dtypebfloat16推論時データ型

10.関連ファイル

ファイル説明
/phoenix/grok.pyTransformer コアアーキテクチャ
/phoenix/recsys_model.pyRanking モデル実装
/phoenix/recsys_retrieval_model.pyRetrieval モデル実装
/phoenix/runners.py推論ランナーとユーティリティ
/phoenix/run_ranker.pyランキングデモスクリプト
/phoenix/run_retrieval.py検索デモスクリプト
/phoenix/test_recsys_model.pyアテンションマスクのテスト
X レコメンドアルゴリズム - オープンソース技術仕様書