Skip to content

Commit 405ff93

Browse files
authored
Fix: guard double-resume in AudioPlaybackManager playPcmData (#453)
* fix: guard double-resume in AudioPlaybackManager playPcmData * fix: make audioTrack volatile for cross-thread visibility * fix: fix AudioPlaybackManager bit depth rejection and stop() completion path * fix: address AudioPlaybackManager race condition and malformed WAV header loop
1 parent cfb7784 commit 405ff93

1 file changed

Lines changed: 94 additions & 71 deletions

File tree

sdk/runanywhere-kotlin/src/androidMain/kotlin/com/runanywhere/sdk/features/tts/AudioPlaybackManager.kt

Lines changed: 94 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,27 @@ import com.runanywhere.sdk.foundation.SDKLogger
77
import kotlinx.coroutines.Dispatchers
88
import kotlinx.coroutines.suspendCancellableCoroutine
99
import kotlinx.coroutines.withContext
10+
import java.util.concurrent.atomic.AtomicBoolean
11+
import java.util.concurrent.atomic.AtomicInteger
1012
import kotlin.coroutines.resume
1113
import kotlin.coroutines.resumeWithException
1214

1315
/**
1416
* Manages audio playback for TTS services on Android.
15-
* Plays WAV audio data (16-bit PCM format) generated by TTS synthesis.
17+
* Plays WAV audio data (8-bit or 16-bit PCM format) generated by TTS synthesis.
1618
*
1719
* Matches iOS AudioPlaybackManager behavior.
1820
*/
1921
class AudioPlaybackManager {
2022
private val logger = SDKLogger.tts
23+
private val playbackIdGenerator = AtomicInteger(0)
2124

25+
@Volatile
2226
private var audioTrack: AudioTrack? = null
2327

28+
@Volatile
29+
private var interruptPlayback: (() -> Unit)? = null
30+
2431
@Volatile
2532
var isPlaying: Boolean = false
2633
private set
@@ -36,72 +43,90 @@ class AudioPlaybackManager {
3643
throw AudioPlaybackException.EmptyAudioData
3744
}
3845

46+
val playbackId = playbackIdGenerator.incrementAndGet()
47+
logger.debug("[playbackId=$playbackId] play() start: totalBytes=${audioData.size}")
48+
3949
withContext(Dispatchers.IO) {
4050
try {
41-
// Parse WAV header to get audio parameters
4251
val wavInfo = parseWavHeader(audioData)
43-
logger.info("Playing audio: ${audioData.size} bytes, ${wavInfo.sampleRate}Hz, ${wavInfo.channels}ch")
52+
logger.info("[playbackId=$playbackId] Playing audio: ${audioData.size} bytes, ${wavInfo.sampleRate}Hz, ${wavInfo.channels}ch, ${wavInfo.bitsPerSample}bit")
4453

45-
// Get PCM data (skip WAV header)
4654
val pcmData = audioData.copyOfRange(wavInfo.dataOffset, audioData.size)
4755

48-
playPcmData(pcmData, wavInfo.sampleRate, wavInfo.channels, wavInfo.bitsPerSample)
56+
playPcmData(playbackId, pcmData, wavInfo.sampleRate, wavInfo.channels, wavInfo.bitsPerSample)
4957

50-
logger.info("Playback completed")
58+
logger.info("[playbackId=$playbackId] play() completed")
5159
} catch (e: Exception) {
52-
logger.error("Playback failed: ${e.message}")
60+
logger.error("[playbackId=$playbackId] play() failed: ${e.message}")
5361
throw if (e is AudioPlaybackException) e else AudioPlaybackException.PlaybackFailed(e.message)
5462
}
5563
}
5664
}
5765

5866
/**
59-
* Stop current playback.
67+
* Stop current playback. If play() is suspended, it will resume with PlaybackInterrupted.
6068
*/
6169
fun stop() {
62-
if (!isPlaying) return
63-
64-
try {
65-
audioTrack?.stop()
66-
audioTrack?.release()
67-
audioTrack = null
68-
} catch (e: Exception) {
69-
logger.error("Error stopping playback: ${e.message}")
70-
}
71-
72-
isPlaying = false
73-
logger.info("Playback stopped")
70+
logger.debug("stop() called: isPlaying=$isPlaying, hasTrack=${audioTrack != null}")
71+
interruptPlayback?.invoke()
7472
}
7573

7674
private suspend fun playPcmData(
75+
playbackId: Int,
7776
pcmData: ByteArray,
7877
sampleRate: Int,
7978
channels: Int,
8079
bitsPerSample: Int,
8180
) = suspendCancellableCoroutine { continuation ->
81+
val resumed = AtomicBoolean(false)
82+
83+
fun cleanup(track: AudioTrack?) {
84+
isPlaying = false
85+
try { track?.stop() } catch (_: Exception) {}
86+
try { track?.release() } catch (_: Exception) {}
87+
if (audioTrack === track) audioTrack = null
88+
interruptPlayback = null
89+
}
90+
91+
fun succeed(track: AudioTrack?) {
92+
val casWon = resumed.compareAndSet(false, true)
93+
logger.debug("[playbackId=$playbackId] completion path=success casWon=$casWon")
94+
if (casWon) {
95+
cleanup(track)
96+
continuation.resume(Unit)
97+
}
98+
}
99+
100+
fun fail(track: AudioTrack?, throwable: Throwable) {
101+
val casWon = resumed.compareAndSet(false, true)
102+
logger.debug("[playbackId=$playbackId] completion path=${throwable::class.simpleName} casWon=$casWon")
103+
if (casWon) {
104+
cleanup(track)
105+
continuation.resumeWithException(throwable)
106+
}
107+
}
108+
82109
try {
110+
logger.debug("[playbackId=$playbackId] playPcmData() enter: pcmBytes=${pcmData.size}, sampleRate=$sampleRate, channels=$channels, bitsPerSample=$bitsPerSample")
111+
83112
val channelConfig =
84-
if (channels == 1) {
85-
AudioFormat.CHANNEL_OUT_MONO
86-
} else {
87-
AudioFormat.CHANNEL_OUT_STEREO
88-
}
113+
if (channels == 1) AudioFormat.CHANNEL_OUT_MONO else AudioFormat.CHANNEL_OUT_STEREO
89114

90115
val audioFormat =
91-
if (bitsPerSample == 16) {
92-
AudioFormat.ENCODING_PCM_16BIT
93-
} else {
94-
AudioFormat.ENCODING_PCM_8BIT
116+
when (bitsPerSample) {
117+
8 -> AudioFormat.ENCODING_PCM_8BIT
118+
16 -> AudioFormat.ENCODING_PCM_16BIT
119+
else -> {
120+
logger.debug("[playbackId=$playbackId] completion path=unsupported_bit_depth bitsPerSample=$bitsPerSample")
121+
continuation.resumeWithException(AudioPlaybackException.InvalidAudioFormat)
122+
return@suspendCancellableCoroutine
123+
}
95124
}
96125

97-
val minBufferSize =
98-
AudioTrack.getMinBufferSize(
99-
sampleRate,
100-
channelConfig,
101-
audioFormat,
102-
)
126+
val minBufferSize = AudioTrack.getMinBufferSize(sampleRate, channelConfig, audioFormat)
103127

104128
if (minBufferSize == AudioTrack.ERROR || minBufferSize == AudioTrack.ERROR_BAD_VALUE) {
129+
logger.warn("[playbackId=$playbackId] Invalid minBufferSize=$minBufferSize")
105130
continuation.resumeWithException(AudioPlaybackException.InvalidAudioFormat)
106131
return@suspendCancellableCoroutine
107132
}
@@ -128,70 +153,57 @@ class AudioPlaybackManager {
128153
.setTransferMode(AudioTrack.MODE_STATIC)
129154
.build()
130155

156+
interruptPlayback = { fail(track, AudioPlaybackException.PlaybackInterrupted) }
131157
audioTrack = track
132158
isPlaying = true
133159

134-
// Write all data
135160
val bytesWritten = track.write(pcmData, 0, pcmData.size)
136161
if (bytesWritten < 0) {
137-
track.release()
138-
audioTrack = null
139-
isPlaying = false
140-
continuation.resumeWithException(AudioPlaybackException.PlaybackFailed("Write failed: $bytesWritten"))
162+
logger.debug("[playbackId=$playbackId] completion path=write_error bytesWritten=$bytesWritten")
163+
fail(track, AudioPlaybackException.PlaybackFailed("Write failed: $bytesWritten"))
164+
return@suspendCancellableCoroutine
165+
}
166+
167+
val bytesPerSample = bitsPerSample / 8
168+
val frameSize = bytesPerSample * channels
169+
if (frameSize <= 0 || bytesWritten % frameSize != 0) {
170+
logger.debug("[playbackId=$playbackId] completion path=invalid_frame frameSize=$frameSize bytesWritten=$bytesWritten")
171+
fail(track, AudioPlaybackException.InvalidAudioFormat)
141172
return@suspendCancellableCoroutine
142173
}
143174

144-
// Set notification marker at end
145-
track.notificationMarkerPosition = pcmData.size / (bitsPerSample / 8 * channels)
175+
track.notificationMarkerPosition = bytesWritten / frameSize
146176
track.setPlaybackPositionUpdateListener(
147177
object : AudioTrack.OnPlaybackPositionUpdateListener {
148178
override fun onMarkerReached(track: AudioTrack?) {
149-
isPlaying = false
150-
track?.stop()
151-
track?.release()
152-
audioTrack = null
153-
continuation.resume(Unit)
179+
succeed(track)
154180
}
155181

156182
override fun onPeriodicNotification(track: AudioTrack?) {
157-
// Not used
158183
}
159184
},
160185
)
161186

162-
// Handle cancellation
163187
continuation.invokeOnCancellation {
164-
stop()
188+
fail(track, AudioPlaybackException.PlaybackInterrupted)
165189
}
166190

167-
// Start playback
191+
logger.debug("[playbackId=$playbackId] track.play()")
168192
track.play()
169193
} catch (e: Exception) {
170-
isPlaying = false
171-
audioTrack?.release()
172-
audioTrack = null
173-
continuation.resumeWithException(e)
194+
fail(audioTrack, e)
174195
}
175196
}
176197

177198
private fun parseWavHeader(data: ByteArray): WavInfo {
178-
if (data.size < 44) {
179-
throw AudioPlaybackException.InvalidAudioFormat
180-
}
199+
if (data.size < 44) throw AudioPlaybackException.InvalidAudioFormat
181200

182-
// Check RIFF header
183201
val riff = String(data.copyOfRange(0, 4))
184-
if (riff != "RIFF") {
185-
throw AudioPlaybackException.InvalidAudioFormat
186-
}
202+
if (riff != "RIFF") throw AudioPlaybackException.InvalidAudioFormat
187203

188-
// Check WAVE format
189204
val wave = String(data.copyOfRange(8, 12))
190-
if (wave != "WAVE") {
191-
throw AudioPlaybackException.InvalidAudioFormat
192-
}
205+
if (wave != "WAVE") throw AudioPlaybackException.InvalidAudioFormat
193206

194-
// Parse fmt chunk
195207
val channels = (data[22].toInt() and 0xFF) or ((data[23].toInt() and 0xFF) shl 8)
196208
val sampleRate =
197209
(data[24].toInt() and 0xFF) or
@@ -200,7 +212,10 @@ class AudioPlaybackManager {
200212
((data[27].toInt() and 0xFF) shl 24)
201213
val bitsPerSample = (data[34].toInt() and 0xFF) or ((data[35].toInt() and 0xFF) shl 8)
202214

203-
// Find data chunk (usually at offset 44 but can vary)
215+
if (channels !in 1..2) throw AudioPlaybackException.InvalidAudioFormat
216+
if (sampleRate <= 0) throw AudioPlaybackException.InvalidAudioFormat
217+
if (bitsPerSample !in setOf(8, 16)) throw AudioPlaybackException.InvalidAudioFormat
218+
204219
var dataOffset = 12
205220
while (dataOffset < data.size - 8) {
206221
val chunkId = String(data.copyOfRange(dataOffset, dataOffset + 4))
@@ -211,13 +226,21 @@ class AudioPlaybackManager {
211226
((data[dataOffset + 7].toInt() and 0xFF) shl 24)
212227

213228
if (chunkId == "data") {
214-
dataOffset += 8 // Skip chunk header
229+
dataOffset += 8
215230
break
216231
}
217232

218-
dataOffset += 8 + chunkSize
233+
if (chunkSize < 0) throw AudioPlaybackException.InvalidAudioFormat
234+
val paddedChunkSize = chunkSize + (chunkSize and 1)
235+
val nextOffset = dataOffset.toLong() + 8L + paddedChunkSize.toLong()
236+
if (nextOffset <= dataOffset.toLong() || nextOffset > data.size.toLong()) {
237+
throw AudioPlaybackException.InvalidAudioFormat
238+
}
239+
dataOffset = nextOffset.toInt()
219240
}
220241

242+
if (dataOffset >= data.size) throw AudioPlaybackException.InvalidAudioFormat
243+
221244
return WavInfo(
222245
sampleRate = sampleRate,
223246
channels = channels,
@@ -262,4 +285,4 @@ sealed class AudioPlaybackException : Exception() {
262285

263286
override val message: String = "Invalid audio format"
264287
}
265-
}
288+
}

0 commit comments

Comments
 (0)