package com.codeupsoft.interceptor.xss.core;

import com.codeupsoft.interceptor.xss.handler.XssClearHandler;
import jakarta.servlet.ReadListener;
import jakarta.servlet.ServletInputStream;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import org.apache.commons.lang3.StringUtils;

/**
 * XSS-protecting request wrapper.
 *
 * <p>Wraps an {@link HttpServletRequest} and runs request parameters (names and values), header
 * values and the request body through a {@link XssClearHandler} before they reach the caller.
 *
 * <p>The cleaned body is read once and cached, so {@link #getInputStream()} and {@link
 * #getReader()} may be called repeatedly and both observe the same cleaned content. The cache is
 * not synchronized; an instance is meant to be used by the single thread handling the request.
 *
 * @author codeupsoft
 */
public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {

  /** Strategy that performs the actual XSS sanitization of a string. */
  private final XssClearHandler xssClearHandler;

  /** Lazily populated UTF-8 bytes of the cleaned request body; {@code null} until first read. */
  private byte[] cleanedBody;

  /**
   * Creates a wrapper around the given request.
   *
   * @param request the original request
   * @param xssClearHandler handler used to clean XSS content; must not be {@code null}
   * @throws NullPointerException if {@code xssClearHandler} is {@code null}
   */
  public XssHttpServletRequestWrapper(HttpServletRequest request, XssClearHandler xssClearHandler) {
    super(request);
    // Fail fast instead of throwing an NPE later from an arbitrary accessor.
    this.xssClearHandler = Objects.requireNonNull(xssClearHandler, "xssClearHandler");
  }

  /**
   * Removes XSS content from a string by delegating to the configured handler.
   *
   * @param value the string to clean
   * @return the cleaned string, as returned by the handler
   */
  private String cleanXss(String value) {
    return xssClearHandler.clearXss(value);
  }

  /**
   * Returns the value of the named parameter after XSS cleaning.
   *
   * <p>{@code null} and blank values are returned unchanged.
   *
   * @param name the parameter name
   * @return the cleaned parameter value, or the original value if it is {@code null} or blank
   */
  @Override
  public String getParameter(String name) {
    String value = super.getParameter(name);
    if (StringUtils.isNotBlank(value)) {
      return cleanXss(value);
    }
    return value;
  }

  /**
   * Returns all values of the named parameter, each passed through XSS cleaning.
   *
   * @param name the parameter name
   * @return a new array of cleaned values, or the original array if it is {@code null} or empty
   */
  @Override
  public String[] getParameterValues(String name) {
    String[] values = super.getParameterValues(name);
    if (values != null && values.length > 0) {
      String[] newValues = new String[values.length];
      for (int i = 0; i < values.length; i++) {
        newValues[i] = cleanXss(values[i]);
      }
      return newValues;
    }
    return values;
  }

  /**
   * Returns all request parameters with both names and values passed through XSS cleaning.
   *
   * @return a new map of cleaned parameters, or an empty immutable map if there are none
   */
  @Override
  public Map<String, String[]> getParameterMap() {
    Map<String, String[]> paramMap = super.getParameterMap();
    if (Objects.isNull(paramMap) || paramMap.isEmpty()) {
      return Map.of();
    }
    Map<String, String[]> newParamMap = HashMap.newHashMap(paramMap.size());
    for (Map.Entry<String, String[]> entry : paramMap.entrySet()) {
      String[] rawValues = entry.getValue();
      String[] values = new String[rawValues.length];
      for (int i = 0; i < rawValues.length; i++) {
        values[i] = cleanXss(rawValues[i]);
      }
      // NOTE(review): if two distinct raw names clean to the same key, the later one wins.
      newParamMap.put(cleanXss(entry.getKey()), values);
    }
    return newParamMap;
  }

  /**
   * Returns the value of the named header after XSS cleaning.
   *
   * <p>{@code null} and blank values are returned unchanged.
   *
   * @param name the header name
   * @return the cleaned header value, or the original value if it is {@code null} or blank
   */
  @Override
  public String getHeader(String name) {
    String value = super.getHeader(name);
    if (StringUtils.isNotBlank(value)) {
      return cleanXss(value);
    }
    return value;
  }

  /**
   * Returns an input stream over the XSS-cleaned request body.
   *
   * <p>The body is read from the wrapped request and cleaned only once; every call returns a new
   * stream positioned at the start of the cached cleaned bytes. Asynchronous (non-blocking) reads
   * via {@link ServletInputStream#setReadListener} are not supported.
   *
   * <p>NOTE(review): the cleaned body may differ in length from the original Content-Length
   * header; consumers that trust that header instead of reading to EOF may be affected.
   *
   * @return a stream over the cleaned body (UTF-8 encoded)
   * @throws IOException if reading the underlying request body fails
   */
  @Override
  public ServletInputStream getInputStream() throws IOException {
    final ByteArrayInputStream stream = new ByteArrayInputStream(getCleanedBody());

    return new ServletInputStream() {

      @Override
      public int read() {
        return stream.read();
      }

      @Override
      public int read(byte[] b, int off, int len) {
        // Bulk read avoids the byte-at-a-time default implementation.
        return stream.read(b, off, len);
      }

      @Override
      public int available() {
        return stream.available();
      }

      @Override
      public boolean isFinished() {
        // The whole body is in memory, so we are finished once every byte is consumed.
        return stream.available() == 0;
      }

      @Override
      public boolean isReady() {
        // In-memory data can always be read without blocking.
        return true;
      }

      @Override
      public void setReadListener(ReadListener readListener) {
        throw new UnsupportedOperationException();
      }
    };
  }

  /**
   * Returns a reader over the XSS-cleaned request body.
   *
   * <p>Overridden so that callers using the character API see the same cleaned body as callers of
   * {@link #getInputStream()} rather than the raw, unfiltered body of the wrapped request.
   *
   * @return a UTF-8 reader over the cleaned body
   * @throws IOException if reading the underlying request body fails
   */
  @Override
  public BufferedReader getReader() throws IOException {
    return new BufferedReader(new InputStreamReader(getInputStream(), StandardCharsets.UTF_8));
  }

  /**
   * Reads, cleans and caches the request body on first use.
   *
   * @return the cleaned body as UTF-8 bytes; empty if the handler returned {@code null}
   * @throws IOException if reading the underlying request body fails
   */
  private byte[] getCleanedBody() throws IOException {
    if (cleanedBody == null) {
      String cleaned = inputHandlers(super.getInputStream());
      // Encode with the same charset the body was decoded with, not the platform default.
      cleanedBody = cleaned == null ? new byte[0] : cleaned.getBytes(StandardCharsets.UTF_8);
    }
    return cleanedBody;
  }

  /**
   * Reads the entire input stream as UTF-8 text and removes XSS content from it.
   *
   * <p>The stream is closed before this method returns.
   *
   * @param inputStream the stream to read
   * @return the cleaned content
   * @throws IOException if reading the stream fails
   */
  public String inputHandlers(ServletInputStream inputStream) throws IOException {
    // Read the raw bytes in one go so line separators are preserved; a readLine() loop would
    // silently drop every '\n' / "\r\n" and corrupt multi-line bodies.
    final String body;
    try (inputStream) {
      body = new String(inputStream.readAllBytes(), StandardCharsets.UTF_8);
    }
    return cleanXss(body);
  }
}
