Extending LangChain4j to support the DeepSeek-R1 model and output its reasoning content
| Java

Our company has an AI project that uses LangChain4j + openai4j to call AI models. It started out with OpenAI only, and now needs to integrate DeepSeek.

Although the DeepSeek API is compatible with the OpenAI format, the deepseek-r1 model still differs from OpenAI in a few ways:

  • The response carries an extra reasoning_content field containing the model's reasoning;
  • When conversation memory is used, the content of reasoning_content should not be included in the context.

A look at the openai4j source shows that its response classes have no reasoning_content field, so openai4j cannot be used directly to call deepseek-r1.


Searching around, I found the open-source deepseek4j and wondered whether LangChain4j and deepseek4j could be combined.

First we need to be clear about what LangChain4j and deepseek4j each do.

LangChain4j does not call the API directly; it wraps what happens before and after the API call, adding conversation memory, a unified way to call different models, tool calling and so on, which simplifies the development of AI features.

deepseek4j is the library that actually calls the API.
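
To make that division of labour concrete, here is a rough sketch of how the existing OpenAI path is typically wired together (the project's real code is not shown in this article; it assumes the langchain4j-open-ai module already present in the project, and the model name and key lookup are purely illustrative):

import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;

public class OpenAiWiringSketch {

    public static void main(String[] args) {
        // OpenAiStreamingChatModel (backed by openai4j) is the piece that actually talks HTTP;
        // LangChain4j's AiServices layer (see step 3 below) adds memory, tool calls and prompt plumbing on top.
        StreamingChatLanguageModel model = OpenAiStreamingChatModel.builder()
                .apiKey(System.getenv("OPENAI_API_KEY")) // illustrative
                .modelName("gpt-4o-mini")                // illustrative
                .build();

        System.out.println("model ready: " + model);
    }
}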

Once that is clear, the approach follows: wherever LangChain4j previously called openai4j, make it call deepseek4j instead. To do that, we first need to read the LangChain4j source and understand its call chain.

In the source you can find DefaultAiServices.class, which contains this code:

    public T build() {
        //......

        Object proxyInstance = Proxy.newProxyInstance(this.context.aiServiceClass.getClassLoader(), new Class[]{this.context.aiServiceClass}, new InvocationHandler() {
            private final ExecutorService executor = Executors.newCachedThreadPool();

            public Object invoke(Object proxy, Method method, Object[] args) throws Exception {
                if (method.getDeclaringClass() == Object.class) {
                    return method.invoke(this, args);
                } else {
                    //.....

                    if (streaming) {
                        TokenStream tokenStream = new AiServiceTokenStream((List)messages, (List)toolSpecifications, (Map)toolExecutors, augmentationResult != null ? augmentationResult.contents() : null, DefaultAiServices.this.context, memoryId);
                        return returnType == TokenStream.class ? tokenStream : this.adapt(tokenStream, returnType);
                    } else {
                        Response response;
                        if (supportsJsonSchema && jsonSchema.isPresent()) {
                            ChatRequest chatRequest = ChatRequest.builder().messages((List)messages).toolSpecifications((List)toolSpecifications).responseFormat(ResponseFormat.builder().type(ResponseFormatType.JSON).jsonSchema((JsonSchema)jsonSchema.get()).build()).build();
                            ChatResponse chatResponse = DefaultAiServices.this.context.chatModel.chat(chatRequest);
                            response = new Response(chatResponse.aiMessage(), chatResponse.tokenUsage(), chatResponse.finishReason());
                        } else {
                            response = toolSpecifications == null ? DefaultAiServices.this.context.chatModel.generate((List)messages) : DefaultAiServices.this.context.chatModel.generate((List)messages, (List)toolSpecifications);
                        }// ....

    }

I've omitted a lot of non-essential code here. As you can see, for streaming calls LangChain4j wraps the request in an AiServiceTokenStream, which ultimately invokes the generate method of StreamingChatLanguageModel.

Next, look at OpenAiStreamingChatModel in the source; it is an implementation of StreamingChatLanguageModel:

public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, TokenCountEstimator {
    
    private void generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications, ToolSpecification toolThatMustBeExecuted, StreamingResponseHandler<AiMessage> handler) {
        // ... request construction, response builder and listener setup omitted ...

        this.client.chatCompletion(request).onPartialResponse((partialResponse) -> {
            handle(partialResponse, handler);
        }).onComplete(() -> {
           listener.onResponse(responseContext);
            handler.onComplete(response);
        }).onError((error) -> {
            listener.onError(errorContext);
            handler.onError(error);
        }).execute();
    }

    private static void handle(ChatCompletionResponse partialResponse, StreamingResponseHandler<AiMessage> handler) {
        List<ChatCompletionChoice> choices = partialResponse.choices();
        if (choices != null && !choices.isEmpty()) {
            Delta delta = ((ChatCompletionChoice)choices.get(0)).delta();
            String content = delta.content();
            if (content != null) {
                handler.onNext(content);
            }

        }
    }
}


The StreamingResponseHandler<AiMessage> handler parameter is what drives the streaming output: every call to handler.onNext(...) streams a piece of content to the caller. Each time a chunk of the data stream arrives, onPartialResponse() is invoked; once the whole stream has been returned, onComplete() is invoked. So all we need is a new class that implements the StreamingChatLanguageModel interface, switches the underlying API call from openai4j to deepseek4j, and adapts the onPartialResponse and onComplete handling to our business needs.
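
To make the handler contract concrete, here is a minimal, purely illustrative handler (not part of the project) that just prints whatever the model streams:

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.output.Response;

// Purely illustrative: prints every streamed token and the final finish reason.
public class PrintingHandler implements StreamingResponseHandler<AiMessage> {

    @Override
    public void onNext(String token) {
        // called once for every partial chunk the model pushes
        System.out.print(token);
    }

    @Override
    public void onComplete(Response<AiMessage> response) {
        // called once the stream has ended, with the aggregated message
        System.out.println(System.lineSeparator() + "finish reason: " + response.finishReason());
    }

    @Override
    public void onError(Throwable error) {
        error.printStackTrace();
    }
}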

The implementation

1. Add the LangChain4j and deepseek4j Maven dependencies

<langchain4j.version>0.35.0</langchain4j.version>
<dependency>
   <groupId>dev.langchain4j</groupId>
   <artifactId>langchain4j</artifactId>
   <version>${langchain4j.version}</version>
</dependency>
<dependency>
   <groupId>dev.langchain4j</groupId>
   <artifactId>langchain4j-reactor</artifactId>
   <version>${langchain4j.version}</version>
</dependency>
<dependency>
   <groupId>io.github.pig-mesh.ai</groupId>
   <artifactId>deepseek4j-core</artifactId>
   <version>1.4.5</version>
</dependency>

Note that only deepseek4j-core is needed here; deepseek-spring-boot-starter is not required.

2. Model a DeepSeekStreamingChatModel on OpenAiStreamingChatModel


import com.alibaba.fastjson.JSON;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.chat.listener.*;
import dev.langchain4j.model.openai.OpenAiTokenizer;
import dev.langchain4j.model.output.Response;
import io.github.pigmesh.ai.deepseek.core.DeepSeekClient;
import io.github.pigmesh.ai.deepseek.core.chat.*;
import io.github.pigmesh.ai.deepseek.core.shared.StreamOptions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.Proxy;
import java.time.Duration;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;

public class DeepSeekStreamingChatModel implements StreamingChatLanguageModel, TokenCountEstimator {
    private static final Logger log = LoggerFactory.getLogger(DeepSeekStreamingChatModel.class);
    private final DeepSeekClient client;
    private final String modelName;
    private final Double temperature;
    private final Double topP;
    private final List<String> stop;
    private final Integer maxTokens;
    private final Integer maxCompletionTokens;
    private final Double presencePenalty;
    private final Double frequencyPenalty;
    private final Map<String, Integer> logitBias;
    private final ResponseFormat responseFormat;
    private final Integer seed;
    private final String user;
    private final Boolean strictTools;
    private final Boolean parallelToolCalls;
    private final Tokenizer tokenizer;
    private final List<ChatModelListener> listeners;
    private final String baseUrl;

    // Prefix marking deepseek-r1 reasoning content, used to tell it apart from the regular content
    public static final String REASON_CONTENT_PREFIX = "reasoning^&Content:";

    public DeepSeekStreamingChatModel(String baseUrl, String apiKey, String organizationId, String modelName, Double temperature, Double topP, List<String> stop, Integer maxTokens, Integer maxCompletionTokens, Double presencePenalty, Double frequencyPenalty, Map<String, Integer> logitBias, String responseFormat, Integer seed, String user, Boolean strictTools, Boolean parallelToolCalls, Duration timeout, Proxy proxy, Boolean logRequests, Boolean logResponses, Tokenizer tokenizer, Map<String, String> customHeaders, List<ChatModelListener> listeners) {
        timeout = (Duration)Utils.getOrDefault(timeout, Duration.ofSeconds(60L));
        this.baseUrl = baseUrl;
        this.client = DeepSeekClient.builder().baseUrl((String)Utils.getOrDefault(baseUrl, "https://api.openai.com/v1")).openAiApiKey(apiKey).organizationId(organizationId).callTimeout(timeout).connectTimeout(timeout).readTimeout(timeout).writeTimeout(timeout).proxy(proxy).logRequests(logRequests).logStreamingResponses(logResponses).userAgent("langchain4j-openai").customHeaders(customHeaders).build();
        this.modelName = (String)Utils.getOrDefault(modelName, "gpt-3.5-turbo");
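        // NOTE: the defaults above ("https://api.openai.com/v1" and "gpt-3.5-turbo") are carried over from
        // OpenAiStreamingChatModel; in this project both baseUrl and modelName are supplied explicitly
        // (see the builder call in step 3), so the DeepSeek values come from configuration.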
        this.temperature = (Double)Utils.getOrDefault(temperature, 0.7);
        this.topP = topP;
        this.stop = stop;
        this.maxTokens = maxTokens;
        this.maxCompletionTokens = maxCompletionTokens;
        this.presencePenalty = presencePenalty;
        this.frequencyPenalty = frequencyPenalty;
        this.logitBias = logitBias;
        this.responseFormat = responseFormat == null ? null : ResponseFormat.builder().type(ResponseFormatType.valueOf(responseFormat.toUpperCase(Locale.ROOT))).build();
        this.seed = seed;
        this.user = user;
        this.strictTools = (Boolean)Utils.getOrDefault(strictTools, false);
        this.parallelToolCalls = parallelToolCalls;
        this.tokenizer = (Tokenizer)Utils.getOrDefault(tokenizer, OpenAiTokenizer::new);
        this.listeners = (List)(listeners == null ? Collections.emptyList() : new ArrayList(listeners));
    }

    public String modelName() {
        return this.modelName;
    }

    public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
        this.generate(messages, (List)null, (ToolSpecification)null, handler);
    }

    public void generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications, StreamingResponseHandler<AiMessage> handler) {
        this.generate(messages, toolSpecifications, (ToolSpecification)null, handler);
    }

    public void generate(List<ChatMessage> messages, ToolSpecification toolSpecification, StreamingResponseHandler<AiMessage> handler) {
        this.generate(messages, (List)null, toolSpecification, handler);
    }

    private void generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications, ToolSpecification toolThatMustBeExecuted, StreamingResponseHandler<AiMessage> handler) {
        ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder().stream(true).streamOptions(StreamOptions.builder().includeUsage(true).build()).model(this.modelName).messages(InternalDeepSeekHelper.toOpenAiMessages(messages)).temperature(this.temperature).topP(this.topP).stop(this.stop).maxTokens(this.maxTokens).maxCompletionTokens(this.maxCompletionTokens).presencePenalty(this.presencePenalty).frequencyPenalty(this.frequencyPenalty).logitBias(this.logitBias).responseFormat(this.responseFormat).seed(this.seed).user(this.user).parallelToolCalls(this.parallelToolCalls);
        if (toolThatMustBeExecuted != null) {
            requestBuilder.tools(InternalDeepSeekHelper.toTools(Collections.singletonList(toolThatMustBeExecuted), this.strictTools));
            requestBuilder.toolChoice(toolThatMustBeExecuted.name());
        } else if (!Utils.isNullOrEmpty(toolSpecifications)) {
            requestBuilder.tools(InternalDeepSeekHelper.toTools(toolSpecifications, this.strictTools));
        }

        ChatCompletionRequest request = requestBuilder.build();
        ChatModelRequest modelListenerRequest = InternalDeepSeekHelper.createModelListenerRequest(request, messages, toolSpecifications);
        Map<Object, Object> attributes = new ConcurrentHashMap();
        ChatModelRequestContext requestContext = new ChatModelRequestContext(modelListenerRequest, attributes);
        this.listeners.forEach((listener) -> {
            try {
                listener.onRequest(requestContext);
            } catch (Exception var3) {
                log.warn("Exception while calling model listener", var3);
            }

        });
        DeepSeekStreamingResponseBuilder responseBuilder = new DeepSeekStreamingResponseBuilder();
        AtomicReference<String> responseId = new AtomicReference();
        AtomicReference<String> responseModel = new AtomicReference();
        this.client.chatCompletion(request).onPartialResponse((partialResponse) -> {
            responseBuilder.append(partialResponse);
            handle(partialResponse, handler);
            if (!Utils.isNullOrBlank(partialResponse.id())) {
                responseId.set(partialResponse.id());
            }

            if (!Utils.isNullOrBlank(partialResponse.model())) {
                responseModel.set(partialResponse.model());
            }

        }).onComplete(() -> {
            Response<AiMessage> response = responseBuilder.build();
            ChatModelResponse modelListenerResponse = InternalDeepSeekHelper.createModelListenerResponse((String)responseId.get(), (String)responseModel.get(), response);
            ChatModelResponseContext responseContext = new ChatModelResponseContext(modelListenerResponse, modelListenerRequest, attributes);
            this.listeners.forEach((listener) -> {
                try {
                    listener.onResponse(responseContext);
                } catch (Exception var3) {
                    log.warn("Exception while calling model listener", var3);
                }

            });
            handler.onComplete(response);
        }).onError((error) -> {
            Response<AiMessage> response = responseBuilder.build();
            ChatModelResponse modelListenerPartialResponse = InternalDeepSeekHelper.createModelListenerResponse((String)responseId.get(), (String)responseModel.get(), response);
            ChatModelErrorContext errorContext = new ChatModelErrorContext(error, modelListenerRequest, modelListenerPartialResponse, attributes);
            this.listeners.forEach((listener) -> {
                try {
                    listener.onError(errorContext);
                } catch (Exception var3) {
                    log.warn("Exception while calling model listener", var3);
                }

            });
            handler.onError(error);
        }).execute();
    }

    private static void handle(ChatCompletionResponse partialResponse, StreamingResponseHandler<AiMessage> handler) {
        List<ChatCompletionChoice> choices = partialResponse.choices();
        if (choices != null && !choices.isEmpty()) {
            Delta delta = ((ChatCompletionChoice)choices.get(0)).delta();
            String content = delta.content();
            String reasoningContent = delta.reasoningContent();
            StringBuilder builder = new StringBuilder();

            if (reasoningContent != null) {
                builder.append(REASON_CONTENT_PREFIX).append(reasoningContent);
            }
            if (content != null) {
                builder.append(content);
            }
            if (reasoningContent != null || content != null) {
                handler.onNext(builder.toString());
            }
        }
    }

    public int estimateTokenCount(List<ChatMessage> messages) {
        return this.tokenizer.estimateTokenCountInMessages(messages);
    }
//
//    /** @deprecated */
//    @Deprecated
//    public static DeepSeekStreamingChatModel withApiKey(String apiKey) {
//        return builder().apiKey(apiKey).build();
//    }
//
    public static DeepSeekStreamingChatModelBuilder builder() {
        return new DeepSeekStreamingChatModel.DeepSeekStreamingChatModelBuilder();
    }


    public static class DeepSeekStreamingChatModelBuilder {
        private String baseUrl;
        private String apiKey;
        private String organizationId;
        private String modelName;
        private Double temperature;
        private Double topP;
        private List<String> stop;
        private Integer maxTokens;
        private Integer maxCompletionTokens;
        private Double presencePenalty;
        private Double frequencyPenalty;
        private Map<String, Integer> logitBias;
        private String responseFormat;
        private Integer seed;
        private String user;
        private Boolean strictTools;
        private Boolean parallelToolCalls;
        private Duration timeout;
        private Proxy proxy;
        private Boolean logRequests;
        private Boolean logResponses;
        private Tokenizer tokenizer;
        private Map<String, String> customHeaders;
        private List<ChatModelListener> listeners;

        public DeepSeekStreamingChatModelBuilder() {
        }

        public DeepSeekStreamingChatModel.DeepSeekStreamingChatModelBuilder modelName(String modelName) {
            this.modelName = modelName;
            return this;
        }

        public DeepSeekStreamingChatModel.DeepSeekStreamingChatModelBuilder baseUrl(String baseUrl) {
            this.baseUrl = baseUrl;
            return this;
        }

        public DeepSeekStreamingChatModel.DeepSeekStreamingChatModelBuilder apiKey(String apiKey) {
            this.apiKey = apiKey;
            return this;
        }

        public DeepSeekStreamingChatModel.DeepSeekStreamingChatModelBuilder organizationId(String organizationId) {
            this.organizationId = organizationId;
            return this;
        }

        public DeepSeekStreamingChatModel.DeepSeekStreamingChatModelBuilder temperature(Double temperature) {
            this.temperature = temperature;
            return this;
        }

        public DeepSeekStreamingChatModel.DeepSeekStreamingChatModelBuilder topP(Double topP) {
            this.topP = topP;
            return this;
        }

        public DeepSeekStreamingChatModel.DeepSeekStreamingChatModelBuilder stop(List<String> stop) {
            this.stop = stop;
            return this;
        }

        public DeepSeekStreamingChatModel.DeepSeekStreamingChatModelBuilder maxTokens(Integer maxTokens) {
            this.maxTokens = maxTokens;
            return this;
        }

        public DeepSeekStreamingChatModel.DeepSeekStreamingChatModelBuilder maxCompletionTokens(Integer maxCompletionTokens) {
            this.maxCompletionTokens = maxCompletionTokens;
            return this;
        }

        public DeepSeekStreamingChatModel.DeepSeekStreamingChatModelBuilder presencePenalty(Double presencePenalty) {
            this.presencePenalty = presencePenalty;
            return this;
        }

        public DeepSeekStreamingChatModel.DeepSeekStreamingChatModelBuilder frequencyPenalty(Double frequencyPenalty) {
            this.frequencyPenalty = frequencyPenalty;
            return this;
        }

        public DeepSeekStreamingChatModel.DeepSeekStreamingChatModelBuilder logitBias(Map<String, Integer> logitBias) {
            this.logitBias = logitBias;
            return this;
        }

        public DeepSeekStreamingChatModel.DeepSeekStreamingChatModelBuilder responseFormat(String responseFormat) {
            this.responseFormat = responseFormat;
            return this;
        }

        public DeepSeekStreamingChatModel.DeepSeekStreamingChatModelBuilder seed(Integer seed) {
            this.seed = seed;
            return this;
        }

        public DeepSeekStreamingChatModel.DeepSeekStreamingChatModelBuilder user(String user) {
            this.user = user;
            return this;
        }

        public DeepSeekStreamingChatModel.DeepSeekStreamingChatModelBuilder strictTools(Boolean strictTools) {
            this.strictTools = strictTools;
            return this;
        }

        public DeepSeekStreamingChatModel.DeepSeekStreamingChatModelBuilder parallelToolCalls(Boolean parallelToolCalls) {
            this.parallelToolCalls = parallelToolCalls;
            return this;
        }

        public DeepSeekStreamingChatModel.DeepSeekStreamingChatModelBuilder timeout(Duration timeout) {
            this.timeout = timeout;
            return this;
        }

        public DeepSeekStreamingChatModel.DeepSeekStreamingChatModelBuilder proxy(Proxy proxy) {
            this.proxy = proxy;
            return this;
        }

        public DeepSeekStreamingChatModel.DeepSeekStreamingChatModelBuilder logRequests(Boolean logRequests) {
            this.logRequests = logRequests;
            return this;
        }

        public DeepSeekStreamingChatModel.DeepSeekStreamingChatModelBuilder logResponses(Boolean logResponses) {
            this.logResponses = logResponses;
            return this;
        }

        public DeepSeekStreamingChatModel.DeepSeekStreamingChatModelBuilder tokenizer(Tokenizer tokenizer) {
            this.tokenizer = tokenizer;
            return this;
        }

        public DeepSeekStreamingChatModel.DeepSeekStreamingChatModelBuilder customHeaders(Map<String, String> customHeaders) {
            this.customHeaders = customHeaders;
            return this;
        }

        public DeepSeekStreamingChatModel.DeepSeekStreamingChatModelBuilder listeners(List<ChatModelListener> listeners) {
            this.listeners = listeners;
            return this;
        }

        public DeepSeekStreamingChatModel build() {
            return new DeepSeekStreamingChatModel(this.baseUrl, this.apiKey, this.organizationId, this.modelName, this.temperature, this.topP, this.stop, this.maxTokens, this.maxCompletionTokens, this.presencePenalty, this.frequencyPenalty, this.logitBias, this.responseFormat, this.seed, this.user, this.strictTools, this.parallelToolCalls, this.timeout, this.proxy, this.logRequests, this.logResponses, this.tokenizer, this.customHeaders, this.listeners);
        }

        public String toString() {
            return "DeepSeekStreamingChatModel.DeepSeekStreamingChatModelBuilder(baseUrl=" + this.baseUrl + ", apiKey=" + this.apiKey + ", organizationId=" + this.organizationId + ", modelName=" + this.modelName + ", temperature=" + this.temperature + ", topP=" + this.topP + ", stop=" + this.stop + ", maxTokens=" + this.maxTokens + ", maxCompletionTokens=" + this.maxCompletionTokens + ", presencePenalty=" + this.presencePenalty + ", frequencyPenalty=" + this.frequencyPenalty + ", logitBias=" + this.logitBias + ", responseFormat=" + this.responseFormat + ", seed=" + this.seed + ", user=" + this.user + ", strictTools=" + this.strictTools + ", parallelToolCalls=" + this.parallelToolCalls + ", timeout=" + this.timeout + ", proxy=" + this.proxy + ", logRequests=" + this.logRequests + ", logResponses=" + this.logResponses + ", tokenizer=" + this.tokenizer + ", customHeaders=" + this.customHeaders + ", listeners=" + this.listeners + ")";
        }
    }
}

DeepSeekStreamingResponseBuilder accumulates the streamed deltas (content, reasoning_content, tool calls and token usage) into the final Response:

package com.gowining.gowiningai.knowledge.support.handler.model;

import io.github.pigmesh.ai.deepseek.core.chat.ChatCompletionChoice;
import io.github.pigmesh.ai.deepseek.core.chat.ChatCompletionResponse;
import io.github.pigmesh.ai.deepseek.core.chat.Delta;
import io.github.pigmesh.ai.deepseek.core.chat.FunctionCall;
import io.github.pigmesh.ai.deepseek.core.chat.ToolCall;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.AiMessage;
import com.gowining.gowiningai.knowledge.support.handler.model.InternalDeepSeekHelper;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import io.github.pigmesh.ai.deepseek.core.completion.CompletionChoice;
import io.github.pigmesh.ai.deepseek.core.completion.CompletionResponse;
import io.github.pigmesh.ai.deepseek.core.shared.Usage;

import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

public class DeepSeekStreamingResponseBuilder {
    private final StringBuffer contentBuilder = new StringBuffer();
    private final StringBuffer reasoningContentBuilder = new StringBuffer();
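    // Reasoning tokens are accumulated here but intentionally not included in build(),
    // so reasoning_content never ends up in the AiMessage that LangChain4j stores in chat memory.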
    private final StringBuffer toolNameBuilder = new StringBuffer();
    private final StringBuffer toolArgumentsBuilder = new StringBuffer();
    private final Map<Integer, DeepSeekStreamingResponseBuilder.ToolExecutionRequestBuilder> indexToToolExecutionRequestBuilder = new ConcurrentHashMap();
    private volatile TokenUsage tokenUsage;
    private volatile FinishReason finishReason;


    public DeepSeekStreamingResponseBuilder() {
    }

    public void append(ChatCompletionResponse partialResponse) {
        if (partialResponse != null) {
            Usage usage = partialResponse.usage();
            if (usage != null) {
                this.tokenUsage = InternalDeepSeekHelper.tokenUsageFrom(usage);
            }

            List<ChatCompletionChoice> choices = partialResponse.choices();
            if (choices != null && !choices.isEmpty()) {
                ChatCompletionChoice chatCompletionChoice = (ChatCompletionChoice)choices.get(0);
                if (chatCompletionChoice != null) {
                    String finishReason = chatCompletionChoice.finishReason();
                    if (finishReason != null) {
                        this.finishReason = InternalDeepSeekHelper.finishReasonFrom(finishReason);
                    }

                    Delta delta = chatCompletionChoice.delta();
                    if (delta != null) {
                        String content = delta.content();
                        String reasoningContent = delta.reasoningContent();
                        if (reasoningContent != null) {
                            this.reasoningContentBuilder.append(reasoningContent);
                        }
                        if (content != null) {
                            this.contentBuilder.append(content);
                        } else {
                            if (delta.functionCall() != null) {
                                FunctionCall functionCall = delta.functionCall();
                                if (functionCall.name() != null) {
                                    this.toolNameBuilder.append(functionCall.name());
                                }

                                if (functionCall.arguments() != null) {
                                    this.toolArgumentsBuilder.append(functionCall.arguments());
                                }
                            }

                            if (delta.toolCalls() != null && !delta.toolCalls().isEmpty()) {
                                ToolCall toolCall = (ToolCall)delta.toolCalls().get(0);
                                DeepSeekStreamingResponseBuilder.ToolExecutionRequestBuilder toolExecutionRequestBuilder = (DeepSeekStreamingResponseBuilder.ToolExecutionRequestBuilder)this.indexToToolExecutionRequestBuilder.computeIfAbsent(toolCall.index(), (idx) -> {
                                    return new DeepSeekStreamingResponseBuilder.ToolExecutionRequestBuilder();
                                });
                                if (toolCall.id() != null) {
                                    toolExecutionRequestBuilder.idBuilder.append(toolCall.id());
                                }

                                FunctionCall functionCall = toolCall.function();
                                if (functionCall.name() != null) {
                                    toolExecutionRequestBuilder.nameBuilder.append(functionCall.name());
                                }

                                if (functionCall.arguments() != null) {
                                    toolExecutionRequestBuilder.argumentsBuilder.append(functionCall.arguments());
                                }
                            }

                        }
                    }
                }
            }
        }
    }

    public void append(CompletionResponse partialResponse) {
        if (partialResponse != null) {
            Usage usage = partialResponse.usage();
            if (usage != null) {
                this.tokenUsage = InternalDeepSeekHelper.tokenUsageFrom(usage);
            }

            List<CompletionChoice> choices = partialResponse.choices();
            if (choices != null && !choices.isEmpty()) {
                CompletionChoice completionChoice = (CompletionChoice)choices.get(0);
                if (completionChoice != null) {
                    String finishReason = completionChoice.finishReason();
                    if (finishReason != null) {
                        this.finishReason = InternalDeepSeekHelper.finishReasonFrom(finishReason);
                    }

                    String token = completionChoice.text();
                    if (token != null) {
                        this.contentBuilder.append(token);
                    }

                }
            }
        }
    }

    public Response<AiMessage> build() {
        String content = this.contentBuilder.toString();
        if (!content.isEmpty()) {
            return Response.from(AiMessage.from(content), this.tokenUsage, this.finishReason);
        } else {
            String toolName = this.toolNameBuilder.toString();
            if (!toolName.isEmpty()) {
                ToolExecutionRequest toolExecutionRequest = ToolExecutionRequest.builder().name(toolName).arguments(this.toolArgumentsBuilder.toString()).build();
                return Response.from(AiMessage.from(new ToolExecutionRequest[]{toolExecutionRequest}), this.tokenUsage, this.finishReason);
            } else if (!this.indexToToolExecutionRequestBuilder.isEmpty()) {
                List<ToolExecutionRequest> toolExecutionRequests = (List)this.indexToToolExecutionRequestBuilder.values().stream().map((it) -> {
                    return ToolExecutionRequest.builder().id(it.idBuilder.toString()).name(it.nameBuilder.toString()).arguments(it.argumentsBuilder.toString()).build();
                }).collect(Collectors.toList());
                return Response.from(AiMessage.from(toolExecutionRequests), this.tokenUsage, this.finishReason);
            } else {
                return null;
            }
        }
    }

    private static class ToolExecutionRequestBuilder {
        private final StringBuffer idBuilder;
        private final StringBuffer nameBuilder;
        private final StringBuffer argumentsBuilder;

        private ToolExecutionRequestBuilder() {
            this.idBuilder = new StringBuffer();
            this.nameBuilder = new StringBuffer();
            this.argumentsBuilder = new StringBuffer();
        }
    }
}

InternalDeepSeekHelper mirrors LangChain4j's internal OpenAI helper, converting between LangChain4j types and deepseek4j types:

package com.gowining.gowiningai.knowledge.support.handler.model;

import io.github.pigmesh.ai.deepseek.core.chat.AssistantMessage;
import io.github.pigmesh.ai.deepseek.core.chat.ChatCompletionChoice;
import io.github.pigmesh.ai.deepseek.core.chat.ChatCompletionRequest;
import io.github.pigmesh.ai.deepseek.core.chat.ChatCompletionResponse;
import io.github.pigmesh.ai.deepseek.core.chat.Content;
import io.github.pigmesh.ai.deepseek.core.chat.ContentType;
import io.github.pigmesh.ai.deepseek.core.chat.Function;
import io.github.pigmesh.ai.deepseek.core.chat.FunctionCall;
import io.github.pigmesh.ai.deepseek.core.chat.FunctionMessage;
import io.github.pigmesh.ai.deepseek.core.chat.ImageDetail;
import io.github.pigmesh.ai.deepseek.core.chat.ImageUrl;
import io.github.pigmesh.ai.deepseek.core.chat.JsonArraySchema;
import io.github.pigmesh.ai.deepseek.core.chat.JsonBooleanSchema;
import io.github.pigmesh.ai.deepseek.core.chat.JsonEnumSchema;
import io.github.pigmesh.ai.deepseek.core.chat.JsonIntegerSchema;
import io.github.pigmesh.ai.deepseek.core.chat.JsonNumberSchema;
import io.github.pigmesh.ai.deepseek.core.chat.JsonObjectSchema;
import io.github.pigmesh.ai.deepseek.core.chat.JsonSchemaElement;
import io.github.pigmesh.ai.deepseek.core.chat.JsonStringSchema;
import io.github.pigmesh.ai.deepseek.core.chat.Message;
import io.github.pigmesh.ai.deepseek.core.chat.ResponseFormat;
import io.github.pigmesh.ai.deepseek.core.chat.Tool;
import io.github.pigmesh.ai.deepseek.core.chat.ToolCall;
import io.github.pigmesh.ai.deepseek.core.chat.ToolMessage;
import io.github.pigmesh.ai.deepseek.core.chat.ToolType;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolParameters;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.internal.Exceptions;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.chat.request.ResponseFormatType;
import dev.langchain4j.model.chat.request.json.JsonSchema;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import io.github.pigmesh.ai.deepseek.core.shared.Usage;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class InternalDeepSeekHelper {
    static final String OPENAI_URL = "https://api.openai.com/v1";
    static final String OPENAI_DEMO_API_KEY = "demo";
    static final String OPENAI_DEMO_URL = "http://langchain4j.dev/demo/openai/v1";
    static final String DEFAULT_USER_AGENT = "langchain4j-openai";

    public InternalDeepSeekHelper() {
    }

    public static List<Message> toOpenAiMessages(List<ChatMessage> messages) {
        return (List)messages.stream().map(com.gowining.gowiningai.knowledge.support.handler.model.InternalDeepSeekHelper::toOpenAiMessage).collect(Collectors.toList());
    }

    public static Message toOpenAiMessage(ChatMessage message) {
        if (message instanceof SystemMessage) {
            return io.github.pigmesh.ai.deepseek.core.chat.SystemMessage.from(((SystemMessage)message).text());
        } else if (message instanceof UserMessage) {
            UserMessage userMessage = (UserMessage)message;
            return userMessage.hasSingleText() ? io.github.pigmesh.ai.deepseek.core.chat.UserMessage.builder().content(userMessage.text()).name(userMessage.name()).build() : io.github.pigmesh.ai.deepseek.core.chat.UserMessage.builder().content((List)userMessage.contents().stream().map(com.gowining.gowiningai.knowledge.support.handler.model.InternalDeepSeekHelper::toOpenAiContent).collect(Collectors.toList())).name(userMessage.name()).build();
        } else if (message instanceof AiMessage) {
            AiMessage aiMessage = (AiMessage)message;
            if (!aiMessage.hasToolExecutionRequests()) {
                return AssistantMessage.from(aiMessage.text());
            } else {
                ToolExecutionRequest toolExecutionRequest = (ToolExecutionRequest)aiMessage.toolExecutionRequests().get(0);
                if (toolExecutionRequest.id() == null) {
                    FunctionCall functionCall = FunctionCall.builder().name(toolExecutionRequest.name()).arguments(toolExecutionRequest.arguments()).build();
                    return AssistantMessage.builder().functionCall(functionCall).build();
                } else {
                    List<ToolCall> toolCalls = (List)aiMessage.toolExecutionRequests().stream().map((it) -> {
                        return ToolCall.builder().id(it.id()).type(ToolType.FUNCTION).function(FunctionCall.builder().name(it.name()).arguments(it.arguments()).build()).build();
                    }).collect(Collectors.toList());
                    return AssistantMessage.builder().toolCalls(toolCalls).build();
                }
            }
        } else if (message instanceof ToolExecutionResultMessage) {
            ToolExecutionResultMessage toolExecutionResultMessage = (ToolExecutionResultMessage)message;
            return (Message)(toolExecutionResultMessage.id() == null ? FunctionMessage.from(toolExecutionResultMessage.toolName(), toolExecutionResultMessage.text()) : ToolMessage.from(toolExecutionResultMessage.id(), toolExecutionResultMessage.text()));
        } else {
            throw Exceptions.illegalArgument("Unknown message type: " + message.type(), new Object[0]);
        }
    }

    private static Content toOpenAiContent(dev.langchain4j.data.message.Content content) {
        if (content instanceof TextContent) {
            return toOpenAiContent((TextContent)content);
        } else if (content instanceof ImageContent) {
            return toOpenAiContent((ImageContent)content);
        } else {
            throw Exceptions.illegalArgument("Unknown content type: " + content, new Object[0]);
        }
    }

    private static Content toOpenAiContent(TextContent content) {
        return Content.builder().type(ContentType.TEXT).text(content.text()).build();
    }

    private static Content toOpenAiContent(ImageContent content) {
        return Content.builder().type(ContentType.IMAGE_URL).imageUrl(ImageUrl.builder().url(toUrl(content.image())).detail(toDetail(content.detailLevel())).build()).build();
    }

    private static String toUrl(Image image) {
        return image.url() != null ? image.url().toString() : String.format("data:%s;base64,%s", image.mimeType(), image.base64Data());
    }

    private static ImageDetail toDetail(ImageContent.DetailLevel detailLevel) {
        return detailLevel == null ? null : ImageDetail.valueOf(detailLevel.name());
    }

    public static List<Tool> toTools(Collection<ToolSpecification> toolSpecifications, boolean strict) {
        return (List)toolSpecifications.stream().map((toolSpecification) -> {
            return toTool(toolSpecification, strict);
        }).collect(Collectors.toList());
    }

    private static Tool toTool(ToolSpecification toolSpecification, boolean strict) {
        Function.Builder functionBuilder = Function.builder().name(toolSpecification.name()).description(toolSpecification.description()).parameters(toOpenAiParameters(toolSpecification.parameters(), strict));
        if (strict) {
            functionBuilder.strict(true);
        }

        Function function = functionBuilder.build();
        return Tool.from(function);
    }

    /** @deprecated */
    @Deprecated
    public static List<Function> toFunctions(Collection<ToolSpecification> toolSpecifications) {
        return (List)toolSpecifications.stream().map(com.gowining.gowiningai.knowledge.support.handler.model.InternalDeepSeekHelper::toFunction).collect(Collectors.toList());
    }

    /** @deprecated */
    @Deprecated
    private static Function toFunction(ToolSpecification toolSpecification) {
        return Function.builder().name(toolSpecification.name()).description(toolSpecification.description()).parameters(toOpenAiParameters(toolSpecification.parameters(), false)).build();
    }

    private static JsonObjectSchema toOpenAiParameters(ToolParameters toolParameters, boolean strict) {
        JsonObjectSchema.Builder builder;
        if (toolParameters == null) {
            builder = JsonObjectSchema.builder();
            if (strict) {
                builder.additionalProperties(false);
            }

            return builder.build();
        } else {
            builder = JsonObjectSchema.builder().properties(toOpenAiProperties(toolParameters.properties(), strict)).required(toolParameters.required());
            if (strict) {
                builder.required(new ArrayList(toolParameters.properties().keySet())).additionalProperties(false);
            }

            return builder.build();
        }
    }

    private static Map<String, JsonSchemaElement> toOpenAiProperties(Map<String, ?> properties, boolean strict) {
        Map<String, JsonSchemaElement> openAiProperties = new LinkedHashMap();
        properties.forEach((key, value) -> {
            openAiProperties.put(key, toOpenAiJsonSchemaElement((Map)value, strict));
        });
        return openAiProperties;
    }

    private static JsonSchemaElement toOpenAiJsonSchemaElement(Map<String, ?> properties, boolean strict) {
        Object type = properties.get("type");
        String description = (String)properties.get("description");
        if ("object".equals(type)) {
            List<String> required = (List)properties.get("required");
            JsonObjectSchema.Builder builder = JsonObjectSchema.builder().description(description).properties(toOpenAiProperties((Map)properties.get("properties"), strict));
            if (required != null) {
                builder.required(required);
            }

            if (strict) {
                builder.required(new ArrayList(((Map)properties.get("properties")).keySet())).additionalProperties(false);
            }

            return builder.build();
        } else if ("array".equals(type)) {
            return JsonArraySchema.builder().description(description).items(toOpenAiJsonSchemaElement((Map)properties.get("items"), strict)).build();
        } else if (properties.get("enum") != null) {
            return JsonEnumSchema.builder().description(description).enumValues((List)properties.get("enum")).build();
        } else if ("string".equals(type)) {
            return JsonStringSchema.builder().description(description).build();
        } else if ("integer".equals(type)) {
            return JsonIntegerSchema.builder().description(description).build();
        } else if ("number".equals(type)) {
            return JsonNumberSchema.builder().description(description).build();
        } else if ("boolean".equals(type)) {
            return JsonBooleanSchema.builder().description(description).build();
        } else {
            throw new IllegalArgumentException("Unknown type " + type);
        }
    }

    public static AiMessage aiMessageFrom(ChatCompletionResponse response) {
        AssistantMessage assistantMessage = ((ChatCompletionChoice)response.choices().get(0)).message();
        String text = assistantMessage.content();
        List<ToolCall> toolCalls = assistantMessage.toolCalls();
        if (!Utils.isNullOrEmpty(toolCalls)) {
            List<ToolExecutionRequest> toolExecutionRequests = (List)toolCalls.stream().filter((toolCall) -> {
                return toolCall.type() == ToolType.FUNCTION;
            }).map(com.gowining.gowiningai.knowledge.support.handler.model.InternalDeepSeekHelper::toToolExecutionRequest).collect(Collectors.toList());
            return Utils.isNullOrBlank(text) ? AiMessage.from(toolExecutionRequests) : AiMessage.from(text, toolExecutionRequests);
        } else {
            FunctionCall functionCall = assistantMessage.functionCall();
            if (functionCall != null) {
                ToolExecutionRequest toolExecutionRequest = ToolExecutionRequest.builder().name(functionCall.name()).arguments(functionCall.arguments()).build();
                return Utils.isNullOrBlank(text) ? AiMessage.from(new ToolExecutionRequest[]{toolExecutionRequest}) : AiMessage.from(text, Collections.singletonList(toolExecutionRequest));
            } else {
                return AiMessage.from(text);
            }
        }
    }

    private static ToolExecutionRequest toToolExecutionRequest(ToolCall toolCall) {
        FunctionCall functionCall = toolCall.function();
        return ToolExecutionRequest.builder().id(toolCall.id()).name(functionCall.name()).arguments(functionCall.arguments()).build();
    }

    public static TokenUsage tokenUsageFrom(Usage openAiUsage) {
        return openAiUsage == null ? null : new TokenUsage(openAiUsage.promptTokens(), openAiUsage.completionTokens(), openAiUsage.totalTokens());
    }

    public static FinishReason finishReasonFrom(String openAiFinishReason) {
        if (openAiFinishReason == null) {
            return null;
        } else {
            switch (openAiFinishReason) {
                case "stop":
                    return FinishReason.STOP;
                case "length":
                    return FinishReason.LENGTH;
                case "tool_calls":
                case "function_call":
                    return FinishReason.TOOL_EXECUTION;
                case "content_filter":
                    return FinishReason.CONTENT_FILTER;
                default:
                    return null;
            }
        }
    }

    static ChatModelRequest createModelListenerRequest(ChatCompletionRequest request, List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
        return ChatModelRequest.builder().model(request.model()).temperature(request.temperature()).topP(request.topP()).maxTokens((Integer)Utils.getOrDefault(request.maxCompletionTokens(), request.maxTokens())).messages(messages).toolSpecifications(toolSpecifications).build();
    }

    static ChatModelResponse createModelListenerResponse(String responseId, String responseModel, Response<AiMessage> response) {
        return response == null ? null : ChatModelResponse.builder().id(responseId).model(responseModel).tokenUsage(response.tokenUsage()).finishReason(response.finishReason()).aiMessage((AiMessage)response.content()).build();
    }

    static ResponseFormat toOpenAiResponseFormat(dev.langchain4j.model.chat.request.ResponseFormat responseFormat, Boolean strict) {
        if (responseFormat != null && responseFormat.type() != ResponseFormatType.TEXT) {
            JsonSchema jsonSchema = responseFormat.jsonSchema();
            if (jsonSchema == null) {
                return ResponseFormat.builder().type(io.github.pigmesh.ai.deepseek.core.chat.ResponseFormatType.JSON_OBJECT).build();
            } else if (!(jsonSchema.rootElement() instanceof dev.langchain4j.model.chat.request.json.JsonObjectSchema)) {
                throw new IllegalArgumentException("For OpenAI, the root element of the JSON Schema must be a JsonObjectSchema, but it was: " + jsonSchema.rootElement().getClass());
            } else {
                io.github.pigmesh.ai.deepseek.core.chat.JsonSchema openAiJsonSchema = io.github.pigmesh.ai.deepseek.core.chat.JsonSchema.builder().name(jsonSchema.name()).strict(strict).schema((JsonObjectSchema)toOpenAiJsonSchemaElement(jsonSchema.rootElement())).build();
                return ResponseFormat.builder().type(io.github.pigmesh.ai.deepseek.core.chat.ResponseFormatType.JSON_SCHEMA).jsonSchema(openAiJsonSchema).build();
            }
        } else {
            return null;
        }
    }

    private static JsonSchemaElement toOpenAiJsonSchemaElement(dev.langchain4j.model.chat.request.json.JsonSchemaElement jsonSchemaElement) {
        if (jsonSchemaElement instanceof dev.langchain4j.model.chat.request.json.JsonStringSchema) {
            return JsonStringSchema.builder().description(((dev.langchain4j.model.chat.request.json.JsonStringSchema)jsonSchemaElement).description()).build();
        } else if (jsonSchemaElement instanceof dev.langchain4j.model.chat.request.json.JsonIntegerSchema) {
            return JsonIntegerSchema.builder().description(((dev.langchain4j.model.chat.request.json.JsonIntegerSchema)jsonSchemaElement).description()).build();
        } else if (jsonSchemaElement instanceof dev.langchain4j.model.chat.request.json.JsonNumberSchema) {
            return JsonNumberSchema.builder().description(((dev.langchain4j.model.chat.request.json.JsonNumberSchema)jsonSchemaElement).description()).build();
        } else if (jsonSchemaElement instanceof dev.langchain4j.model.chat.request.json.JsonBooleanSchema) {
            return JsonBooleanSchema.builder().description(((dev.langchain4j.model.chat.request.json.JsonBooleanSchema)jsonSchemaElement).description()).build();
        } else if (jsonSchemaElement instanceof dev.langchain4j.model.chat.request.json.JsonEnumSchema) {
            return JsonEnumSchema.builder().description(((dev.langchain4j.model.chat.request.json.JsonEnumSchema)jsonSchemaElement).description()).enumValues(((dev.langchain4j.model.chat.request.json.JsonEnumSchema)jsonSchemaElement).enumValues()).build();
        } else if (jsonSchemaElement instanceof dev.langchain4j.model.chat.request.json.JsonArraySchema) {
            return JsonArraySchema.builder().description(((dev.langchain4j.model.chat.request.json.JsonArraySchema)jsonSchemaElement).description()).items(toOpenAiJsonSchemaElement(((dev.langchain4j.model.chat.request.json.JsonArraySchema)jsonSchemaElement).items())).build();
        } else if (jsonSchemaElement instanceof dev.langchain4j.model.chat.request.json.JsonObjectSchema) {
            Map<String, dev.langchain4j.model.chat.request.json.JsonSchemaElement> properties = ((dev.langchain4j.model.chat.request.json.JsonObjectSchema)jsonSchemaElement).properties();
            Map<String, JsonSchemaElement> openAiProperties = new LinkedHashMap();
            properties.forEach((key, value) -> {
                openAiProperties.put(key, toOpenAiJsonSchemaElement(value));
            });
            return JsonObjectSchema.builder().description(((dev.langchain4j.model.chat.request.json.JsonObjectSchema)jsonSchemaElement).description()).properties(openAiProperties).required(((dev.langchain4j.model.chat.request.json.JsonObjectSchema)jsonSchemaElement).required()).additionalProperties(((dev.langchain4j.model.chat.request.json.JsonObjectSchema)jsonSchemaElement).additionalProperties()).build();
        } else {
            throw new IllegalArgumentException("Unknown type: " + jsonSchemaElement);
        }
    }
}

3. Invocation

Build a DeepSeekStreamingChatModel and make the streaming call, returning a Flux:

DeepSeekStreamingChatModel streamingChatLanguageModel = DeepSeekStreamingChatModel.builder()
        .apiKey(aiModel.getApiKey())
        .baseUrl(aiModel.getBaseUrl())
        .modelName(aiModel.getModelName())
        .temperature(aiModel.getTemperature())
        .topP(aiModel.getTopP())
        .maxTokens(aiModel.getResponseLimit())
        .proxy(getProxy(aiModel.getBaseUrl()))
        .listeners(chatModelListenerList)
        .timeout(Duration.ofSeconds(aiKnowledgeProperties.getConnectTimeout()))
        .logRequests(aiKnowledgeProperties.isShowLog())
        .logResponses(aiKnowledgeProperties.isShowLog())
        .build();

DeepSeekR1AssistantService deepSeekR1AssistantService = AiServices.builder(DeepSeekR1AssistantService.class)
        .streamingChatLanguageModel(streamingChatLanguageModel)
        .chatMemoryProvider(chatMemoryAdvisorProvider)
        .build();

Flux<String> chatResponseFlux = deepSeekR1AssistantService.chat(chatMessageDTO.getConversationId(),
        chatMessageDTO.getContent());

return chatResponseFlux.map(AiMessageResultDTO::new);
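
DeepSeekR1AssistantService itself is not shown in this article. A minimal sketch consistent with the call above (the method signature is taken from the snippet; the annotations are my assumption) could look like this:

import dev.langchain4j.service.MemoryId;
import dev.langchain4j.service.UserMessage;
import reactor.core.publisher.Flux;

// Sketch only: langchain4j-reactor (declared in the dependencies above) lets AI services return Flux<String>.
public interface DeepSeekR1AssistantService {

    Flux<String> chat(@MemoryId String conversationId, @UserMessage String message);
}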

In the constructor of AiMessageResultDTO, distinguish reasoning content from answer content:

public class AiMessageResultDTO {

	private String message;
	private String reasonContent;

	// REASON_CONTENT_PREFIX is statically imported from DeepSeekStreamingChatModel
	public AiMessageResultDTO(String message) {
		if (message != null && message.startsWith(REASON_CONTENT_PREFIX)) {
			this.reasonContent = message.substring(REASON_CONTENT_PREFIX.length());
			this.message = "";
		} else {
			this.message = message;
		}
	}
}
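
As a hypothetical consumer (assuming the DTO exposes standard getters, which are not shown above), the stream can then be split into reasoning and answer parts:

chatResponseFlux.map(AiMessageResultDTO::new)
        .subscribe(dto -> {
            if (dto.getReasonContent() != null && !dto.getReasonContent().isEmpty()) {
                System.out.print(dto.getReasonContent());   // reasoning ("thinking") tokens
            } else {
                System.out.print(dto.getMessage());         // answer tokens
            }
        });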



Author: 不是好驴
Original link: https://www.baddonkey.cn/detail/55
Copyright: original article; reposting is allowed with attribution to the source.
