JEPA as a Neural Tokenizer: Learning Robust Speech Representations with Density Adaptive Attention
Georgios Ioannides∗, Christos Constantinou∗, Aman Chadha∗, Aaron Elkins, Linsey Pang, Ravid Shwartz-Ziv, Yann LeCun
Abstract
We introduce a two-stage self-supervised framework that combines the Joint-Embedding Predictive Architecture (JEPA) with a Density Adaptive Attention Mechanism (DAAM) for learning robust speech representations. Stage 1 uses JEPA with DAAM to learn semantic audio features via masked prediction in latent space, fully decoupled from waveform reconstruction. Stage 2 leverages these representations for efficient tokenization using Finite Scalar Quantization (FSQ) and a mixed-radix packing scheme, followed by high-fidelity waveform reconstruction with a HiFi-GAN decoder. By integrating Gaussian mixture-based density-adaptive gating into the JEPA encoder, the model performs adaptive temporal feature selection and discovers hierarchical speech structure at a low frame rate of 2.5 Hz. The resulting tokens (47.5 tokens/sec) provide a reversible, highly compressed, and language-model-friendly representation that is competitive with, and often more efficient than, existing neural audio codecs.
1 Carnegie Mellon University, Amazon GenAI, James Silberrad Brown Center for Artificial Intelligence
October 25, 2025
∗ Work does not relate to position at Amazon.
Overview
We introduce a two-stage self-supervised learning framework that combines the Joint-Embedding Predictive Architecture (JEPA) [Assran et al., 2023] with Density Adaptive Attention Mechanisms (DAAM) for learning robust speech representations. This approach decouples representation learning from reconstruction: Stage 1 employs JEPA with DAAM to learn semantic audio features through masked prediction, while Stage 2 leverages these representations for efficient tokenization via Finite Scalar Quantization (FSQ) [Mentzer et al., 2023] and high-quality reconstruction through HiFi-GAN [Kong et al., 2020].
Key innovation. By integrating Density Adaptive Attention-based gating (Gaussian mixture gating) [Ioannides et al., 2024] into the JEPA encoder, we achieve adaptive feature selection during self-supervised learning. Combined with a mixed-radix packing scheme, the learned representations capture hierarchical speech structure (due to progressive downsampling from layer to layer) at a low frame rate of 2.5 Hz, enabling efficient speech modeling without labeled data.
Motivation: Why JEPA for Speech?
Traditional speech codec training couples representation learning with reconstruction objectives, forcing the encoder to prioritize features that minimize waveform-level losses. This conflates two distinct goals:
- Learning semantically meaningful representations that capture linguistic and acoustic structure.
- Preserving perceptual quality for high-fidelity reconstruction.
JEPA addresses this by separating concerns: the encoder learns to predict masked representations in latent space (Stage 1), then a separate decoder learns to map these representations to audio (Stage 2). This architectural separation enables:
· Better representations: the encoder optimizes for semantic content rather than low-level waveform details.
· Efficiency: fine-tuning the encoder reduces Stage 2 training cost.
· Flexibility: the same encoder can support multiple downstream tasks (text-to-speech, voice conversion, automatic speech recognition, etc.).
· Scalability: Stage 1 can leverage large unlabeled datasets.
The integration of DAAM enhances this framework by introducing adaptive attention that learns which temporal regions and features are most informative for prediction, naturally discovering speech-relevant patterns.
Stage 1: Self-Supervised JEPA Encoder with DAAM
JEPA Masking Strategy
The JEPA framework employs block-based temporal masking to create a self-supervised learning objective. For a batch of audio sequences with temporal length T, binary masks m ∈ {0, 1}^{B×T} are generated, where 1 indicates visible (context) regions and 0 indicates masked (target) regions.
Block Masking Algorithm. Given mask ratio ρ ∈ [0 , 1], minimum span length s min , and maximum span length s max , we construct masks as follows:
- Initialize: m ← 1_{B×T} (all positions visible).
- For each sample b ∈ {1, . . . , B}:
  - (a) Compute target: n_mask = ⌊ρ · T⌋.
  - (b) Initialize counter: n_masked ← 0.
  - While n_masked < n_mask:
    - (a) Sample span length: ℓ ∼ Uniform(s_min, s_max).
    - (b) Sample start position: t_start ∼ Uniform(0, T − ℓ).
    - (c) Compute end position: t_end ← min(t_start + ℓ, T).
    - (d) Set mask: m[b, t] ← 0 for all t ∈ [t_start, t_end).
    - (e) Update counter: n_masked ← n_masked + (t_end − t_start).
- Return: mask tensor m.

This block masking strategy creates contiguous masked spans rather than random individual positions, forcing the model to learn longer-range temporal dependencies and semantic content.

Stage 1: JEPA Pretraining

Figure 1: The input waveform is processed by three parallel pathways: (1) an online encoder (trainable, green) that processes the full audio and feeds into a predictor network (yellow) after feature-space masking with a learned mask token, (2) a target encoder (purple) updated via EMA that also processes the full audio to generate z_target, and (3) a masking strategy module (blue) that generates binary masks. The MSE loss is computed only on masked regions between z_predicted and z_target (stop-gradient), with gradients backpropagating only through the online encoder and predictor. The target encoder provides stable representations without receiving gradients directly [Grill et al., 2020].
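For concreteness, the masking procedure above can be sketched in NumPy (an illustrative sketch, not the released implementation; the counter here checks actually-masked cells, so overlapping spans are not double-counted):

```python
import numpy as np

def block_mask(B, T, rho=0.5, s_min=2, s_max=None, rng=None):
    """Block-based temporal masking: 1 = visible (context), 0 = masked (target)."""
    if rng is None:
        rng = np.random.default_rng(0)
    if s_max is None:
        s_max = max(s_min, T // 4)        # adaptive maximum span
    m = np.ones((B, T), dtype=np.int64)
    n_mask = int(rho * T)                 # target number of masked timesteps
    for b in range(B):
        # Keep adding contiguous spans until enough cells are masked.
        while (m[b] == 0).sum() < n_mask:
            l = int(rng.integers(s_min, s_max + 1))   # span length
            t0 = int(rng.integers(0, T - l + 1))      # start position
            m[b, t0:min(t0 + l, T)] = 0
    return m

m = block_mask(B=2, T=40)
```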
Masking hyperparameters.
· Mask ratio: ρ = 0.5 (50% of timesteps masked).
· Minimum span: s_min = 2 frames.
· Maximum span: s_max = T/4 frames (adaptive to sequence length).
At 2.5 Hz frame rate, this corresponds to variable spans adapted to the sequence length.
Density Adaptive Attention for Temporal Feature Modulation
The core innovation integrating a stabilized version of the original DAAM into JEPA is the DensityAdaptiveAttention module, which computes adaptive attention gates based on learned Gaussian mixture distributions. Unlike standard self-attention that computes pairwise dot-products between positions, DAAM learns to identify statistically salient temporal regions based on their distributional characteristics.
Mathematical Formulation
For input features x ∈ R^{B×C×T} (batch size, channels, time), the DAAM module operates along the temporal axis.
Step 1: Temporal statistics. For each batch and channel, compute the mean and variance across time:
$$
\mu_{b,c} = \frac{1}{T} \sum_{t=1}^{T} x_{b,c,t}, \qquad
\sigma^2_{b,c} = \frac{1}{T} \sum_{t=1}^{T} \left( x_{b,c,t} - \mu_{b,c} \right)^2
$$
Step 2: Learnable Gaussian parameters. For K Gaussian components, we maintain learnable parameters:
· Mean offsets: δ = [δ_1, . . . , δ_K] ∈ R^K, initialized to δ_k = 0.
· Log-scale parameters: ν = [ν_1, . . . , ν_K] ∈ R^K, initialized to ν_k = log(0.5).
The positive scales are computed via softplus:

$$
\tilde{\sigma}_k = \mathrm{softplus}(\nu_k) + \epsilon = \log\!\left(1 + e^{\nu_k}\right) + \epsilon
$$

with ϵ = 10^-3 for numerical stability.
Step 3: Standardized deviations. Each timestep is standardized with respect to component k:

$$
y^{(k)}_{b,c,t} = \frac{x_{b,c,t} - \left( \mu_{b,c} + \delta_k \right)}{\sigma_{b,c}\, \tilde{\sigma}_k}
$$
Step 4: Log-density under each Gaussian. The log-probability density at each timestep is:
$$
\log p_k(x_{b,c,t}) = -\tfrac{1}{2} \left( y^{(k)}_{b,c,t} \right)^2 - \log\!\left( \sigma_{b,c}\, \tilde{\sigma}_k \right) - \tfrac{1}{2} \log(2\pi)
$$
Step 5: Mixture aggregation via log-sum-exp. To form a mixture of Gaussians:
$$
\log p(x_{b,c,t}) = \log \frac{1}{K} \sum_{k=1}^{K} \exp\!\big( \log p_k(x_{b,c,t}) \big)
= \operatorname{LSE}_k\big( \log p_k(x_{b,c,t}) \big) - \log K
$$
Step 6: Attention gate and feature modulation. The final attention gate normalizes the mixture log-density over time,

$$
g_{b,c,t} = \exp\!\Big( \log p(x_{b,c,t}) - \max_{t'} \log p(x_{b,c,t'}) \Big) \in (0, 1],
$$

and the output features are

$$
\tilde{\mathbf{x}} = \mathbf{x} \odot \mathbf{g},
$$

where ⊙ denotes element-wise multiplication.
DAAM operates on a learned 1-channel attention projection over time: features are first projected to a single channel, the Gaussian mixture gate is computed on that 1D temporal signal, and the resulting gate scales the full feature tensor.
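A minimal NumPy sketch of the gate computation on a single 1-D temporal signal (illustrative only; the final max-normalization is one plausible choice and the released implementation may differ):

```python
import numpy as np

def daam_gate(x, delta, nu, eps=1e-3):
    """Density-adaptive gate over time for a 1-D signal x of shape [T].

    delta: per-component mean offsets (K,); nu: log-scale parameters (K,).
    Returns a gate in (0, 1] per timestep.
    """
    mu, var = x.mean(), max(x.var(), 1e-6)        # temporal statistics (variance clamped)
    sigma = np.sqrt(var)
    scales = np.log1p(np.exp(nu)) + eps           # softplus -> positive scales
    # Log-density of each timestep under each Gaussian component.
    z = (x[None, :] - (mu + delta[:, None])) / (sigma * scales[:, None])
    logp = -0.5 * z**2 - np.log(sigma * scales[:, None]) - 0.5 * np.log(2 * np.pi)
    # Mixture with uniform weights 1/K via log-sum-exp.
    m = logp.max(axis=0)
    log_mix = m + np.log(np.exp(logp - m).mean(axis=0))
    return np.exp(log_mix - log_mix.max())        # normalize so the peak gate is 1

K = 4
gate = daam_gate(np.sin(np.linspace(0, 6, 100)),
                 np.zeros(K), np.full(K, np.log(0.5)))
```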
The core innovation integrating a stabilized version of the original DAAM into JEPA is the DensityAdaptiveAttention module, which computes adaptive attention gates based on learned Gaussian mixture distributions. Unlike standard self-attention that computes pairwise dot-products between positions, DAAM learns to identify statistically salient temporal regions based on their distributional characteristics.
Implementation details.
· All computations in FP32 for numerical stability.
· Variance clamped: var ≥ 10^-6.
· Softplus ensures positive scales: σ̃_k > 0.
· Number of Gaussians: K = 4 across all layers.
JEPA Encoder Architecture
The JEPA encoder consists of two parallel pathways that share weights but serve different roles.
Context encoder (online network). Processes the full audio input. Masking is applied later in feature space by replacing hidden timesteps with a learned mask token before the predictor. Parameters are updated via gradient descent.
Target encoder (EMA network). Shares the same architecture and also processes the full audio to produce the prediction targets z_target. Its parameters are updated as an exponential moving average of the online encoder and receive no gradients.
Convolutional--Transformer Hybrid Design
Downsampling path. The input raw waveform [B, 1, T_wav] passes through five strided Conv1D blocks of progressively increasing channel width, ending in a 512-channel bottleneck. The total stride is 8 × 8 × 5 × 5 × 6 = 9600 samples/hop at 24 kHz, resulting in a latent representation [B, 512, T_z], where T_z corresponds to approximately a 2.5 Hz frame rate.
Conformer blocks [Gulati et al., 2020]. We use 8 Conformer layers with 16 attention heads. Each layer comprises self-attention, feedforward, convolution, and layer normalization. DAAM gating is applied in the encoder blocks (after the strided convolutions and residual stacks); there is no DAAM after the Conformer blocks in the current implementation.
Integration with DAAM. Within each encoder block, after the strided convolution and residual stack, features pass through GAttnGateG modules that:
- Project features to a single channel via 1 × 1 convolution.
- Compute a DAAM gate from projected features.
- Apply learned scaling
$$
\mathbf{y} = \mathbf{x} + \alpha \left( \mathbf{g} \odot \mathbf{x} \right),
$$

where α (initialized to 0.05) controls modulation strength.

Figure 4: Stage 1 JEPA masked prediction loss (MSE) over training steps. JEPA+DAAM (blue) converges faster and to a lower final loss ( ∼ 0 . 09) compared to JEPA without DAAM (orange, ∼ 0 . 17), demonstrating that Density Adaptive Attention enables more efficient representation learning. Both models use identical architectures except for DAAM gating.
JEPA Predictor Network
The predictor takes context representations and predicts masked regions. It uses two Conformer blocks with 16 attention heads, processing masked context features and outputting predictions for all temporal positions. The predictor only receives context (visible) regions but must predict features at all positions; the mask is applied to the loss.
Stage 1 Training Objective
The JEPA training objective is pure self-supervised prediction in latent space.
Loss Function

$$
\mathcal{L}_{\text{JEPA}} = \frac{1}{N_{\text{mask}}\, C} \sum_{t \in \mathcal{M}} \sum_{c=1}^{C} \big( \hat{z}_{c,t} - \mathrm{sg}(z^{\text{target}}_{c,t}) \big)^2
$$

where:
· M = { t : m_t = 0 } is the set of masked positions,

Figure 2: JEPA online encoder architecture. Input waveform passes through an initial Conv1D layer followed by 5 encoder blocks, each containing Conv1D with stride, SnakeBeta activation, residual blocks, and Gaussian Adaptive Attention gating. Features are projected through a bottleneck Conv1D layer and processed by 8 Conformer blocks (each with FNN, multi-head attention with 16 heads, depthwise convolution, and a second FNN) to produce the final representation z . The target encoder shares this architecture but is updated via exponential moving average rather than backpropagation.
· N_mask = |M|,
· C is the channel dimension,
· sg(·) denotes the stop-gradient operation.
The loss is computed only on masked regions by weighting the squared differences with the mask, and is normalized by the number of masked tokens times channels.
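A minimal NumPy sketch of this masked latent-space MSE (shapes and names are illustrative, not from the released code; the stop-gradient is a no-op outside an autograd framework):

```python
import numpy as np

def jepa_loss(z_pred, z_target, mask):
    """Masked MSE in latent space.

    z_pred, z_target: [B, C, T]; mask: [B, T] with 0 = masked (target) positions.
    The target would carry a stop-gradient in the autograd setting.
    """
    masked = (mask == 0).astype(z_pred.dtype)        # 1 where the loss applies
    sq = (z_pred - z_target) ** 2                    # [B, C, T]
    n_mask = masked.sum()                            # number of masked positions
    C = z_pred.shape[1]
    return (sq * masked[:, None, :]).sum() / (n_mask * C)

B, C, T = 2, 8, 10
rng = np.random.default_rng(0)
z_pred = rng.normal(size=(B, C, T))
z_target = rng.normal(size=(B, C, T))
mask = np.ones((B, T), dtype=int)
mask[:, T // 2 :] = 0                                # mask the second half of each sequence
loss = jepa_loss(z_pred, z_target, mask)
```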
EMA Target Update
After each training step, the target encoder parameters are updated via EMA:
$$
\theta_{\text{target}} \leftarrow \tau\, \theta_{\text{target}} + (1 - \tau)\, \theta_{\text{online}},
$$
with momentum coefficient τ = 0.996.
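The update is a simple per-parameter interpolation; a sketch (illustrative; real code would iterate over the encoder's parameter tensors):

```python
def ema_update(target_params, online_params, tau=0.996):
    """theta_target <- tau * theta_target + (1 - tau) * theta_online, per parameter."""
    return [tau * t + (1 - tau) * o for t, o in zip(target_params, online_params)]

tgt = ema_update([0.0], [1.0])   # one scalar "parameter": moves 0.4% toward the online value
```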
Stage 1 hyperparameters.
· Optimizer: AdamW with β1 = 0.8, β2 = 0.99.
· Learning rate: 1.5 × 10^-4.
· Weight decay: 10^-3.
· Batch size: 32.
· Max audio length: 15 s @ 24 kHz.
· Training steps: 24 000.
Collapse monitoring. We monitor (without backpropagation) the standard deviation of predictor outputs across batch and temporal dimensions. If the mean standard deviation falls below 0.01, a warning is logged. This does not contribute to the loss.

Figure 3: JEPA predictor network architecture. The predictor takes masked context features z masked and processes them through: (1) an expansion Conv1D layer that doubles the channel dimension, (2) two Conformer blocks separated by an intermediate Conv1D for feature refinement, and (3) a projection Conv1D that reduces back to the original dimensionality, producing predicted features z pred at all positions including masked regions.
Stage 2: Fine-Tuning Encoder + FSQ Quantization + HiFi-GAN Decoder
After Stage 1 completes, the JEPA encoder weights are fine-tuned and used as a feature extractor for Stage 2. Stage 2 introduces quantization and waveform reconstruction.
Finite Scalar Quantization (FSQ)
FSQ provides efficient discrete tokenization without codebook learning [Mentzer et al., 2023]. Unlike VQ-VAE, which maintains learnable codebooks, FSQ uses fixed scalar quantization per dimension. Let z_e ∈ R^{B×C×T} be encoder features.
FSQ Formulation
Projection. Encoder features are projected to the code dimension and bounded to (−1, 1):

$$
\mathbf{z}'_e = \tanh\!\big( W_p\, \mathbf{z}_e \big)
$$
Quantization. For dimension d with L_d levels, define the uniformly spaced code values

$$
v_j = -1 + \frac{2j}{L_d - 1}, \qquad j = 0, \ldots, L_d - 1,
$$

and map each input to the nearest code value:

$$
q_d(v) = -1 + \frac{2}{L_d - 1}\, \operatorname{round}\!\left( \frac{(L_d - 1)(v + 1)}{2} \right).
$$

The quantized value is z_q[d] = q_d(z'_e[d]).
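A scalar sketch of the per-dimension quantizer under the formulation above (assuming inputs are already bounded to [−1, 1]; names are illustrative):

```python
import numpy as np

def fsq_quantize(v, L=4):
    """Quantize a value v in [-1, 1] to the nearest of L uniformly spaced levels.

    Returns (quantized value, integer code in {0, ..., L-1}).
    """
    idx = int(np.round((v + 1) / 2 * (L - 1)))   # nearest level index
    return -1.0 + 2.0 * idx / (L - 1), idx

val, idx = fsq_quantize(0.9)   # snaps to the top level for L = 4
```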
Configuration.
· Levels: L = [4, 4, 4, 4].
· Code dimension: C = 128.
· Temperature: τ = 1.0.

Straight-through estimator. During backpropagation, gradients bypass the non-differentiable rounding:

$$
\mathbf{z}_q = \mathbf{z}'_e + \mathrm{sg}\big( q(\mathbf{z}'_e) - \mathbf{z}'_e \big), \qquad
\frac{\partial \mathcal{L}}{\partial \mathbf{z}'_e} = \frac{\partial \mathcal{L}}{\partial \mathbf{z}_q}.
$$
Mixed-Radix Token Packing
To maximize compression efficiency, we implement a mixed-radix packing algorithm that converts FSQ indices into compact integer tokens [Simon, 2024].
Let i ∈ Z^{B×T×D} denote FSQ indices, with dimension-specific radices r = [r_1, . . . , r_G] for a group of G dimensions.
Mixed-Radix Encoding
Any combination [ i 1 , . . . , i G ] is encoded as
$$
\mathrm{token} = \sum_{k=1}^{G} i_k \prod_{j=k+1}^{G} r_j
$$
Example. For G = 7 and r = [4, 4, 4, 4, 4, 4, 4] with i = [2, 1, 3, 0, 2, 1, 3]:

$$
\mathrm{token} = 2 \cdot 4^{6} + 1 \cdot 4^{5} + 3 \cdot 4^{4} + 0 \cdot 4^{3} + 2 \cdot 4^{2} + 1 \cdot 4 + 3 = 10023,
$$

with maximum value 4^7 − 1 = 16383.
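The encoding fits in a few lines of Python (the left-to-right accumulation shown here is algebraically equivalent to the sum above; names are illustrative):

```python
def pack(digits, radices):
    """Mixed-radix packing: token = sum_k i_k * prod_{j>k} r_j."""
    token = 0
    for i, r in zip(digits, radices):
        token = token * r + i   # Horner-style left-to-right accumulation
    return token

token = pack([2, 1, 3, 0, 2, 1, 3], [4] * 7)   # -> 10023
```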
Efficient Iterative Computation
Using Horner's method [Knuth, 1997]:
$$
\mathrm{token} = i_G + r_G \Big( i_{G-1} + r_{G-1} \big( i_{G-2} + \cdots + r_2\, i_1 \big) \Big),
$$
implemented right-to-left:

- Initialize: token ← i_G, place value m ← 1.
- For k = G − 1 down to 1: m ← m · r_{k+1}; token ← token + i_k · m.

Table 1: Comparison of tokenization approaches.
Padding and Grouping
Our FSQ implementation yields D = 128 quantized dimensions. We choose group size G = 7:
· Number of groups: ⌈128 / 7⌉ = 19.
· Padding: 19 × 7 − 128 = 5 dimensions with radix 1.
Token rate.
Groups per frame: 19. Tokens/sec:
$$
2.5\ \text{frames/sec} \times 19\ \text{tokens/frame} = 47.5\ \text{tokens/sec}.
$$
Decoding
The reverse operation extracts indices:
- Initialize rem = token.
- For k = 1 to G:
  · prod = ∏_{j=k+1}^{G} r_j.
  · i_k = ⌊rem / prod⌋.
  · rem = rem mod prod.
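An equivalent divmod formulation of this loop, peeling digits right-to-left (an illustrative sketch; the repository may implement it differently):

```python
def unpack(token, radices):
    """Invert mixed-radix packing by peeling digits from least to most significant."""
    digits = []
    for r in reversed(radices):
        token, d = divmod(token, r)   # d is the current least-significant digit
        digits.append(d)
    return digits[::-1]               # restore most-significant-first order

digits = unpack(10023, [4] * 7)   # -> [2, 1, 3, 0, 2, 1, 3]
```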
Comparison to Alternatives
Advantages:
· Perfect reversibility via modular arithmetic.
· Near-optimal compression for given radices.
· No learned codebook (unlike VQ-VAE).
· Flexible grouping G trading vocabulary size versus token rate.
· Integer-only operations, hardware-friendly.
With G = 7 and radix 4, the per-token vocabulary is 4^7 = 16384, comparable to subword vocabularies used in NLP.
Integration with Language Models
The compact tokens enable direct training of decoder-only Transformers for speech generation:
· Input: discrete token sequence at 47.5 tokens/sec.
· Output: next-token prediction over a 16 384-way vocabulary.
· Decoding: tokens → FSQ indices → dequantized features → waveform via HiFi-GAN.
Frame Rate Comparison with Neural Codecs
Table 2: Frame rate comparison with state-of-the-art neural codecs.
HiFi-GAN Decoder
The decoder upsamples quantized representations back to waveform using HiFi-GAN with DAAM gating in residual blocks [Kong et al., 2020].
Decoder Architecture
Quantized features [B, 512, T_z] are upsampled via ConvTranspose1D blocks of progressively decreasing channel width, with strides 6, 5, 5, 8, 8 (total stride 9600, mirroring the encoder), yielding the output waveform [B, 1, T_wav]. Each block consists of:
· Upsampling ConvTranspose1D.
· Multi-receptive-field (MRF) residual blocks with (optionally) DAAM gating.
ResBlock with DAAM.
- Leaky ReLU activation.
- Dilated convolution.
- Residual connection.
Decoder hyperparameters.
· Upsample kernels: [3, 7, 11, 15, 23, 32].
· Residual blocks: 8 per stage.
Stage 2 Training Objective
Stage 2 optimizes the FSQ quantizer, HiFi-GAN decoder, and JEPA encoder.
Figure 5: HiFi-GAN decoder architecture (Stage 2). Quantized features z q are upsampled through a bottleneck Conv1D followed by 5 decoder blocks. Each block contains ConvTranspose1D upsampling and MRF residual blocks with different kernel sizes (3, 7, 11, 15, 23, 32) to capture multi-scale temporal patterns. SnakeBeta activations provide periodic inductive bias for high-fidelity audio generation [Ziyin et al., 2020].
Total Loss

The Stage 2 objective combines three terms:

$$
\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{recon}} + \lambda_{\text{stft}}\, \mathcal{L}_{\text{stft}} + \lambda_{\text{gan}}\, \mathcal{L}_{\text{gan}}.
$$

1. Reconstruction loss (L1).

$$
\mathcal{L}_{\text{recon}} = \big\| x - \hat{x} \big\|_1.
$$

2. Multi-resolution STFT loss [Yamamoto et al., 2020], with spectral convergence and log-magnitude losses at each resolution m:

$$
\mathcal{L}_{\text{sc}}^{(m)} = \frac{\big\| \, |\mathrm{STFT}_m(x)| - |\mathrm{STFT}_m(\hat{x})| \, \big\|_F}{\big\| \, |\mathrm{STFT}_m(x)| \, \big\|_F}, \qquad
\mathcal{L}_{\text{mag}}^{(m)} = \frac{1}{N_m} \big\| \log|\mathrm{STFT}_m(x)| - \log|\mathrm{STFT}_m(\hat{x})| \big\|_1,
$$

$$
\mathcal{L}_{\text{stft}} = \frac{1}{M} \sum_{m=1}^{M} \Big( \mathcal{L}_{\text{sc}}^{(m)} + \mathcal{L}_{\text{mag}}^{(m)} \Big).
$$

STFT configurations.
· FFT sizes: [2048, 1024, 512, 256, 128].
· Hop sizes: [512, 256, 128, 64, 32].
· Window: Hann.

3. GAN loss. We use multi-period and multi-scale discriminators [Kumar et al., 2019]. Generator loss:

$$
\mathcal{L}_{\text{adv}} = \sum_{k} \big( D_k(\hat{x}) - 1 \big)^2.
$$

Feature matching:

$$
\mathcal{L}_{\text{fm}} = \sum_{k} \sum_{l} \frac{1}{N_l} \big\| D_k^{(l)}(x) - D_k^{(l)}(\hat{x}) \big\|_1.
$$

GAN total:

$$
\mathcal{L}_{\text{gan}} = \mathcal{L}_{\text{adv}} + \mathcal{L}_{\text{fm}}.
$$

Discriminator loss:

$$
\mathcal{L}_{D} = \sum_{k} \Big( \big( D_k(x) - 1 \big)^2 + D_k(\hat{x})^2 \Big).
$$
Loss weights and training schedule.
· λ_stft = 2.0.
· λ_gan = 0.1.
· Discriminator warmup: 5000 steps (discriminator frozen).
· After warmup: discriminator updated every step.
Stage 2 hyperparameters.
· Optimizer: AdamW, β1 = 0.8, β2 = 0.99.
· Learning rate: 1.5 × 10^-4 (decoder), 0.75 × 10^-4 (discriminators).
· Weight decay: 10^-3.
· Batch size: 8.
· Training steps: 29 000.
Experimental Setup
Dataset
· LibriLight (large-scale unlabeled English speech corpus) [Kahn et al., 2020].
· Training split: ∼9000 hours (combined across the two stages).
· Validation: held-out speakers.
· Sample rate: 24 kHz.
· Max audio length: 15 s.
Data Preprocessing
- Resample to 24 kHz if needed.
- Convert to mono by averaging channels.
- No further preprocessing (normalization handled in-model).
Table 3: Model architecture and parameter efficiency.
Distributed Training
· Hardware: 2× NVIDIA A100 (80 GB).
· Mixed precision: FP16 for forward/backward, FP32 for critical ops.
· Gradient accumulation: 1 step.
· Global batch size: 64 (Stage 1), 16 (Stage 2).
Inference Pipeline
At inference time:
- Raw waveform → JEPA encoder → latent features.
- Latent features → FSQ quantization → discrete tokens.
- Tokens → dequantization → quantized features.
- Quantized features → HiFi-GAN decoder → reconstructed waveform.
Token rate: 47.5 tokens/sec (with G = 7 packing).
Model Architecture and Efficiency
Parameter Counts
Training Efficiency
Key features:
· Two-stage training: self-supervised pretraining + supervised fine-tuning.
· Inference efficiency: 191M parameters (no EMA network).
Table 4: Training efficiency of the two stages.
Evaluation Metrics
We report qualitative evaluations, as all variants were trained under limited computational budgets and this work presents preliminary findings.
Baselines.
- JEPA baseline: JEPA encoder without DAAM gating.
- WavLM-Large [Chen et al., 2021]: pre-trained self-supervised model.
- JEPA+DAAM: JEPA encoder with DAAM gating (ours).
Discussion
Why DAAM Improves JEPA Representations
Integrating Density Adaptive Attention into JEPA provides several advantages.
Comparison to standard attention. Standard softmax-based self-attention computes pairwise correlations between positions, answering 'Which timesteps are similar to this one?' DAAM instead computes statistical salience: 'Which timesteps have unusual or informative statistical properties?' via Gaussian mixture modeling of temporal statistics.
Because it operates on temporal statistics rather than full pairwise similarity matrices, DAAM can capture salient temporal patterns without the quadratic complexity of full self-attention.
Limitations and Future Work
Current limitations and directions for future work include:
- Fixed masking strategy. Block masking with fixed span distributions may not adapt optimally to varying speech rates or linguistic structure. Future work includes adaptive masking sensitive to acoustic or linguistic boundaries.
- Monolingual evaluation. Experiments are currently limited to English (LibriLight). Generalization to tonal and morphologically rich languages remains open.
- Limited data scale. Pretraining has been conducted on relatively modest amounts of data compared to large-scale SSL systems; conclusions are restricted to emerging capabilities.
- Cross-modal JEPA. Extending to audio-visual or audio-text joint embedding prediction for multimodal representations is a promising direction.
Code Availability
The complete implementation of the JEPA+DAAM framework, including training scripts, model architectures, and data processing pipelines, is available at:
https://github.com/gioannides/Density-Adaptive-JEPA
The repository includes:
· Stage 1 JEPA encoder training with DAAM.
· Stage 2 decoder training with the encoder.
· FSQ quantization and mixed-radix packing algorithms.
· HiFi-GAN decoder with optional DAAM gating.
· DeepSpeed integration for distributed training.
Conclusion
We introduced a two-stage self-supervised framework combining the Joint-Embedding Predictive Architecture (JEPA) with a Density Adaptive Attention Mechanism (DAAM) for learning robust speech representations. Stage 1 uses JEPA with DAAM to learn semantic audio features via masked prediction in latent space, fully decoupled from waveform reconstruction; Stage 2 leverages these representations for efficient tokenization using Finite Scalar Quantization (FSQ) and a mixed-radix packing scheme, followed by high-fidelity waveform reconstruction with a HiFi-GAN decoder. Gaussian mixture-based density-adaptive gating in the JEPA encoder enables adaptive temporal feature selection and the discovery of hierarchical speech structure at a low 2.5 Hz frame rate. The resulting tokens (47.5 tokens/sec) provide a reversible, highly compressed, and language-model-friendly representation that is competitive with, and often more efficient than, existing neural audio codecs.
Key innovation. By integrating Density Adaptive Attention-based gating (Gaussian Mixture gating) (Ioannides2024DAAM) into the JEPA encoder, we achieve adaptive feature selection during self-supervised learning. Combined with a mixed-radix packing scheme, the learned representations capture hierarchical speech structure—due to progressive downsampling from layer to layer—at a low frame rate of 2.5 Hz, enabling efficient speech modeling without labeled data.
Traditional speech codec training couples representation learning with reconstruction objectives, forcing the encoder to prioritize features that minimize waveform-level losses. This conflates two distinct goals:
Learning semantically meaningful representations that capture linguistic and acoustic structure.
Preserving perceptual quality for high-fidelity reconstruction.
JEPA addresses this by separating concerns: the encoder learns to predict masked representations in latent space (Stage 1), then a separate decoder learns to map these representations to audio (Stage 2). This architectural separation enables:
Better representations: the encoder optimizes for semantic content rather than low-level waveform details.
Efficiency: fine-tuning the encoder reduces Stage 2 training cost.
Flexibility: the same encoder can support multiple downstream tasks (text-to-speech, voice conversion, automatic speech recognition, etc.).
Scalability: Stage 1 can leverage large unlabeled datasets.
The integration of DAAM enhances this framework by introducing adaptive attention that learns which temporal regions and features are most informative for prediction, naturally discovering speech-relevant patterns.
The JEPA framework employs block-based temporal masking to create a self-supervised learning objective. For a batch of audio sequences with temporal length TT, binary masks 𝐦∈{0,1}B×T\mathbf{m}\in{0,1}^{B\times T} are generated, where 11 indicates visible (context) regions and 0 indicates masked (target) regions.
Given mask ratio ρ∈[0,1]\rho\in[0,1], minimum span length smins_{\text{min}}, and maximum span length smaxs_{\text{max}}, we construct masks as follows:
Initialize: 𝐦←𝟏B×T\mathbf{m}\leftarrow\mathbf{1}_{B\times T} (all positions visible).
For each sample b∈{1,…,B}b\in{1,\ldots,B}:
Compute target: nmask=⌊ρ⋅T⌋n_{\text{mask}}=\lfloor\rho\cdot T\rfloor.
While nmasked<nmaskn_{\text{masked}}<n_{\text{mask}}:
Sample span length: ℓ∼Uniform(smin,smax)\ell\sim\text{Uniform}(s_{\text{min}},s_{\text{max}}).
Compute end position: tend←min(tstart+ℓ,T)t_{\text{end}}\leftarrow\min(t_{\text{start}}+\ell,T).
Set mask: 𝐦[b,t]←0\mathbf{m}[b,t]\leftarrow 0 for all t∈[tstart,tend)t\in[t_{\text{start}},t_{\text{end}}).
Return: mask tensor 𝐦\mathbf{m}.
This block masking strategy creates contiguous masked spans rather than random individual positions, forcing the model to learn longer-range temporal dependencies and semantic content.
Mask ratio: ρ=0.5\rho=0.5 (50% of timesteps masked).
Minimum span: smin=2s_{\text{min}}=2 frames.
At 2.5 Hz frame rate, this corresponds to variable spans adapted to the sequence length.
The core innovation integrating a stabilized version of the original DAAM into JEPA is the DensityAdaptiveAttention module, which computes adaptive attention gates based on learned Gaussian mixture distributions. Unlike standard self-attention that computes pairwise dot-products between positions, DAAM learns to identify statistically salient temporal regions based on their distributional characteristics.
For input features 𝐱∈ℝB×C×T\mathbf{x}\in\mathbb{R}^{B\times C\times T} (batch size, channels, time), the DAAM module operates along the temporal axis.
For each batch and channel, compute the mean and variance across time:
For $K$ Gaussian components, we maintain learnable parameters:
Mean offsets: $\bm{\delta}=[\delta_{1},\ldots,\delta_{K}]\in\mathbb{R}^{K}$, initialized to $\delta_{k}=0$.
Raw scales: $\bm{\nu}=[\nu_{1},\ldots,\nu_{K}]\in\mathbb{R}^{K}$.
The positive scales are computed via softplus:
with $\epsilon=10^{-3}$ for numerical stability.
The log-probability density at each timestep is:
To form a mixture of Gaussians:
The final attention gate is
and the output features are
where $\odot$ denotes element-wise multiplication.
DAAM operates on a learned 1-channel attention projection over time: features are first projected to a single channel, the Gaussian mixture gate is computed on that 1D temporal signal, and the resulting gate scales the full feature tensor.
Variance clamped: $\text{var}\geq 10^{-6}$.
Softplus ensures positive scales: $\tilde{\sigma}_{k}>0$.
Number of Gaussians: $K=4$ across all layers.
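To make the formulation concrete, here is a NumPy sketch of the gate on a 1-channel temporal signal, following the mean/variance, softplus-scale, log-density, and logsumexp-mixture steps above. Learnable parameters appear as plain arrays; this is an illustrative simplification, not the repository implementation:

```python
import numpy as np

def daam_gate(x, delta, nu, eps=1e-3):
    """Density-adaptive gate over time. x: (B, 1, T); delta, nu: shape (K,)."""
    mu = x.mean(axis=-1, keepdims=True)                    # temporal mean   (B,1,1)
    var = np.maximum(x.var(axis=-1, keepdims=True), 1e-6)  # clamped variance
    sigma = np.sqrt(var)
    sig_k = np.log1p(np.exp(nu)) + eps                     # softplus scales, > 0
    # normalized deviation z_{k,t} for each Gaussian component -> (B,1,K,T)
    z = (x[..., None, :] - (mu[..., None, :] + delta[:, None])) / (
        sigma[..., None, :] * sig_k[:, None] + eps)
    logp = -0.5 * z**2 - np.log(sig_k)[:, None] - 0.5 * np.log(2 * np.pi)
    # mixture via logsumexp over K components, minus log K
    m = logp.max(axis=-2, keepdims=True)
    log_g = m.squeeze(-2) + np.log(np.exp(logp - m).sum(axis=-2)) - np.log(len(delta))
    return np.exp(log_g)                                   # gate, shape (B,1,T)

def apply_gate(feat, gate, alpha=0.05):
    """Residual modulation of the full feature tensor: y = x * (1 + alpha * gate)."""
    return feat * (1.0 + alpha * gate)
```

The gate computed on the 1-channel projection then scales the full feature tensor, as described above.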
The JEPA encoder consists of two parallel pathways that share weights but serve different roles.
Processes the full audio input. Masking is applied later in feature space by replacing hidden timesteps with a learned mask token before the predictor. Parameters are updated via gradient descent.
The input raw waveform $[B,1,T_{\text{wav}}]$ passes through Conv1D blocks with stride, progressing through channel dimensions
The total stride is $8\times 8\times 5\times 5\times 6=9600$ samples/hop at 24 kHz, resulting in a latent representation $[B,512,T_{z}]$, where $T_{z}$ corresponds to a 2.5 Hz frame rate.
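The hop and frame-rate arithmetic can be checked directly:

```python
# Encoder strides multiply to the hop size; frame rate = sample_rate / hop.
strides = [8, 8, 5, 5, 6]
hop = 1
for s in strides:
    hop *= s                      # 9600 samples per latent frame at 24 kHz
frame_rate = 24000 / hop          # 2.5 Hz latent frame rate
tokens_per_sec = frame_rate * 19  # 19 mixed-radix groups per frame -> 47.5
```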
We use 8 Conformer layers with 16 attention heads. Each layer comprises self-attention, feedforward, convolution, and layer normalization. DAAM gating is applied in the encoder blocks (after the strided convolutions and residual stacks); there is no DAAM after the Conformer blocks in the current implementation.
Within each convolutional encoder block, features pass through DAAM attention-gate modules that:
Project features to a single channel via a $1\times 1$ convolution.
Compute a DAAM gate from projected features.
Apply learned scaling
where $\alpha$ (initialized to $0.05$) controls modulation strength.
The predictor takes context representations and predicts masked regions. It uses two Conformer blocks with 16 attention heads, processing masked context features and outputting predictions for all temporal positions. The predictor only receives context (visible) regions but must predict features at all positions; the mask is applied to the loss.
The JEPA training objective is pure self-supervised prediction in latent space.
$\mathcal{M}=\{t:m_{t}=0\}$ is the set of masked positions,
$N_{\text{mask}}=|\mathcal{M}|$,
The loss is computed only on masked regions by weighting squared differences and normalized by the number of masked tokens times channels.
After each training step, the target encoder parameters are updated via EMA:
with momentum coefficient $\tau=0.996$.
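The masked objective and EMA update can be sketched compactly; parameters are plain arrays/dicts here for illustration, and the stop-gradient is implicit in NumPy:

```python
import numpy as np

def jepa_masked_mse(z_pred, z_target, mask):
    """MSE restricted to masked positions (mask: 1 = visible, 0 = masked).
    z_pred, z_target: (B, C, T); z_target plays the stop-gradient role."""
    masked = (mask == 0)[:, None, :]              # (B, 1, T), broadcasts over channels
    C = z_pred.shape[1]
    sq = ((z_pred - z_target) ** 2) * masked      # zero out visible positions
    return sq.sum() / (masked.sum() * C)          # normalize by N_mask * C

def ema_update(theta_target, theta_online, tau=0.996):
    """theta_target <- tau * theta_target + (1 - tau) * theta_online."""
    return {k: tau * v + (1 - tau) * theta_online[k] for k, v in theta_target.items()}
```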
Optimizer: AdamW with $\beta_{1}=0.8$, $\beta_{2}=0.99$.
Learning rate: $1.5\times 10^{-4}$.
Weight decay: $10^{-3}$.
Batch size: 32.
Max audio length: 15 s @ 24 kHz.
Training steps: 24 000.
We monitor (without backpropagation) the standard deviation of predictor outputs across batch and temporal dimensions. If the mean standard deviation falls below $0.01$, a warning is logged. This does not contribute to the loss.
After Stage 1 completes, the JEPA encoder weights are fine-tuned and used as a feature extractor for Stage 2. Stage 2 introduces quantization and waveform reconstruction.
FSQ provides efficient discrete tokenization without codebook learning (Mentzer2023FSQ). Unlike VQ-VAE, which maintains learnable codebooks, FSQ uses fixed scalar quantization per dimension.
Let $\mathbf{z}_{e}\in\mathbb{R}^{B\times C\times T}$ be encoder features.
For dimension $d$ with level $L_{d}$, define boundaries
The quantized value is $\mathbf{z}_{q}[d]=q_{d}(\mathbf{z}_{e}'[d])$.
Levels: $\mathbf{L}=[4,4,4,4]$.
Code dimension: $C=128$.
During backpropagation, gradients bypass the non-differentiable rounding via the straight-through estimator:
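The forward quantization can be sketched as follows; the grid follows the boundary formula above, and the straight-through backward pass is omitted in this NumPy illustration:

```python
import numpy as np

def fsq_quantize(z_e, levels):
    """Finite Scalar Quantization: bound with tanh, snap each dim to a fixed grid.
    z_e: (..., D); levels: per-dimension level counts, len D."""
    z = np.tanh(z_e)                                   # bound features to (-1, 1)
    idx = np.empty_like(z, dtype=np.int64)
    z_q = np.empty_like(z)
    for d, L in enumerate(levels):
        grid = (2 * np.arange(L) - L + 1) / L          # L=4 -> [-.75, -.25, .25, .75]
        i = np.abs(z[..., d, None] - grid).argmin(axis=-1)   # nearest boundary
        idx[..., d] = i
        z_q[..., d] = grid[i]
    return z_q, idx
```

In training, the straight-through estimator would copy gradients from `z_q` back to `z_e`.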
To maximize compression efficiency, we implement a mixed-radix packing algorithm that converts FSQ indices into compact integer tokens (Simon2024MixedRadixArxiv).
Let $\mathbf{i}\in\mathbb{Z}^{B\times T\times D}$ denote FSQ indices, with dimension-specific radices $\mathbf{r}=[r_{1},\ldots,r_{G}]$ for a group of $G$ dimensions.
Any combination $[i_{1},\ldots,i_{G}]$ is encoded as
For $G=7$ and $\mathbf{r}=[4,4,4,4,4,4,4]$ with $\mathbf{i}=[2,1,3,0,2,1,3]$:
with maximum value $4^{7}-1=16383$.
Using Horner's method (MixedRadixKnuth1997), the sum is evaluated iteratively, most-significant digit first:
Initialize $\text{token}\leftarrow i_{1}$.
For $k=2$ up to $G$: $\text{token}\leftarrow\text{token}\cdot r_{k}+i_{k}$.
Our FSQ implementation yields $D=128$ quantized dimensions. We choose group size $G=7$:
Number of groups: $\lceil 128/7\rceil=19$.
Padding: $19\times 7-128=5$ dimensions with radix 1.
Frame rate:
Groups per frame: 19. Tokens/sec:
To decode, initialize $\text{rem}\leftarrow\text{token}$ and recover digits most-significant first: for $k=1,\ldots,G$,
$\text{prod}=\prod_{j=k+1}^{G}r_{j}$,
$i_{k}=\lfloor\text{rem}/\text{prod}\rfloor$,
$\text{rem}\leftarrow\text{rem}\bmod\text{prod}$.
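A self-contained sketch of packing and decoding; the worked example above ($\mathbf{i}=[2,1,3,0,2,1,3]$, radix 4) round-trips through token 10023:

```python
def pack(indices, radices):
    """token = sum_k i_k * prod_{j>k} r_j, evaluated by Horner's rule."""
    token = 0
    for i, r in zip(indices, radices):
        token = token * r + i          # shift by this digit's radix, add digit
    return token

def unpack(token, radices):
    """Invert pack() by peeling digits from least significant to most."""
    digits = []
    for r in reversed(radices):
        digits.append(token % r)
        token //= r
    return digits[::-1]
```

Radix-1 padding dimensions always carry digit 0 and contribute nothing, so padded radices can simply be included in `radices`.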
Advantages:
Perfect reversibility via modular arithmetic.
Near-optimal compression for given radices.
No learned codebook (unlike VQ-VAE).
Flexible grouping $G$ trading vocabulary size versus token rate.
Integer-only operations, hardware-friendly.
With $G=7$ and radix 4, the per-token vocabulary is $4^{7}=16384$, comparable to subword vocabularies used in NLP.
The compact tokens enable direct training of decoder-only Transformers for speech generation:
Input: discrete token sequence at 47.5 tokens/sec.
Output: next-token prediction over a 16 384-way vocabulary.
Decoding: tokens → FSQ indices → dequantized features → waveform via HiFi-GAN.
The decoder upsamples quantized representations back to waveform using HiFi-GAN with DAAM gating in residual blocks (Kong2020HiFiGAN).
with strides $6,5,5,8,8$ (total stride $9600$), yielding output waveform $[B,1,T_{\text{wav}}]$.
Each block consists of:
Multi-receptive-field (MRF) residual blocks with (optionally) DAAM gating.
Leaky ReLU activation.
Upsample kernels: [3,7,11,15,23,32][3,7,11,15,23,32].
Residual blocks: 8 per stage.
Stage 2 optimizes the FSQ quantizer, HiFi-GAN decoder, and JEPA encoder.
with spectral convergence
and log-magnitude loss
FFT sizes: [2048, 1024, 512, 256, 128].
Window: Hann.
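A NumPy sketch of the multi-resolution STFT loss, combining spectral convergence and log-magnitude terms over the listed FFT sizes; the quarter-FFT hop and the $10^{-7}$ log floor are illustrative assumptions:

```python
import numpy as np

def stft_mag(x, n_fft, hop):
    """Magnitude STFT with a Hann window (simplified: no centering or padding)."""
    win = np.hanning(n_fft)
    frames = [x[i:i + n_fft] * win for i in range(0, len(x) - n_fft + 1, hop)]
    return np.abs(np.fft.rfft(np.stack(frames), axis=-1))

def mr_stft_loss(x_hat, x, fft_sizes=(2048, 1024, 512, 256, 128)):
    """Sum spectral-convergence + log-magnitude L1 terms across resolutions."""
    total = 0.0
    for n_fft in fft_sizes:
        S = stft_mag(x, n_fft, n_fft // 4)
        S_hat = stft_mag(x_hat, n_fft, n_fft // 4)
        sc = np.linalg.norm(S_hat - S) / (np.linalg.norm(S) + 1e-8)   # Frobenius
        log_mag = np.mean(np.abs(np.log(S_hat + 1e-7) - np.log(S + 1e-7)))
        total += sc + log_mag
    return total
```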
We use multi-period and multi-scale discriminators (Kumar2019MelGAN).
Generator loss:
$\lambda_{\text{stft}}=2.0$.
Discriminator warmup: 5000 steps (discriminator frozen).
LibriLight (large-scale unlabeled English speech corpus) (Kahn2020LibriLight).
Training split: $\sim$9000 hours (combined across the two stages).
Resample to 24 kHz if needed.
Convert to mono by averaging channels.
No further preprocessing (normalization handled in-model).
Hardware: 2× NVIDIA A100 (80 GB).
Mixed precision: FP16 for forward/backward, FP32 for critical ops.
Global batch size: 64 (Stage 1), 16 (Stage 2).
At inference time:
Latent features → FSQ quantization → discrete tokens.
Inference efficiency: 191M parameters (no EMA network).
We report qualitative evaluations, as all variants were trained under limited computational budgets and this work presents preliminary findings.
JEPA baseline: JEPA encoder without DAAM gating.
WavLM-Large (Chen2021WavLM): pre-trained self-supervised model.
Integrating Density Adaptive Attention into JEPA provides several advantages.
Standard softmax-based self-attention computes pairwise correlations between positions, answering “Which timesteps are similar to this one?” DAAM instead computes statistical salience: “Which timesteps have unusual or informative statistical properties?” via Gaussian mixture modeling of temporal statistics.
Because it operates on temporal statistics rather than full pairwise similarity matrices, DAAM can capture salient temporal patterns without the quadratic complexity of full self-attention.
Current limitations and directions for future work include:
Fixed masking strategy. Block masking with fixed span distributions may not adapt optimally to varying speech rates or linguistic structure. Future work includes adaptive masking sensitive to acoustic or linguistic boundaries.
Monolingual evaluation. Experiments are currently limited to English (LibriLight). Generalization to tonal and morphologically rich languages remains open.
Limited data scale. Pretraining has been conducted on relatively modest amounts of data compared to large-scale SSL systems; conclusions are restricted to emerging capabilities.
Cross-modal JEPA. Extending to audio–visual or audio–text joint embedding prediction for multimodal representations is a promising direction.
The complete implementation of the JEPA+DAAM framework, including training scripts, model architectures, and data processing pipelines, is available at:
https://github.com/gioannides/Density-Adaptive-JEPA
FSQ quantization and mixed-radix packing algorithms.
DeepSpeed integration for distributed training.
We introduced a two-stage self-supervised framework combining Joint-Embedding Predictive Architecture (JEPA) with Density Adaptive Attention Mechanisms (DAAM) for efficient speech representation learning. Stage 1 trains a JEPA encoder with DAAM-based gating to learn robust semantic representations via masked prediction using only MSE loss on masked regions. Stage 2 leverages these representations for reconstruction using L1 loss, multi-resolution STFT loss, and adversarial GAN losses, together with FSQ and HiFi-GAN.
A DAAM-enhanced JEPA encoder that uses Gaussian mixture-based attention for adaptive feature selection during self-supervised learning.
An efficient tokenization scheme based on mixed-radix FSQ packing, achieving 47.5 tokens/sec, substantially lower than many existing neural audio codecs while remaining reversible.
These results show that probabilistic attention mechanisms can improve representation learning by dynamically identifying acoustically salient regions during masked prediction, and that JEPA can serve as a powerful neural tokenizer for speech, suitable for integration with large language models and other sequence models.
Table: S3.T1: Comparison of tokenization approaches.
| Approach | Tokens/sec | Reversible | Notes |
|---|---|---|---|
| No packing (128 dims) | 320 | Yes | Each FSQ dim is a token |
| Mixed-radix (ours, $G=7$) | 47.5 | Yes | Pack 7 dims/token |
| VQ codebook | Variable | Yes | Requires learned codebook |
Table: S3.T2: Frame rate comparison with state-of-the-art neural codecs.
| Model | Frame Rate | Notes |
|---|---|---|
| Ours (JEPA+FSQ) | 2.5 Hz | Mixed-radix packing (19 groups/frame) |
| U-Codec (Yang2025UCodec) | 5 Hz | Ultra-low for LLM-TTS |
| Mimi (LlamaMimi2025) | 12.5 Hz | Semantic distillation |
| DualCodec (Li2025DualCodec) | 12.5–25 Hz | Dual-stream architecture |
| SoundStream (24 kHz) (Zeghidour2021SoundStream) | 75 Hz | 13.3 ms frames |
| EnCodec (24 kHz) (Defossez2022EnCodec) | 75 Hz | 75 steps/sec @ 24 kHz |
| DAC (44.1 kHz) (DACJAX2024TokenRate) | 86 Hz | Stride 512 @ 44.1 kHz |
Table: S5.T3: Model architecture and parameter efficiency.
| Component | Parameters | Notes |
|---|---|---|
| Stage 1: JEPA encoder training | | |
| Online encoder | 121.7M | Trainable |
| Target encoder (EMA) | 118.5M | Momentum update |
| Predictor network | 3.2M | Trainable |
| Stage 1 total | 240.2M | 121.7M trainable |
| Stage 2: decoder training | | |
| JEPA encoder | 240.2M | Fine-tuned |
| FSQ quantizer | ~0.01M | Trainable |
| HiFi-GAN decoder | 69.2M | Trainable |
| Stage 2 total | 309.5M | 69.3M trainable |
| Final model (inference) | | |
| Encoder only | 121.7M | Online encoder only |
| FSQ + decoder | 69.3M | |
| Inference total | 191.0M | Single-pass model |
Table: S5.T4: Training efficiency of the two stages.
| Metric | Stage 1 (JEPA) | Stage 2 (decoder) |
|---|---|---|
| Trainable parameters | 121.7M (50.7%) | 69.3M (22.4%) |
| Training steps | 24K | 29K |
| Batch size (per GPU) | 32 | 8 |
| Learning rate | $1.5\times 10^{-4}$ | $1.5\times 10^{-4}$ |
The input waveform is processed by three parallel pathways: (1) an online encoder (trainable, green) that processes the full audio and feeds into a predictor network (yellow) after feature-space masking with a learned mask token, (2) a target encoder (purple) updated via EMA that also processes the full audio to generate $\mathbf{z}_{\text{target}}$, and (3) a masking strategy module (blue) that generates binary masks. The MSE loss is computed only on masked regions between $\mathbf{z}_{\text{predicted}}$ and $\mathbf{z}_{\text{target}}$ (stop-gradient), with gradients backpropagating only through the online encoder and predictor. The target encoder provides stable representations without receiving gradients directly (Grill2020BYOL).
JEPA online encoder architecture. Input waveform passes through an initial Conv1D layer followed by 5 encoder blocks, each containing Conv1D with stride, SnakeBeta activation, residual blocks, and Gaussian Adaptive Attention gating. Features are projected through a bottleneck Conv1D layer and processed by 8 Conformer blocks (each with FNN, multi-head attention with 16 heads, depthwise convolution, and a second FNN) to produce the final representation $\mathbf{z}$. The target encoder shares this architecture but is updated via exponential moving average rather than backpropagation.
JEPA predictor network architecture. The predictor takes masked context features $\mathbf{z}_{\text{masked}}$ and processes them through: (1) an expansion Conv1D layer that doubles the channel dimension, (2) two Conformer blocks separated by an intermediate Conv1D for feature refinement, and (3) a projection Conv1D that reduces back to the original dimensionality, producing predicted features $\mathbf{z}_{\text{pred}}$ at all positions including masked regions.
Stage 1 JEPA masked prediction loss (MSE) over training steps. JEPA+DAAM (blue) converges faster and to a lower final loss ($\sim 0.09$) compared to JEPA without DAAM (orange, $\sim 0.17$), demonstrating that Density Adaptive Attention enables more efficient representation learning. Both models use identical architectures except for DAAM gating.
HiFi-GAN decoder architecture (Stage 2). Quantized features $\mathbf{z}_{q}$ are upsampled through a bottleneck Conv1D followed by 5 decoder blocks. Each block contains ConvTranspose1D upsampling and MRF residual blocks with different kernel sizes (3, 7, 11, 15, 23, 32) to capture multi-scale temporal patterns. SnakeBeta activations provide periodic inductive bias for high-fidelity audio generation (Ziyin2020Snake).
$$ \tilde{\sigma}_k = \text{softplus}(\nu_k) + \epsilon = \log(1 + \exp(\nu_k)) + \epsilon, $$
$$ z_{k,t} = \frac{x_{:,:,t} - (\mu + \delta_k)}{\sigma \cdot \tilde{\sigma}_k + \epsilon}. $$
$$ \log p_k(x_t) = -\frac{1}{2}z_{k,t}^2 - \log \tilde{\sigma}_k - \frac{1}{2}\log(2\pi). $$
$$ \log \mathbf{G}(x_t) = \text{logsumexp}(\{\log p_1(x_t), \ldots, \log p_K(x_t)\}) - \log K. $$
$$ \mathbf{G}(x_{t})=\exp(\log\mathbf{G}(x_{t})), $$
$$ \mathbf{y}_t = \mathbf{x}_t \odot \mathbf{G}(x_t), $$
$$ 64 \rightarrow 128 \rightarrow 256 \rightarrow 384 \rightarrow 512 \rightarrow 512. $$
$$ \mathcal{L}_{\text{JEPA}} = \frac{1}{N_{\text{mask}} \cdot C} \sum_{t \in \mathcal{M}} \left\| \mathbf{z}_{\text{pred}}^{(t)} - \text{sg}(\mathbf{z}_{\text{target}}^{(t)}) \right\|^2, $$
$$ \boldsymbol{\theta}_{\text{target}} \leftarrow \tau \boldsymbol{\theta}_{\text{target}} + (1-\tau)\boldsymbol{\theta}_{\text{online}}, $$
$$ \mathbf{z}_e' = \tanh(\mathbf{z}_e). $$
$$ B_d = \left\{ \frac{2i - L_d + 1}{L_d} : i \in \{0,1,\ldots,L_d-1\} \right\}. $$
$$ q_d(x) = \arg\min_{b \in B_d} |x - b|. $$
$$ \frac{\partial \mathcal{L}}{\partial \mathbf{z}_e} \approx \frac{\partial \mathcal{L}}{\partial \mathbf{z}_q}. $$
$$ \text{token} = \sum_{k=1}^{G} i_k \prod_{j=k+1}^{G} r_j. $$
$$ \text{token} = i_1 \cdot r_2 \cdots r_G + \cdots + i_{G-1}\cdot r_G + i_G, $$
$$ f = \frac{\text{sample\_rate}}{\text{hop}} = \frac{24000}{9600} = 2.5~\text{Hz}. $$
$$ \text{tps} = 2.5 \times 19 = 47.5. $$
$$ 512 \rightarrow 384 \rightarrow 256 \rightarrow 128 \rightarrow 64, $$
$$ \mathcal{L}_{\text{total}} = \mathcal{L}_{\text{rec}} + \lambda_{\text{stft}}\mathcal{L}_{\text{stft}} + \lambda_{\text{gan}}\mathcal{L}_{\text{gan}}. $$
$$ \mathcal{L}_{\text{rec}} = \frac{1}{T_{\text{wav}}} \sum_{t=1}^{T_{\text{wav}}} |\hat{x}_t - x_t|. $$
$$ \mathcal{L}_{\text{sc}}^{(m)} = \frac{\left\| |S_m(\hat{x})| - |S_m(x)| \right\|_F}{\left\| |S_m(x)| \right\|_F}, $$
$$ \mathcal{L}_{\text{gen}} = \sum_{d \in \{\text{MPD}, \text{MSD}\}} \mathbb{E}\big[(D_d(\hat{x}) - 1)^2\big]. $$
$$ \mathcal{L}_{\text{gan}} = \mathcal{L}_{\text{gen}} + \mathcal{L}_{\text{feat}}. $$
$$ \begin{aligned} \text{token} &= 2 \cdot 4^6 + 1 \cdot 4^5 + 3 \cdot 4^4 + 0 \cdot 4^3 + 2 \cdot 4^2 + 1 \cdot 4^1 + 3 \cdot 4^0 \\ &= 10023, \end{aligned} $$
- A two-stage training paradigm that cleanly separates representation learning from reconstruction, allowing pure self-supervised pretraining followed by reconstruction-focused fine-tuning.
$$ \mathbf{y} = \mathbf{x} \cdot (1 + \alpha \cdot \text{gate}), $$
$$ \begin{aligned} \mu &= \frac{1}{T}\sum_{t=1}^{T} x_{:,:,t} \in \mathbb{R}^{B \times C \times 1}, \\ \sigma^2 &= \frac{1}{T}\sum_{t=1}^{T} (x_{:,:,t} - \mu)^2 \in \mathbb{R}^{B \times C \times 1}. \end{aligned} $$
$$ \text{token} \leftarrow \text{token} \cdot r_{k} + i_{k}. $$
Contents
1 Hybrid Discrete-Continuous Speech Representations via JEPA with Density Adaptive Attention
  1.1 Overview
  1.2 Motivation: Why JEPA for Speech?
2 Stage 1: Self-Supervised JEPA Encoder with DAAM
  2.1 JEPA Masking Strategy
  2.2 Density Adaptive Attention for Temporal Feature Modulation
    2.2.1 Mathematical Formulation
  2.3 JEPA Encoder Architecture
    2.3.1 Convolutional-Transformer Hybrid Design
  2.4 JEPA Predictor Network
  2.5 Stage 1 Training Objective
    2.5.1 Loss Function
    2.5.2 EMA Target Update
3 Stage 2: Fine-Tuning Encoder + FSQ Quantization + HiFi-GAN Decoder
  3.1 Finite Scalar Quantization (FSQ)
    3.1.1 FSQ Formulation
  3.2 Mixed-Radix Token Packing
    3.2.1 Mixed-Radix Encoding
    3.2.2 Efficient Iterative Computation
    3.2.3 Padding and Grouping
    3.2.4 Decoding
    3.2.5 Comparison to Alternatives
    3.2.6 Integration with Language Models
    3.2.7 Frame Rate Comparison with Neural Codecs
  3.3 HiFi-GAN Decoder
    3.3.1 Decoder Architecture
  3.4 Stage 2 Training Objective
    3.4.1 Total Loss
4 Experimental Setup
  4.1 Dataset
  4.2 Data Preprocessing
  4.3 Distributed Training
  4.4 Inference Pipeline
5 Model Architecture and Efficiency
  5.1 Parameter Counts
  5.2 Training Efficiency
6 Evaluation Metrics
7 Discussion
  7.1 Why DAAM Improves JEPA Representations
8 Limitations and Future Work
9 Code Availability
10 Conclusion
References
[Assran2023IJEPA] Assran, Mahmoud, Duval, Quentin, Misra, Ishan, Bojanowski, Piotr, Vincent, Pascal, Rabbat, Michael, LeCun, Yann, Ballas, Nicolas. (2023). Self-Supervised Learning From Images With a Joint-Embedding Predictive Architecture. CVPR.
[Li2025DualCodec] Li, Jiaqi, Lin, Xiaolong, Li, Zhekai, Huang, Shixi, Wang, Yuancheng, Wang, Chaoren, Zhan, Zhenpeng, Wu, Zhizheng. (2025). DualCodec: A Low-Frame-Rate, Semantically-Enhanced Neural Audio Codec for Speech Generation. arXiv preprint arXiv:2505.13000.
[Huang2025LLMJEPA] Huang, Hai, LeCun, Yann, Balestriero, Randall. (2025). LLM-JEPA: Large Language Models Meet Joint Embedding Predictive Architectures. arXiv preprint arXiv:2509.14252.
[Mo2024CJEPA] Mo, Shanshan, Liu, Yaliang, Wang, Wei, others. (2024). Connecting Joint-Embedding Predictive Architecture with Contrastive Learning. arXiv preprint arXiv:2410.19560.
[Ioannides2024DAAM] Ioannides, Georgios, Chadha, Aman, Elkins, Aaron. (2024). Density Adaptive Attention is All You Need: Robust Parameter-Efficient Fine-Tuning Across Multiple Modalities. arXiv preprint arXiv:2401.11143.
[Ioannides2024DAAMAudio] Ioannides, Georgios, Kieback, Adrian, Chadha, Aman, Elkins, Aaron. (2024). Density Adaptive Attention-based Speech Network: Enhancing Feature Understanding for Mental Health Disorders. arXiv preprint arXiv:2409.00391.
[Mentzer2023FSQ] Mentzer, Fabian, Minnen, David, Agustsson, Eirikur, Tschannen, Michael. (2023). Finite Scalar Quantization: VQ-VAE Made Simple. arXiv preprint arXiv:2309.15505.
[Grill2020BYOL] Grill, Jean-Bastien, Strub, Florian, Altché, Florent, et al. (2020). Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning. NeurIPS.
[Gulati2020Conformer] Gulati, Anmol, Qin, James, Chiu, Chung-Cheng, Parmar, Niki, Zhang, Yu, others. (2020). Conformer: Convolution-augmented Transformer for Speech Recognition. INTERSPEECH.
[Kong2020HiFiGAN] Kong, Jungil, Kim, Jaehyeon, Bae, Jaekyoung. (2020). HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis. arXiv preprint arXiv:2010.05646.
[Yamamoto2020ParallelWaveGAN] Yamamoto, Ryuichi, Song, Eunwoo, Kim, Jae-Min. (2020). Parallel WaveGAN: A fast waveform generation model based on generative adversarial networks with multi-resolution spectrogram. arXiv preprint arXiv:1910.11480.
[Kumar2019MelGAN] Kumar, Kundan, Kumar, Rithesh, de Boissiere, Thibault, others. (2019). MelGAN: Generative Adversarial Networks for Conditional Waveform Synthesis. arXiv preprint arXiv:1910.06711.
[Zeghidour2021SoundStream] Zeghidour, Neil, Luebs, Alejandro, Omran, Ahmed, Skoglund, Jan, Tagliasacchi, Marco. (2021). SoundStream: An End-to-End Neural Audio Codec. arXiv preprint arXiv:2107.03312.
[Defossez2022EnCodec] Défossez, Alexandre, Copet, Jade, Synnaeve, Gabriel, Adi, Yossi. (2022). High Fidelity Neural Audio Compression. arXiv preprint arXiv:2210.13438.
[DACJAX2024TokenRate] Braun, David. (2024). DAC-JAX: A JAX Implementation of the Descript Audio Codec. arXiv preprint arXiv:2405.11554.
[Chen2021WavLM] Chen, Sanyuan, Wang, Chengyi, Chen, Zhengyang, Wu, Yu, Liu, Shujie, et al. (2021). WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing. arXiv preprint arXiv:2110.13900.
[Kahn2020LibriLight] Kahn, Jacob, Riviere, Morgane, Zheng, Weiran, Kharitonov, Eugene, others. (2020). Libri-Light: A Benchmark for ASR with Limited or No Supervision. arXiv preprint arXiv:1912.07875.
[Ziyin2020Snake] Ziyin, Liu, Hartwig, Tilman, Ueda, Masahito. (2020). Neural Networks Fail to Learn Periodic Functions and How to Fix It. NeurIPS.
[TSJEPA2025] Ennadir, Soukaina, others. (2025). Joint Embeddings Go Temporal. arXiv preprint arXiv:2509.25449.
[MixedRadixKnuth1997] Knuth, Donald E.. (1997). The Art of Computer Programming, Vol. 2: Seminumerical Algorithms (3rd ed.).
[Simon2024MixedRadixArxiv] Simon, Damien. (2024). Mixed radix numeration bases: Horner's rule, Yang-Baxter equation and Furstenberg's conjecture. arXiv preprint arXiv:2405.19798.
[Yang2025UCodec] Yang, Xuefei, others. (2025). Ultra Low Frame-rate Neural Speech Codec for Fast High-Fidelity Speech Synthesis. arXiv preprint arXiv:2510.16718.
[LlamaMimi2025] Anonymous or Multiple. (2025). Llama-Mimi: Speech Language Models with Interleaved Semantic and Acoustic Tokens. arXiv preprint arXiv:2509.14882.