From 00369640a87e2adf348d9eb6dde59d920652f7db Mon Sep 17 00:00:00 2001 From: Andre Henriques Date: Mon, 8 Apr 2024 17:45:32 +0100 Subject: [PATCH] feat: improve zip processing speed closes #71 --- config.toml | 3 + logic/models/data.go | 225 +++++++++++++++++++++++++----------- logic/models/train/train.go | 8 +- logic/utils/config.go | 10 +- 4 files changed, 171 insertions(+), 75 deletions(-) diff --git a/config.toml b/config.toml index fc52843..f6b16db 100644 --- a/config.toml +++ b/config.toml @@ -1,2 +1,5 @@ PORT=5002 + HOSTNAME="https://testing.andr3h3nriqu3s.com" + +NUMBER_OF_WORKERS=20 diff --git a/logic/models/data.go b/logic/models/data.go index a6c1287..0212af4 100644 --- a/logic/models/data.go +++ b/logic/models/data.go @@ -28,13 +28,95 @@ func InsertIfNotPresent(ss []string, s string) []string { return ss } +/* +This function will process a single file from the uploaded zip file +*/ +func fileProcessor( + c *Context, + model *BaseModel, + reader *zip.ReadCloser, + ids map[string]string, + base_path string, + index int, + file_chan chan *zip.File, + back_channel chan int, +) { + defer func() { + if r := recover(); r != nil { + c.Logger.Error("Recovered in file processor", "processor id", index, "due to", r) + } + }() + + for file := range file_chan { + c.Logger.Debug("Processing File", "file", file.Name) + data, err := reader.Open(file.Name) + if err != nil { + c.Logger.Error("Could not open file in zip %s\n", "file name", file.Name, "err", err) + back_channel <- index + continue + } + defer data.Close() + file_data, err := io.ReadAll(data) + if err != nil { + c.Logger.Error("Could not open file in zip %s\n", "file name", file.Name, "err", err) + back_channel <- index + continue + } + + // TODO check if the file is a valid photo that matched the defined photo on the database + + parts := strings.Split(file.Name, "/") + + mode := model_classes.DATA_POINT_MODE_TRAINING + if parts[0] == "testing" { + mode = model_classes.DATA_POINT_MODE_TESTING + } + + data_point_id, err := model_classes.AddDataPoint(c.Db, ids[parts[1]], "id://", mode) + if err != nil { + c.Logger.Error("Failed to add datapoint", "model", model.Id, "file name", file.Name, "err", err) + back_channel <- -index - 1 + return + } + + file_path := path.Join(base_path, data_point_id+"."+model.Format) + f, err := os.Create(file_path) + if err != nil { + c.Logger.Error("Failed to add datapoint", "model", model.Id, "file name", file.Name, "err", err) + back_channel <- -index - 1 + return + } + defer f.Close() + f.Write(file_data) + + if !testImgForModel(c, model, file_path) { + c.Logger.Errorf("Image did not have valid format for model %s (in zip: %s)!", file_path, file.Name) + c.Logger.Warn("Not failling updating data point to status -1") + message := "Image did not have valid format for the model" + if err = model_classes.UpdateDataPointStatus(c.Db, data_point_id, -1, &message); err != nil { + c.Logger.Error("Failed to update data point", "model", model.Id, "file name", file.Name, "err", err) + back_channel <- -index - 1 + return + } + } + + back_channel <- index + } +} + func processZipFile(c *Context, model *BaseModel) { + + + var err error + + failed := func(msg string) { + c.Logger.Error(msg, "err", err) + ModelUpdateStatus(c, model.Id, FAILED_PREPARING_ZIP_FILE) + } + reader, err := zip.OpenReader(path.Join("savedData", model.Id, "base_data.zip")) if err != nil { - // TODO add msg to error - ModelUpdateStatus(c, model.Id, FAILED_PREPARING_ZIP_FILE) - fmt.Printf("Faield to proccess zip file failed to open reader\n") - fmt.Println(err) + failed("Failed to proccess zip file failed to open reader") return } defer reader.Close() @@ -51,8 +133,7 @@ func processZipFile(c *Context, model *BaseModel) { } if paths[0] != "training" && paths[0] != "testing" { - fmt.Printf("Invalid file '%s' TODO add msg to response!!!\n", file.Name) - ModelUpdateStatus(c, model.Id, FAILED_PREPARING_ZIP_FILE) + failed(fmt.Sprintf("Invalid file '%s'!", file.Name)) return } @@ -64,90 +145,95 @@ func processZipFile(c *Context, model *BaseModel) { } if !reflect.DeepEqual(testing, training) { - fmt.Printf("testing and training are diferent\n") - fmt.Println(testing) - fmt.Println(training) - ModelUpdateStatus(c, model.Id, FAILED_PREPARING_ZIP_FILE) + c.Logger.Info("Diff", "testing", testing, "training", training) + failed("Testing and Training datesets are diferent") return } base_path := path.Join("savedData", model.Id, "data") if err = os.MkdirAll(base_path, os.ModePerm); err != nil { - fmt.Printf("Failed to create base_path dir\n") - ModelUpdateStatus(c, model.Id, FAILED_PREPARING_ZIP_FILE) + failed("Failed to create base_path dir\n") return } + c.Logger.Info("File Structure looks good to append data", "model", model.Id) + ids := map[string]string{} for i, name := range training { id, err := model_classes.CreateClass(c.Db, model.Id, i, name) if err != nil { - fmt.Printf("Failed to create class '%s' on db\n", name) - ModelUpdateStatus(c, id, FAILED_PREPARING_ZIP_FILE) + failed(fmt.Sprintf("Failed to create the class '%s'", name)) return } ids[name] = id } + back_channel := make(chan int, c.Handle.Config.NumberOfWorkers) + + file_chans := make([]chan *zip.File, c.Handle.Config.NumberOfWorkers) + + for i := 0; i < c.Handle.Config.NumberOfWorkers; i++ { + file_chans[i] = make(chan *zip.File, 2) + go fileProcessor(c, model, reader, ids, base_path, i, file_chans[i], back_channel) + } + + clean_up_channels := func() { + for i := 0; i < c.Handle.Config.NumberOfWorkers; i++ { + close(file_chans[i]) + } + for i := 0; i < c.Handle.Config.NumberOfWorkers - 1; i++ { + _ = <- back_channel + } + close(back_channel) + } + + first_round := true + + channel_to_send := 0 + + // Parelalize this + for _, file := range reader.Reader.File { + // Skip if dir if file.Name[len(file.Name)-1] == '/' { continue } - data, err := reader.Open(file.Name) - if err != nil { - fmt.Printf("Could not open file in zip %s\n", file.Name) - ModelUpdateStatus(c, model.Id, FAILED_PREPARING_ZIP_FILE) - return - } - defer data.Close() - file_data, err := io.ReadAll(data) - if err != nil { - fmt.Printf("Could not read file file in zip %s\n", file.Name) - ModelUpdateStatus(c, model.Id, FAILED_PREPARING_ZIP_FILE) - return - } + file_chans[channel_to_send] <- file - // TODO check if the file is a valid photo that matched the defined photo on the database - parts := strings.Split(file.Name, "/") + if first_round { + channel_to_send += 1 + if c.Handle.Config.NumberOfWorkers == channel_to_send { + first_round = false + } + } + + // Can not do else if because need to handle the case where the value changes in + // previous if + if !first_round { + new_id, ok := <- back_channel + if !ok { + c.Logger.Fatal("Something is very wrong please check as this line should be unreachable") + } - mode := model_classes.DATA_POINT_MODE_TRAINING - if parts[0] == "testing" { - mode = model_classes.DATA_POINT_MODE_TESTING - } + if new_id < 0 { + c.Logger.Error("Worker failed", "worker id", -(new_id + 1)) + clean_up_channels() + failed("One of the workers failed due to db error") + return + } - data_point_id, err := model_classes.AddDataPoint(c.Db, ids[parts[1]], "id://", mode) - if err != nil { - fmt.Printf("Failed to add data point for %s\n", model.Id) - fmt.Println(err) - ModelUpdateStatus(c, model.Id, FAILED_PREPARING_ZIP_FILE) - return - } + channel_to_send = new_id + } - file_path := path.Join(base_path, data_point_id+"."+model.Format) - f, err := os.Create(file_path) - if err != nil { - fmt.Printf("Could not create file %s\n", file_path) - ModelUpdateStatus(c, model.Id, FAILED_PREPARING_ZIP_FILE) - return - } - defer f.Close() - f.Write(file_data) - - if !testImgForModel(c, model, file_path) { - c.Logger.Errorf("Image did not have valid format for model %s (in zip: %s)!", file_path, file.Name) - c.Logger.Warn("Not failling updating data point to status -1") - message := "Image did not have valid format for the model" - if err = model_classes.UpdateDataPointStatus(c.Db, data_point_id, -1, &message); err != nil { - c.Logger.Error("Failed to update data point status") - ModelUpdateStatus(c, model.Id, FAILED_PREPARING_ZIP_FILE) - } - } } - fmt.Printf("Added data to model '%s'!\n", model.Id) + clean_up_channels() + + c.Logger.Info("Added data to model", "model", model.Id) + ModelUpdateStatus(c, model.Id, CONFIRM_PRE_TRAINING) } @@ -202,19 +288,20 @@ func processZipFileExpand(c *Context, model *BaseModel) { ids := map[string]string{} - var baseOrder struct { - Order int `db:"class_order"` - } + var baseOrder struct { + Order int `db:"class_order"` + } - err = GetDBOnce(c, &baseOrder, "model_classes where model_id=$1 order by class_order desc;", model.Id) - if err != nil { - failed("Failed to get the last class_order") - } + err = GetDBOnce(c, &baseOrder, "model_classes where model_id=$1 order by class_order desc;", model.Id) + if err != nil { + failed("Failed to get the last class_order") + } - base := baseOrder.Order + 1 + base := baseOrder.Order + 1 for i, name := range training { - id, err := model_classes.CreateClass(c.Db, model.Id, base + i, name) + id, _err := model_classes.CreateClass(c.Db, model.Id, base+i, name) + err = _err if err != nil { failed(fmt.Sprintf("Failed to create class '%s' on db\n", name)) return diff --git a/logic/models/train/train.go b/logic/models/train/train.go index 0707349..e2abaf2 100644 --- a/logic/models/train/train.go +++ b/logic/models/train/train.go @@ -1903,12 +1903,14 @@ func handleTrain(handle *Handle) { return c.JsonBadRequest("Definition not on status 3(training)") } - c.Logger.Info("Updated model_definition!", "model", model_id, "progress", epoch, "accuracy", accuracy) + c.Logger.Debug("Updated model_definition!", "model", model_id, "progress", epoch, "accuracy", accuracy) _, err = c.Db.Exec("update model_definition set epoch_progress=$1, accuracy=$2 where id=$3", epoch, accuracy, def_id) if err != nil { return c.Error500(err) } + + c.ShowMessage = false return nil }) @@ -1951,12 +1953,14 @@ func handleTrain(handle *Handle) { return c.JsonBadRequest("Head not on status 3(training)") } - c.Logger.Info("Updated model_head!", "head", head_id, "progress", epoch, "accuracy", accuracy) + c.Logger.Debug("Updated model_head!", "head", head_id, "progress", epoch, "accuracy", accuracy) _, err = c.Db.Exec("update exp_model_head set epoch_progress=$1, accuracy=$2 where id=$3", epoch, accuracy, head_id) if err != nil { return c.Error500(err) } + + c.ShowMessage = false return nil }) } diff --git a/logic/utils/config.go b/logic/utils/config.go index a5257b7..aeaa846 100644 --- a/logic/utils/config.go +++ b/logic/utils/config.go @@ -8,8 +8,9 @@ import ( ) type Config struct { - Hostname string - Port int + Hostname string + Port int + NumberOfWorkers int `toml:"number_of_workers"` } func LoadConfig() Config { @@ -21,8 +22,9 @@ func LoadConfig() Config { log.Error("Failed to load config file", "err", err) // Use default values return Config{ - Hostname: "localhost", - Port: 8000, + Hostname: "localhost", + Port: 8000, + NumberOfWorkers: 10, } }