Phoenix ML モデル仕様書
X (旧 Twitter) のレコメンデーションシステムの中核となる機械学習モデルの技術仕様
目次
1.概要
Phoenix は X (旧 Twitter) のレコメンデーションシステムの中核となる機械学習モデルです。Grok ベースの Transformer アーキテクチャを採用し、ユーザーの過去の行動履歴と候補投稿から、最適なコンテンツをランキング・検索します。
1.1システム構成
Phoenix は以下の 2 つの主要コンポーネントで構成されています:
| コンポーネント | 役割 | 主要ファイル |
|---|---|---|
| Two-Tower Retrieval Model | 大規模候補プールからの高速検索 | recsys_retrieval_model.py |
| Ranking Transformer | 候補投稿の精密なランキング | recsys_model.py |
1.2処理フロー
2.Two-Tower Retrieval Model
Two-Tower アーキテクチャは、ユーザーと候補アイテムを別々のタワー(エンコーダー)で埋め込み、内積類似度で高速検索を実現します。
2.1アーキテクチャ概要
User Tower
- - Phoenix Transformer
- - Mean Pooling
- - L2 Normalization
Output: [B, D] normalized
Candidate Tower
- - MLP Projection (2-layer)
- - L2 Normalization
Output: [N, D] normalized
Dot Product Similarity
scores = user @ corpus.T→ Top-K Selection
2.2User Tower
User Tower は Phoenix Transformer を使用してユーザー表現を生成します。
入力シーケンス構成
[User Embedding] + [History Embeddings (S tokens)]
1 token S tokens (default: 128)処理フロー
# 1. ユーザー埋め込みの生成
user_embeddings, user_padding_mask = block_user_reduce(
user_hashes, # [B, num_user_hashes]
user_embeddings, # [B, num_user_hashes, D]
num_user_hashes, # 2
emb_size, # 128
)
# Output: [B, 1, D]
# 2. 履歴埋め込みの生成
history_embeddings, history_padding_mask = block_history_reduce(
history_post_hashes, # [B, S, num_item_hashes]
history_post_embeddings, # [B, S, num_item_hashes, D]
history_author_embeddings, # [B, S, num_author_hashes, D]
history_product_surface_embeddings, # [B, S, D]
history_actions_embeddings, # [B, S, D]
num_item_hashes, # 2
num_author_hashes, # 2
)
# Output: [B, S, D]
# 3. シーケンス連結
embeddings = concat([user_embeddings, history_embeddings], axis=1)
# Shape: [B, 1+S, D]
# 4. Transformer エンコーディング(因果的アテンション)
model_output = transformer(embeddings, padding_mask)
# 5. Mean Pooling + L2 正規化
user_representation = mean_pool(model_output, padding_mask)
user_representation = l2_normalize(user_representation)
# Output: [B, D]2.3Candidate Tower
Candidate Tower は軽量な MLP で候補投稿を埋め込み空間にマッピングします。
Linear (in → 2*D)
+ SiLU
Linear (2*D → D)
2.4Similarity Search
正規化された埋め込みベクトル間の内積により、コサイン類似度を計算します。
def _retrieve_top_k(
self,
user_representation: jax.Array, # [B, D]
corpus_embeddings: jax.Array, # [N, D]
top_k: int,
corpus_mask: Optional[jax.Array] = None,
) -> Tuple[jax.Array, jax.Array]:
"""Top-K候補の検索
Returns:
top_k_indices: [B, K] 上位K件のインデックス
top_k_scores: [B, K] 類似度スコア
"""
# 内積による類似度計算
scores = jnp.matmul(user_representation, corpus_embeddings.T)
# マスク適用(無効な候補を除外)
if corpus_mask is not None:
scores = jnp.where(corpus_mask[None, :], scores, -INF)
# Top-K選択
top_k_scores, top_k_indices = jax.lax.top_k(scores, top_k)
return top_k_indices, top_k_scoresNote
3.Ranking Transformer
Ranking Transformer は、検索された候補投稿を精密にスコアリングし、最終的なランキングを決定します。
3.1アーキテクチャ概要
Input Sequence:
[User (1)] + [History (S)] + [Candidates (C)]Attention Mask
- User+History: Causal Attention (下三角行列)
- Candidates: User+History への attend + Self-attention
- (他の候補には attend しない)
Phoenix Transformer Layers
h = h + LayerNorm(MHA(LayerNorm(h)))h = h + LayerNorm(FFN(LayerNorm(h)))Output Projection
logits = candidate_embeddings @ unembeddingShape: [B, C, num_actions]
3.2Candidate Isolation(候補分離アテンションマスク)
各候補投稿が独立してスコアリングされることを保証するため、特別なアテンションマスクを使用します。
シーケンス: [user, h1, h2, c1, c2, c3]
位置: 0 1 2 3 4 5
アテンションマスク:
Keys: u h1 h2 c1 c2 c3
Query u : 1 0 0 0 0 0 <- 因果的
Query h1 : 1 1 0 0 0 0 <- 因果的
Query h2 : 1 1 1 0 0 0 <- 因果的
Query c1 : 1 1 1 1 0 0 <- user+history + 自己
Query c2 : 1 1 1 0 1 0 <- user+history + 自己
Query c3 : 1 1 1 0 0 1 <- user+history + 自己3.3Multi-Action Prediction(マルチアクション予測)
Phoenix は 19 種類のユーザーアクションを同時に予測します。
| インデックス | アクション名 | 説明 |
|---|---|---|
| 0 | favorite_score | いいね確率 |
| 1 | reply_score | 返信確率 |
| 2 | repost_score | リポスト確率 |
| 3 | photo_expand_score | 画像展開確率 |
| 4 | click_score | クリック確率 |
| 5 | profile_click_score | プロフィールクリック確率 |
| 6 | vqv_score | 動画視聴完了確率 |
| 7 | share_score | 共有確率 |
| 8 | share_via_dm_score | DM共有確率 |
| 9 | share_via_copy_link_score | リンクコピー確率 |
| 10 | dwell_score | 滞在確率 |
| 11 | quote_score | 引用確率 |
| 12 | quoted_click_score | 引用クリック確率 |
| 13 | follow_author_score | 著者フォロー確率 |
| 14 | not_interested_score | 興味なし確率 |
| 15 | block_author_score | 著者ブロック確率 |
| 16 | mute_author_score | 著者ミュート確率 |
| 17 | report_score | 通報確率 |
| 18 | dwell_time | 滞在時間予測 |
4.入力特徴量
4.1特徴量の種類
1. ユーザーハッシュ埋め込み
user_hashes: [B, num_user_hashes]user_embeddings: [B, num_user_hashes, D]2. 履歴埋め込み
history_post_hashes: [B, S, num_item_hashes]history_post_embeddings: [B, S, num_item_hashes, D]history_author_embeddings: [B, S, num_author_hashes, D]history_actions: [B, S, num_actions]history_product_surface: [B, S]3. 候補埋め込み
candidate_post_hashes: [B, C, num_item_hashes]candidate_post_embeddings: [B, C, num_item_hashes, D]candidate_author_embeddings: [B, C, num_author_hashes, D]candidate_product_surface: [B, C]B: バッチサイズ | S: 履歴シーケンス長 (default: 128) | C: 候補数 (default: 32) | D: 埋め込み次元 (default: 128)
4.2Product Surface
Product Surface は、ユーザーがコンテンツとインタラクションした場所を示すカテゴリ特徴量です。
product_surface_vocab_size: int = 16 # サーフェスの種類数5.Hash-Based Embeddings
Phoenix は、メモリ効率とスケーラビリティのためにハッシュベースの埋め込みを採用しています。
従来の埋め込み:
entity_id → embedding_table[entity_id]問題: 数十億のユニークIDに対して巨大なテーブルが必要
ハッシュベース埋め込み:
entity_id → [hash_1(id), hash_2(id), ...] → lookup → combine利点: 固定サイズのテーブルで任意のIDを処理可能
5.1HashConfig
@dataclass
class HashConfig:
"""ハッシュベース埋め込みの設定"""
num_user_hashes: int = 2 # ユーザーIDに使用するハッシュ関数の数
num_item_hashes: int = 2 # 投稿IDに使用するハッシュ関数の数
num_author_hashes: int = 2 # 著者IDに使用するハッシュ関数の数6.モデル設定
6.1TransformerConfig
@dataclass
class TransformerConfig:
emb_size: int # 埋め込み次元
key_size: int # アテンションキーの次元
num_q_heads: int # クエリヘッド数
num_kv_heads: int # キー/バリューヘッド数(GQA対応)
num_layers: int # Transformer レイヤー数
widening_factor: float = 4.0 # FFN拡張係数
attn_output_multiplier: float = 1.0 # アテンション出力スケール
name: Optional[str] = None
def make(self) -> "Transformer":
return Transformer(
num_q_heads=self.num_q_heads,
num_kv_heads=self.num_kv_heads,
widening_factor=self.widening_factor,
key_size=self.key_size,
attn_output_multiplier=self.attn_output_multiplier,
num_layers=self.num_layers,
)6.2デフォルト設定例
emb_size = 128
num_actions = 19
history_seq_len = 32
candidate_seq_len = 8
hash_config = HashConfig(
num_user_hashes=2,
num_item_hashes=2,
num_author_hashes=2,
)
recsys_model = PhoenixModelConfig(
emb_size=emb_size,
num_actions=num_actions,
history_seq_len=history_seq_len,
candidate_seq_len=candidate_seq_len,
hash_config=hash_config,
product_surface_vocab_size=16,
model=TransformerConfig(
emb_size=emb_size,
widening_factor=2,
key_size=64,
num_q_heads=2,
num_kv_heads=2,
num_layers=2,
attn_output_multiplier=0.125,
),
)7.Transformer アーキテクチャ詳細
7.1全体構造
Input: embeddings [B, T, D], padding_mask [B, T]
for layer_idx in range(num_layers):
h = inputs
# Self-Attention Block
h_attn = MHABlock(RMSNorm(h), mask)
h_attn = RMSNorm(h_attn)
h = h + h_attn
# Feed-Forward Block
h_dense = DenseBlock(RMSNorm(h))
h_dense = RMSNorm(h_dense)
h = h + h_dense
Output: embeddings [B, T, D]
7.2Multi-Head Attention
Grouped Query Attention (GQA) 対応のマルチヘッドアテンションを使用しています。
# Q/K/V 射影
query_heads = self._linear_projection(query, self.key_size, self.num_q_heads)
key_heads = self._linear_projection(key, self.key_size, self.num_kv_heads)
value_heads = self._linear_projection(value, self.value_size, self.num_kv_heads)
# Rotary Position Embedding (RoPE) 適用
rotate = RotaryEmbedding(dim=self.key_size)
key_heads = rotate(key_heads, seq_dim=1, offset=0)
query_heads = rotate(query_heads, seq_dim=1, offset=0)
# GQA: query heads をグループ化
query_heads = query_heads.reshape((b, t, kv_h, h // kv_h, d))
# アテンションスコア計算
attn_logits = jnp.einsum("...thHd,...Thd->...hHtT", query_heads, key_heads)
attn_logits *= self.attn_output_multiplier
# Soft-capping(数値安定性のため)
max_attn_val = 30.0
attn_logits = max_attn_val * jnp.tanh(attn_logits / max_attn_val)
# マスク適用 + Softmax
attn_logits = jnp.where(mask, attn_logits, -1e30)
attn_weights = jax.nn.softmax(attn_logits)7.3Feed-Forward Network (DenseBlock)
@dataclass
class DenseBlock(hk.Module):
"""SwiGLU 活性化を使用した FFN"""
widening_factor: float = 4.0
def __call__(self, inputs: jax.Array) -> jax.Array:
_, _, model_size = inputs.shape
# SwiGLU: gate * GELU(x)
h_v = Linear(ffn_size)(inputs) # Value branch
h_w1 = jax.nn.gelu(Linear(ffn_size)(inputs)) # Gate branch
h_dense = Linear(model_size)(h_w1 * h_v) # Output projection
return h_dense7.4Rotary Position Embedding (RoPE)
位置情報を回転行列として埋め込み、相対位置関係を効率的にエンコードします。
class RotaryEmbedding(hk.Module):
"""回転位置埋め込み(RoPE)
参考: https://arxiv.org/abs/2104.09864
"""
def __init__(self, dim: int, base_exponent: int = 10000):
self.dim = dim
self.base_exponent = base_exponent
def __call__(self, x: jax.Array, seq_dim: int, offset: jax.Array) -> jax.Array:
# 周波数の計算
exponents = jnp.arange(0, self.dim, 2, dtype=jnp.float32)
inv_freq = 1.0 / (self.base_exponent ** (exponents / self.dim))
# 位置インデックス
t = jnp.arange(x.shape[seq_dim]) + offset
# 位相角の計算
phase = jnp.einsum("bi,j->bij", t, inv_freq)
phase = jnp.tile(phase, reps=(1, 2))
# 回転の適用
x = x * jnp.cos(phase) + rotate_half(x) * jnp.sin(phase)
return x8.推論パイプライン
8.1ランキング推論
# 1. モデル設定
recsys_model = PhoenixModelConfig(...)
# 2. 推論ランナーの作成と初期化
inference_runner = RecsysInferenceRunner(
runner=ModelRunner(model=recsys_model, bs_per_device=0.125),
name="recsys_local",
)
inference_runner.initialize()
# 3. バッチの作成
batch, embeddings = create_example_batch(
batch_size=1,
emb_size=emb_size,
history_len=history_seq_len,
num_candidates=candidate_seq_len,
num_actions=num_actions,
...
)
# 4. ランキング実行
ranking_output = inference_runner.rank(batch, embeddings)
# 5. 結果の取得
scores = ranking_output.scores # [B, C, num_actions]
ranked_indices = ranking_output.ranked_indices # [B, C]
p_favorite = ranking_output.p_favorite_score # [B, C]8.2検索推論
# 1. モデル設定
retrieval_model_config = PhoenixRetrievalModelConfig(...)
# 2. 推論ランナーの作成と初期化
inference_runner = RecsysRetrievalInferenceRunner(
runner=RetrievalModelRunner(model=retrieval_model_config, bs_per_device=0.125),
name="retrieval_local",
)
inference_runner.initialize()
# 3. コーパスの設定
corpus_embeddings, corpus_post_ids = create_example_corpus(
corpus_size=1000,
emb_size=emb_size,
)
inference_runner.set_corpus(corpus_embeddings, corpus_post_ids)
# 4. バッチの作成
batch, embeddings = create_example_batch(...)
# 5. 検索実行
retrieval_output = inference_runner.retrieve(
batch,
embeddings,
top_k=10,
)
# 6. 結果の取得
top_k_indices = retrieval_output.top_k_indices # [B, K]
top_k_scores = retrieval_output.top_k_scores # [B, K]
user_representation = retrieval_output.user_representation # [B, D]9.パラメータサマリー
| パラメータ | デフォルト値 | 説明 |
|---|---|---|
emb_size | 128 | 埋め込み次元 |
num_layers | 2 | Transformer レイヤー数 |
num_q_heads | 2 | クエリヘッド数 |
num_kv_heads | 2 | キー/バリューヘッド数 |
key_size | 64 | アテンションキー次元 |
widening_factor | 2.0 | FFN拡張係数 |
attn_output_multiplier | 0.125 | アテンション出力スケール |
history_seq_len | 128 | 最大履歴シーケンス長 |
candidate_seq_len | 32 | 最大候補数 |
num_actions | 19 | 予測アクション数 |
num_user_hashes | 2 | ユーザーハッシュ数 |
num_item_hashes | 2 | アイテムハッシュ数 |
num_author_hashes | 2 | 著者ハッシュ数 |
product_surface_vocab_size | 16 | プロダクトサーフェス種類数 |
fprop_dtype | bfloat16 | 推論時データ型 |
10.関連ファイル
| ファイル | 説明 |
|---|---|
/phoenix/grok.py | Transformer コアアーキテクチャ |
/phoenix/recsys_model.py | Ranking モデル実装 |
/phoenix/recsys_retrieval_model.py | Retrieval モデル実装 |
/phoenix/runners.py | 推論ランナーとユーティリティ |
/phoenix/run_ranker.py | ランキングデモスクリプト |
/phoenix/run_retrieval.py | 検索デモスクリプト |
/phoenix/test_recsys_model.py | アテンションマスクのテスト |