device_vulkan.c
cpp
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// An example of setting up the Vulkan driver.
#include <stddef.h>
#include <stdio.h>
#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "iree/hal/drivers/vulkan/registration/driver_module.h"
// Compiled module embedded here to avoid file IO:
#include "samples/simple_embedding/simple_embedding_test_bytecode_module_vulkan_c.h"
// Creates the HAL device used by the sample through the Vulkan driver.
//
// Registers the Vulkan driver module with the default driver registry, creates
// the driver registered under the name "vulkan" and asks it for its default
// device.
//
// On success |out_device| receives the device; the caller must release it.
// If creating the driver or the device fails, that failure status is returned
// to the caller.
iree_status_t create_sample_device(iree_allocator_t host_allocator,
                                   iree_hal_device_t** out_device) {
  // Only register the Vulkan HAL driver.
  IREE_RETURN_IF_ERROR(iree_hal_vulkan_driver_module_register(
      iree_hal_driver_registry_default()));

  // Create the HAL driver from the name.
  iree_hal_driver_t* driver = NULL;
  iree_string_view_t identifier = iree_make_cstring_view("vulkan");
  iree_status_t status = iree_hal_driver_registry_try_create(
      iree_hal_driver_registry_default(), identifier, host_allocator, &driver);

  // Create the default device (primary GPU).
  if (iree_status_is_ok(status)) {
    status = iree_hal_driver_create_default_device(driver, host_allocator,
                                                   out_device);
  }

  // This function no longer needs its driver reference, whether or not the
  // device was created. |driver| may still be NULL here if driver creation
  // failed (presumably release tolerates NULL -- the original code relied on
  // this as well).
  iree_hal_driver_release(driver);

  // Propagate the real outcome. Returning an unconditional OK here would hide
  // driver/device creation failures from the caller and leak the status.
  return status;
}
// Reads the whole file at |path| into a buffer allocated from |allocator|.
//
// On success |out_data| receives the buffer and |out_size| its length in
// bytes; the caller owns the buffer and frees it with iree_allocator_free()
// using the same allocator.
//
// On failure a non-OK status is returned, |out_data| is NULL, |out_size| is 0
// and no allocation made by this function is left behind.
iree_status_t read_file(const char* path, iree_allocator_t allocator,
                        void** out_data, size_t* out_size) {
  // Give the outputs defined values first so that a caller that does not
  // check the returned status never observes indeterminate data.
  *out_data = NULL;
  *out_size = 0;

  FILE* file = fopen(path, "rb");
  if (!file) {
    return iree_make_status(IREE_STATUS_NOT_FOUND, "failed to open file '%s'",
                            path);
  }

  // Everything below funnels through |cleanup| so the file is closed exactly
  // once and a partially filled buffer is never handed out.
  iree_status_t status = iree_ok_status();
  void* data = NULL;
  long size = 0;

  // Determine the file length by seeking to the end and reading the offset.
  if (fseek(file, 0, SEEK_END) != 0) {
    status = iree_make_status(IREE_STATUS_DATA_LOSS, "fseek failed");
    goto cleanup;
  }
  size = ftell(file);
  if (size < 0) {
    status = iree_make_status(IREE_STATUS_DATA_LOSS, "ftell failed");
    goto cleanup;
  }
  if (fseek(file, 0, SEEK_SET) != 0) {
    status = iree_make_status(IREE_STATUS_DATA_LOSS, "rewind failed");
    goto cleanup;
  }

  status = iree_allocator_malloc(allocator, size, &data);
  if (!iree_status_is_ok(status)) {
    goto cleanup;
  }

  if (fread(data, 1, (size_t)size, file) != (size_t)size) {
    status = iree_make_status(IREE_STATUS_DATA_LOSS, "incomplete read");
    goto cleanup;
  }

  // Success: hand the buffer to the caller and make sure cleanup keeps it.
  *out_data = data;
  *out_size = (size_t)size;
  data = NULL;

cleanup:
  fclose(file);
  if (data) {
    iree_allocator_free(allocator, data);
  }
  return status;
}
const iree_const_byte_span_t load_bytecode_module_data() {
const struct iree_file_toc_t* module_file_toc =
iree_samples_simple_embedding_test_module_vulkan_create();
// === 3. Load .vmfb from file ===
const char* model_path = "/data/local/tmp/qwen25_05b/qwen2_5_05b_android_gpu.vmfb";
void* model_data = NULL;
size_t model_size = 0;
read_file(model_path, iree_allocator_system(), &model_data, &model_size);
printf("module_file_toc size: %ld, model_size:%ld\n", module_file_toc->size, model_size);
return iree_make_const_byte_span(model_data,
model_size);
}
simple_embedding.c
cpp
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// An example of setting up the HAL module to invoke the `module.main_graph`
// entry point of a bytecode module on the device implemented by different
// backends via create_sample_device().
//
// NOTE: this file does not properly handle error cases and will leak on
// failure. Applications that are just going to exit()/abort() on failure can
// probably get away with the same thing but really should prefer not to.
#include <stdio.h>
#include <time.h>
#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/module.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode/module.h"
// Compile-time sizes for the hard-coded demo in this file.
// MAX_NEW_TOKENS is not referenced by any code visible in this file.
#define MAX_NEW_TOKENS 100
// Run() passes NUM_LAYERS * 2 additional 4-D float32 inputs to the entry
// point (presumably one key and one value cache per layer -- confirm against
// the model).
#define NUM_LAYERS 24
// NUM_HEADS and HEAD_DIM are the second and fourth dimension of each of those
// 4-D inputs.
#define NUM_HEADS 2
#define HEAD_DIM 64
#define PROMPT_LEN 6 // Demo prompt length; hard-coded below.
// Ids of the hard-coded demo prompt (PROMPT_LEN entries).
static const int64_t prompt_ids[PROMPT_LEN] = {14880,109432,104455,103949,103168,1773};
// Creates the HAL device for the selected backend target. The definition
// lives in a backend-specific source file.
// The HAL device is returned based on the implementation, and it must be
// released by the caller.
extern iree_status_t create_sample_device(iree_allocator_t host_allocator,
iree_hal_device_t** out_device);
// Loads the VM bytecode module for the selected backend target and returns a
// span over its bytes.
// The bytecode module is generated for the specific backend and platform.
extern const iree_const_byte_span_t load_bytecode_module_data();
iree_status_t Run() {
iree_vm_instance_t* instance = NULL;
IREE_RETURN_IF_ERROR(iree_vm_instance_create(
IREE_VM_TYPE_CAPACITY_DEFAULT, iree_allocator_system(), &instance));
IREE_RETURN_IF_ERROR(iree_hal_module_register_all_types(instance));
iree_hal_device_t* device = NULL;
IREE_RETURN_IF_ERROR(create_sample_device(iree_allocator_system(), &device),
"create device");
iree_vm_module_t* hal_module = NULL;
IREE_RETURN_IF_ERROR(iree_hal_module_create(
instance, iree_hal_module_device_policy_default(), /*device_count=*/1,
&device, IREE_HAL_MODULE_FLAG_SYNCHRONOUS,
iree_hal_module_debug_sink_stdio(stderr), iree_allocator_system(),
&hal_module));
printf("1111\n");
// Load bytecode module from the embedded data.
const iree_const_byte_span_t module_data = load_bytecode_module_data();
iree_vm_module_t* bytecode_module = NULL;
IREE_RETURN_IF_ERROR(iree_vm_bytecode_module_create(
instance, module_data, iree_allocator_null(), iree_allocator_system(),
&bytecode_module));
// Allocate a context that will hold the module state across invocations.
iree_vm_context_t* context = NULL;
iree_vm_module_t* modules[] = {hal_module, bytecode_module};
IREE_RETURN_IF_ERROR(iree_vm_context_create_with_modules(
instance, IREE_VM_CONTEXT_FLAG_NONE, IREE_ARRAYSIZE(modules), &modules[0],
iree_allocator_system(), &context));
iree_vm_module_release(hal_module);
iree_vm_module_release(bytecode_module);
printf("2222\n");
// Lookup the entry point function.
// Note that we use the synchronous variant which operates on pure type/shape
// erased buffers.
const char kMainFunctionName[] = "module.main_graph";
iree_vm_function_t main_function;
IREE_RETURN_IF_ERROR(iree_vm_context_resolve_function(
context, iree_make_cstring_view(kMainFunctionName), &main_function));
// Initial buffer contents for 4 * 2 = 8.
const int64_t prompt_ids[] = {14880, 109432, 104455, 103949, 103168, 1773};
const int64_t attention_mask[] = {1, 1, 1, 1, 1, 1};
const int64_t position_ids[] = {0, 1, 2, 3, 4, 5};
printf("prompt_ids size: %lu\n", sizeof(prompt_ids));
// Allocate buffers in device-local memory so that if the device has an
// independent address space they live on the fast side of the fence.
iree_hal_dim_t shape[2] = {1, PROMPT_LEN};
iree_hal_buffer_view_t* arg0_buffer_view = NULL;
iree_hal_buffer_view_t* arg1_buffer_view = NULL;
iree_hal_buffer_view_t* arg2_buffer_view = NULL;
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_allocate_buffer_copy(
device, iree_hal_device_allocator(device), IREE_ARRAYSIZE(shape), shape,
IREE_HAL_ELEMENT_TYPE_SINT_64, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
(iree_hal_buffer_params_t){
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
.usage = IREE_HAL_BUFFER_USAGE_DEFAULT,
},
iree_make_const_byte_span(prompt_ids, sizeof(prompt_ids)), &arg0_buffer_view));
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_allocate_buffer_copy(
device, iree_hal_device_allocator(device), IREE_ARRAYSIZE(shape), shape,
IREE_HAL_ELEMENT_TYPE_SINT_64, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
(iree_hal_buffer_params_t){
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
.usage = IREE_HAL_BUFFER_USAGE_DEFAULT,
},
iree_make_const_byte_span(attention_mask, sizeof(attention_mask)), &arg1_buffer_view));
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_allocate_buffer_copy(
device, iree_hal_device_allocator(device), IREE_ARRAYSIZE(shape), shape,
IREE_HAL_ELEMENT_TYPE_SINT_64, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
(iree_hal_buffer_params_t){
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
.usage = IREE_HAL_BUFFER_USAGE_DEFAULT,
},
iree_make_const_byte_span(position_ids, sizeof(position_ids)), &arg2_buffer_view));
printf("4444\n");
iree_hal_buffer_view_t* pos_view[NUM_LAYERS*2];
iree_hal_dim_t kv_shape[4] = {1, NUM_HEADS, 6, HEAD_DIM};
for (int i=0;i<NUM_LAYERS*2;++i) {
pos_view[i] = NULL;
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_allocate_buffer_copy(
device, iree_hal_device_allocator(device), IREE_ARRAYSIZE(kv_shape), kv_shape,
IREE_HAL_ELEMENT_TYPE_FLOAT_32, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
(iree_hal_buffer_params_t){
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
.usage = IREE_HAL_BUFFER_USAGE_DEFAULT,
},
iree_make_const_byte_span(NULL, 0), &pos_view[i]));
}
// Setup call inputs with our buffers.
iree_vm_list_t* inputs = NULL;
IREE_RETURN_IF_ERROR(
iree_vm_list_create(iree_vm_make_undefined_type_def(),
/*capacity=*/2, iree_allocator_system(), &inputs),
"can't allocate input vm list");
iree_vm_ref_t arg0_buffer_view_ref =
iree_hal_buffer_view_move_ref(arg0_buffer_view);
iree_vm_ref_t arg1_buffer_view_ref =
iree_hal_buffer_view_move_ref(arg1_buffer_view);
iree_vm_ref_t arg2_buffer_view_ref =
iree_hal_buffer_view_move_ref(arg2_buffer_view);
IREE_RETURN_IF_ERROR(
iree_vm_list_push_ref_move(inputs, &arg0_buffer_view_ref));
IREE_RETURN_IF_ERROR(
iree_vm_list_push_ref_move(inputs, &arg1_buffer_view_ref));
IREE_RETURN_IF_ERROR(
iree_vm_list_push_ref_move(inputs, &arg2_buffer_view_ref));
iree_vm_ref_t pos_view_ref[NUM_LAYERS*2];
for (int i=0;i<NUM_LAYERS*2;++i) {
pos_view_ref[i] =
iree_hal_buffer_view_move_ref(pos_view[i]);
IREE_RETURN_IF_ERROR(
iree_vm_list_push_ref_move(inputs, &pos_view_ref[i]));
}
printf("5555\n");
// Prepare outputs list to accept the results from the invocation.
// The output vm list is allocated statically.
iree_vm_list_t* outputs = NULL;
IREE_RETURN_IF_ERROR(
iree_vm_list_create(iree_vm_make_undefined_type_def(),
/*capacity=*/1 + NUM_LAYERS * 2, iree_allocator_system(), &outputs),
"can't allocate output vm list");
printf("6666\n");
clock_t start_time, end_time;
start_time = clock();
// Synchronously invoke the function.
IREE_RETURN_IF_ERROR(iree_vm_invoke(
context, main_function, IREE_VM_INVOCATION_FLAG_NONE,
/*policy=*/NULL, inputs, outputs, iree_allocator_system()));
end_time = clock();
double dur = ((double)(end_time - start_time))/CLOCKS_PER_SEC;
printf("8888 %f\n", dur);
// Get the result buffers from the invocation.
iree_hal_buffer_view_t* ret_buffer_view =
iree_vm_list_get_buffer_view_assign(outputs, 0);
if (ret_buffer_view == NULL) {
return iree_make_status(IREE_STATUS_NOT_FOUND,
"can't find return buffer view");
}
/*
// Read back the results and ensure we got the right values.
float results[] = {0.0f, 0.0f, 0.0f, 0.0f};
IREE_RETURN_IF_ERROR(iree_hal_device_transfer_d2h(
device, iree_hal_buffer_view_buffer(ret_buffer_view), 0, results,
sizeof(results), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
iree_infinite_timeout()));
for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(results); ++i) {
if (results[i] != 8.0f) {
printf("%f, ", results[i]);
//return iree_make_status(IREE_STATUS_UNKNOWN, "result mismatches");
}
}
*/
iree_vm_list_release(inputs);
iree_vm_list_release(outputs);
iree_hal_device_release(device);
iree_vm_context_release(context);
iree_vm_instance_release(instance);
return iree_ok_status();
}
// Program entry point: runs the sample once and reports any failure.
//
// A failing status is printed to stderr and then freed. The process exit code
// is the numeric status code of the result returned by Run().
int main() {
  printf("start\n");

  const iree_status_t run_status = Run();

  // Read the code before the status is freed below.
  const int exit_code = (int)iree_status_code(run_status);

  const int failed = !iree_status_is_ok(run_status);
  if (failed) {
    iree_status_fprint(stderr, run_status);
    iree_status_free(run_status);
  }

  fprintf(stdout, "simple_embedding done2\n");
  return exit_code;
}