TFLiteでのカスタム演算子の記述— GPU
GPUでカスタムオペレーターを実行するために必要なのは、シェーダーコードを記述し、ここに新しい操作があることをTFLiteに通知することだけです。
私たちのすべての作業はtensorflow/tensorflow/lite/delegates/gpu
、テンソルフローリポジトリのパスにあります。
また、このプロセスで作成したすべての新しいファイルを追跡し、BUILD
適切な場所でファイルに追加する必要があります。また、必要に応じて適切なヘッダーファイルをインポートします。
オペレーションの登録
を探しcustom_resitry.cc
ます。そこに空の関数がありますRegisterCustomOps
。これはregistry.cc
、組み込みのopsを登録した後に呼び出されます。関数に渡されるシェーダーのハッシュマップにカスタム操作を追加する必要があります。この行を関数に追加します。
(*shaders)["sin"].push_back(NewSinNodeShader())
TFLiteにopの属性と出力サイズ(メモリ割り当て用)の読み取り方法を知らせるには、opパーサーを作成する必要があります。これはに座り込みtensorflow/tensorflow/lite/delegates/gpu/common/default
ます。
ヘッダーファイルを作成し、op属性をキャプチャする構造体とを継承するクラスを定義しますTFLiteOperationParser
。
struct SinAttributes {
float frequency;
float phase;
}
class SinOperationParser: public TFLiteOperationParser {
// 2 functions to be overriden
IsSupported(some args);
Parse(some args);
}
// data = tflite_node->data
// where tflite_node is a parameter of the function
auto cast_data = reinterpreted_cast<const uint8_t*>(data)
const flexbuffers::Map map = flexbuffers::GetRoot(cast_data, data_size).AsMap();
SinAttributes attr;
// freq is the parameter name in python api of the op
attr.frequency = map["freq"].AsFloat();
attr.phase = map["phase"].AsFloat();
// Inside NewCustomOperationParser()
if (op_name == "sin") {
return std::make_unique<SinOperationParser>();
}
すべてのGPUオペレーションは、クラスを継承し、関数NodeShader
をオーバーライドする必要があります。このGenerateCode
関数は、GLSLで記述されたシェーダーコードが存在する場所です。
これは座り込みますtensorflow/tensorflow/lite/delegates/gpu/common/gl
最初のパラメータはcontext
、操作の属性、入力および出力の形状などの重要な情報を提供するものです。属性はタイプstd::any
であるため、データにアクセスする前にカスタム構造体にキャストする必要があります。入力形状と出力形状はベクトルのベクトルです。
2番目のパラメーターはGenerateCode
、ポインターとして渡される構造体であり、関数本体で指定する必要があります。以下のメンバーがいます
kernel_size
パラメータ—glslシェーダーでハードコーディングされるなどの単純なデータのキーと値のペア- オブジェクト—カーネルの重みなどの入力テンソル以外の読み取り専用データのキーと値のペア。glslシェーダーで均一なバッファーオブジェクト*としてバインドされます。
- shared_variables —読み取り/書き込みされるが、出力テンソルではないデータのキーと値のペア。おそらく中間の計算結果を格納するために、glslシェーダーでSSBO*としてバインドされます。
- ワークロード—起動するスレッド数に関する3Dintベクトル
workload.x / workgroup.x
ワークグループ—各ワークグループのスレッド数を示す3D整数ベクトル。通常、ワークロードとワークグループは、ワークグループごとに1つのスレッドを持つオペレーターの出力形状になります- source_code —実際のGLSLソースコード
- input、output —入力テンソルと出力テンソルにアクセスする方法を指定する列挙型。各スレッドに対応する要素の読み取り/書き込みのみが必要な要素ごとの操作を記述している場合は、を使用します
value_n = op(value_n)
。入力テンソルの任意の要素にアクセスする場合は、を使用します。output_data_n = op(input_data_n[x,y], input_data_n[x+dx,y+dy])
ここで、x
およびはスレッドy
を使用して取得されます。これが入出力テンソルのインデックスですInvocationID
。n
const parameters = {
"frequency": ctx.attr.frequency,
"phase": ctx.attr.phase,
}
*generated_code = {
parameters,
{}, // objects
{}, // shared variables
uint3(), // workload (0,0,0) means one thread for one output element
uint3(), // workgroup
R"value_0 = sin(value_0/$frequency$ + $phase$)", // GLSL code
IOStructure::AUTO, // input access type
IOStructure::AUTO // output access type
}