diff --git a/logic/models/edit.go b/logic/models/edit.go index 63cb37f..527e710 100644 --- a/logic/models/edit.go +++ b/logic/models/edit.go @@ -107,22 +107,40 @@ func handleEdit(handle *Handle) { Accuracy float64 } - def_rows, err := c.Db.Query("select id, status, epoch, epoch_progress, accuracy from model_definition where model_id=$1 order by created_on asc", model.Id) - if err != nil { - return c.Error500(err) - } - defer def_rows.Close() - defs := []defrow{} - for def_rows.Next() { - var def defrow - err = def_rows.Scan(&def.Id, &def.Status, &def.Epoch, &def.EpochProgress, &def.Accuracy) - if err != nil { - return c.Error500(err) - } - defs = append(defs, def) - } + if model.Type == 2 { + def_rows, err := c.Db.Query("select md.id, md.status, md.epoch, h.epoch_progress, h.accuracy from model_definition as md inner join exp_model_head as h on h.def_id = md.id where md.model_id=$1 order by md.created_on asc", model.Id) + if err != nil { + return c.Error500(err) + } + defer def_rows.Close() + + for def_rows.Next() { + var def defrow + err = def_rows.Scan(&def.Id, &def.Status, &def.Epoch, &def.EpochProgress, &def.Accuracy) + if err != nil { + return c.Error500(err) + } + defs = append(defs, def) + } + } else { + def_rows, err := c.Db.Query("select id, status, epoch, epoch_progress, accuracy from model_definition where model_id=$1 order by created_on asc", model.Id) + if err != nil { + return c.Error500(err) + } + defer def_rows.Close() + + for def_rows.Next() { + var def defrow + err = def_rows.Scan(&def.Id, &def.Status, &def.Epoch, &def.EpochProgress, &def.Accuracy) + if err != nil { + return c.Error500(err) + } + defs = append(defs, def) + } + } + type layerdef struct { id string @@ -164,39 +182,11 @@ func handleEdit(handle *Handle) { return c.Error500(err) } - c.Logger.Info("res", "id", lastLayer.Id, "start", lastLayer.Range_start, "end", lastLayer.Range_end, "shape", fmt.Sprintf("%d, 1", lastLayer.Range_end-lastLayer.Range_start + 1)) - layers = append(layers, layerdef{ id: lastLayer.Id, LayerType: LAYER_DENSE, Shape: fmt.Sprintf("%d, 1", lastLayer.Range_end-lastLayer.Range_start + 1), }) - - /* - lastLayer, err := c.Db.Query("select id, range_start, range_end from exp_model_head where def_id=$1 and status=3;", def.Id) - if err != nil { - return c.Error500(err) - } - defer lastLayer.Close() - - if !lastLayer.Next() { - c.Logger.Info("Could not find the model head for", "def_id", def.Id) - continue - } - - head_id, range_start, range_end := "", 0, 0 - err = lastLayer.Scan(&head_id, &range_start, &range_end) - - if err != nil { - return c.Error500(err) - } - - layers = append(layers, layerdef{ - id: head_id, - LayerType: LAYER_DENSE, - Shape: fmt.Sprintf("%d, 1", range_end-range_start), - }) - */ } break