LangChain.js supports integration with endpoints hosted on AWS SageMaker. See Amazon SageMaker JumpStart for a list of available models and for instructions on deploying your own.

Setup

You'll need to install the official SageMaker SDK as a peer dependency:
npm install @aws-sdk/client-sagemaker-runtime
For general instructions on installing LangChain packages, see this section.
npm install @langchain/community @langchain/core

Usage

import {
  SageMakerEndpoint,
  SageMakerLLMContentHandler,
} from "@langchain/community/llms/sagemaker_endpoint";

interface ResponseJsonInterface {
  generation: {
    content: string;
  };
}

// Customize this handler for the model you will be using
class LLama213BHandler implements SageMakerLLMContentHandler {
  contentType = "application/json";

  accepts = "application/json";

  async transformInput(
    prompt: string,
    modelKwargs: Record<string, unknown>
  ): Promise<Uint8Array> {
    const payload = {
      inputs: [[{ role: "user", content: prompt }]],
      parameters: modelKwargs,
    };

    const stringifiedPayload = JSON.stringify(payload);

    return new TextEncoder().encode(stringifiedPayload);
  }

  async transformOutput(output: Uint8Array): Promise<string> {
    const response_json = JSON.parse(
      new TextDecoder("utf-8").decode(output)
    ) as ResponseJsonInterface[];
    const content = response_json[0]?.generation.content ?? "";
    return content;
  }
}

const contentHandler = new LLama213BHandler();

const model = new SageMakerEndpoint({
  endpointName: "aws-llama-2-13b-chat",
  modelKwargs: {
    temperature: 0.5,
    max_new_tokens: 700,
    top_p: 0.9,
  },
  endpointKwargs: {
    CustomAttributes: "accept_eula=true",
  },
  contentHandler,
  clientOptions: {
    region: "YOUR AWS ENDPOINT REGION",
    credentials: {
      accessKeyId: "YOUR AWS ACCESS KEY ID",
      secretAccessKey: "YOUR AWS SECRET ACCESS KEY",
    },
  },
});

const res = await model.invoke(
  "Hello, my name is John Doe, tell me a joke about llamas "
);

console.log(res);

/*
  [
    {
      content: "Hello, John Doe! Here's a joke about llamas:
        Why did the llama become a gardener?
        Because it was great at llama-scaping!"
    }
  ]
 */
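
The content handler above does two things: it serializes the prompt and model parameters into request bytes, and it parses the response bytes back into a string. The following is a minimal standalone sketch of that round trip, useful for checking a handler's logic without calling a live endpoint. The helper names (`encodePayload`, `decodeResponse`) and the sample response body are hypothetical and are not part of LangChain.js; the payload and response shapes are assumed to match the Llama 2 chat handler shown above and may differ for other models.

```typescript
// Hypothetical types mirroring the payload and response shapes used by the handler above.
interface ChatMessage {
  role: "user" | "assistant" | "system";
  content: string;
}

interface GenerationResponse {
  generation?: { content?: string };
}

// Serialize the prompt and parameters the same way transformInput does.
function encodePayload(
  prompt: string,
  modelKwargs: Record<string, unknown>
): Uint8Array {
  const inputs: ChatMessage[][] = [[{ role: "user", content: prompt }]];
  return new TextEncoder().encode(
    JSON.stringify({ inputs, parameters: modelKwargs })
  );
}

// Parse response bytes the same way transformOutput does,
// falling back to an empty string when the expected fields are missing.
function decodeResponse(output: Uint8Array): string {
  const parsed: unknown = JSON.parse(new TextDecoder("utf-8").decode(output));
  if (!Array.isArray(parsed)) return "";
  const first = parsed[0] as GenerationResponse | undefined;
  return first?.generation?.content ?? "";
}

// A made-up response body, for illustration only.
const sampleResponse = new TextEncoder().encode(
  JSON.stringify([{ generation: { content: "Hi there!" } }])
);

const requestBytes = encodePayload("Tell me a joke", { temperature: 0.5 });
const requestJson = new TextDecoder().decode(requestBytes);
const decoded = decodeResponse(sampleResponse);
const emptyDecoded = decodeResponse(new TextEncoder().encode("[]"));

console.log(requestJson);
console.log(decoded);
```

Unlike the handler above, this sketch also guards against a missing `generation` field, so a malformed or empty response yields an empty string rather than a thrown error. Whether that is the right behavior depends on how you want failures surfaced.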

Related links