/*
 * Decompiled with CFR 0.152.
 */
package com.tokenwatcher;

import com.tokenwatcher.BaseMonitoredClient;
import com.tokenwatcher.Event;
import com.tokenwatcher.TokenWatcherConfig;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Wraps a client object in a {@link Proxy} that records an {@link Event} for
 * selected method calls (see {@code MonitoringInvocationHandler#shouldMonitor}).
 *
 * <p>For each monitored call the proxy measures the elapsed time, reads the
 * model name from the arguments and the token counts from the result via
 * reflection, builds a success or error event through the inherited
 * {@code createSuccessEvent}/{@code createErrorEvent} methods and hands it to
 * the inherited {@code buffer}. Calls that are not monitored are forwarded to
 * the wrapped client unchanged.
 *
 * <p>Because a JDK dynamic proxy is used, the returned object only implements
 * the <em>interfaces</em> of the wrapped client; a client whose type is a
 * concrete class without interfaces cannot be usefully wrapped.
 */
public class MonitoredOpenAI extends BaseMonitoredClient {
    private static final Logger logger = LoggerFactory.getLogger(MonitoredOpenAI.class);
    private static final String PROVIDER = "openai";
    /** Model name reported when none can be read from the call arguments. */
    private static final String UNKNOWN_MODEL = "unknown";
    /** The client instance this monitor was created for. */
    private final Object wrappedClient;

    private MonitoredOpenAI(Object client, TokenWatcherConfig config, Map<String, Object> metadata) {
        super(config, metadata);
        this.wrappedClient = client;
    }

    /**
     * Wraps {@code client} without additional metadata.
     *
     * @param client the client to monitor; must not be {@code null}
     * @param config configuration passed to the base monitored client
     * @param <T>    the interface type the caller uses for the client
     * @return a proxy that forwards to {@code client} and records events
     * @throws NullPointerException if {@code client} is {@code null}
     */
    public static <T> T wrap(T client, TokenWatcherConfig config) {
        return MonitoredOpenAI.wrap(client, config, null);
    }

    /**
     * Wraps {@code client} in a monitoring proxy.
     *
     * <p>The proxy implements every interface found on the client's class and
     * on its superclasses, so clients that inherit their API interface from a
     * base class can still be cast back to that interface.
     *
     * @param client   the client to monitor; must not be {@code null}
     * @param config   configuration passed to the base monitored client
     * @param metadata metadata passed to the base monitored client; the
     *                 two-argument overload passes {@code null} here
     * @param <T>      the interface type the caller uses for the client
     * @return a proxy that forwards to {@code client} and records events
     * @throws NullPointerException if {@code client} is {@code null}
     */
    public static <T> T wrap(T client, TokenWatcherConfig config, Map<String, Object> metadata) {
        Objects.requireNonNull(client, "client");
        MonitoredOpenAI monitor = new MonitoredOpenAI(client, config, metadata);
        Class<?> clientType = client.getClass();
        ClassLoader loader = clientType.getClassLoader();
        InvocationHandler handler = new MonitoringInvocationHandler(monitor, client, PROVIDER);

        Class<?>[] interfaces = collectInterfaces(clientType);
        if (interfaces.length == 0) {
            logger.warn("{} implements no interfaces; the monitoring proxy cannot be cast to the client type",
                    clientType.getName());
        }

        Object proxy;
        try {
            proxy = Proxy.newProxyInstance(loader, interfaces, handler);
        } catch (IllegalArgumentException e) {
            // An inherited interface may be rejected by Proxy (e.g. not visible
            // through this class loader). Fall back to the directly declared
            // interfaces so that anything that could be wrapped before still can.
            proxy = Proxy.newProxyInstance(loader, clientType.getInterfaces(), handler);
        }

        // The cast is unchecked: it is only valid when T is one of the proxied interfaces.
        @SuppressWarnings("unchecked")
        T typedProxy = (T) proxy;
        return typedProxy;
    }

    /**
     * Collects the interfaces declared by {@code type} and by each of its
     * superclasses, directly declared ones first, without duplicates.
     */
    private static Class<?>[] collectInterfaces(Class<?> type) {
        Set<Class<?>> found = new LinkedHashSet<>();
        for (Class<?> current = type; current != null; current = current.getSuperclass()) {
            for (Class<?> iface : current.getInterfaces()) {
                found.add(iface);
            }
        }
        return found.toArray(new Class<?>[0]);
    }

    /**
     * Invocation handler that forwards every call to the wrapped target and
     * records an event for the calls selected by {@link #shouldMonitor(Method)}.
     */
    private static class MonitoringInvocationHandler implements InvocationHandler {
        private final MonitoredOpenAI monitor;
        private final Object target;
        private final String provider;

        MonitoringInvocationHandler(MonitoredOpenAI monitor, Object target, String provider) {
            this.monitor = monitor;
            this.target = target;
            this.provider = provider;
        }

        /**
         * Forwards the call to the target; for monitored methods also records a
         * success or error event.
         *
         * <p>Exceptions thrown by the target are rethrown as-is (unwrapped from
         * {@link InvocationTargetException}). Failures while building or
         * buffering an event are logged and never affect the call's outcome.
         */
        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            if (!shouldMonitor(method)) {
                return invokeTarget(method, args);
            }

            // nanoTime is monotonic, so the latency is immune to wall-clock adjustments.
            long startNanos = System.nanoTime();
            String extracted = extractModel(args);
            String model = extracted != null ? extracted : UNKNOWN_MODEL;
            String context = this.monitor.getContext();

            Object result;
            try {
                result = invokeTarget(method, args);
            } catch (Throwable failure) {
                // Only failures of the target call itself are reported as errors.
                recordError(model, elapsedMillis(startNanos), failure, context);
                throw failure;
            }
            recordSuccess(model, result, elapsedMillis(startNanos), context);
            return result;
        }

        /**
         * Invokes {@code method} on the target and rethrows the target's own
         * exception instead of the reflective {@link InvocationTargetException}
         * wrapper, so proxy callers see the exception the client really threw.
         */
        private Object invokeTarget(Method method, Object[] args) throws Throwable {
            try {
                return method.invoke(this.target, args);
            } catch (InvocationTargetException e) {
                Throwable cause = e.getCause();
                throw cause != null ? cause : e;
            }
        }

        /** Milliseconds elapsed since the given {@link System#nanoTime()} reading. */
        private static long elapsedMillis(long startNanos) {
            return (System.nanoTime() - startNanos) / 1_000_000L;
        }

        /** Builds and buffers a success event; any failure is logged, not propagated. */
        private void recordSuccess(String model, Object result, long latencyMs, String context) {
            try {
                Integer inputTokens = extractInputTokens(result);
                Integer outputTokens = extractOutputTokens(result);
                Event event = this.monitor.createSuccessEvent(this.provider, model, inputTokens, outputTokens, latencyMs, context);
                this.monitor.buffer.addEvent(event);
            } catch (Exception e) {
                logger.warn("Failed to record success event for provider {}", this.provider, e);
            }
        }

        /** Builds and buffers an error event; any failure is logged, not propagated. */
        private void recordError(String model, long latencyMs, Throwable failure, String context) {
            try {
                Event event = this.monitor.createErrorEvent(this.provider, model, latencyMs, failure, context);
                this.monitor.buffer.addEvent(event);
            } catch (Exception e) {
                logger.warn("Failed to record error event for provider {}", this.provider, e);
            }
        }

        /**
         * Returns {@code true} when the method name contains one of
         * {@code "create"}, {@code "complete"}, {@code "chat"} or
         * {@code "embedding"}. The match is case-sensitive.
         */
        // NOTE(review): camel-cased names such as "streamChatCompletion" contain none of
        // these lower-case fragments and are therefore not monitored - confirm that is intended.
        private boolean shouldMonitor(Method method) {
            String name = method.getName();
            return name.contains("create") || name.contains("complete") || name.contains("chat") || name.contains("embedding");
        }

        /**
         * Returns the string form of the first non-null value obtained by
         * calling a public no-arg {@code getModel()} on the arguments, in order,
         * or {@code null} when no argument yields one.
         */
        private String extractModel(Object[] args) {
            if (args == null || args.length == 0) {
                return null;
            }
            for (Object arg : args) {
                if (arg == null) {
                    continue;
                }
                try {
                    Object model = arg.getClass().getMethod("getModel").invoke(arg);
                    if (model != null) {
                        return model.toString();
                    }
                } catch (Exception ignored) {
                    // Best effort: this argument has no usable getModel(); try the next one.
                }
            }
            return null;
        }

        /** Reads the prompt token count from the result's usage object, if any. */
        private Integer extractInputTokens(Object result) {
            return extractTokenCount(result, "getPromptTokens", "promptTokens");
        }

        /** Reads the completion token count from the result's usage object, if any. */
        private Integer extractOutputTokens(Object result) {
            return extractTokenCount(result, "getCompletionTokens", "completionTokens");
        }

        /**
         * Calls {@code getUsage()} on the result and then tries each accessor
         * name in order on the usage object, returning the first value that is
         * a {@link Number} as an {@code int}.
         *
         * @return the token count, or {@code null} when the result is
         *         {@code null}, has no usable {@code getUsage()}, or none of
         *         the accessors yields a number
         */
        private Integer extractTokenCount(Object result, String... methodNames) {
            if (result == null) {
                return null;
            }
            Object usage = readAccessor(result, "getUsage");
            if (usage == null) {
                return null;
            }
            for (String methodName : methodNames) {
                Object tokens = readAccessor(usage, methodName);
                if (tokens instanceof Number) {
                    return ((Number) tokens).intValue();
                }
            }
            return null;
        }

        /**
         * Invokes the public no-arg method {@code name} on {@code source} and
         * returns its value, or {@code null} when the method is missing,
         * inaccessible or throws.
         */
        private static Object readAccessor(Object source, String name) {
            try {
                return source.getClass().getMethod(name).invoke(source);
            } catch (Exception ignored) {
                // Best effort: token data is optional and response shapes vary.
                return null;
            }
        }
    }
}

