trait Generate extends AnyRef
Abstract Value Members
- abstract def getModelOutput(encoderInputIds: Seq[Array[Int]], decoderInputIds: Seq[Array[Int]], decoderEncoderStateTensors: Either[Tensor, OnnxTensor], encoderAttentionMaskTensors: Either[Tensor, OnnxTensor], maxLength: Int, session: Either[Session, (OrtEnvironment, OrtSession)], ovInferRequest: Option[InferRequest] = None): Array[Array[Float]]
Calls the model and returns the output logits.
- encoderInputIds
Input IDs for the Encoder
- decoderInputIds
Input IDs for the Decoder
- decoderEncoderStateTensors
Tensor of encoded input for the decoder
- encoderAttentionMaskTensors
Tensor for encoder attention mask
- maxLength
Max length of the input
- session
TensorFlow session, or an ONNX Runtime environment and session
- returns
Logits for the input
Concrete Value Members
- final def !=(arg0: Any): Boolean
- Definition Classes
- AnyRef → Any
- final def ##: Int
- Definition Classes
- AnyRef → Any
- final def ==(arg0: Any): Boolean
- Definition Classes
- AnyRef → Any
- final def asInstanceOf[T0]: T0
- Definition Classes
- Any
- def beamSearch(encoderInputIdsVals: Seq[Array[Int]], inputIdsVal: Seq[Array[Int]], decoderEncoderStateTensors: Either[Tensor, OnnxTensor], encoderAttentionMaskTensors: Either[Tensor, OnnxTensor], beamScorer: BeamScorer, logitProcessor: LogitProcessorList, maxLength: Int, padTokenId: Int, eosTokenId: Int, doSample: Boolean, randomSeed: Option[Long], session: Either[Session, (OrtEnvironment, OrtSession)], applySoftmax: Boolean, ovInferRequest: Option[InferRequest] = None, stopTokenIds: Array[Int] = Array()): Array[Array[Int]]
Beam Search for text generation
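- encoderInputIdsVals
Input token IDs for the encoder, one array per sequence in the batch
- inputIdsVal
Initial decoder input token IDs
- decoderEncoderStateTensors
Encoded input for the decoder, as a TensorFlow Tensor or an ONNX Runtime OnnxTensor
- encoderAttentionMaskTensors
Encoder attention mask, as a TensorFlow Tensor or an ONNX Runtime OnnxTensor
- beamScorer
Scorer that keeps track of and ranks the beam hypotheses
- logitProcessor
List of processors applied to the next-token logits at each decoding step
- maxLength
Maximum length of the generated sequences
- padTokenId
ID of the padding token
- eosTokenId
ID of the end-of-sequence token
- doSample
Whether to sample the next tokens instead of choosing the highest-scoring ones
- randomSeed
Optional seed for the random number generator used when sampling
- session
TensorFlow session, or an ONNX Runtime environment and session
- returns
Generated token ID sequences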
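To illustrate what beam search does with `beamSize`, `maxLength`, and `eosTokenId`, here is a minimal, self-contained sketch. It is not the trait's implementation: `BeamSearchSketch`, `Hypothesis`, and `nextLogProbs` are hypothetical names, and `nextLogProbs` stands in for the model call (such as `getModelOutput`), returning one log-probability per vocabulary entry for a given prefix.

```scala
object BeamSearchSketch {
  /** A partial sequence with its cumulative log-probability. */
  final case class Hypothesis(tokens: Vector[Int], score: Double, done: Boolean)

  /** Minimal beam search. `nextLogProbs` is a hypothetical stand-in for the model:
    * given a prefix, it returns one log-probability per vocabulary entry.
    */
  def beamSearch(
      start: Vector[Int],
      nextLogProbs: Vector[Int] => Array[Double],
      beamSize: Int,
      maxLength: Int,
      eosTokenId: Int): Vector[Int] = {
    var beams = Vector(Hypothesis(start, 0.0, done = false))
    var length = start.length
    while (length < maxLength && beams.exists(!_.done)) {
      // Expand every unfinished hypothesis by every vocabulary token.
      val candidates = beams.flatMap { h =>
        if (h.done) Vector(h) // finished hypotheses compete unchanged
        else {
          val logProbs = nextLogProbs(h.tokens)
          logProbs.indices.toVector.map { tok =>
            Hypothesis(h.tokens :+ tok, h.score + logProbs(tok), tok == eosTokenId)
          }
        }
      }
      // Keep only the beamSize highest-scoring hypotheses, best first.
      beams = candidates.sortWith(_.score > _.score).take(beamSize)
      length += 1
    }
    beams.head.tokens // beams is sorted best-first
  }
}
```

With `beamSize = 1` this reduces to greedy decoding. A wider beam can recover a sequence whose first token is not the locally best choice but whose total probability is higher. The sketch ranks hypotheses by raw cumulative log-probability with no length normalization, which is a simplification.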
- def clone(): AnyRef
- Attributes
- protected[lang]
- Definition Classes
- AnyRef
- Annotations
- @throws(classOf[java.lang.CloneNotSupportedException]) @HotSpotIntrinsicCandidate() @native()
- final def eq(arg0: AnyRef): Boolean
- Definition Classes
- AnyRef
- def equals(arg0: AnyRef): Boolean
- Definition Classes
- AnyRef → Any
- def generate(inputIds: Seq[Array[Int]], decoderEncoderStateTensors: Either[Tensor, OnnxTensor], encoderAttentionMaskTensors: Either[Tensor, OnnxTensor], decoderInputs: Array[Array[Int]], maxOutputLength: Int, minOutputLength: Int, doSample: Boolean, beamSize: Int, numReturnSequences: Int, temperature: Double, topK: Int, topP: Double, repetitionPenalty: Double, noRepeatNgramSize: Int, vocabSize: Int, eosTokenId: Int, paddingTokenId: Int, randomSeed: Option[Long], ignoreTokenIds: Array[Int] = Array(), session: Either[Session, (OrtEnvironment, OrtSession)], applySoftmax: Boolean = true, ovInferRequest: Option[InferRequest] = None, stopTokenIds: Array[Int] = Array()): Array[Array[Int]]
Text Generation using Beam Search
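- inputIds
Input token IDs, one array per sequence in the batch
- decoderEncoderStateTensors
Encoded input for the decoder, as a TensorFlow Tensor or an ONNX Runtime OnnxTensor
- encoderAttentionMaskTensors
Encoder attention mask, as a TensorFlow Tensor or an ONNX Runtime OnnxTensor
- decoderInputs
Initial decoder input token IDs
- maxOutputLength
Maximum length of the generated output
- minOutputLength
Minimum length of the generated output
- doSample
Whether to sample the next tokens instead of choosing the highest-scoring ones
- beamSize
Number of beams used in beam search
- numReturnSequences
Number of sequences to return for each input
- temperature
Temperature used to scale the logits before sampling; values below 1.0 sharpen the distribution
- topK
Number of highest-probability tokens kept for top-k sampling
- topP
Cumulative probability threshold for nucleus (top-p) sampling
- repetitionPenalty
Penalty for tokens that have already been generated; 1.0 means no penalty
- noRepeatNgramSize
If greater than 0, n-grams of this size can occur only once in the output
- vocabSize
Size of the model vocabulary
- eosTokenId
ID of the end-of-sequence token
- paddingTokenId
ID of the padding token
- randomSeed
Optional seed for the random number generator used when sampling
- ignoreTokenIds
Token IDs to ignore in the generated output
- session
TensorFlow session, or an ONNX Runtime environment and session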
- returns
Array of generated sequences
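The `temperature`, `topK`, and `topP` parameters are conventionally applied to the next-token logits before sampling. The sketch below shows that convention under stated assumptions: temperature first, then top-k, then top-p, as in common generation libraries. `LogitFilterSketch` and its methods are hypothetical and are not the trait's implementation.

```scala
object LogitFilterSketch {
  private val NegInf = Float.NegativeInfinity

  /** Divides the logits by the temperature; values below 1.0 sharpen the distribution. */
  def applyTemperature(logits: Array[Float], temperature: Double): Array[Float] = {
    require(temperature > 0, "temperature must be positive")
    logits.map(l => (l / temperature).toFloat)
  }

  /** Masks every logit below the topK-th largest with -Infinity (ties at the cutoff are kept). */
  def topKFilter(logits: Array[Float], topK: Int): Array[Float] =
    if (topK <= 0 || topK >= logits.length) logits
    else {
      val threshold = logits.sortWith(_ > _)(topK - 1)
      logits.map(l => if (l < threshold) NegInf else l)
    }

  /** Keeps the smallest set of most likely tokens whose total probability reaches topP. */
  def topPFilter(logits: Array[Float], topP: Double): Array[Float] =
    if (topP >= 1.0) logits
    else {
      val probs = softmax(logits)
      val byProb = probs.indices.sortWith((a, b) => probs(a) > probs(b)) // most likely first
      val keep = scala.collection.mutable.Set.empty[Int]
      var cumulative = 0.0
      val it = byProb.iterator
      // Always keep at least one token; stop once the kept mass reaches topP.
      while (it.hasNext && (keep.isEmpty || cumulative < topP)) {
        val i = it.next()
        keep += i
        cumulative += probs(i)
      }
      Array.tabulate(logits.length)(i => if (keep(i)) logits(i) else NegInf)
    }

  /** Applies temperature, then top-k, then top-p. */
  def process(logits: Array[Float], temperature: Double, topK: Int, topP: Double): Array[Float] =
    topPFilter(topKFilter(applyTemperature(logits, temperature), topK), topP)

  /** Numerically stable softmax (subtracts the max logit before exponentiating). */
  private def softmax(logits: Array[Float]): Array[Double] = {
    val max = logits.reduce(_ max _)
    val exps = logits.map(l => math.exp((l - max).toDouble))
    val total = exps.sum
    exps.map(_ / total)
  }
}
```

Masked logits become zero-probability tokens once a softmax is applied, so sampling from the filtered logits can only pick tokens that survived both filters.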
- def getCDF(probs: Array[Float]): Array[Float]
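`getCDF` has no description here. Its name and signature suggest that it returns the cumulative distribution of a probability vector. The following minimal sketch is written under that assumption; `CdfSketch` is a hypothetical name.

```scala
object CdfSketch {
  /** Running sum: cdf(i) = probs(0) + probs(1) + ... + probs(i). */
  def getCDF(probs: Array[Float]): Array[Float] =
    probs.scanLeft(0f)(_ + _).drop(1) // drop the leading 0 seeded by scanLeft
}
```

With a CDF, a draw `u` that is uniform in [0, 1) maps to the first index `i` with `cdf(i) > u`. This is inverse-transform sampling.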
- final def getClass(): Class[_ <: AnyRef]
- Definition Classes
- AnyRef → Any
- Annotations
- @HotSpotIntrinsicCandidate() @native()
- def hashCode(): Int
- Definition Classes
- AnyRef → Any
- Annotations
- @HotSpotIntrinsicCandidate() @native()
- final def isInstanceOf[T0]: Boolean
- Definition Classes
- Any
- def logSoftmax(values: Array[Float]): Array[Float]
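`logSoftmax` is likewise undocumented. Assuming it computes log(softmax(x)), a numerically stable form is x_i - max - log(sum_j exp(x_j - max)). This form avoids overflow for large logits. `LogSoftmaxSketch` is a hypothetical name, and this is not the trait's code.

```scala
object LogSoftmaxSketch {
  /** log(softmax(x)) computed as x - max - log(sum(exp(x - max))) to avoid overflow. */
  def logSoftmax(values: Array[Float]): Array[Float] = {
    val max = values.reduce(_ max _)
    val logSumExp = math.log(values.map(v => math.exp((v - max).toDouble)).sum)
    values.map(v => (v - max - logSumExp).toFloat)
  }
}
```

Summing log-probabilities is how beam search accumulates hypothesis scores, which is one reason a log-space variant is useful alongside `softmax`.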
- def multinomialSampling(logitValues: Array[Float], k: Int, seed: Option[Long]): Array[Int]
Samples from a multinomial distribution using the provided logits.
- logitValues
The logits to sample from
- k
The number of samples to draw
- seed
The random seed to use
- returns
The sampled indices
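A common way to sample from a multinomial distribution given logits is inverse-transform sampling. First convert the logits to probabilities with a stable softmax, then build the CDF, then map each uniform draw to the first index whose cumulative probability exceeds it. The sketch assumes independent draws with replacement; the description above does not say whether the trait's method samples with or without replacement. `MultinomialSketch` is a hypothetical name.

```scala
import scala.util.Random

object MultinomialSketch {
  /** Draws k indices, with replacement, from softmax(logitValues) by inverse-CDF sampling. */
  def multinomialSampling(logitValues: Array[Float], k: Int, seed: Option[Long]): Array[Int] = {
    // Numerically stable softmax: subtracting the max logit prevents exp overflow.
    val max = logitValues.reduce(_ max _)
    val exps = logitValues.map(l => math.exp((l - max).toDouble))
    val total = exps.sum
    val cdf = exps.map(_ / total).scanLeft(0.0)(_ + _).drop(1)
    // A fixed seed makes the draws reproducible; None uses an unseeded generator.
    val rng = seed.map(s => new Random(s)).getOrElse(new Random())
    Array.fill(k) {
      val u = rng.nextDouble() // uniform in [0, 1)
      val idx = cdf.indexWhere(_ > u)
      // Rounding can leave cdf.last just below 1.0; fall back to the last index.
      if (idx >= 0) idx else cdf.length - 1
    }
  }
}
```

Passing the same `Some(seed)` twice yields identical draws, which is what makes sampled generation reproducible.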
- final def ne(arg0: AnyRef): Boolean
- Definition Classes
- AnyRef
- final def notify(): Unit
- Definition Classes
- AnyRef
- Annotations
- @HotSpotIntrinsicCandidate() @native()
- final def notifyAll(): Unit
- Definition Classes
- AnyRef
- Annotations
- @HotSpotIntrinsicCandidate() @native()
- def reshapeArray(inputArray: Array[Array[Float]], numRows: Int, numCols: Int): Array[Array[Float]]
Reshapes a 2D array into a new 2D array with the specified number of rows and columns.
- inputArray
The input array to reshape
- numRows
The number of rows in the output array
- numCols
The number of columns in the output array
- returns
The reshaped array
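The sketch below assumes row-major order: the input rows are concatenated, and the result is cut into `numRows` rows of `numCols` elements each. A typical use would be turning a flat block of logits into one row per sequence. `ReshapeSketch` is a hypothetical name, and the element order is an assumption rather than something stated above.

```scala
object ReshapeSketch {
  /** Concatenates the input rows, then splits the result into numRows rows of numCols values. */
  def reshapeArray(inputArray: Array[Array[Float]], numRows: Int, numCols: Int): Array[Array[Float]] = {
    require(numRows > 0 && numCols > 0, "numRows and numCols must be positive")
    val flat = inputArray.flatten // row-major: row 0, then row 1, ...
    require(flat.length == numRows * numCols,
      s"cannot reshape ${flat.length} elements into $numRows x $numCols")
    flat.grouped(numCols).toArray
  }
}
```

For example, reshaping a 2 x 3 input `[[1, 2, 3], [4, 5, 6]]` to 3 x 2 gives `[[1, 2], [3, 4], [5, 6]]`.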
- def sample(logits: Seq[Float], k: Int, seed: Long = 42): Array[Int]
Samples from a multinomial distribution using the provided logits.
- logits
The logits to sample from
- k
The number of samples to draw
- seed
The random seed to use
- returns
The sampled indices
- def softmax(logitValues: Array[Float]): Array[Float]
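`softmax` has no description here. By its name, it presumably maps logits to probabilities. A standard numerically stable version subtracts the maximum logit before exponentiating. This leaves the result mathematically unchanged while preventing overflow. `SoftmaxSketch` is a hypothetical name, and this is not the trait's code.

```scala
object SoftmaxSketch {
  /** exp(x_i - max) / sum_j exp(x_j - max); subtracting the max avoids overflow. */
  def softmax(logitValues: Array[Float]): Array[Float] = {
    val max = logitValues.reduce(_ max _)
    val exps = logitValues.map(v => math.exp((v - max).toDouble))
    val sum = exps.sum
    exps.map(e => (e / sum).toFloat)
  }
}
```

Without the max subtraction, logits around 1000 would overflow `exp` to infinity and yield NaN probabilities.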
- final def synchronized[T0](arg0: => T0): T0
- Definition Classes
- AnyRef
- def toString(): String
- Definition Classes
- AnyRef → Any
- final def wait(arg0: Long, arg1: Int): Unit
- Definition Classes
- AnyRef
- Annotations
- @throws(classOf[java.lang.InterruptedException])
- final def wait(arg0: Long): Unit
- Definition Classes
- AnyRef
- Annotations
- @throws(classOf[java.lang.InterruptedException]) @native()
- final def wait(): Unit
- Definition Classes
- AnyRef
- Annotations
- @throws(classOf[java.lang.InterruptedException])
Deprecated Value Members
- def finalize(): Unit
- Attributes
- protected[lang]
- Definition Classes
- AnyRef
- Annotations
- @throws(classOf[java.lang.Throwable]) @Deprecated
- Deprecated
(Since version 9)