Struct MatxUtil

struct MatxUtil

Public Static Functions

static std::shared_ptr<rmm::device_buffer> cast(const DevMemInfo &input, TypeId output_type)

Convert the elements of a device buffer to another element type.

Parameters
  • input

  • output_type

Returns

std::shared_ptr<rmm::device_buffer>
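A minimal host-side sketch of the conversion, assuming cast behaves like an element-wise static_cast. The name cast_sketch is hypothetical, std::vector stands in for rmm::device_buffer, and the target type is a template parameter rather than a runtime TypeId. No RMM or MatX APIs are used.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical host-side stand-in for MatxUtil::cast: converts every element
// of `input` to OutT with static_cast. The real function works on device
// memory and selects the output type at runtime from a TypeId; this sketch
// fixes both types at compile time instead.
template <typename OutT, typename InT>
std::vector<OutT> cast_sketch(const std::vector<InT>& input)
{
    std::vector<OutT> output;
    output.reserve(input.size());
    for (const InT& value : input)
    {
        output.push_back(static_cast<OutT>(value));
    }
    // The conversion is one-to-one, so the element count must not change.
    assert(output.size() == input.size());
    return output;
}
```

For example, cast_sketch<std::int32_t>(std::vector<double>{1.5, -2.0, 3.9}) yields {1, -2, 3}, because a floating-point to integer static_cast truncates toward zero.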

static std::shared_ptr<rmm::device_buffer> create_seq_ids(TensorIndex row_count, TensorIndex fea_len, TypeId output_type, std::shared_ptr<MemoryDescriptor> md, TensorIndex start_idx = 0)

Builds an Nx3 sequence ID (seq_ids) matrix, where N is row_count.

Parameters
  • row_count

  • fea_len

  • output_type

  • md

  • start_idx

Returns

std::shared_ptr<rmm::device_buffer>
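The sketch below shows one plausible Nx3 layout on the host. It assumes each row is {start_idx + i, 0, fea_len - 1}, that is, an input row index followed by the first and last feature positions; this layout is an assumption, not taken from the description above. The name create_seq_ids_sketch and the use of std::int64_t in place of TensorIndex are hypothetical, and the sketch omits the output_type and md arguments of the real function.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side sketch of an Nx3 seq_ids matrix, stored row-major so
// that element (i, j) lives at index i * 3 + j. Assumed row layout:
// {start_idx + i, 0, fea_len - 1}.
std::vector<std::int64_t> create_seq_ids_sketch(std::int64_t row_count,
                                                std::int64_t fea_len,
                                                std::int64_t start_idx = 0)
{
    assert(row_count >= 0 && fea_len >= 1);
    std::vector<std::int64_t> seq_ids(static_cast<std::size_t>(row_count) * 3, 0);
    for (std::int64_t i = 0; i < row_count; ++i)
    {
        const std::size_t row = static_cast<std::size_t>(i) * 3;
        seq_ids[row + 0] = start_idx + i;  // input row this entry maps to
        seq_ids[row + 1] = 0;              // first feature position
        seq_ids[row + 2] = fea_len - 1;    // last feature position
    }
    return seq_ids;
}
```

Under these assumptions, create_seq_ids_sketch(2, 4, 10) produces {10, 0, 3, 11, 0, 3}.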

static void offset_seq_ids(const DevMemInfo &input, TensorIndex offset)

Adds a constant offset to a sequence ID (seq_ids) tensor.

Parameters
  • input

  • offset
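A host-side sketch, assuming the offset is added only to column 0 (the row index) of a row-major Nx3 seq_ids matrix and that the tensor is modified in place, which the void return type suggests. The column choice and the name offset_seq_ids_sketch are assumptions for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side sketch: shifts the row-index column (column 0) of a
// row-major Nx3 seq_ids matrix by `offset`, in place. Columns 1 and 2 are
// left untouched under the assumed layout.
void offset_seq_ids_sketch(std::vector<std::int64_t>& seq_ids, std::int64_t offset)
{
    assert(seq_ids.size() % 3 == 0);
    for (std::size_t row = 0; row < seq_ids.size(); row += 3)
    {
        seq_ids[row] += offset;
    }
}
```

Applying an offset of 5 to {0, 0, 3, 1, 0, 3} gives {5, 0, 3, 6, 0, 3}.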

static std::shared_ptr<rmm::device_buffer> logits(const DevMemInfo &input)

Calculate logits on a device_buffer.

Parameters
  • input

Returns

std::shared_ptr<rmm::device_buffer>
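The description does not say which transform is applied. One common reading, assumed here, is the logistic sigmoid 1 / (1 + exp(-x)) applied element-wise, which maps raw model outputs to values between 0 and 1. logits_sketch is a hypothetical host-side stand-in, not the library's implementation.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical host-side sketch, assuming "logits" means applying the
// logistic sigmoid 1 / (1 + exp(-x)) to every element of the input.
template <typename T>
std::vector<T> logits_sketch(const std::vector<T>& input)
{
    std::vector<T> output(input.size());
    for (std::size_t i = 0; i < input.size(); ++i)
    {
        output[i] = T(1) / (T(1) + std::exp(-input[i]));
    }
    assert(output.size() == input.size());
    return output;
}
```

Under this assumption, an input of 0.0 maps to exactly 0.5, large positive inputs approach 1, and large negative inputs approach 0.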

static std::shared_ptr<rmm::device_buffer> transpose(const DevMemInfo &input)

Transpose the input matrix.

Parameters
  • input

Returns

std::shared_ptr<rmm::device_buffer>
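A host-side sketch of a 2-D transpose, assuming a row-major rows x cols input. transpose_sketch and its explicit shape arguments are hypothetical; the real function takes only the DevMemInfo and presumably reads the shape from it.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side sketch: element (r, c) of a row-major rows x cols
// input becomes element (c, r) of the row-major cols x rows output.
template <typename T>
std::vector<T> transpose_sketch(const std::vector<T>& input, std::size_t rows, std::size_t cols)
{
    assert(input.size() == rows * cols);
    std::vector<T> output(input.size());
    for (std::size_t r = 0; r < rows; ++r)
    {
        for (std::size_t c = 0; c < cols; ++c)
        {
            output[c * rows + r] = input[r * cols + c];
        }
    }
    return output;
}
```

For the 2x3 matrix {1, 2, 3, 4, 5, 6}, the sketch returns the 3x2 matrix {1, 4, 2, 5, 3, 6}.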

static std::shared_ptr<rmm::device_buffer> threshold(const DevMemInfo &input, double thresh_val, bool by_row)

Returns an array of booleans where element [i,j] is true if x[i,j] >= thresh_val. When by_row is true, an Nx1 array is returned instead, where row i is true if any value in that row meets or exceeds thresh_val.

Parameters
  • input

  • thresh_val

  • by_row

Returns

std::shared_ptr<rmm::device_buffer>
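A host-side sketch of both modes, assuming a row-major rows x cols input of doubles. threshold_sketch and its shape arguments are hypothetical, and std::vector<bool> stands in for the returned device buffer.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side sketch on a row-major rows x cols matrix.
// by_row == false: one entry per element, true where x[i][j] >= thresh_val.
// by_row == true:  one entry per row, true if any element in the row
//                  is >= thresh_val.
std::vector<bool> threshold_sketch(const std::vector<double>& input, std::size_t rows,
                                   std::size_t cols, double thresh_val, bool by_row)
{
    assert(input.size() == rows * cols);
    if (!by_row)
    {
        std::vector<bool> output(input.size());
        for (std::size_t i = 0; i < input.size(); ++i)
        {
            output[i] = input[i] >= thresh_val;
        }
        return output;
    }

    std::vector<bool> output(rows, false);
    for (std::size_t r = 0; r < rows; ++r)
    {
        for (std::size_t c = 0; c < cols; ++c)
        {
            if (input[r * cols + c] >= thresh_val)
            {
                output[r] = true;
                break;  // one match is enough for this row
            }
        }
    }
    return output;
}
```

With x = {0.1, 0.9, 0.2, 0.3} as a 2x2 matrix and thresh_val = 0.5, element mode gives {false, true, false, false} and row mode gives {true, false}.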

static std::shared_ptr<rmm::device_buffer> reduce_max(const DevMemInfo &input, const ShapeType &seq_ids, TensorIndex seq_id_offset, const ShapeType &output_shape)

Returns a buffer of shape output_shape in which each element is the maximum of the input values that seq_ids maps to it. For example, given the hypothetical input:


    input   = [5, 2, 8, 9, 8, 2, 1]
    seq_ids = [0, 0, 0, 1, 2, 3, 3]

Will return: [8, 9, 8, 2]

Parameters
  • input

  • seq_ids

  • seq_id_offset

  • output_shape

Returns

std::shared_ptr<rmm::device_buffer>
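A host-side sketch of the grouped maximum, generalized to a row-major matrix with cols columns so that each output row is the element-wise maximum of the input rows mapped to it. Subtracting seq_id_offset from each seq_id to find the output row is an assumption about that parameter, and reduce_max_sketch, output_rows, and cols are hypothetical stand-ins for the real output_shape argument.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical host-side sketch: input row i is folded into output row
// (seq_ids[i] - seq_id_offset) with an element-wise maximum. Output rows that
// receive no input rows keep the initial value numeric_limits<double>::lowest().
std::vector<double> reduce_max_sketch(const std::vector<double>& input,
                                      const std::vector<std::int64_t>& seq_ids,
                                      std::int64_t seq_id_offset,
                                      std::size_t output_rows,
                                      std::size_t cols)
{
    assert(input.size() == seq_ids.size() * cols);
    std::vector<double> output(output_rows * cols, std::numeric_limits<double>::lowest());
    for (std::size_t i = 0; i < seq_ids.size(); ++i)
    {
        const std::int64_t out_row = seq_ids[i] - seq_id_offset;
        assert(out_row >= 0 && static_cast<std::size_t>(out_row) < output_rows);
        for (std::size_t c = 0; c < cols; ++c)
        {
            double& slot = output[static_cast<std::size_t>(out_row) * cols + c];
            slot = std::max(slot, input[i * cols + c]);
        }
    }
    return output;
}
```

With cols = 1, seq_id_offset = 0 and output_rows = 4, the sketch reproduces the example above and returns {8, 9, 8, 2}.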

© Copyright 2023, NVIDIA. Last updated on Apr 11, 2023.