@@ -7,20 +7,27 @@ import com.runanywhere.sdk.foundation.SDKLogger
77import kotlinx.coroutines.Dispatchers
88import kotlinx.coroutines.suspendCancellableCoroutine
99import kotlinx.coroutines.withContext
10+ import java.util.concurrent.atomic.AtomicBoolean
11+ import java.util.concurrent.atomic.AtomicInteger
1012import kotlin.coroutines.resume
1113import kotlin.coroutines.resumeWithException
1214
1315/* *
1416 * Manages audio playback for TTS services on Android.
15- * Plays WAV audio data (16-bit PCM format) generated by TTS synthesis.
17+ * Plays WAV audio data (8-bit or 16-bit PCM format) generated by TTS synthesis.
1618 *
1719 * Matches iOS AudioPlaybackManager behavior.
1820 */
1921class AudioPlaybackManager {
2022 private val logger = SDKLogger .tts
23+ private val playbackIdGenerator = AtomicInteger (0 )
2124
25+ @Volatile
2226 private var audioTrack: AudioTrack ? = null
2327
28+ @Volatile
29+ private var interruptPlayback: (() -> Unit )? = null
30+
2431 @Volatile
2532 var isPlaying: Boolean = false
2633 private set
@@ -36,72 +43,90 @@ class AudioPlaybackManager {
3643 throw AudioPlaybackException .EmptyAudioData
3744 }
3845
46+ val playbackId = playbackIdGenerator.incrementAndGet()
47+ logger.debug(" [playbackId=$playbackId ] play() start: totalBytes=${audioData.size} " )
48+
3949 withContext(Dispatchers .IO ) {
4050 try {
41- // Parse WAV header to get audio parameters
4251 val wavInfo = parseWavHeader(audioData)
43- logger.info(" Playing audio: ${audioData.size} bytes, ${wavInfo.sampleRate} Hz, ${wavInfo.channels} ch" )
52+ logger.info(" [playbackId= $playbackId ] Playing audio: ${audioData.size} bytes, ${wavInfo.sampleRate} Hz, ${wavInfo.channels} ch, ${wavInfo.bitsPerSample} bit " )
4453
45- // Get PCM data (skip WAV header)
4654 val pcmData = audioData.copyOfRange(wavInfo.dataOffset, audioData.size)
4755
48- playPcmData(pcmData, wavInfo.sampleRate, wavInfo.channels, wavInfo.bitsPerSample)
56+ playPcmData(playbackId, pcmData, wavInfo.sampleRate, wavInfo.channels, wavInfo.bitsPerSample)
4957
50- logger.info(" Playback completed" )
58+ logger.info(" [playbackId= $playbackId ] play() completed" )
5159 } catch (e: Exception ) {
52- logger.error(" Playback failed: ${e.message} " )
60+ logger.error(" [playbackId= $playbackId ] play() failed: ${e.message} " )
5361 throw if (e is AudioPlaybackException ) e else AudioPlaybackException .PlaybackFailed (e.message)
5462 }
5563 }
5664 }
5765
5866 /* *
59- * Stop current playback.
67+ * Stop current playback. If play() is suspended, it will resume with PlaybackInterrupted.
6068 */
6169 fun stop () {
62- if (! isPlaying) return
63-
64- try {
65- audioTrack?.stop()
66- audioTrack?.release()
67- audioTrack = null
68- } catch (e: Exception ) {
69- logger.error(" Error stopping playback: ${e.message} " )
70- }
71-
72- isPlaying = false
73- logger.info(" Playback stopped" )
70+ logger.debug(" stop() called: isPlaying=$isPlaying , hasTrack=${audioTrack != null } " )
71+ interruptPlayback?.invoke()
7472 }
7573
7674 private suspend fun playPcmData (
75+ playbackId : Int ,
7776 pcmData : ByteArray ,
7877 sampleRate : Int ,
7978 channels : Int ,
8079 bitsPerSample : Int ,
8180 ) = suspendCancellableCoroutine { continuation ->
81+ val resumed = AtomicBoolean (false )
82+
83+ fun cleanup (track : AudioTrack ? ) {
84+ isPlaying = false
85+ try { track?.stop() } catch (_: Exception ) {}
86+ try { track?.release() } catch (_: Exception ) {}
87+ if (audioTrack == = track) audioTrack = null
88+ interruptPlayback = null
89+ }
90+
91+ fun succeed (track : AudioTrack ? ) {
92+ val casWon = resumed.compareAndSet(false , true )
93+ logger.debug(" [playbackId=$playbackId ] completion path=success casWon=$casWon " )
94+ if (casWon) {
95+ cleanup(track)
96+ continuation.resume(Unit )
97+ }
98+ }
99+
100+ fun fail (track : AudioTrack ? , throwable : Throwable ) {
101+ val casWon = resumed.compareAndSet(false , true )
102+ logger.debug(" [playbackId=$playbackId ] completion path=${throwable::class .simpleName} casWon=$casWon " )
103+ if (casWon) {
104+ cleanup(track)
105+ continuation.resumeWithException(throwable)
106+ }
107+ }
108+
82109 try {
110+ logger.debug(" [playbackId=$playbackId ] playPcmData() enter: pcmBytes=${pcmData.size} , sampleRate=$sampleRate , channels=$channels , bitsPerSample=$bitsPerSample " )
111+
83112 val channelConfig =
84- if (channels == 1 ) {
85- AudioFormat .CHANNEL_OUT_MONO
86- } else {
87- AudioFormat .CHANNEL_OUT_STEREO
88- }
113+ if (channels == 1 ) AudioFormat .CHANNEL_OUT_MONO else AudioFormat .CHANNEL_OUT_STEREO
89114
90115 val audioFormat =
91- if (bitsPerSample == 16 ) {
92- AudioFormat .ENCODING_PCM_16BIT
93- } else {
94- AudioFormat .ENCODING_PCM_8BIT
116+ when (bitsPerSample) {
117+ 8 -> AudioFormat .ENCODING_PCM_8BIT
118+ 16 -> AudioFormat .ENCODING_PCM_16BIT
119+ else -> {
120+ logger.debug(" [playbackId=$playbackId ] completion path=unsupported_bit_depth bitsPerSample=$bitsPerSample " )
121+ continuation.resumeWithException(AudioPlaybackException .InvalidAudioFormat )
122+ return @suspendCancellableCoroutine
123+ }
95124 }
96125
97- val minBufferSize =
98- AudioTrack .getMinBufferSize(
99- sampleRate,
100- channelConfig,
101- audioFormat,
102- )
126+ val minBufferSize = AudioTrack .getMinBufferSize(sampleRate, channelConfig, audioFormat)
103127
104128 if (minBufferSize == AudioTrack .ERROR || minBufferSize == AudioTrack .ERROR_BAD_VALUE ) {
129+ logger.warn(" [playbackId=$playbackId ] Invalid minBufferSize=$minBufferSize " )
105130 continuation.resumeWithException(AudioPlaybackException .InvalidAudioFormat )
106131 return @suspendCancellableCoroutine
107132 }
@@ -128,70 +153,57 @@ class AudioPlaybackManager {
128153 .setTransferMode(AudioTrack .MODE_STATIC )
129154 .build()
130155
156+ interruptPlayback = { fail(track, AudioPlaybackException .PlaybackInterrupted ) }
131157 audioTrack = track
132158 isPlaying = true
133159
134- // Write all data
135160 val bytesWritten = track.write(pcmData, 0 , pcmData.size)
136161 if (bytesWritten < 0 ) {
137- track.release()
138- audioTrack = null
139- isPlaying = false
140- continuation.resumeWithException(AudioPlaybackException .PlaybackFailed (" Write failed: $bytesWritten " ))
162+ logger.debug(" [playbackId=$playbackId ] completion path=write_error bytesWritten=$bytesWritten " )
163+ fail(track, AudioPlaybackException .PlaybackFailed (" Write failed: $bytesWritten " ))
164+ return @suspendCancellableCoroutine
165+ }
166+
167+ val bytesPerSample = bitsPerSample / 8
168+ val frameSize = bytesPerSample * channels
169+ if (frameSize <= 0 || bytesWritten % frameSize != 0 ) {
170+ logger.debug(" [playbackId=$playbackId ] completion path=invalid_frame frameSize=$frameSize bytesWritten=$bytesWritten " )
171+ fail(track, AudioPlaybackException .InvalidAudioFormat )
141172 return @suspendCancellableCoroutine
142173 }
143174
144- // Set notification marker at end
145- track.notificationMarkerPosition = pcmData.size / (bitsPerSample / 8 * channels)
175+ track.notificationMarkerPosition = bytesWritten / frameSize
146176 track.setPlaybackPositionUpdateListener(
147177 object : AudioTrack .OnPlaybackPositionUpdateListener {
148178 override fun onMarkerReached (track : AudioTrack ? ) {
149- isPlaying = false
150- track?.stop()
151- track?.release()
152- audioTrack = null
153- continuation.resume(Unit )
179+ succeed(track)
154180 }
155181
156182 override fun onPeriodicNotification (track : AudioTrack ? ) {
157- // Not used
158183 }
159184 },
160185 )
161186
162- // Handle cancellation
163187 continuation.invokeOnCancellation {
164- stop( )
188+ fail(track, AudioPlaybackException . PlaybackInterrupted )
165189 }
166190
167- // Start playback
191+ logger.debug( " [playbackId= $playbackId ] track.play() " )
168192 track.play()
169193 } catch (e: Exception ) {
170- isPlaying = false
171- audioTrack?.release()
172- audioTrack = null
173- continuation.resumeWithException(e)
194+ fail(audioTrack, e)
174195 }
175196 }
176197
177198 private fun parseWavHeader (data : ByteArray ): WavInfo {
178- if (data.size < 44 ) {
179- throw AudioPlaybackException .InvalidAudioFormat
180- }
199+ if (data.size < 44 ) throw AudioPlaybackException .InvalidAudioFormat
181200
182- // Check RIFF header
183201 val riff = String (data.copyOfRange(0 , 4 ))
184- if (riff != " RIFF" ) {
185- throw AudioPlaybackException .InvalidAudioFormat
186- }
202+ if (riff != " RIFF" ) throw AudioPlaybackException .InvalidAudioFormat
187203
188- // Check WAVE format
189204 val wave = String (data.copyOfRange(8 , 12 ))
190- if (wave != " WAVE" ) {
191- throw AudioPlaybackException .InvalidAudioFormat
192- }
205+ if (wave != " WAVE" ) throw AudioPlaybackException .InvalidAudioFormat
193206
194- // Parse fmt chunk
195207 val channels = (data[22 ].toInt() and 0xFF ) or ((data[23 ].toInt() and 0xFF ) shl 8 )
196208 val sampleRate =
197209 (data[24 ].toInt() and 0xFF ) or
@@ -200,7 +212,10 @@ class AudioPlaybackManager {
200212 ((data[27 ].toInt() and 0xFF ) shl 24 )
201213 val bitsPerSample = (data[34 ].toInt() and 0xFF ) or ((data[35 ].toInt() and 0xFF ) shl 8 )
202214
203- // Find data chunk (usually at offset 44 but can vary)
215+ if (channels !in 1 .. 2 ) throw AudioPlaybackException .InvalidAudioFormat
216+ if (sampleRate <= 0 ) throw AudioPlaybackException .InvalidAudioFormat
217+ if (bitsPerSample !in setOf (8 , 16 )) throw AudioPlaybackException .InvalidAudioFormat
218+
204219 var dataOffset = 12
205220 while (dataOffset < data.size - 8 ) {
206221 val chunkId = String (data.copyOfRange(dataOffset, dataOffset + 4 ))
@@ -211,13 +226,21 @@ class AudioPlaybackManager {
211226 ((data[dataOffset + 7 ].toInt() and 0xFF ) shl 24 )
212227
213228 if (chunkId == " data" ) {
214- dataOffset + = 8 // Skip chunk header
229+ dataOffset + = 8
215230 break
216231 }
217232
218- dataOffset + = 8 + chunkSize
233+ if (chunkSize < 0 ) throw AudioPlaybackException .InvalidAudioFormat
234+ val paddedChunkSize = chunkSize + (chunkSize and 1 )
235+ val nextOffset = dataOffset.toLong() + 8L + paddedChunkSize.toLong()
236+ if (nextOffset <= dataOffset.toLong() || nextOffset > data.size.toLong()) {
237+ throw AudioPlaybackException .InvalidAudioFormat
238+ }
239+ dataOffset = nextOffset.toInt()
219240 }
220241
242+ if (dataOffset >= data.size) throw AudioPlaybackException .InvalidAudioFormat
243+
221244 return WavInfo (
222245 sampleRate = sampleRate,
223246 channels = channels,
@@ -262,4 +285,4 @@ sealed class AudioPlaybackException : Exception() {
262285
263286 override val message: String = " Invalid audio format"
264287 }
265- }
288+ }
0 commit comments