/*
 * Decompiled with CFR 0.152.
 */
package cn.bugstack.openai.executor.model.baidu;

import cn.bugstack.openai.executor.Executor;
import cn.bugstack.openai.executor.model.baidu.config.BaiduConfig;
import cn.bugstack.openai.executor.model.baidu.utils.AccessTokenUtils;
import cn.bugstack.openai.executor.model.baidu.valobj.BaiduCompletionRequest;
import cn.bugstack.openai.executor.model.baidu.valobj.BaiduCompletionResponse;
import cn.bugstack.openai.executor.model.baidu.valobj.BaiduImageRequest;
import cn.bugstack.openai.executor.model.baidu.valobj.Message;
import cn.bugstack.openai.executor.model.baidu.valobj.Usage;
import cn.bugstack.openai.executor.parameter.ChatChoice;
import cn.bugstack.openai.executor.parameter.CompletionRequest;
import cn.bugstack.openai.executor.parameter.CompletionResponse;
import cn.bugstack.openai.executor.parameter.ImageRequest;
import cn.bugstack.openai.executor.parameter.ImageResponse;
import cn.bugstack.openai.executor.parameter.ParameterHandler;
import cn.bugstack.openai.executor.parameter.PictureRequest;
import cn.bugstack.openai.executor.result.ResultHandler;
import cn.bugstack.openai.session.Configuration;
import com.alibaba.fastjson.JSON;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nullable;
import okhttp3.Call;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link Executor} implementation that talks to the Baidu endpoints described by {@link BaiduConfig}.
 *
 * <p>The class translates the SDK's generic request objects into Baidu request objects
 * ({@link #getParameterObject(CompletionRequest)}), sends them with OkHttp, and converts streamed
 * Baidu responses back into the SDK's generic {@link CompletionResponse} JSON
 * ({@link #eventSourceListener(EventSourceListener)}).
 *
 * <p>Caller-supplied credentials are expected in the form {@code "<apiKey>.<apiSecret>"}.
 */
public class BaiduModelExecutor
implements Executor,
ParameterHandler<BaiduCompletionRequest>,
ResultHandler {
    private static final Logger log = LoggerFactory.getLogger(BaiduModelExecutor.class);
    /** Content type used for both the header and the request body of every call. */
    private static final String JSON_CONTENT_TYPE = "application/json";
    /** Shared mapper; creating a new {@link ObjectMapper} for every request is needlessly expensive. */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    private final BaiduConfig baiduConfig;
    private final EventSource.Factory factory;
    private final OkHttpClient okHttpClient;

    /**
     * Creates an executor from the session configuration.
     *
     * @param configuration source of the Baidu config, the OkHttp client and the SSE factory
     */
    public BaiduModelExecutor(Configuration configuration) {
        this.baiduConfig = configuration.getBaiduConfig();
        this.okHttpClient = configuration.getOkHttpClient();
        this.factory = configuration.createRequestFactory();
    }

    /**
     * Starts a streaming completion using the configured host and credentials.
     *
     * @param completionRequest   generic completion request; must have streaming enabled
     * @param eventSourceListener listener that receives the converted events
     * @return the opened event source
     * @throws Exception if the token cannot be obtained or the request cannot be built
     */
    @Override
    public EventSource completions(CompletionRequest completionRequest, EventSourceListener eventSourceListener) throws Exception {
        return this.completions(null, null, completionRequest, eventSourceListener);
    }

    /**
     * Starts a streaming completion, optionally overriding host and credentials.
     *
     * @param apiHostByUser       host override, or {@code null} to use the configured host
     * @param apiKeyByUser        credentials as {@code "<apiKey>.<apiSecret>"}, or {@code null} to use
     *                            the configured ones
     * @param completionRequest   generic completion request; must have streaming enabled
     * @param eventSourceListener listener that receives the converted events
     * @return the opened event source
     * @throws RuntimeException         if {@code completionRequest.isStream()} is {@code false}
     * @throws IllegalArgumentException if {@code apiKeyByUser} is not in the expected format
     * @throws Exception                if the token cannot be obtained or the body cannot be serialized
     */
    @Override
    public EventSource completions(String apiHostByUser, String apiKeyByUser, CompletionRequest completionRequest, EventSourceListener eventSourceListener) throws Exception {
        if (!completionRequest.isStream()) {
            throw new RuntimeException("illegal parameter stream is false!");
        }
        String apiHost = this.resolveApiHost(apiHostByUser);
        String accessToken = this.resolveAccessToken(apiKeyByUser);
        BaiduCompletionRequest baiduCompletionRequest = this.getParameterObject(completionRequest);
        String modelUrl = BaiduConfig.CompletionsUrl.valueOf(completionRequest.getModel()).getUrl();
        String jsonBody = OBJECT_MAPPER.writeValueAsString(baiduCompletionRequest);
        Request request = buildJsonPost(apiHost, modelUrl, accessToken, jsonBody);
        return this.factory.newEventSource(request, this.eventSourceListener(eventSourceListener));
    }

    /**
     * Generates images using the configured host and credentials.
     *
     * @param imageRequest generic image request
     * @return the parsed image response
     * @throws Exception if the call fails
     */
    @Override
    public ImageResponse genImages(ImageRequest imageRequest) throws Exception {
        return this.genImages(null, null, imageRequest);
    }

    /**
     * Generates images with a blocking HTTP call, optionally overriding host and credentials.
     *
     * @param apiHostByUser host override, or {@code null} to use the configured host
     * @param apiKeyByUser  credentials as {@code "<apiKey>.<apiSecret>"}, or {@code null} to use the
     *                      configured ones
     * @param imageRequest  generic image request ({@code n}, {@code size}, {@code prompt} and model are read)
     * @return the response body parsed into an {@link ImageResponse}
     * @throws IllegalArgumentException if {@code apiKeyByUser} is not in the expected format
     * @throws IOException              if the call fails, is unsuccessful, or has no body
     */
    @Override
    public ImageResponse genImages(String apiHostByUser, String apiKeyByUser, ImageRequest imageRequest) throws IOException {
        BaiduImageRequest baiduImageRequest = BaiduImageRequest.builder()
                .n(imageRequest.getN())
                .size(imageRequest.getSize())
                .prompt(imageRequest.getPrompt())
                .build();
        String apiHost = this.resolveApiHost(apiHostByUser);
        String accessToken = this.resolveAccessToken(apiKeyByUser);
        String modelUrl = BaiduConfig.CompletionsUrl.valueOf(imageRequest.getModel()).getUrl();
        Request request = buildJsonPost(apiHost, modelUrl, accessToken, JSON.toJSONString(baiduImageRequest));
        Call call = this.okHttpClient.newCall(request);
        // try-with-resources guarantees the response (and its connection) is released even when
        // the call is unsuccessful and the body is never consumed.
        try (Response response = call.execute()) {
            if (response.isSuccessful() && response.body() != null) {
                return JSON.parseObject(response.body().string(), ImageResponse.class);
            }
            throw new IOException("Failed to get image response");
        }
    }

    /**
     * Not implemented by this executor.
     *
     * @return always {@code null}
     */
    @Override
    public EventSource pictureUnderstanding(PictureRequest pictureRequest, EventSourceListener eventSourceListener) throws Exception {
        return null;
    }

    /**
     * Not implemented by this executor.
     *
     * @return always {@code null}
     */
    @Override
    public EventSource pictureUnderstanding(String apiHostByUser, String apiKeyByUser, PictureRequest pictureRequest, EventSourceListener eventSourceListener) throws Exception {
        return null;
    }

    /**
     * Converts a generic completion request into a Baidu completion request.
     *
     * <p>Copies role and content of every message, plus the stream flag, {@code topP} and temperature.
     *
     * @param completionRequest generic completion request
     * @return the equivalent Baidu request
     */
    @Override
    public BaiduCompletionRequest getParameterObject(CompletionRequest completionRequest) {
        List<cn.bugstack.openai.executor.parameter.Message> sourceMessages = completionRequest.getMessages();
        ArrayList<Message> baiduMessages = new ArrayList<Message>();
        for (cn.bugstack.openai.executor.parameter.Message source : sourceMessages) {
            Message target = new Message();
            target.setRole(source.getRole());
            target.setContent(source.getContent());
            baiduMessages.add(target);
        }
        BaiduCompletionRequest baiduCompletionRequest = new BaiduCompletionRequest();
        baiduCompletionRequest.setStream(completionRequest.isStream());
        baiduCompletionRequest.setTopP(completionRequest.getTopP());
        baiduCompletionRequest.setTemperature(completionRequest.getTemperature());
        baiduCompletionRequest.setMessages(baiduMessages);
        return baiduCompletionRequest;
    }

    /**
     * Wraps a listener so that each Baidu event payload is re-emitted as generic
     * {@link CompletionResponse} JSON.
     *
     * <p>Intermediate events are forwarded with their original id and type. The terminating event
     * (the one whose end flag is {@code true}) additionally carries a {@code "stop"} finish reason,
     * token usage when present, and a creation timestamp, and is forwarded with {@code null} id and type.
     * Open, close and failure callbacks are passed through unchanged.
     *
     * @param eventSourceListener downstream listener
     * @return the adapting listener
     */
    @Override
    public EventSourceListener eventSourceListener(final EventSourceListener eventSourceListener) {
        return new EventSourceListener(){

            @Override
            public void onEvent(EventSource eventSource, @Nullable String id, @Nullable String type, String data) {
                BaiduCompletionResponse response = JSON.parseObject(data, BaiduCompletionResponse.class);
                // A missing end flag is treated as "not ended" instead of failing with an NPE.
                boolean ended = Boolean.TRUE.equals(response.getIsEnd());

                ChatChoice chatChoice = new ChatChoice();
                chatChoice.setDelta(cn.bugstack.openai.executor.parameter.Message.builder().name("").role(CompletionRequest.Role.SYSTEM).content(response.getResult()).build());
                // The choice is added exactly once; adding the same instance twice would put a
                // duplicate entry into the serialized payload.
                ArrayList<ChatChoice> choices = new ArrayList<ChatChoice>();
                choices.add(chatChoice);

                CompletionResponse completionResponse = new CompletionResponse();
                completionResponse.setChoices(choices);

                if (!ended) {
                    eventSourceListener.onEvent(eventSource, id, type, JSON.toJSONString(completionResponse));
                    return;
                }

                chatChoice.setFinishReason("stop");
                Usage usage = response.getUsage();
                if (usage != null) {
                    cn.bugstack.openai.executor.parameter.Usage openaiUsage = new cn.bugstack.openai.executor.parameter.Usage();
                    openaiUsage.setPromptTokens(usage.getPromptTokens());
                    openaiUsage.setCompletionTokens(usage.getCompletionTokens());
                    openaiUsage.setTotalTokens(usage.getTotalTokens());
                    completionResponse.setUsage(openaiUsage);
                }
                completionResponse.setCreated(System.currentTimeMillis());
                eventSourceListener.onEvent(eventSource, null, null, JSON.toJSONString(completionResponse));
            }

            @Override
            public void onOpen(EventSource eventSource, Response response) {
                eventSourceListener.onOpen(eventSource, response);
            }

            @Override
            public void onClosed(EventSource eventSource) {
                eventSourceListener.onClosed(eventSource);
            }

            @Override
            public void onFailure(EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
                eventSourceListener.onFailure(eventSource, t, response);
            }
        };
    }

    /** Returns the caller-supplied host, or the configured host when none was supplied. */
    private String resolveApiHost(String apiHostByUser) {
        return null == apiHostByUser ? this.baiduConfig.getApiHost() : apiHostByUser;
    }

    /**
     * Obtains an access token for either the configured credentials or the caller-supplied
     * {@code "<apiKey>.<apiSecret>"} pair.
     *
     * @throws IllegalArgumentException if {@code apiKeyByUser} does not contain both parts
     */
    private String resolveAccessToken(String apiKeyByUser) throws IOException {
        String apiKey = this.baiduConfig.getApiKey();
        String apiSecret = this.baiduConfig.getApiSecret();
        String authHost = this.baiduConfig.getAuthHost();
        if (apiKeyByUser != null) {
            String[] apiKeySecret = apiKeyByUser.split("\\.");
            if (apiKeySecret.length < 2) {
                // The value itself is a secret, so it is deliberately left out of the message.
                throw new IllegalArgumentException("apiKeyByUser must have the form '<apiKey>.<apiSecret>'");
            }
            apiKey = apiKeySecret[0];
            apiSecret = apiKeySecret[1];
        }
        return AccessTokenUtils.getAccessToken(this.okHttpClient, apiKey, apiSecret, authHost);
    }

    /** Builds the JSON POST request for {@code apiHost + path}, with the token as a query parameter. */
    private static Request buildJsonPost(String apiHost, String path, String accessToken, String jsonBody) {
        return new Request.Builder()
                .addHeader("Content-Type", JSON_CONTENT_TYPE)
                .url(apiHost.concat(path).concat("?access_token=").concat(accessToken))
                .post(RequestBody.create(MediaType.parse(JSON_CONTENT_TYPE), jsonBody))
                .build();
    }
}

