package cn.bugstack.openai.executor.model.xunfei;

import cn.bugstack.openai.executor.Executor;
import cn.bugstack.openai.executor.model.xunfei.config.XunFeiConfig;
import cn.bugstack.openai.executor.model.xunfei.utils.URLAuthUtils;
import cn.bugstack.openai.executor.model.xunfei.valobj.*;
import cn.bugstack.openai.executor.model.xunfei.valobj.Usage;
import cn.bugstack.openai.executor.parameter.*;
import cn.bugstack.openai.session.Configuration;
import com.alibaba.fastjson.JSON;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * Executor for the iFlytek (XunFei) Spark large language model: https://console.xfyun.cn/services/bm3
 * <p>
 * XunFei streams completions over a WebSocket instead of Server-Sent Events. This executor adapts
 * that WebSocket stream to the SSE-style {@link EventSourceListener} callbacks: every XunFei frame is
 * re-encoded as a JSON {@link CompletionResponse} and delivered through {@code onEvent}.
 *
 * @author 小傅哥 (WeChat: fustack)
 */
@Slf4j
public class XunFeiModelExecutor implements Executor, ParameterHandler<XunFeiCompletionRequest> {

    /**
     * XunFei endpoint and credentials (API host, appid, API key, API secret).
     */
    private final XunFeiConfig xunFeiConfig;
    /**
     * HTTP client used to open the WebSocket connection.
     */
    private final OkHttpClient okHttpClient;

    /**
     * @param configuration session configuration supplying the XunFei config and the shared OkHttp client
     */
    public XunFeiModelExecutor(Configuration configuration) {
        this.xunFeiConfig = configuration.getXunFeiConfig();
        this.okHttpClient = configuration.getOkHttpClient();
    }

    /**
     * Streams a chat completion from XunFei.
     *
     * @param completionRequest   provider-neutral request, converted by {@link #getParameterObject}
     * @param eventSourceListener receives {@code onOpen}, then one {@code onEvent} per response chunk
     *                            (a JSON-encoded {@link CompletionResponse}), then exactly one of
     *                            {@code onClosed} or {@code onFailure}
     * @return a handle whose {@link EventSource#cancel()} aborts the underlying WebSocket
     */
    @Override
    public EventSource completions(CompletionRequest completionRequest, EventSourceListener eventSourceListener) throws Exception {
        // 1. Convert the request into XunFei's format.
        XunFeiCompletionRequest xunFeiCompletionRequest = getParameterObject(completionRequest);
        // 2. Build the authenticated connection URL from the configured host, key and secret.
        String authURL = URLAuthUtils.getAuthURl(xunFeiConfig.getApiHost(), xunFeiConfig.getApiKey(), xunFeiConfig.getApiSecret());
        Request request = new Request.Builder()
                .url(authURL)
                .build();
        // 3. Open the WebSocket; the listener sends the request once the connection is open.
        WebSocket webSocket = okHttpClient.newWebSocket(request, new BigModelWebSocketListener(xunFeiCompletionRequest, eventSourceListener));
        // 4. Expose the socket through the SSE-style EventSource interface.
        return new EventSource() {

            @NotNull
            @Override
            public Request request() {
                return request;
            }

            @Override
            public void cancel() {
                webSocket.cancel();
            }
        };
    }

    /**
     * Not supported for XunFei: authentication also needs an API secret and an appid, which this
     * overload has no way to receive.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public EventSource completions(String apiHostByUser, String apiKeyByUser, CompletionRequest completionRequest, EventSourceListener eventSourceListener) throws Exception {
        throw new UnsupportedOperationException("XunFei completions with user-supplied apiHost/apiKey are not supported yet");
    }

    /**
     * Bridges XunFei's WebSocket protocol to an SSE-style {@link EventSourceListener}.
     * <p>
     * Sends the request as JSON once the socket opens, converts every successful frame into a
     * {@link CompletionResponse} event, and closes the socket after the END frame or an error frame.
     * At most one terminal callback ({@code onClosed} or {@code onFailure}) reaches the listener.
     */
    protected static class BigModelWebSocketListener extends WebSocketListener {

        private final XunFeiCompletionRequest request;
        private final EventSourceListener eventSourceListener;
        /** Set once a terminal callback has been delivered, so a second one is suppressed. */
        private final AtomicBoolean finished = new AtomicBoolean(false);
        /** The socket this listener is attached to, captured from every OkHttp callback. */
        private volatile WebSocket webSocket;
        /** EventSource facade handed to {@link #eventSourceListener}; delegates to {@link #webSocket}. */
        private final EventSource eventSource;

        public BigModelWebSocketListener(XunFeiCompletionRequest request, EventSourceListener eventSourceListener) {
            this.request = request;
            this.eventSourceListener = eventSourceListener;
            this.eventSource = new EventSource() {

                @NotNull
                @Override
                public Request request() {
                    // The facade is only handed out from callbacks, after the socket has been captured.
                    return BigModelWebSocketListener.this.webSocket.request();
                }

                @Override
                public void cancel() {
                    BigModelWebSocketListener.this.webSocket.cancel();
                }
            };
        }

        @Override
        public void onOpen(@NotNull WebSocket webSocket, @NotNull Response response) {
            this.webSocket = webSocket;
            eventSourceListener.onOpen(eventSource, response);
            // send() only enqueues the frame, so no helper thread is needed; the socket is closed
            // from onMessage once the END (or an error) frame has been received.
            webSocket.send(JSON.toJSONString(request));
        }

        @Override
        public void onMessage(WebSocket webSocket, String text) {
            this.webSocket = webSocket;
            XunFeiCompletionResponse response = JSON.parseObject(text, XunFeiCompletionResponse.class);
            XunFeiCompletionResponse.Header header = response.getHeader();
            int code = header.getCode();

            // A non-success code ends the exchange: report it to the caller and shut the socket down.
            if (XunFeiCompletionResponse.Header.Code.SUCCESS.getValue() != code) {
                notifyFailure(new IllegalStateException("XunFei returned error code " + code + ": " + text), null);
                webSocket.close(1000, "");
                return;
            }

            XunFeiCompletionResponse.Payload payload = response.getPayload();
            Choices choices = payload.getChoices();
            List<Text> texts = choices.getText();

            // One fresh ChatChoice per text fragment so list entries never alias each other.
            List<ChatChoice> chatChoices = new ArrayList<>();
            if (texts != null) {
                for (Text t : texts) {
                    chatChoices.add(deltaChoice(t.getContent()));
                }
            }

            int status = header.getStatus();
            if (XunFeiCompletionResponse.Header.Status.START.getValue() == status
                    || XunFeiCompletionResponse.Header.Status.ING.getValue() == status) {
                emit(contentResponse(chatChoices));
            } else if (XunFeiCompletionResponse.Header.Status.END.getValue() == status) {
                // The END frame can still carry the last text fragment: forward it first so it is not
                // lost, then send a separate terminal event with the "stop" marker and token usage.
                if (!chatChoices.isEmpty()) {
                    emit(contentResponse(chatChoices));
                }
                List<ChatChoice> stopChoices = new ArrayList<>();
                stopChoices.add(deltaChoice("stop"));
                CompletionResponse finalResponse = contentResponse(stopChoices);
                finalResponse.setUsage(toOpenAiUsage(payload.getUsage()));
                finalResponse.setCreated(System.currentTimeMillis());
                emit(finalResponse);
                // The server has finished answering; start a graceful close (onClosed follows).
                webSocket.close(1000, "");
            }
        }

        @Override
        public void onClosing(@NotNull WebSocket webSocket, int code, @NotNull String reason) {
            this.webSocket = webSocket;
            // The peer started the close handshake; answer it so OkHttp reports onClosed.
            // A no-op (returns false) if this side already initiated the close.
            webSocket.close(1000, "");
        }

        @Override
        public void onClosed(WebSocket webSocket, int code, String reason) {
            this.webSocket = webSocket;
            notifyClosed();
        }

        @Override
        public void onFailure(WebSocket webSocket, Throwable t, @Nullable Response response) {
            this.webSocket = webSocket;
            notifyFailure(t, response);
        }

        /** Forwards {@code onClosed} unless a terminal callback was already delivered. */
        private void notifyClosed() {
            if (finished.compareAndSet(false, true)) {
                eventSourceListener.onClosed(eventSource);
            }
        }

        /** Forwards {@code onFailure} unless a terminal callback was already delivered. */
        private void notifyFailure(Throwable t, @Nullable Response response) {
            if (finished.compareAndSet(false, true)) {
                eventSourceListener.onFailure(eventSource, t, response);
            }
        }

        /** Serializes the response to JSON and delivers it to the listener as a single event. */
        private void emit(CompletionResponse completionResponse) {
            eventSourceListener.onEvent(eventSource, null, null, JSON.toJSONString(completionResponse));
        }

        /** Wraps a list of choices in a new {@link CompletionResponse}. */
        private static CompletionResponse contentResponse(List<ChatChoice> chatChoices) {
            CompletionResponse completionResponse = new CompletionResponse();
            completionResponse.setChoices(chatChoices);
            return completionResponse;
        }

        /** Builds a streaming delta choice carrying the given content. */
        private static ChatChoice deltaChoice(String content) {
            ChatChoice chatChoice = new ChatChoice();
            chatChoice.setDelta(cn.bugstack.openai.executor.parameter.Message.builder()
                    .name("")
                    .role(CompletionRequest.Role.SYSTEM)
                    .content(content)
                    .build());
            return chatChoice;
        }

        /** Copies XunFei's token usage counters into the provider-neutral usage object. */
        private static cn.bugstack.openai.executor.parameter.Usage toOpenAiUsage(Usage usage) {
            Usage.Text usageText = usage.getText();
            cn.bugstack.openai.executor.parameter.Usage openaiUsage = new cn.bugstack.openai.executor.parameter.Usage();
            openaiUsage.setPromptTokens(usageText.getPromptTokens());
            openaiUsage.setCompletionTokens(usageText.getCompletionTokens());
            openaiUsage.setTotalTokens(usageText.getTotalTokens());
            return openaiUsage;
        }
    }

    /**
     * Converts a provider-neutral {@link CompletionRequest} into XunFei's request payload.
     *
     * @param completionRequest source request; its temperature, max tokens and message contents are copied
     * @return the XunFei request with header (appid + random uid), chat parameters and message texts
     */
    @Override
    public XunFeiCompletionRequest getParameterObject(CompletionRequest completionRequest) {
        // Header: configured appid plus a random 10-character uid for this request.
        XunFeiCompletionRequest.Header header = XunFeiCompletionRequest.Header.builder()
                .appid(xunFeiConfig.getAppid())
                .uid(UUID.randomUUID().toString().substring(0, 10))
                .build();
        // Model parameters.
        // NOTE(review): the model domain is hard-coded to "generalv2"; confirm whether it should be configurable.
        XunFeiCompletionRequest.Parameter parameter = XunFeiCompletionRequest.Parameter.builder().chat(Chat.builder()
                .domain("generalv2")
                .temperature(completionRequest.getTemperature())
                .maxTokens(completionRequest.getMaxTokens())
                .build()).build();
        // Message contents.
        List<cn.bugstack.openai.executor.parameter.Message> messages = completionRequest.getMessages();
        List<Text> texts = new ArrayList<>(messages.size());
        for (cn.bugstack.openai.executor.parameter.Message message : messages) {
            // NOTE(review): every message is sent with the USER role, so system/assistant turns lose
            // their original role — confirm which roles Text.Role supports before mapping them.
            texts.add(Text.builder()
                    .role(Text.Role.USER.getName())
                    .content(message.getContent())
                    .build());
        }

        XunFeiCompletionRequest.Payload payload = XunFeiCompletionRequest.Payload.builder()
                .message(cn.bugstack.openai.executor.model.xunfei.valobj.Message.builder().text(texts).build())
                .build();

        return XunFeiCompletionRequest.builder()
                .header(header)
                .parameter(parameter)
                .payload(payload)
                .build();
    }

}
