package cn.bugstack.openai.executor.model.chatgpt;

import cn.bugstack.openai.executor.Executor;
import cn.bugstack.openai.executor.model.chatgpt.config.ChatGPTConfig;
import cn.bugstack.openai.executor.model.chatgpt.valobj.ChatGPTCompletionRequest;
import cn.bugstack.openai.executor.parameter.CompletionRequest;
import cn.bugstack.openai.executor.parameter.Message;
import cn.bugstack.openai.executor.parameter.ParameterHandler;
import cn.bugstack.openai.executor.result.ResultHandler;
import cn.bugstack.openai.session.Configuration;
import com.fasterxml.jackson.databind.ObjectMapper;
import okhttp3.MediaType;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/**
 * ChatGPT 模型执行器 https://openai.apifox.cn/doc-3222729
 *
 * @author 小傅哥，微信：fustack
 */
public class ChatGPTModelExecutor implements Executor, ParameterHandler<ChatGPTCompletionRequest>, ResultHandler {

    /**
     * 配置信息
     */
    private final ChatGPTConfig chatGPTConfig;
    /**
     * 工厂事件
     */
    private final EventSource.Factory factory;

    public ChatGPTModelExecutor(Configuration configuration) {
        this.chatGPTConfig = configuration.getChatGPTConfig();
        this.factory = configuration.createRequestFactory();
    }

    @Override
    public EventSource completions(CompletionRequest completionRequest, EventSourceListener eventSourceListener) throws Exception {
        return completions(null, null, completionRequest, eventSourceListener);
    }

    @Override
    public EventSource completions(String apiHostByUser, String apiKeyByUser, CompletionRequest completionRequest, EventSourceListener eventSourceListener) throws Exception {
        // 1. 核心参数校验；不对用户的传参做更改，只返回错误信息。
        if (!completionRequest.isStream()) {
            throw new RuntimeException("illegal parameter stream is false!");
        }

        // 2. 动态设置 Host、Key，便于用户传递自己的信息
        String apiHost = null == apiHostByUser ? chatGPTConfig.getApiHost() : apiHostByUser;
        String apiKey = null == apiKeyByUser ? chatGPTConfig.getApiKey() : apiKeyByUser;

        // 3. 转换参数信息
        ChatGPTCompletionRequest chatGPTCompletionRequest = getParameterObject(completionRequest);

        // 4. 构建请求信息
        Request request = new Request.Builder()
                .header("Authorization", "Bearer " + apiKey)
                .url(apiHost.concat(chatGPTConfig.getV1_chat_completions()))
                .post(RequestBody.create(MediaType.parse(Configuration.APPLICATION_JSON), new ObjectMapper().writeValueAsString(chatGPTCompletionRequest)))
                .build();

        // 5. 返回事件结果
        return factory.newEventSource(request, eventSourceListener);
    }

    @Override
    public ChatGPTCompletionRequest getParameterObject(CompletionRequest completionRequest) {
        // 转换参数
        List<cn.bugstack.openai.executor.model.chatgpt.valobj.Message> chatGPTMessages = new ArrayList<>();
        List<Message> messages = completionRequest.getMessages();
        for (Message message : messages) {
            cn.bugstack.openai.executor.model.chatgpt.valobj.Message messageVO = new cn.bugstack.openai.executor.model.chatgpt.valobj.Message();
            messageVO.setContent(message.getContent());
            messageVO.setName(message.getName());
            messageVO.setRole(message.getRole());
            chatGPTMessages.add(messageVO);
        }

        // 封装参数
        ChatGPTCompletionRequest chatGPTCompletionRequest = new ChatGPTCompletionRequest();
        chatGPTCompletionRequest.setModel(completionRequest.getModel());
        chatGPTCompletionRequest.setTemperature(completionRequest.getTemperature());
        chatGPTCompletionRequest.setTopP(completionRequest.getTopP());
        chatGPTCompletionRequest.setStream(completionRequest.isStream());
        chatGPTCompletionRequest.setMessages(chatGPTMessages);

        return chatGPTCompletionRequest;
    }

    @Override
    public EventSourceListener eventSourceListener(EventSourceListener eventSourceListener) {
        return eventSourceListener;
    }

}
