ONNX Runtime generate() Java API
Note: this API is in preview and is subject to change.
- Install and import
- SimpleGenAI Class
- Exception Class
- Model class
- Tokenizer class
- TokenizerStream class
- Tensor Class
- GeneratorParams class
- Generator class
- Sequences Class
Install and import
The Java API is delivered by the ai.onnxruntime.genai Java package. Package publication is pending. To build the package from source, see the build from source guide.
import ai.onnxruntime.genai.*;
SimpleGenAI Class
The SimpleGenAI class provides a simple usage example of the GenAI API. It works with a model that generates text based on a prompt, processing a single prompt at a time. Usage:
Create an instance of the class with the path to the model. The path should also contain the GenAI configuration files.
SimpleGenAI genAI = new SimpleGenAI(folderPath);
Call createGeneratorParams with the prompt text. Set any other search options via the GeneratorParams object as needed using setSearchOption.
GeneratorParams generatorParams = genAI.createGeneratorParams(promptText);
// .. set additional generator params before calling generate()
Call generate with the GeneratorParams object and an optional listener.
String fullResponse = genAI.generate(generatorParams, listener);
The listener is a callback that lets tokens be used as they are generated. Create a class that implements the Consumer<String> interface and provide an instance of that class as the listener argument.
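As an illustration of the listener described above, the following is a minimal sketch of a Consumer<String> implementation that echoes each token and keeps a copy of the text received so far. The class name CollectingListener and its text() method are hypothetical and not part of the GenAI API.

```java
import java.util.function.Consumer;

// Hypothetical listener: echoes each streamed token and keeps the accumulated text.
final class CollectingListener implements Consumer<String> {
    private final StringBuilder collected = new StringBuilder();

    @Override
    public void accept(String token) {
        // Keep this method fast: token generation waits until accept returns.
        collected.append(token);
        System.out.print(token);
    }

    // Returns everything received so far.
    String text() {
        return collected.toString();
    }
}
```

An instance can then be passed as the second argument, for example genAI.generate(generatorParams, new CollectingListener()).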
Constructor
public SimpleGenAI(String modelPath) throws GenAIException
Throws
GenAIException
- on failure.
Generate Method
Generate text based on the prompt and settings in GeneratorParams.
NOTE: This only handles a single sequence of input (i.e. a single prompt which equates to batch size of 1).
public String generate(GeneratorParams generatorParams, Consumer<String> listener) throws GenAIException
Parameters
generatorParams
: the prompt and settings to run the model with.
listener
: optional callback for tokens to be provided as they are generated.
NOTE: Token generation will be blocked until the listener’s accept method returns.
Throws
GenAIException
- on failure.
Returns
The generated text.
Example
SimpleGenAI generator = new SimpleGenAI(modelPath);
GeneratorParams params = generator.createGeneratorParams("What's 6 times 7?");
Consumer<String> listener = token -> logger.info("onTokenGenerate: " + token);
String result = generator.generate(params, listener);
logger.info("Result: " + result);
createGeneratorParams Method
Create the generator parameters and add the prompt text. The user can set other search options via the GeneratorParams object prior to running generate.
public GeneratorParams createGeneratorParams(String prompt) throws GenAIException
Parameters
prompt
: the prompt text to encode.
Throws
GenAIException
- on failure.
Returns
The generator parameters.
Exception Class
An exception which contains the error message and code produced by the native layer.
Constructor
public GenAIException(String message)
Example
catch (GenAIException e) {
throw new GenAIException("Token generation loop failed.", e);
}
Model class
Constructor
Model(String modelPath)
Create Tokenizer Method
Creates a Tokenizer instance for this model. The model contains the configuration information that determines the tokenizer to use.
public Tokenizer createTokenizer() throws GenAIException
Throws
GenAIException
- if the call to the GenAI native API fails
Returns
The tokenizer instance.
Generate Method
public Sequences generate(GeneratorParams generatorParams) throws GenAIException
Parameters
generatorParams
: the generator parameters.
Throws
GenAIException
- if the call to the GenAI native API fails.
Returns
The generated sequences.
Example
Sequences output = model.generate(generatorParams);
createGeneratorParams Method
Creates a GeneratorParams instance for executing the model.
NOTE: GeneratorParams internally uses the Model, so the Model instance must remain valid.
public GeneratorParams createGeneratorParams() throws GenAIException
Throws
GenAIException
- if the call to the GenAI native API fails.
Returns
The GeneratorParams instance.
Example
GeneratorParams params = model.createGeneratorParams();
Tokenizer class
Encode Method
Encodes a string into a sequence of token ids.
public Sequences encode(String string) throws GenAIException
Parameters
string
: text to encode as token ids.
Throws
GenAIException
- if the call to the GenAI native API fails.
Returns
A Sequences object with a single sequence in it.
Example
Sequences encodedPrompt = tokenizer.encode(prompt);
Decode Method
Decodes a sequence of token ids into text.
public String decode(int[] sequence) throws GenAIException
Parameters
sequence
: collection of token ids to decode to text
Throws
GenAIException
- if the call to the GenAI native API fails.
Returns
The text representation of the sequence.
Example
String result = tokenizer.decode(output_ids);
encodeBatch Method
Encodes an array of strings into a sequence of token ids for each input.
public Sequences encodeBatch(String[] strings) throws GenAIException
Parameters
strings
: collection of strings to encode as token ids.
Throws
GenAIException
- if the call to the GenAI native API fails.
Returns
A Sequences object with one sequence per input string.
Example
Sequences encoded = tokenizer.encodeBatch(inputs);
decodeBatch Method
Decodes a batch of sequences of token ids into text.
public String[] decodeBatch(Sequences sequences) throws GenAIException
Parameters
sequences
: a Sequences object with one or more sequences of token ids.
Throws
GenAIException
- if the call to the GenAI native API fails.
Returns
An array of strings with the text representation of each sequence.
Example
String[] decoded = tokenizer.decodeBatch(encoded);
createStream Method
Creates a TokenizerStream object for streaming tokenization. This is used with the Generator class to decode each token as it is generated.
public TokenizerStream createStream() throws GenAIException
Throws
GenAIException
- if the call to the GenAI native API fails.
Returns
The new TokenizerStream instance.
TokenizerStream class
This class is used to convert individual tokens when using Generator.generateNextToken.
Decode method
public String decode(int token) throws GenAIException
Parameters
token
: int value for token
Throws
GenAIException
Tensor Class
Constructs a Tensor with the given data, shape, and element type.
public Tensor(ByteBuffer data, long[] shape, ElementType elementType) throws GenAIException
Parameters
data
: the data for the Tensor. Must be a direct ByteBuffer.
shape
: the shape of the Tensor.
elementType
: the type of elements in the Tensor.
Throws
GenAIException
Example
Create a 2x2 Tensor with 32-bit float data.
long[] shape = {2, 2};
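// Direct buffers default to big-endian; use the platform's native order for data passed to native code.
ByteBuffer data = ByteBuffer.allocateDirect(4 * Float.BYTES).order(ByteOrder.nativeOrder());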
FloatBuffer floatBuffer = data.asFloatBuffer();
floatBuffer.put(new float[] {1.0f, 2.0f, 3.0f, 4.0f});
Tensor tensor = new Tensor(data, shape, Tensor.ElementType.float32);
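The buffer preparation in the example above can be wrapped in a small helper. This is a sketch only: the class TensorBuffers and the method directFloatBuffer are hypothetical names, and the choice of native byte order is an assumption about what the native layer expects rather than something stated in this reference.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Hypothetical helper for preparing float data for the Tensor constructor.
final class TensorBuffers {
    private TensorBuffers() {}

    // Copies the values into a new direct buffer in the platform's native byte order
    // (assumption: the native layer reads the bytes without conversion).
    static ByteBuffer directFloatBuffer(float[] values) {
        ByteBuffer data = ByteBuffer.allocateDirect(values.length * Float.BYTES)
                .order(ByteOrder.nativeOrder());
        // The float view has its own position, so data itself stays at position 0.
        data.asFloatBuffer().put(values);
        return data;
    }
}
```

The returned buffer can be passed as the data argument, with a shape whose element count matches values.length.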
GeneratorParams class
The GeneratorParams
class represents the parameters used for generating sequences with a model. Set the prompt using setInput, and any other search options using setSearchOption.
Create a Generator Params object
GeneratorParams params = new GeneratorParams(model);
setSearchOption Method
public void setSearchOption(String optionName, double value) throws GenAIException
Throws
GenAIException
Example
Set search option to limit the model generation length.
generatorParams.setSearchOption("max_length", 10);
setSearchOption Method
public void setSearchOption(String optionName, boolean value) throws GenAIException
Throws
GenAIException
Example
generatorParams.setSearchOption("early_stopping", true);
setInput Method
Sets the prompt/s for model execution. The sequences are created by using Tokenizer.encode or Tokenizer.encodeBatch.
public void setInput(Sequences sequences) throws GenAIException
Parameters
sequences
: sequences containing the encoded prompt.
Throws
GenAIException
- if the call to the GenAI native API fails.
Example
generatorParams.setInput(encodedPrompt);
setInput Method
Sets the prompt/s token ids for model execution. The tokenIds are the token ids of the encoded prompt/s.
public void setInput(int[] tokenIds, int sequenceLength, int batchSize)
throws GenAIException
Parameters
tokenIds
: the token ids of the encoded prompt/s.
sequenceLength
: the length of each sequence.
batchSize
: size of the batch.
Throws
GenAIException
- if the call to the GenAI native API fails.
NOTE: all sequences in the batch must be the same length.
Example
generatorParams.setInput(tokenIds, sequenceLength, batchSize);
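To illustrate the equal-length requirement, here is a minimal sketch that flattens a batch of token id arrays into a single int[] and rejects ragged input. The class TokenBatches and the method flatten are hypothetical, and the layout (one sequence after another) is an assumption, not something this reference confirms.

```java
import java.util.Objects;

// Hypothetical helper for building the tokenIds argument from per-prompt arrays.
final class TokenBatches {
    private TokenBatches() {}

    // Concatenates the sequences in order (assumed layout) after checking
    // that every sequence has the same length.
    static int[] flatten(int[][] batch) {
        Objects.requireNonNull(batch, "batch");
        if (batch.length == 0) {
            throw new IllegalArgumentException("batch must not be empty");
        }
        int sequenceLength = batch[0].length;
        int[] flat = new int[batch.length * sequenceLength];
        for (int i = 0; i < batch.length; i++) {
            if (batch[i].length != sequenceLength) {
                throw new IllegalArgumentException(
                        "sequence " + i + " has length " + batch[i].length
                                + ", expected " + sequenceLength);
            }
            System.arraycopy(batch[i], 0, flat, i * sequenceLength, sequenceLength);
        }
        return flat;
    }
}
```

With such a helper, the call would look like generatorParams.setInput(TokenBatches.flatten(batch), batch[0].length, batch.length).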
Generator class
The Generator class generates output using a model and generator parameters. The expected usage is to loop until isDone returns true. Within the loop, call computeLogits followed by generateNextToken.
The newly generated token can be retrieved with getLastTokenInSequence and decoded with TokenizerStream.decode.
After the generation process is done, getSequence can be used to retrieve the complete generated sequence if needed.
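The control flow of that loop can be sketched without a model. In the sketch below, TokenSource and TokenDecoder are hypothetical stand-in interfaces that mirror only the method names documented here (the real methods also declare GenAIException, omitted for brevity), and ScriptedTokenSource is a fake that replays fixed token ids. None of these types belong to the GenAI API.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Stand-in for the Generator methods used by the loop (not the real class).
interface TokenSource {
    boolean isDone();
    void computeLogits();
    void generateNextToken();
    int getLastTokenInSequence(long sequenceIndex);
}

// Stand-in for TokenizerStream.decode.
interface TokenDecoder {
    String decode(int token);
}

final class GenerationLoop {
    private GenerationLoop() {}

    // Loops until isDone returns true, decoding each new token of sequence 0.
    static String run(TokenSource generator, TokenDecoder stream) {
        StringBuilder text = new StringBuilder();
        while (!generator.isDone()) {
            generator.computeLogits();
            generator.generateNextToken();
            int token = generator.getLastTokenInSequence(0);
            text.append(stream.decode(token));
        }
        return text.toString();
    }
}

// Fake source that replays a fixed list of token ids so the loop runs without a model.
final class ScriptedTokenSource implements TokenSource {
    private final Deque<Integer> pending = new ArrayDeque<>();
    private int last = -1;

    ScriptedTokenSource(int... tokens) {
        for (int t : tokens) {
            pending.add(t);
        }
    }

    @Override public boolean isDone() { return pending.isEmpty(); }
    @Override public void computeLogits() { /* nothing to compute in the fake */ }
    @Override public void generateNextToken() { last = pending.remove(); }
    @Override public int getLastTokenInSequence(long sequenceIndex) { return last; }
}
```

With the real classes, the same loop shape would call the corresponding Generator and TokenizerStream methods inside a try block that handles GenAIException.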
Create a Generator
Constructs a Generator object with the given model and generator parameters.
Generator(Model model, GeneratorParams generatorParams)
Parameters
model
: the model.
generatorParams
: the generator parameters.
Throws
GenAIException
- if the call to the GenAI native API fails.
isDone Method
Checks if the generation process is done.
public boolean isDone()
Returns
Returns true if the generation process is done, false otherwise.
computeLogits Method
Computes the logits for the next token in the sequence.
public void computeLogits() throws GenAIException
Throws
GenAIException
- if the call to the GenAI native API fails.
getSequence Method
Retrieves a sequence of token ids for the specified sequence index.
public int[] getSequence(long sequenceIndex) throws GenAIException
Parameters
sequenceIndex
: the index of the sequence.
Throws
GenAIException
- if the call to the GenAI native API fails.
Returns
An array of integers with the sequence of token ids.
Example
int[] outputIds = generator.getSequence(0);
generateNextToken Method
Generates the next token in the sequence.
public void generateNextToken() throws GenAIException
Throws
GenAIException
- if the call to the GenAI native API fails.
getLastTokenInSequence Method
Retrieves the last token in the sequence for the specified sequence index.
public int getLastTokenInSequence(long sequenceIndex) throws GenAIException
Parameters
sequenceIndex
: the index of the sequence.
Throws
GenAIException
- if the call to the GenAI native API fails.
Returns
The last token in the sequence.
Sequences Class
Represents a collection of encoded prompts/responses.
numSequences Method
Gets the number of sequences in the collection. This is equivalent to the batch size.
public long numSequences()
Returns
The number of sequences.
Example
int numSequences = (int) sequences.numSequences();
getSequence Method
Gets the sequence at the specified index.
public int[] getSequence(long sequenceIndex)
Parameters
sequenceIndex
: The index of the sequence.
Returns
The sequence as an array of integers.