asc.language.basic.set_vector_mask

asc.language.basic.set_vector_mask(length: int, dtype: DataType, mode: MaskMode) None
asc.language.basic.set_vector_mask(mask_high: int, mask_low: int, dtype: DataType, mode: MaskMode) None

用于在矢量计算时设置mask。使用前需要先调用 set_mask_count/set_mask_norm 设置 mask 模式。 在不同模式下,mask的含义不同:

  • Normal 模式

    mask参数用来控制单次迭代内参与计算的元素个数。此时又可以划分为如下两种模式:

    • 连续模式(len):表示单次迭代中前面连续多少个元素参与计算。取值范围和操作数的数据类型有关,数据类型不同,每次迭代内能够处理的元素个数最大值不同。

      • 操作数为16位时:mask ∈ [1, 128]

      • 操作数为32位时:mask ∈ [1, 64]

      • 操作数为64位时:mask ∈ [1, 32]

    • 逐比特模式(mask_high / mask_low):按位控制参与计算的元素,bit位的值为1表示参与计算,0表示不参与。

      分为mask_high(高位mask)和mask_low(低位mask)。参数取值范围和操作数的数据类型有关,数据类型不同,每次迭代内能够处理的元素个数最大值不同。

      • 操作数为16位时:mask_low、mask_high ∈ [0, 2⁶⁴-1],并且不同时为 0

      • 操作数为32位时:mask_high = 0,mask_low ∈ (0, 2⁶⁴-1]

      • 操作数为64位时:mask_high = 0,mask_low ∈ (0, 2³²-1]

  • Counter 模式

    mask参数表示整个矢量计算参与计算的元素个数。

对应的Ascend C函数原型

template <typename T, MaskMode mode = MaskMode::NORMAL>
__aicore__ static inline void SetVectorMask(const uint64_t maskHigh, const uint64_t maskLow);
template <typename T, MaskMode mode = MaskMode::NORMAL>
__aicore__ static inline void SetVectorMask(int32_t len);

参数说明

  • mask_high

    • Normal模式:对应Normal模式下的逐比特模式,可以按位控制哪些元素参与计算。传入高位mask值。

    • Counter模式:需要置0,本入参不生效。

  • mask_low

    • Normal模式:对应Normal模式下的逐比特模式,可以按位控制哪些元素参与计算。传入低位mask值。

    • Counter模式:整个矢量计算过程中,参与计算的元素个数。

  • len

    • Normal模式:对应Normal模式下的mask连续模式,表示单次迭代内表示前面连续的多少个元素参与计算。

    • Counter模式:整个矢量计算过程中,参与计算的元素个数。

  • dtype:矢量计算操作数的数据类型,由 Python 前端显式指定,用于推导 C++ 模板参数 T。

  • mode: mask 模式,类型为 MaskMode 枚举值 - asc.MaskMode.NORMAL:Normal 模式,支持连续模式与逐比特模式。 - asc.MaskMode.COUNTER:Counter 模式,mask 参数表示整个矢量计算参与的总元素个数。

约束说明

该接口仅在矢量计算API的isSetMask模板参数为false时生效,使用完成后需要使用ResetMask将mask恢复为默认值。

调用示例

  • Counter 模式:整个计算中参与 128 个元素

    len = 128
    asc.set_mask_count()
    asc.set_vector_mask(len, dtype=asc.float16, mode=asc.MaskMode.COUNTER)
    asc.reset_mask()
    
  • Normal 模式(逐bit模式):使用 bitmask 控制参与计算的元素

    mask_high = 2**64 - 1
    mask_low = 2**64 - 1
    asc.set_mask_norm()
    asc.set_vector_mask(mask_high, mask_low, dtype=asc.float16, mode=asc.MaskMode.NORMAL)
    asc.reset_mask()
    
  • Normal 模式(连续模式):前 64 个元素参与每次迭代计算

    len = 64
    asc.set_mask_norm()
    asc.set_vector_mask(len, dtype=asc.float32, mode=asc.MaskMode.NORMAL)
    asc.reset_mask()