Defined in File matx_util.hpp
-
struct MatxUtil
-
Public Static Functions
-
static std::shared_ptr<rmm::device_buffer> cast(const DevMemInfo &input, TypeId output_type)
Converts the contents of a device_buffer from its current element type to output_type.
- Parameters
input –
output_type –
- Returns
std::shared_ptr<rmm::device_buffer>
-
static std::shared_ptr<rmm::device_buffer> create_seg_ids(size_t row_count, size_t fea_len, TypeId output_type)
Builds an Nx3 segment ID matrix.
- Parameters
row_count –
fea_len –
output_type –
- Returns
std::shared_ptr<rmm::device_buffer>
-
static std::shared_ptr<rmm::device_buffer> logits(const DevMemInfo &input)
Calculate logits on device_buffer.
- Parameters
input –
- Returns
std::shared_ptr<rmm::device_buffer>
-
static std::shared_ptr<rmm::device_buffer> transpose(const DevMemInfo &input)
Perform transpose.
- Parameters
input –
- Returns
std::shared_ptr<rmm::device_buffer>
-
static std::shared_ptr<rmm::device_buffer> threshold(const DevMemInfo &input, double thresh_val, bool by_row)
Returns an array of booleans where each element is true if x[i,j] >= thresh_val. When by_row is true, an Nx1 array is returned instead; each entry is true if any value in the corresponding row is at or above the threshold.
- Parameters
input –
thresh_val –
by_row –
- Returns
std::shared_ptr<rmm::device_buffer>
-
static std::shared_ptr<rmm::device_buffer> reduce_max(const DevMemInfo &input, const std::vector<int32_t> &seq_ids, size_t seq_id_offset, const std::vector<int64_t> &output_shape)
Returns a buffer with shape output_shape containing the maximum of the values in input, grouped according to seq_ids.
For example, given the hypothetical input:
input   = [5, 2, 8, 9, 8, 2, 1]
seq_ids = [0, 0, 0, 1, 2, 3, 3]
the result is: [8, 9, 8, 2]
- Parameters
input –
seq_ids –
seq_id_offset –
output_shape –
- Returns
std::shared_ptr<rmm::device_buffer>
-
static std::shared_ptr<rmm::device_buffer> cast(const DevMemInfo &input, TypeId output_type)