feat: closes #24
This commit is contained in:
parent
f68f0e2444
commit
84b9c40a72
@ -110,7 +110,7 @@ func handleRun(handle *Handle) {
|
|||||||
|
|
||||||
root := tg.NewRoot()
|
root := tg.NewRoot()
|
||||||
|
|
||||||
tf_img := ReadPNG(root, path.Join(run_path, "img.png"), 3)
|
tf_img := ReadPNG(root, path.Join(run_path, "img.png"), int64(model.ImageMode))
|
||||||
|
|
||||||
exec_results := tg.Exec(root, []tf.Output{tf_img.Value()}, nil, &tf.SessionOptions{})
|
exec_results := tg.Exec(root, []tf.Output{tf_img.Value()}, nil, &tf.SessionOptions{})
|
||||||
inputImage, err:= tf.NewTensor(exec_results[0].Value())
|
inputImage, err:= tf.NewTensor(exec_results[0].Value())
|
||||||
@ -130,8 +130,6 @@ func handleRun(handle *Handle) {
|
|||||||
vi := 0
|
vi := 0
|
||||||
var predictions = results[0].Value().([][]float32)[0]
|
var predictions = results[0].Value().([][]float32)[0]
|
||||||
|
|
||||||
fmt.Println(predictions)
|
|
||||||
|
|
||||||
for i, v := range predictions {
|
for i, v := range predictions {
|
||||||
if v > vmax {
|
if v > vmax {
|
||||||
vi = i
|
vi = i
|
||||||
@ -155,7 +153,6 @@ func handleRun(handle *Handle) {
|
|||||||
var name string
|
var name string
|
||||||
if err = rows.Scan(&name); err != nil { return nil }
|
if err = rows.Scan(&name); err != nil { return nil }
|
||||||
|
|
||||||
|
|
||||||
LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{
|
LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{
|
||||||
"Model": model,
|
"Model": model,
|
||||||
"Result": name,
|
"Result": name,
|
||||||
|
@ -19,27 +19,12 @@ import (
|
|||||||
|
|
||||||
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
|
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
|
||||||
id = ""
|
id = ""
|
||||||
_, err = db.Exec("insert into model_definition (model_id, target_accuracy) values ($1, $2);", model_id, target_accuracy)
|
rows, err := db.Query("insert into model_definition (model_id, target_accuracy) values ($1, $2) returning id;", model_id, target_accuracy)
|
||||||
if err != nil {
|
if err != nil { return }
|
||||||
return
|
defer rows.Close()
|
||||||
}
|
if !rows.Next() { return id, errors.New("Something wrong!") }
|
||||||
|
|
||||||
rows, err := db.Query("select id from model_definition where model_id=$1 order by created_on DESC;", model_id)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
if !rows.Next() {
|
|
||||||
return id, errors.New("Something wrong!")
|
|
||||||
}
|
|
||||||
|
|
||||||
err = rows.Scan(&id)
|
err = rows.Scan(&id)
|
||||||
if err != nil {
|
return
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ModelDefinitionStatus int
|
type ModelDefinitionStatus int
|
||||||
@ -103,7 +88,7 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int,
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func trainDefinition(handle *Handle, model_id string, definition_id string) (accuracy float64, err error) {
|
func trainDefinition(handle *Handle, model *BaseModel, definition_id string) (accuracy float64, err error) {
|
||||||
accuracy = 0
|
accuracy = 0
|
||||||
layers, err := handle.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
layers, err := handle.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -128,14 +113,14 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Generate run folder
|
// Generate run folder
|
||||||
run_path := path.Join("/tmp", model_id, "defs", definition_id)
|
run_path := path.Join("/tmp", model.Id, "defs", definition_id)
|
||||||
|
|
||||||
err = os.MkdirAll(run_path, os.ModePerm)
|
err = os.MkdirAll(run_path, os.ModePerm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = generateCvs(handle, run_path, model_id)
|
_, err = generateCvs(handle, run_path, model.Id)
|
||||||
if err != nil { return }
|
if err != nil { return }
|
||||||
|
|
||||||
// Create python script
|
// Create python script
|
||||||
@ -153,8 +138,9 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
|
|||||||
if err = tmpl.Execute(f, AnyMap{
|
if err = tmpl.Execute(f, AnyMap{
|
||||||
"Layers": got,
|
"Layers": got,
|
||||||
"Size": got[0].Shape,
|
"Size": got[0].Shape,
|
||||||
"DataDir": path.Join(getDir(), "savedData", model_id, "data"),
|
"DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
|
||||||
"RunPath": run_path,
|
"RunPath": run_path,
|
||||||
|
"ColorMode": model.ImageMode,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -167,7 +153,7 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Copy result around
|
// Copy result around
|
||||||
result_path := path.Join("savedData", model_id, "defs", definition_id)
|
result_path := path.Join("savedData", model.Id, "defs", definition_id)
|
||||||
|
|
||||||
if err = os.MkdirAll(result_path, os.ModePerm); err != nil {
|
if err = os.MkdirAll(result_path, os.ModePerm); err != nil {
|
||||||
return
|
return
|
||||||
@ -235,7 +221,7 @@ func trainModel(handle *Handle, model *BaseModel) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, def := range definitions {
|
for _, def := range definitions {
|
||||||
accuracy, err := trainDefinition(handle, model.Id, def.id)
|
accuracy, err := trainDefinition(handle, model, def.id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Failed to train definition!Err:\n")
|
fmt.Printf("Failed to train definition!Err:\n")
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
|
@ -10,8 +10,9 @@ type BaseModel struct {
|
|||||||
Status int
|
Status int
|
||||||
Id string
|
Id string
|
||||||
|
|
||||||
Width int
|
ImageMode int
|
||||||
Height int
|
Width int
|
||||||
|
Height int
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -30,7 +31,7 @@ const (
|
|||||||
var ModelNotFoundError = errors.New("Model not found error")
|
var ModelNotFoundError = errors.New("Model not found error")
|
||||||
|
|
||||||
func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
|
func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
|
||||||
rows, err := db.Query("select name, status, id, width, height from models where id=$1;", id)
|
rows, err := db.Query("select name, status, id, width, height, color_mode from models where id=$1;", id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -41,10 +42,17 @@ func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
base = &BaseModel{}
|
base = &BaseModel{}
|
||||||
err = rows.Scan(&base.Name, &base.Status, &base.Id, &base.Width, &base.Height)
|
var colorMode string
|
||||||
|
err = rows.Scan(&base.Name, &base.Status, &base.Id, &base.Width, &base.Height, &colorMode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
switch colorMode {
|
||||||
|
case "greyscale":
|
||||||
|
base.ImageMode = 1
|
||||||
|
default:
|
||||||
|
panic("unkown color mode")
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -32,10 +32,7 @@ def pathToLabel(path):
|
|||||||
#return tf.strings.as_string([path])
|
#return tf.strings.as_string([path])
|
||||||
|
|
||||||
def decode_image(img):
|
def decode_image(img):
|
||||||
# channels were reduced to 1 since image is grayscale
|
img = tf.io.decode_png(img, channels={{.ColorMode}})
|
||||||
# TODO chnage channel number based if grayscale
|
|
||||||
img = tf.io.decode_png(img, channels=3)
|
|
||||||
|
|
||||||
return tf.image.resize(img, image_size)
|
return tf.image.resize(img, image_size)
|
||||||
|
|
||||||
def process_path(path):
|
def process_path(path):
|
||||||
|
Loading…
Reference in New Issue
Block a user