Struct MatxUtil

Struct Documentation

struct MatxUtil

Public Static Functions

static std::shared_ptr<rmm::device_buffer> cast(
const DevMemInfo &input,
TypeId output_type
)

Convert the elements of input to output_type, returning the result as a device_buffer.

Parameters:
  • input

  • output_type

Returns:

std::shared_ptr<rmm::device_buffer>

static std::shared_ptr<rmm::device_buffer> create_seq_ids(
TensorIndex row_count,
TensorIndex fea_len,
TypeId output_type,
std::shared_ptr<MemoryDescriptor> md,
TensorIndex start_idx = 0
)

Builds an Nx3 segment ID matrix.

Parameters:
  • row_count

  • fea_len

  • output_type

  • md

  • start_idx

Returns:

std::shared_ptr<rmm::device_buffer>

static void offset_seq_ids(
const DevMemInfo &input,
TensorIndex offset
)

Adds a constant offset to a seg_ids tensor.

Parameters:
  • input

  • offset

static std::shared_ptr<rmm::device_buffer> logits(
const DevMemInfo &input
)

Calculate logits on a device_buffer.

Parameters:
  • input

Returns:

std::shared_ptr<rmm::device_buffer>

static std::shared_ptr<rmm::device_buffer> transpose(
const DevMemInfo &input
)

Transpose the input matrix.

Parameters:
  • input

Returns:

std::shared_ptr<rmm::device_buffer>

static std::shared_ptr<rmm::device_buffer> threshold(
const DevMemInfo &input,
double thresh_val,
bool by_row
)

Returns an array of booleans where element [i,j] is true when x[i,j] >= thresh_val. When by_row is true, an Nx1 array is returned instead, with a row set to true if any value in that row is greater than or equal to the threshold.

Parameters:
  • input

  • thresh_val

  • by_row

Returns:

std::shared_ptr<rmm::device_buffer>

static std::shared_ptr<rmm::device_buffer> reduce_max(
const DevMemInfo &input,
const ShapeType &seq_ids,
TensorIndex seq_id_offset,
const ShapeType &output_shape
)

Returns a buffer of shape output_shape containing, for each sequence ID in seq_ids, the maximum of the input values mapped to that ID. For example, given a hypothetical input of:

input =   [5, 2, 8, 9, 8, 2, 1]
seq_ids = [0, 0, 0, 1, 2, 3, 3]
The result is: [8, 9, 8, 2]

Parameters:
  • input

  • seq_ids

  • seq_id_offset

  • output_shape

Returns:

std::shared_ptr<rmm::device_buffer>