141 lines
No EOL
5 KiB
Kotlin
141 lines
No EOL
5 KiB
Kotlin
package com.mistral.chat.api
|
|
|
|
import com.google.gson.Gson
import com.google.gson.JsonArray
import com.google.gson.JsonElement
import com.google.gson.JsonObject
import com.mistral.chat.data.Message
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ensureActive
import kotlinx.coroutines.withContext
import okhttp3.MediaType.Companion.toMediaType
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.RequestBody.Companion.toRequestBody
import java.util.concurrent.TimeUnit
import kotlin.coroutines.cancellation.CancellationException
|
|
|
|
/**
 * Thin client for the Mistral REST API rooted at [BASE_URL].
 *
 * Every request runs as a blocking OkHttp call inside `withContext(Dispatchers.IO)`, so the
 * suspend functions may be called from any dispatcher. HTTP errors, I/O errors and malformed
 * responses are returned as [Result.failure]; coroutine cancellation is rethrown, not wrapped.
 *
 * @param apiKey sent as `Authorization: Bearer <apiKey>` on every request.
 */
class MistralClient(private val apiKey: String) {

    private val client = OkHttpClient.Builder()
        .connectTimeout(60, TimeUnit.SECONDS)
        .readTimeout(120, TimeUnit.SECONDS)
        .writeTimeout(60, TimeUnit.SECONDS)
        .build()

    private val gson = Gson()

    private val jsonMediaType = "application/json".toMediaType()

    /**
     * Lists the models returned by `GET /models` as `(id, displayName)` pairs.
     *
     * Only entries whose `id` ends with `-latest` and whose `created` field is a positive number
     * are kept; later duplicates of an id are dropped, otherwise the server's order is preserved.
     * The display name is the id without the `-latest` suffix, split on `-`, with each word
     * capitalised (`"mistral-small-latest"` -> `"Mistral Small"`). A response object without a
     * `data` array yields an empty list.
     *
     * @return the models, or a failure for a non-2xx status (message includes code and body),
     *   an empty body, or an I/O / JSON error.
     */
    suspend fun getModels(): Result<List<Pair<String, String>>> = withContext(Dispatchers.IO) {
        resultOf {
            val request = Request.Builder()
                .url("$BASE_URL/models")
                .addHeader("Authorization", "Bearer $apiKey")
                .get()
                .build()

            // `use` guarantees the response (and its connection) is released on every path,
            // including the error path below.
            client.newCall(request).execute().use { response ->
                val body = response.body?.string().orEmpty()
                if (!response.isSuccessful) {
                    throw apiError(response.code, body)
                }
                val json = parseObject(body) ?: throw Exception("No response from API")
                json.arrayOrNull("data")
                    ?.mapNotNull { element ->
                        element.takeIf { it.isJsonObject }?.asJsonObject?.let { toModelEntry(it) }
                    }
                    ?.distinctBy { (id, _) -> id }
                    .orEmpty()
            }
        }
    }

    /**
     * Sends [messages] to `POST /chat/completions` with `temperature = 0.7`.
     *
     * Each [Message] becomes `{"role": "user" | "assistant", "content": ...}` depending on
     * `isUser`.
     *
     * When [onChunk] is null the request is non-streaming and the reply is read from
     * `choices[0].message.content`. When [onChunk] is non-null the request sets `stream: true`
     * and the body is expected to be server-sent events (`data: {...}` lines, terminated by
     * `data: [DONE]`). It is consumed incrementally: every non-empty `choices[0].delta.content`
     * fragment is passed to [onChunk] as it arrives, and the fragments are concatenated into
     * the returned content. If the server nevertheless answers with a JSON content type, the
     * full reply is passed to [onChunk] once.
     *
     * [onChunk] is invoked on a [Dispatchers.IO] thread. Cancellation is checked before each
     * streamed line; a blocking read that is already in progress is not interrupted.
     *
     * @param model model id to request; also the fallback for the returned model name.
     * @return `(content, model)`, where `model` is the `model` field reported by the API or
     *   [model] if absent; or a failure for a non-2xx status (message includes code and body),
     *   an empty or choice-less response, or an I/O / JSON error.
     */
    suspend fun chat(
        model: String,
        messages: List<Message>,
        onChunk: ((String) -> Unit)? = null,
    ): Result<Pair<String, String>> = withContext(Dispatchers.IO) {
        resultOf {
            // OkHttp derives the Content-Type header from the body's media type.
            val request = Request.Builder()
                .url("$BASE_URL/chat/completions")
                .addHeader("Authorization", "Bearer $apiKey")
                .post(buildChatBody(model, messages, stream = onChunk != null).toRequestBody(jsonMediaType))
                .build()

            client.newCall(request).execute().use { response ->
                val body = response.body
                if (!response.isSuccessful) {
                    throw apiError(response.code, body?.string())
                }
                when {
                    body == null -> throw Exception("No response from API")
                    onChunk != null && body.contentType()?.subtype != "json" ->
                        body.charStream().buffered().useLines { lines ->
                            // Check for cancellation before each line so a cancelled caller
                            // stops consuming (and forwarding) the stream promptly.
                            parseStream(lines.onEach { ensureActive() }, model, onChunk)
                        }
                    else -> parseCompletion(body.string(), model).also { (content, _) ->
                        // Streaming was requested but a plain JSON reply arrived: deliver it whole.
                        if (onChunk != null && content.isNotEmpty()) onChunk(content)
                    }
                }
            }
        }
    }

    /**
     * Runs [block] and wraps its value in [Result.success]; any [Exception] becomes
     * [Result.failure], except [CancellationException], which is rethrown so coroutine
     * cancellation keeps propagating.
     */
    private inline fun <T> resultOf(block: () -> T): Result<T> =
        try {
            Result.success(block())
        } catch (e: CancellationException) {
            throw e
        } catch (e: Exception) {
            Result.failure(e)
        }

    /** Builds the failure reported for a non-2xx response; a missing body reads "Unknown error". */
    private fun apiError(code: Int, body: String?): Exception =
        Exception("API error: $code - ${body ?: "Unknown error"}")

    /** Serialises the chat-completions request payload to a JSON string. */
    private fun buildChatBody(model: String, messages: List<Message>, stream: Boolean): String {
        val messagesArray = JsonArray()
        for (msg in messages) {
            val entry = JsonObject()
            entry.addProperty("role", if (msg.isUser) "user" else "assistant")
            entry.addProperty("content", msg.content)
            messagesArray.add(entry)
        }

        val payload = JsonObject()
        payload.addProperty("model", model)
        payload.addProperty("temperature", 0.7)
        payload.addProperty("stream", stream)
        payload.add("messages", messagesArray)
        return gson.toJson(payload)
    }

    /**
     * Parses a non-streaming completion.
     *
     * @return `(choices[0].message.content or "", model field or [fallbackModel])`.
     * @throws Exception "No response from API" when the body is empty or `choices` is
     *   missing/empty.
     */
    private fun parseCompletion(body: String, fallbackModel: String): Pair<String, String> {
        val json = parseObject(body) ?: throw Exception("No response from API")
        val choices = json.arrayOrNull("choices")
        if (choices == null || choices.size() == 0) {
            throw Exception("No response from API")
        }
        val message = choices[0].takeIf { it.isJsonObject }?.asJsonObject?.objectOrNull("message")
        val content = contentText(message?.get("content")).orEmpty()
        return content to (json.stringOrNull("model") ?: fallbackModel)
    }

    /**
     * Consumes server-sent-event [lines] of chat-completion chunks, forwarding each non-empty
     * `choices[0].delta.content` fragment to [onChunk] and stopping at `data: [DONE]`.
     *
     * @return the concatenated fragments and the last `model` seen (or [fallbackModel]).
     * @throws Exception "No response from API" if no parseable `data:` event was received.
     */
    private fun parseStream(
        lines: Sequence<String>,
        fallbackModel: String,
        onChunk: (String) -> Unit,
    ): Pair<String, String> {
        val content = StringBuilder()
        var usedModel = fallbackModel
        var sawEvent = false

        for (line in lines) {
            // Only `data:` lines carry payloads; blank separators and other SSE fields are skipped.
            if (!line.startsWith(SSE_DATA_PREFIX)) continue
            val payload = line.removePrefix(SSE_DATA_PREFIX).trim()
            if (payload == SSE_DONE) break
            val event = parseObject(payload) ?: continue
            sawEvent = true

            event.stringOrNull("model")?.let { usedModel = it }
            val delta = contentText(event.firstChoice()?.objectOrNull("delta")?.get("content"))
            if (!delta.isNullOrEmpty()) {
                content.append(delta)
                onChunk(delta)
            }
        }

        if (!sawEvent) throw Exception("No response from API")
        return content.toString() to usedModel
    }

    /** Maps one `/models` entry to `(id, displayName)`, or null if it does not qualify. */
    private fun toModelEntry(entry: JsonObject): Pair<String, String>? {
        val id = entry.stringOrNull("id") ?: return null
        val created = entry.stringOrNull("created")?.toLongOrNull() ?: 0L
        if (created <= 0 || !id.endsWith(LATEST_SUFFIX)) return null
        return id to displayNameFor(id)
    }

    /** `"mistral-small-latest"` -> `"Mistral Small"`; empty `-` segments are dropped. */
    private fun displayNameFor(id: String): String =
        id.removeSuffix(LATEST_SUFFIX)
            .split('-')
            .filter { it.isNotEmpty() }
            .joinToString(" ") { word -> word.replaceFirstChar { c -> c.uppercase() } }

    /**
     * Extracts text from a `content` field: a JSON string/primitive is returned as-is, an array
     * is reduced to the concatenation of its elements' `text` fields (other chunk kinds are
     * ignored); `null`, JSON null and objects yield null.
     */
    private fun contentText(element: JsonElement?): String? {
        if (element == null) return null
        return when {
            element.isJsonPrimitive -> element.asString
            element.isJsonArray -> element.asJsonArray
                .mapNotNull { chunk -> chunk.takeIf { it.isJsonObject }?.asJsonObject?.stringOrNull("text") }
                .joinToString("")
            else -> null
        }
    }

    /** Parses [text] as a JSON object; returns null for an empty or whitespace-only document. */
    private fun parseObject(text: String): JsonObject? = gson.fromJson(text, JsonObject::class.java)

    /** `choices[0]` if it exists and is an object. */
    private fun JsonObject.firstChoice(): JsonObject? =
        arrayOrNull("choices")?.firstOrNull()?.takeIf { it.isJsonObject }?.asJsonObject

    /** The primitive at [key] as a string; null when absent, JSON null, an object or an array. */
    private fun JsonObject.stringOrNull(key: String): String? =
        get(key)?.takeIf { it.isJsonPrimitive }?.asString

    /** The object at [key], or null when absent or not an object. */
    private fun JsonObject.objectOrNull(key: String): JsonObject? =
        get(key)?.takeIf { it.isJsonObject }?.asJsonObject

    /** The array at [key], or null when absent or not an array. */
    private fun JsonObject.arrayOrNull(key: String): JsonArray? =
        get(key)?.takeIf { it.isJsonArray }?.asJsonArray

    companion object {
        private const val BASE_URL = "https://api.mistral.ai/v1"
        private const val LATEST_SUFFIX = "-latest"
        private const val SSE_DATA_PREFIX = "data:"
        private const val SSE_DONE = "[DONE]"

        /**
         * Hard-coded `(model id, display name)` pairs. Unlike [getModels] this list does not
         * query the API and may go stale.
         */
        val AVAILABLE_MODELS: List<Pair<String, String>> = listOf(
            "mistral-small-latest" to "Mistral Small",
            "mistral-medium-latest" to "Mistral Medium",
            "mistral-large-latest" to "Mistral Large",
            "codestral-latest" to "Codestral",
            "pixtral-large-latest" to "Pixtral Large",
        )
    }
}